diff --git a/backend/db/email.go b/backend/db/email.go index d26e01e..27c861b 100644 --- a/backend/db/email.go +++ b/backend/db/email.go @@ -19,12 +19,16 @@ type UserEmail struct { } func (db *DB) UserEmails(ctx context.Context, userID common.UserID) (es []UserEmail, err error) { + return db.UserEmailsTx(ctx, db, userID) +} + +func (db *DB) UserEmailsTx(ctx context.Context, q pgxscan.Querier, userID common.UserID) (es []UserEmail, err error) { sql, args, err := sq.Select("*").From("user_emails").Where("user_id = ?", userID).OrderBy("id").ToSql() if err != nil { return nil, errors.Wrap(err, "building query") } - err = pgxscan.Select(ctx, db, &es, sql, args...) + err = pgxscan.Select(ctx, q, &es, sql, args...) if err != nil { return nil, errors.Wrap(err, "executing query") } @@ -59,10 +63,15 @@ func (db *DB) EmailExists(ctx context.Context, email string) (exists bool, err e return exists, err } +func (db *DB) EmailExistsTx(ctx context.Context, tx pgx.Tx, email string) (exists bool, err error) { + err = tx.QueryRow(ctx, "select exists(SELECT * FROM user_emails WHERE email_address = $1)", email).Scan(&exists) + return exists, err +} + const ErrEmailInUse = errors.Sentinel("email already in use") -// AddEmail adds a new email to the database, and generates a confirmation token for it. -func (db *DB) AddEmail(ctx context.Context, tx pgx.Tx, userID common.UserID, email string) (e UserEmail, err error) { +// AddEmail adds a new email to the database. +func (db *DB) AddEmail(ctx context.Context, q pgxscan.Querier, userID common.UserID, email string) (e UserEmail, err error) { sql, args, err := sq.Insert("user_emails").SetMap(map[string]any{ "id": common.GenerateID(), "user_id": userID, @@ -72,7 +81,7 @@ func (db *DB) AddEmail(ctx context.Context, tx pgx.Tx, userID common.UserID, ema return e, errors.Wrap(err, "building query") } - err = pgxscan.Get(ctx, tx, &e, sql, args...) + err = pgxscan.Get(ctx, q, &e, sql, args...) if err != nil { pge := &pgconn.PgError{} if errors.As(err, &pge) { @@ -104,3 +113,13 @@ func (db *DB) SetPassword(ctx context.Context, tx pgx.Tx, userID common.UserID, _, err = tx.Exec(ctx, sql, args...) return errors.Wrap(err, "executing query") } + +func (db *DB) RemoveEmails(ctx context.Context, tx pgx.Tx, userID common.UserID) (err error) { + sql, args, err := sq.Delete("user_emails").Where("user_id = ?", userID).ToSql() + if err != nil { + return errors.Wrap(err, "building sql") + } + + _, err = tx.Exec(ctx, sql, args...) + return errors.Wrap(err, "executing query") +} diff --git a/backend/db/user.go b/backend/db/user.go index 055c06c..47168f2 100644 --- a/backend/db/user.go +++ b/backend/db/user.go @@ -105,7 +105,10 @@ const ( PreferenceSizeSmall PreferenceSize = "small" ) -func (u User) NumProviders() (numProviders int) { +func (u User) NumProviders(emails []UserEmail) (numProviders int) { + if len(emails) > 0 { + numProviders++ + } if u.Discord != nil { numProviders++ } diff --git a/backend/routes/v1/auth/discord.go b/backend/routes/v1/auth/discord.go index 342bbc0..7576c0e 100644 --- a/backend/routes/v1/auth/discord.go +++ b/backend/routes/v1/auth/discord.go @@ -235,8 +235,13 @@ func (s *Server) discordUnlink(w http.ResponseWriter, r *http.Request) error { return server.APIError{Code: server.ErrNotLinked} } + emails, err := s.DB.UserEmails(ctx, u.SnowflakeID) + if err != nil { + return errors.Wrap(err, "getting user emails") + } + // cannot unlink last auth provider - if u.NumProviders() <= 1 { + if u.NumProviders(emails) <= 1 { return server.APIError{Code: server.ErrLastProvider} } diff --git a/backend/routes/v1/auth/fedi_mastodon.go b/backend/routes/v1/auth/fedi_mastodon.go index 28dee0f..5b9d009 100644 --- a/backend/routes/v1/auth/fedi_mastodon.go +++ b/backend/routes/v1/auth/fedi_mastodon.go @@ -262,8 +262,13 @@ func (s *Server) mastodonUnlink(w http.ResponseWriter, r *http.Request) error { return server.APIError{Code: server.ErrNotLinked} } + emails, err := s.DB.UserEmails(ctx, u.SnowflakeID) + if err != nil { + return errors.Wrap(err, "getting user emails") + } + // cannot unlink last auth provider - if u.NumProviders() <= 1 { + if u.NumProviders(emails) <= 1 { return server.APIError{Code: server.ErrLastProvider} } diff --git a/backend/routes/v1/auth/google.go b/backend/routes/v1/auth/google.go index 166080c..82c2f20 100644 --- a/backend/routes/v1/auth/google.go +++ b/backend/routes/v1/auth/google.go @@ -250,8 +250,13 @@ func (s *Server) googleUnlink(w http.ResponseWriter, r *http.Request) error { return server.APIError{Code: server.ErrNotLinked} } + emails, err := s.DB.UserEmails(ctx, u.SnowflakeID) + if err != nil { + return errors.Wrap(err, "getting user emails") + } + // cannot unlink last auth provider - if u.NumProviders() <= 1 { + if u.NumProviders(emails) <= 1 { return server.APIError{Code: server.ErrLastProvider} } diff --git a/backend/routes/v1/auth/tumblr.go b/backend/routes/v1/auth/tumblr.go index 5c372b6..b274b39 100644 --- a/backend/routes/v1/auth/tumblr.go +++ b/backend/routes/v1/auth/tumblr.go @@ -283,8 +283,13 @@ func (s *Server) tumblrUnlink(w http.ResponseWriter, r *http.Request) error { return server.APIError{Code: server.ErrNotLinked} } + emails, err := s.DB.UserEmails(ctx, u.SnowflakeID) + if err != nil { + return errors.Wrap(err, "getting user emails") + } + // cannot unlink last auth provider - if u.NumProviders() <= 1 { + if u.NumProviders(emails) <= 1 { return server.APIError{Code: server.ErrLastProvider} } diff --git a/backend/routes/v2/auth/email_signup.go b/backend/routes/v2/auth/email_signup.go index 989c139..4f6bd75 100644 --- a/backend/routes/v2/auth/email_signup.go +++ b/backend/routes/v2/auth/email_signup.go @@ -50,6 +50,7 @@ func (s *Server) postEmailSignup(w http.ResponseWriter, r *http.Request) (err er "Ticket": ticket, }) + render.NoContent(w, r) return nil } diff --git a/backend/routes/v2/auth/put_email.go b/backend/routes/v2/auth/put_email.go new file mode 100644 index 0000000..fdddfb5 --- /dev/null +++ b/backend/routes/v2/auth/put_email.go @@ -0,0 +1,225 @@ +package auth + +import ( + "fmt" + "net/http" + "strings" + + "codeberg.org/pronounscc/pronouns.cc/backend/common" + "codeberg.org/pronounscc/pronouns.cc/backend/db" + "codeberg.org/pronounscc/pronouns.cc/backend/log" + "codeberg.org/pronounscc/pronouns.cc/backend/server" + "emperror.dev/errors" + "github.com/go-chi/render" + "github.com/jackc/pgx/v5" + "github.com/mediocregopher/radix/v4" +) + +type putEmailRequest struct { + Email string `json:"email"` + Password string `json:"password"` +} + +func (s *Server) putEmail(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + claims, _ := server.ClaimsFromContext(ctx) + if claims.APIToken { + return server.APIError{Code: server.ErrMissingPermissions, Details: "This endpoint cannot be used by API tokens"} + } + + var req putEmailRequest + err := render.Decode(r, &req) + if err != nil { + return server.APIError{Code: server.ErrBadRequest} + } + + u, err := s.DB.User(ctx, claims.UserID) + if err != nil { + // this should never fail + log.Errorf("getting user: %v", err) + return errors.Wrap(err, "getting user") + } + + tx, err := s.DB.Begin(ctx) + if err != nil { + return errors.Wrap(err, "beginning transaction") + } + defer func() { + _ = tx.Rollback(ctx) + }() + + emails, err := s.DB.UserEmailsTx(ctx, tx, u.SnowflakeID) + if err != nil { + log.Errorf("getting user emails: %v", err) + return errors.Wrap(err, "getting user emails") + } + + if len(emails) > 0 { + if emails[0].EmailAddress == req.Email { + return server.APIError{Code: server.ErrBadRequest, Details: "New email address cannot be the same as the current one"} + } + + return s.putEmailExisting(w, r, tx, u, req) + } + + ticket := common.RandBase64(48) + err = s.DB.Redis.Do(ctx, radix.Cmd(nil, "SET", + emailChangeTicketKey(ticket), emailChangeTicketValue(req.Email, u.SnowflakeID), "EX", "3600")) + if err != nil { + return errors.Wrap(err, "setting email change key") + } + + // if the email address already exists, pretend we sent an email and return + exists, err := s.DB.EmailExistsTx(ctx, tx, req.Email) + if err != nil { + return errors.Wrap(err, "checking if email exists") + } + if exists { + render.NoContent(w, r) + return nil + } + + // set the user's password, this won't do anything unless the email address is actually confirmed + err = s.DB.SetPassword(ctx, tx, u.SnowflakeID, req.Password) + if err != nil { + return errors.Wrap(err, "setting user password") + } + + err = tx.Commit(ctx) + if err != nil { + return errors.Wrap(err, "committing transaction") + } + + // send the email + go s.SendEmail(req.Email, "Confirm your email address", "change", map[string]any{ + "Ticket": ticket, + "Username": u.Username, + }) + + render.NoContent(w, r) + return nil +} + +func (s *Server) putEmailExisting( + w http.ResponseWriter, r *http.Request, tx pgx.Tx, u db.User, req putEmailRequest, +) error { + ctx := r.Context() + + if !u.VerifyPassword(req.Password) { + return server.APIError{Code: server.ErrForbidden, Details: "Invalid password"} + } + + ticket := common.RandBase64(48) + err := s.DB.Redis.Do(ctx, radix.Cmd(nil, "SET", + emailChangeTicketKey(ticket), emailChangeTicketValue(req.Email, u.SnowflakeID), "EX", "3600")) + if err != nil { + return errors.Wrap(err, "setting email change key") + } + + // if the email address already exists, pretend we sent an email and return + exists, err := s.DB.EmailExistsTx(ctx, tx, req.Email) + if err != nil { + return errors.Wrap(err, "checking if email exists") + } + if exists { + render.NoContent(w, r) + return nil + } + + go s.SendEmail(req.Email, "Confirm your email address", "change", map[string]any{ + "Ticket": ticket, + "Username": u.Username, + }) + + render.NoContent(w, r) + return nil +} + +type putEmailConfirmRequest struct { + Ticket string `json:"ticket"` +} + +type putEmailConfirmResponse struct { + Email string `json:"email"` + User userResponse `json:"user"` +} + +func (s *Server) putEmailConfim(w http.ResponseWriter, r *http.Request) (err error) { + ctx := r.Context() + var req putEmailConfirmRequest + err = render.Decode(r, &req) + if err != nil { + return server.APIError{Code: server.ErrBadRequest} + } + + var ticket string + err = s.DB.Redis.Do(ctx, radix.Cmd(&ticket, "GET", emailChangeTicketKey(req.Ticket))) + if err != nil { + return errors.Wrap(err, "getting email change key") + } + if ticket == "" { + return server.APIError{Code: server.ErrBadRequest, Details: "Unknown ticket"} + } + + email, userID, ok := parseEmailChangeTicket(ticket) + if !ok { + return fmt.Errorf("invalid email change ticket %q", ticket) + } + + u, err := s.DB.UserBySnowflake(ctx, userID) + if err != nil { + return errors.Wrap(err, "getting user") + } + + tx, err := s.DB.Begin(ctx) + if err != nil { + return errors.Wrap(err, "beginning transaction") + } + defer func() { + _ = tx.Rollback(ctx) + }() + + err = s.DB.RemoveEmails(ctx, tx, userID) + if err != nil { + return errors.Wrapf(err, "removing existing email addresses for user %v", userID) + } + + dbEmail, err := s.DB.AddEmail(ctx, s.DB, userID, email) + if err != nil { + if err == db.ErrEmailInUse { + // This should only happen if the email was *not* taken when the ticket was sent, but was taken in the meantime. + // i.e. unless another person has access to the mailbox, the user will know what happened + return server.APIError{Code: server.ErrBadRequest, Details: "Email is already in use"} + } + + return errors.Wrap(err, "adding email to user") + } + + render.JSON(w, r, putEmailConfirmResponse{ + Email: dbEmail.EmailAddress, + User: *dbUserToUserResponse(u, nil), + }) + return nil +} + +func emailChangeTicketKey(ticket string) string { + return "email-change:" + ticket +} + +func emailChangeTicketValue(email string, userID common.UserID) string { + return email + ":" + userID.String() +} + +func parseEmailChangeTicket(v string) (email string, userID common.UserID, ok bool) { + before, after, ok := strings.Cut(v, ":") + if !ok { + return "", common.UserID(common.NullSnowflake), false + } + + id, err := common.ParseSnowflake(after) + if err != nil { + return "", common.UserID(common.NullSnowflake), false + } + + return before, common.UserID(id), true +} diff --git a/backend/routes/v2/auth/routes.go b/backend/routes/v2/auth/routes.go index 1ee3826..b93ac35 100644 --- a/backend/routes/v2/auth/routes.go +++ b/backend/routes/v2/auth/routes.go @@ -36,14 +36,14 @@ func Mount(srv *server.Server, r chi.Router) { )) r.Route("/auth/email", func(r chi.Router) { - r.With(server.MustAuth).Get("/", server.WrapHandler(s.getEmails)) // List existing email addresses for account - r.With(server.MustAuth).Post("/", nil) // Add/update email to existing account, { email } - r.With(server.MustAuth).Delete("/{id}", nil) // Remove existing email from account, + r.With(server.MustAuth).Get("/", server.WrapHandler(s.getEmails)) + r.With(server.MustAuth).Put("/", server.WrapHandler(s.putEmail)) + r.With(server.MustAuth).Delete("/{id}", nil) // Remove existing email from account, r.Post("/login", server.WrapHandler(s.postLogin)) // Log in to account, { username, password } r.Post("/signup", server.WrapHandler(s.postEmailSignup)) // Create account, { email } r.Post("/signup/confirm", server.WrapHandler(s.postEmailSignupConfirm)) // Create account, { ticket, username, password } - r.Post("/confirm", nil) // Confirm email address, { ticket } + r.Post("/confirm", server.WrapHandler(s.putEmailConfim)) // Confirm email address, { ticket } r.Patch("/password", nil) // Update password r.Post("/password/forgot", nil) // Forgot/reset password, { email } diff --git a/backend/routes/v2/auth/templates/change.html b/backend/routes/v2/auth/templates/change.html new file mode 100644 index 0000000..5a5a2d2 --- /dev/null +++ b/backend/routes/v2/auth/templates/change.html @@ -0,0 +1,24 @@ + + + + + + + + +

+ To change @{{.Username}}'s email address, press the following link: +
+ Confirm your new email address +
+ Note that this link will expire in one hour. +

+

+ If you didn't mean to change your email address, feel free to ignore this email. +

+ + diff --git a/backend/routes/v2/auth/templates/change.txt b/backend/routes/v2/auth/templates/change.txt new file mode 100644 index 0000000..3320737 --- /dev/null +++ b/backend/routes/v2/auth/templates/change.txt @@ -0,0 +1,5 @@ +To change @{{.Username}}'s email address, press the following link: +{{.BaseURL}}/auth/email/confirm/{{.Ticket}} +This link will expire in one hour. + +If you didn't mean to change your email address, feel free to ignore this email. diff --git a/backend/routes/v2/auth/templates/signup.html b/backend/routes/v2/auth/templates/signup.html index 2058064..00270bc 100644 --- a/backend/routes/v2/auth/templates/signup.html +++ b/backend/routes/v2/auth/templates/signup.html @@ -1,7 +1,7 @@ - +