From c7f486ca21ebf1d11a36b9f5f5e5b1cc148a81be Mon Sep 17 00:00:00 2001 From: Sam Date: Fri, 23 Dec 2022 01:31:43 +0100 Subject: [PATCH] feat(backend): allow changing username in PATCH /users/@me --- backend/db/member.go | 11 ++++++--- backend/db/user.go | 38 +++++++++++++++++++++++++++---- backend/routes/user/patch_user.go | 27 ++++++++++++++++++++-- 3 files changed, 67 insertions(+), 9 deletions(-) diff --git a/backend/db/member.go b/backend/db/member.go index 220b36d..b4d2e23 100644 --- a/backend/db/member.go +++ b/backend/db/member.go @@ -30,13 +30,13 @@ const ( ErrMemberNameInUse = errors.Sentinel("member name already in use") ) -func (db *DB) Member(ctx context.Context, id xid.ID) (m Member, err error) { +func (db *DB) getMember(ctx context.Context, q pgxscan.Querier, id xid.ID) (m Member, err error) { sql, args, err := sq.Select("*").From("members").Where("id = ?", id).ToSql() if err != nil { return m, errors.Wrap(err, "building sql") } - err = pgxscan.Get(ctx, db, &m, sql, args...) + err = pgxscan.Get(ctx, q, &m, sql, args...) if err != nil { if errors.Cause(err) == pgx.ErrNoRows { return m, ErrMemberNotFound @@ -47,6 +47,10 @@ func (db *DB) Member(ctx context.Context, id xid.ID) (m Member, err error) { return m, nil } +func (db *DB) Member(ctx context.Context, id xid.ID) (m Member, err error) { + return db.getMember(ctx, db, id) +} + // UserMember returns a member scoped by user. func (db *DB) UserMember(ctx context.Context, userID xid.ID, memberRef string) (m Member, err error) { sql, args, err := sq.Select("*").From("members"). @@ -98,6 +102,7 @@ func (db *DB) CreateMember(ctx context.Context, tx pgx.Tx, userID xid.ID, name s if err != nil { pge := &pgconn.PgError{} if errors.As(err, &pge) { + // unique constraint violation if pge.Code == "23505" { return m, ErrMemberNameInUse } @@ -146,7 +151,7 @@ func (db *DB) UpdateMember( avatarURLs []string, ) (m Member, err error) { if name == nil && displayName == nil && bio == nil && links == nil && avatarURLs == nil { - return m, ErrNothingToUpdate + return db.getMember(ctx, tx, id) } builder := sq.Update("members").Where("id = ?", id) diff --git a/backend/db/user.go b/backend/db/user.go index 41a9ce4..c4ce209 100644 --- a/backend/db/user.go +++ b/backend/db/user.go @@ -113,9 +113,8 @@ func (u *User) UpdateFromDiscord(ctx context.Context, db pgxscan.Querier, du *di return pgxscan.Get(ctx, db, u, sql, args...) } -// User gets a user by ID. -func (db *DB) User(ctx context.Context, id xid.ID) (u User, err error) { - err = pgxscan.Get(ctx, db, &u, "select * from users where id = $1", id) +func (db *DB) getUser(ctx context.Context, q pgxscan.Querier, id xid.ID) (u User, err error) { + err = pgxscan.Get(ctx, q, &u, "select * from users where id = $1", id) if err != nil { if errors.Cause(err) == pgx.ErrNoRows { return u, ErrUserNotFound @@ -127,6 +126,11 @@ func (db *DB) User(ctx context.Context, id xid.ID) (u User, err error) { return u, nil } +// User gets a user by ID. +func (db *DB) User(ctx context.Context, id xid.ID) (u User, err error) { + return db.getUser(ctx, db, id) +} + // Username gets a user by username. func (db *DB) Username(ctx context.Context, name string) (u User, err error) { err = pgxscan.Get(ctx, db, &u, "select * from users where username = $1", name) @@ -151,6 +155,32 @@ func (db *DB) UsernameTaken(ctx context.Context, username string) (valid, taken return true, taken, err } +// UpdateUsername validates the given username, then updates the given user's name to it if valid. +func (db *DB) UpdateUsername(ctx context.Context, tx pgx.Tx, id xid.ID, newName string) error { + if !usernameRegex.MatchString(newName) { + return ErrInvalidUsername + } + + sql, args, err := sq.Update("users").Set("username", newName).Where("id = ?", id).ToSql() + if err != nil { + return errors.Wrap(err, "building sql") + } + + _, err = db.Exec(ctx, sql, args...) + if err != nil { + pge := &pgconn.PgError{} + if errors.As(err, &pge) { + // unique constraint violation + if pge.Code == "23505" { + return ErrUsernameTaken + } + } + + return errors.Wrap(err, "executing query") + } + return nil +} + func (db *DB) UpdateUser( ctx context.Context, tx pgx.Tx, id xid.ID, @@ -159,7 +189,7 @@ func (db *DB) UpdateUser( avatarURLs []string, ) (u User, err error) { if displayName == nil && bio == nil && links == nil && avatarURLs == nil { - return u, ErrNothingToUpdate + return db.getUser(ctx, tx, id) } builder := sq.Update("users").Where("id = ?", id) diff --git a/backend/routes/user/patch_user.go b/backend/routes/user/patch_user.go index 526f023..3e8561b 100644 --- a/backend/routes/user/patch_user.go +++ b/backend/routes/user/patch_user.go @@ -12,6 +12,7 @@ import ( ) type PatchUserRequest struct { + Username *string `json:"username"` DisplayName *string `json:"display_name"` Bio *string `json:"bio"` Links *[]string `json:"links"` @@ -34,8 +35,15 @@ func (s *Server) patchUser(w http.ResponseWriter, r *http.Request) error { return server.APIError{Code: server.ErrBadRequest} } + // get existing user, for comparison later + u, err := s.DB.User(ctx, claims.UserID) + if err != nil { + return errors.Wrap(err, "getting existing user") + } + // validate that *something* is set - if req.DisplayName == nil && + if req.Username == nil && + req.DisplayName == nil && req.Bio == nil && req.Links == nil && req.Fields == nil && @@ -130,7 +138,22 @@ func (s *Server) patchUser(w http.ResponseWriter, r *http.Request) error { } defer tx.Rollback(ctx) - u, err := s.DB.UpdateUser(ctx, tx, claims.UserID, req.DisplayName, req.Bio, req.Links, avatarURLs) + // update username + if req.Username != nil && *req.Username != u.Username { + err = s.DB.UpdateUsername(ctx, tx, claims.UserID, *req.Username) + if err != nil { + switch err { + case db.ErrUsernameTaken: + return server.APIError{Code: server.ErrUsernameTaken} + case db.ErrInvalidUsername: + return server.APIError{Code: server.ErrInvalidUsername} + default: + return errors.Wrap(err, "updating username") + } + } + } + + u, err = s.DB.UpdateUser(ctx, tx, claims.UserID, req.DisplayName, req.Bio, req.Links, avatarURLs) if err != nil && errors.Cause(err) != db.ErrNothingToUpdate { log.Errorf("updating user: %v", err) return err