From 9e98b614722cd730abd78fe705e6c2d91af1f1d2 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 8 Sep 2022 14:00:41 +0200 Subject: [PATCH] feat: add user names/pronouns to GET /users/{userRef} and PATCH /users/@me --- backend/db/names_pronouns.go | 172 ++++++++++++++++++++++++++++++ backend/routes/user/get_user.go | 48 ++++++++- backend/routes/user/patch_user.go | 78 ++++++++++---- 3 files changed, 271 insertions(+), 27 deletions(-) create mode 100644 backend/db/names_pronouns.go diff --git a/backend/db/names_pronouns.go b/backend/db/names_pronouns.go new file mode 100644 index 0000000..6b5e8a5 --- /dev/null +++ b/backend/db/names_pronouns.go @@ -0,0 +1,172 @@ +package db + +import ( + "context" + "fmt" + "strings" + + "emperror.dev/errors" + "github.com/georgysavva/scany/pgxscan" + "github.com/jackc/pgx/v4" + "github.com/rs/xid" +) + +type WordStatus int + +const ( + StatusUnknown WordStatus = 0 + StatusFavourite WordStatus = 1 + StatusOkay WordStatus = 2 + StatusJokingly WordStatus = 3 + StatusFriendsOnly WordStatus = 4 + StatusAvoid WordStatus = 5 + wordStatusMax WordStatus = 6 +) + +type Name struct { + ID int64 `json:"-"` + Name string `json:"name"` + Status WordStatus `json:"status"` +} + +func (n Name) Validate() string { + if n.Name == "" { + return "name cannot be empty" + } + + if len([]rune(n.Name)) > FieldEntryMaxLength { + return fmt.Sprintf("name must be %d characters or less, is %d", FieldEntryMaxLength, len([]rune(n.Name))) + } + + if n.Status == StatusUnknown || n.Status >= wordStatusMax { + return fmt.Sprintf("status is invalid, must be between 1 and %d, is %d", wordStatusMax-1, n.Status) + } + + return "" +} + +type Pronoun struct { + ID int64 `json:"-"` + DisplayText *string `json:"display_text"` + Pronouns string `json:"pronouns"` + Status WordStatus `json:"status"` +} + +func (p Pronoun) Validate() string { + if p.Pronouns == "" { + return "pronouns cannot be empty" + } + + if p.DisplayText != nil { + if len([]rune(*p.DisplayText)) > FieldEntryMaxLength { + return fmt.Sprintf("display_text must be %d characters or less, is %d", FieldEntryMaxLength, len([]rune(*p.DisplayText))) + } + } + + if len([]rune(p.Pronouns)) > FieldEntryMaxLength { + return fmt.Sprintf("pronouns must be %d characters or less, is %d", FieldEntryMaxLength, len([]rune(p.Pronouns))) + } + + if p.Status == StatusUnknown || p.Status >= wordStatusMax { + return fmt.Sprintf("status is invalid, must be between 1 and %d, is %d", wordStatusMax-1, p.Status) + } + + return "" +} + +func (p Pronoun) String() string { + if p.DisplayText != nil { + return *p.DisplayText + } + + split := strings.Split(p.Pronouns, "/") + if len(split) <= 2 { + return strings.Join(split, "/") + } + + return strings.Join(split[:1], "/") +} + +func (db *DB) UserNames(ctx context.Context, userID xid.ID) (ns []Name, err error) { + sql, args, err := sq.Select("id", "name", "status").From("user_names").Where("user_id = ?", userID).OrderBy("id").ToSql() + if err != nil { + return nil, errors.Wrap(err, "building sql") + } + + err = pgxscan.Select(ctx, db, &ns, sql, args...) + if err != nil { + return nil, errors.Wrap(err, "executing query") + } + return ns, nil +} + +func (db *DB) UserPronouns(ctx context.Context, userID xid.ID) (ps []Pronoun, err error) { + sql, args, err := sq. + Select("id", "display_text", "pronouns", "status"). + From("user_pronouns").Where("user_id = ?", userID). + OrderBy("id").ToSql() + if err != nil { + return nil, errors.Wrap(err, "building sql") + } + + err = pgxscan.Select(ctx, db, &ps, sql, args...) + if err != nil { + return nil, errors.Wrap(err, "executing query") + } + return ps, nil +} + +func (db *DB) SetUserNames(ctx context.Context, tx pgx.Tx, userID xid.ID, names []Name) (err error) { + sql, args, err := sq.Delete("user_names").Where("user_id = ?", userID).ToSql() + if err != nil { + return errors.Wrap(err, "building sql") + } + + _, err = tx.Exec(ctx, sql, args...) + if err != nil { + return errors.Wrap(err, "deleting existing names") + } + + _, err = tx.CopyFrom(ctx, + pgx.Identifier{"user_names"}, + []string{"user_id", "name", "status"}, + pgx.CopyFromSlice(len(names), func(i int) ([]any, error) { + return []any{ + userID, + names[i].Name, + names[i].Status, + }, nil + })) + if err != nil { + return errors.Wrap(err, "inserting new names") + } + return nil +} + +func (db *DB) SetUserPronouns(ctx context.Context, tx pgx.Tx, userID xid.ID, names []Pronoun) (err error) { + sql, args, err := sq.Delete("user_pronouns").Where("user_id = ?", userID).ToSql() + if err != nil { + return errors.Wrap(err, "building sql") + } + + _, err = tx.Exec(ctx, sql, args...) + if err != nil { + return errors.Wrap(err, "deleting existing pronouns") + } + + _, err = tx.CopyFrom(ctx, + pgx.Identifier{"user_pronouns"}, + []string{"user_id", "pronouns", "display_text", "status"}, + pgx.CopyFromSlice(len(names), func(i int) ([]any, error) { + return []any{ + userID, + names[i].Pronouns, + names[i].DisplayText, + names[i].Status, + }, nil + })) + if err != nil { + return errors.Wrap(err, "inserting new pronouns") + } + return nil +} diff --git a/backend/routes/user/get_user.go b/backend/routes/user/get_user.go index c03c0a7..f74efd4 100644 --- a/backend/routes/user/get_user.go +++ b/backend/routes/user/get_user.go @@ -18,6 +18,8 @@ type GetUserResponse struct { Bio *string `json:"bio"` AvatarURL *string `json:"avatar_url"` Links []string `json:"links"` + Names []db.Name `json:"names"` + Pronouns []db.Pronoun `json:"pronouns"` Members []PartialMember `json:"members"` Fields []db.Field `json:"fields"` } @@ -35,7 +37,7 @@ type PartialMember struct { AvatarURL *string `json:"avatar_url"` } -func dbUserToResponse(u db.User, fields []db.Field) GetUserResponse { +func dbUserToResponse(u db.User, fields []db.Field, names []db.Name, pronouns []db.Pronoun) GetUserResponse { return GetUserResponse{ ID: u.ID, Username: u.Username, @@ -43,6 +45,8 @@ func dbUserToResponse(u db.User, fields []db.Field) GetUserResponse { Bio: u.Bio, AvatarURL: u.AvatarURL, Links: u.Links, + Names: names, + Pronouns: pronouns, Fields: fields, } } @@ -61,7 +65,19 @@ func (s *Server) getUser(w http.ResponseWriter, r *http.Request) error { return err } - render.JSON(w, r, dbUserToResponse(u, fields)) + names, err := s.DB.UserNames(ctx, u.ID) + if err != nil { + log.Errorf("getting user names: %v", err) + return err + } + + pronouns, err := s.DB.UserPronouns(ctx, u.ID) + if err != nil { + log.Errorf("getting user pronouns: %v", err) + return err + } + + render.JSON(w, r, dbUserToResponse(u, fields, names, pronouns)) return nil } else if err != db.ErrUserNotFound { log.Errorf("Error getting user by ID: %v", err) @@ -81,13 +97,25 @@ func (s *Server) getUser(w http.ResponseWriter, r *http.Request) error { return err } + names, err := s.DB.UserNames(ctx, u.ID) + if err != nil { + log.Errorf("getting user names: %v", err) + return err + } + + pronouns, err := s.DB.UserPronouns(ctx, u.ID) + if err != nil { + log.Errorf("getting user pronouns: %v", err) + return err + } + fields, err := s.DB.UserFields(ctx, u.ID) if err != nil { log.Errorf("Error getting user fields: %v", err) return err } - render.JSON(w, r, dbUserToResponse(u, fields)) + render.JSON(w, r, dbUserToResponse(u, fields, names, pronouns)) return nil } @@ -101,6 +129,18 @@ func (s *Server) getMeUser(w http.ResponseWriter, r *http.Request) error { return err } + names, err := s.DB.UserNames(ctx, u.ID) + if err != nil { + log.Errorf("getting user names: %v", err) + return err + } + + pronouns, err := s.DB.UserPronouns(ctx, u.ID) + if err != nil { + log.Errorf("getting user pronouns: %v", err) + return err + } + fields, err := s.DB.UserFields(ctx, u.ID) if err != nil { log.Errorf("Error getting user fields: %v", err) @@ -108,7 +148,7 @@ func (s *Server) getMeUser(w http.ResponseWriter, r *http.Request) error { } render.JSON(w, r, GetMeResponse{ - GetUserResponse: dbUserToResponse(u, fields), + GetUserResponse: dbUserToResponse(u, fields, names, pronouns), Discord: u.Discord, DiscordUsername: u.DiscordUsername, }) diff --git a/backend/routes/user/patch_user.go b/backend/routes/user/patch_user.go index cb00bc7..54d5c2b 100644 --- a/backend/routes/user/patch_user.go +++ b/backend/routes/user/patch_user.go @@ -12,13 +12,16 @@ import ( ) type PatchUserRequest struct { - DisplayName *string `json:"display_name"` - Bio *string `json:"bio"` - Links *[]string `json:"links"` - Fields *[]db.Field `json:"fields"` + DisplayName *string `json:"display_name"` + Bio *string `json:"bio"` + Links *[]string `json:"links"` + Names *[]db.Name `json:"names"` + Pronouns *[]db.Pronoun `json:"pronouns"` + Fields *[]db.Field `json:"fields"` } // patchUser parses a PatchUserRequest and updates the user with the given ID. +// TODO: could this be refactored to be less repetitive? names, pronouns, and fields are all validated in the same way func (s *Server) patchUser(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() @@ -71,24 +74,16 @@ func (s *Server) patchUser(w http.ResponseWriter, r *http.Request) error { } } - if (req.Fields) != nil { - // max 25 fields - if len(*req.Fields) > db.MaxFields { - return server.APIError{ - Code: server.ErrBadRequest, - Details: fmt.Sprintf("Too many fields (max %d, current %d)", db.MaxFields, len(*req.Fields)), - } - } + if err := validateSlicePtr("name", req.Names); err != nil { + return err + } - // validate all fields - for i, field := range *req.Fields { - if s := field.Validate(); s != "" { - return server.APIError{ - Code: server.ErrBadRequest, - Details: fmt.Sprintf("field %d: %s", i, s), - } - } - } + if err := validateSlicePtr("pronoun", req.Pronouns); err != nil { + return err + } + + if err := validateSlicePtr("field", req.Fields); err != nil { + return err } // start transaction @@ -105,7 +100,12 @@ func (s *Server) patchUser(w http.ResponseWriter, r *http.Request) error { return err } - var fields []db.Field + var ( + names []db.Name + pronouns []db.Pronoun + fields []db.Field + ) + if req.Fields != nil { err = s.DB.SetUserFields(ctx, tx, claims.UserID, *req.Fields) if err != nil { @@ -127,6 +127,38 @@ func (s *Server) patchUser(w http.ResponseWriter, r *http.Request) error { } // echo the updated user back on success - render.JSON(w, r, dbUserToResponse(u, fields)) + render.JSON(w, r, dbUserToResponse(u, fields, names, pronouns)) + return nil +} + +type validator interface { + Validate() string +} + +// validateSlicePtr validates a slice of validators. +// If the slice is nil, a nil error is returned (assuming that the field is not required) +func validateSlicePtr[T validator](typ string, slice *[]T) error { + if slice == nil { + return nil + } + + // max 25 fields + if len(*slice) > db.MaxFields { + return server.APIError{ + Code: server.ErrBadRequest, + Details: fmt.Sprintf("Too many %ss (max %d, current %d)", typ, db.MaxFields, len(*slice)), + } + } + + // validate all fields + for i, pronouns := range *slice { + if s := pronouns.Validate(); s != "" { + return server.APIError{ + Code: server.ErrBadRequest, + Details: fmt.Sprintf("%s %d: %s", typ, i, s), + } + } + } + return nil }