From 6237ea940fd12048c2e7646499452acca7fa7ee1 Mon Sep 17 00:00:00 2001 From: Sam Date: Fri, 18 Nov 2022 15:27:52 +0100 Subject: [PATCH] feat: add invites to backend --- backend/db/invites.go | 111 +++++++++++++++++++++++++++++++++ backend/db/user.go | 2 + backend/routes/auth/discord.go | 17 ++--- backend/routes/auth/invite.go | 68 ++++++++++++++++++++ backend/routes/auth/routes.go | 4 ++ backend/server/errors.go | 51 ++++++++------- scripts/migrate/001_init.sql | 11 +++- 7 files changed, 234 insertions(+), 30 deletions(-) create mode 100644 backend/db/invites.go create mode 100644 backend/routes/auth/invite.go diff --git a/backend/db/invites.go b/backend/db/invites.go new file mode 100644 index 0000000..12e2506 --- /dev/null +++ b/backend/db/invites.go @@ -0,0 +1,111 @@ +package db + +import ( + "context" + "crypto/rand" + "encoding/base64" + "time" + + "emperror.dev/errors" + "github.com/georgysavva/scany/pgxscan" + "github.com/jackc/pgx/v4" + "github.com/rs/xid" +) + +type Invite struct { + UserID xid.ID + Code string + Created time.Time + Used bool +} + +func (db *DB) UserInvites(ctx context.Context, userID xid.ID) (is []Invite, err error) { + sql, args, err := sq.Select("*").From("invites").Where("user_id = ?", userID).OrderBy("created").ToSql() + if err != nil { + return nil, errors.Wrap(err, "building sql") + } + + err = pgxscan.Select(ctx, db, &is, sql, args...) + if err != nil { + return nil, errors.Wrap(err, "querying database") + } + if len(is) == 0 { + is = []Invite{} + } + + return is, nil +} + +const ErrTooManyInvites = errors.Sentinel("user invite limit reached") + +func (db *DB) CreateInvite(ctx context.Context, userID xid.ID) (i Invite, err error) { + tx, err := db.Begin(ctx) + if err != nil { + return i, errors.Wrap(err, "beginning transaction") + } + defer tx.Rollback(ctx) + + var maxInvites, inviteCount int + err = tx.QueryRow(ctx, "SELECT max_invites FROM users WHERE id = $1", userID).Scan(&maxInvites) + if err != nil { + return i, errors.Wrap(err, "querying invite limit") + } + err = tx.QueryRow(ctx, "SELECT count(*) FROM invites WHERE user_id = $1", userID).Scan(&inviteCount) + if err != nil { + return i, errors.Wrap(err, "querying current invite count") + } + + if inviteCount >= maxInvites { + return i, ErrTooManyInvites + } + + b := make([]byte, 32) + + _, err = rand.Read(b) + if err != nil { + panic(err) + } + + code := base64.RawURLEncoding.EncodeToString(b) + + sql, args, err := sq.Insert("invites").Columns("user_id", "code").Values(userID, code).Suffix("RETURNING *").ToSql() + if err != nil { + return i, errors.Wrap(err, "building insert invite sql") + } + + err = pgxscan.Get(ctx, db, &i, sql, args...) + if err != nil { + return i, errors.Wrap(err, "inserting invite") + } + + err = tx.Commit(ctx) + if err != nil { + return i, errors.Wrap(err, "committing transaction") + } + return i, nil +} + +func (db *DB) InvalidateInvite(ctx context.Context, tx pgx.Tx, code string) (valid, alreadyUsed bool, err error) { + err = tx.QueryRow(ctx, "SELECT used FROM invites WHERE code = $1", code).Scan(&alreadyUsed) + if err != nil { + if errors.Cause(err) == pgx.ErrNoRows { + return false, false, nil + } + + return false, false, errors.Wrap(err, "checking if invite exists and is used") + } + + // valid: true, already used: true + if alreadyUsed { + return true, true, nil + } + + // invite is valid, not already used + _, err = tx.Exec(ctx, "UPDATE invites SET used = true WHERE code = $1", code) + if err != nil { + return false, false, errors.Wrap(err, "updating invite usage") + } + + // valid: true, already used: false + return true, false, nil +} diff --git a/backend/db/user.go b/backend/db/user.go index db726c7..67bb0de 100644 --- a/backend/db/user.go +++ b/backend/db/user.go @@ -24,6 +24,8 @@ type User struct { Discord *string DiscordUsername *string + + MaxInvites int } // usernames must match this regex diff --git a/backend/routes/auth/discord.go b/backend/routes/auth/discord.go index 582cb43..7e89f2e 100644 --- a/backend/routes/auth/discord.go +++ b/backend/routes/auth/discord.go @@ -182,17 +182,18 @@ func (s *Server) discordSignup(w http.ResponseWriter, r *http.Request) error { } if s.RequireInvite { - // TODO: check invites, invalidate invite when done - inviteValid := true - - if !inviteValid { - err = tx.Rollback(ctx) - if err != nil { - return errors.Wrap(err, "rolling back transaction") - } + valid, used, err := s.DB.InvalidateInvite(ctx, tx, req.InviteCode) + if err != nil { + return errors.Wrap(err, "checking and invalidating invite") + } + if !valid { return server.APIError{Code: server.ErrInviteRequired} } + + if used { + return server.APIError{Code: server.ErrInviteAlreadyUsed} + } } // delete sign up ticket diff --git a/backend/routes/auth/invite.go b/backend/routes/auth/invite.go new file mode 100644 index 0000000..00279c3 --- /dev/null +++ b/backend/routes/auth/invite.go @@ -0,0 +1,68 @@ +package auth + +import ( + "net/http" + "time" + + "codeberg.org/u1f320/pronouns.cc/backend/db" + "codeberg.org/u1f320/pronouns.cc/backend/server" + "emperror.dev/errors" + "github.com/go-chi/render" +) + +type inviteResponse struct { + Code string `json:"string"` + Created time.Time `json:"created"` + Used bool `json:"used"` +} + +func dbInviteToResponse(i db.Invite) inviteResponse { + return inviteResponse{ + Code: i.Code, + Created: i.Created, + Used: i.Used, + } +} + +func (s *Server) getInvites(w http.ResponseWriter, r *http.Request) error { + if !s.RequireInvite { + return server.APIError{Code: server.ErrInvitesDisabled} + } + + ctx := r.Context() + claims, _ := server.ClaimsFromContext(ctx) + + is, err := s.DB.UserInvites(ctx, claims.UserID) + if err != nil { + return errors.Wrap(err, "getting user invites") + } + + resps := make([]inviteResponse, len(is)) + for i := range is { + resps[i] = dbInviteToResponse(is[i]) + } + + render.JSON(w, r, resps) + return nil +} + +func (s *Server) createInvite(w http.ResponseWriter, r *http.Request) error { + if !s.RequireInvite { + return server.APIError{Code: server.ErrInvitesDisabled} + } + + ctx := r.Context() + claims, _ := server.ClaimsFromContext(ctx) + + inv, err := s.DB.CreateInvite(ctx, claims.UserID) + if err != nil { + if err == db.ErrTooManyInvites { + return server.APIError{Code: server.ErrInviteLimitReached} + } + + return errors.Wrap(err, "creating invite") + } + + render.JSON(w, r, dbInviteToResponse(inv)) + return nil +} diff --git a/backend/routes/auth/routes.go b/backend/routes/auth/routes.go index e29d28f..fe9a284 100644 --- a/backend/routes/auth/routes.go +++ b/backend/routes/auth/routes.go @@ -63,6 +63,10 @@ func Mount(srv *server.Server, r chi.Router) { // takes discord signup ticket to register account r.Post("/signup", server.WrapHandler(s.discordSignup)) }) + + // invite routes + r.With(server.MustAuth).Get("/invites", server.WrapHandler(s.getInvites)) + r.With(server.MustAuth).Post("/invites", server.WrapHandler(s.createInvite)) }) } diff --git a/backend/server/errors.go b/backend/server/errors.go index f998387..0f6c4c9 100644 --- a/backend/server/errors.go +++ b/backend/server/errors.go @@ -73,13 +73,16 @@ const ( ErrInternalServerError = 500 // catch-all code for unknown errors // Login/authorize error codes - ErrInvalidState = 1001 - ErrInvalidOAuthCode = 1002 - ErrInvalidToken = 1003 // a token was supplied, but it is invalid - ErrInviteRequired = 1004 - ErrInvalidTicket = 1005 // invalid signup ticket - ErrInvalidUsername = 1006 // invalid username (when signing up) - ErrUsernameTaken = 1007 // username taken (when signing up) + ErrInvalidState = 1001 + ErrInvalidOAuthCode = 1002 + ErrInvalidToken = 1003 // a token was supplied, but it is invalid + ErrInviteRequired = 1004 + ErrInvalidTicket = 1005 // invalid signup ticket + ErrInvalidUsername = 1006 // invalid username (when signing up) + ErrUsernameTaken = 1007 // username taken (when signing up) + ErrInvitesDisabled = 1008 // invites are disabled (unneeded) + ErrInviteLimitReached = 1009 // invite limit reached (when creating invites) + ErrInviteAlreadyUsed = 1010 // invite already used (when signing up) // User-related error codes ErrUserNotFound = 2001 @@ -100,13 +103,16 @@ var errCodeMessages = map[int]string{ ErrTooManyRequests: "Rate limit reached", ErrMethodNotAllowed: "Method not allowed", - ErrInvalidState: "Invalid OAuth state", - ErrInvalidOAuthCode: "Invalid OAuth code", - ErrInvalidToken: "Supplied token was invalid", - ErrInviteRequired: "A valid invite code is required", - ErrInvalidTicket: "Invalid signup ticket", - ErrInvalidUsername: "Invalid username", - ErrUsernameTaken: "Username is already taken", + ErrInvalidState: "Invalid OAuth state", + ErrInvalidOAuthCode: "Invalid OAuth code", + ErrInvalidToken: "Supplied token was invalid", + ErrInviteRequired: "A valid invite code is required", + ErrInvalidTicket: "Invalid signup ticket", + ErrInvalidUsername: "Invalid username", + ErrUsernameTaken: "Username is already taken", + ErrInvitesDisabled: "Invites are disabled", + ErrInviteLimitReached: "Your account has reached the invite limit", + ErrInviteAlreadyUsed: "That invite code has already been used", ErrUserNotFound: "User not found", @@ -124,13 +130,16 @@ var errCodeStatuses = map[int]int{ ErrTooManyRequests: http.StatusTooManyRequests, ErrMethodNotAllowed: http.StatusMethodNotAllowed, - ErrInvalidState: http.StatusBadRequest, - ErrInvalidOAuthCode: http.StatusForbidden, - ErrInvalidToken: http.StatusUnauthorized, - ErrInviteRequired: http.StatusBadRequest, - ErrInvalidTicket: http.StatusBadRequest, - ErrInvalidUsername: http.StatusBadRequest, - ErrUsernameTaken: http.StatusBadRequest, + ErrInvalidState: http.StatusBadRequest, + ErrInvalidOAuthCode: http.StatusForbidden, + ErrInvalidToken: http.StatusUnauthorized, + ErrInviteRequired: http.StatusBadRequest, + ErrInvalidTicket: http.StatusBadRequest, + ErrInvalidUsername: http.StatusBadRequest, + ErrUsernameTaken: http.StatusBadRequest, + ErrInvitesDisabled: http.StatusForbidden, + ErrInviteLimitReached: http.StatusForbidden, + ErrInviteAlreadyUsed: http.StatusBadRequest, ErrUserNotFound: http.StatusNotFound, diff --git a/scripts/migrate/001_init.sql b/scripts/migrate/001_init.sql index ae20ee5..a60e2aa 100644 --- a/scripts/migrate/001_init.sql +++ b/scripts/migrate/001_init.sql @@ -12,7 +12,9 @@ create table users ( links text[], discord text unique, -- for Discord oauth - discord_username text + discord_username text, + + max_invites int default 10 ); create table user_names ( @@ -80,3 +82,10 @@ create table member_fields ( friends_only text[], avoid text[] ); + +create table invites ( + user_id text not null references users (id) on delete cascade, + code text primary key, + created timestamp not null default (current_timestamp at time zone 'utc'), + used boolean not null default false +);