From b41ca0b75312802c7c8eccc2fd790b563196dc6d Mon Sep 17 00:00:00 2001 From: Sam Date: Sat, 25 Feb 2023 22:16:22 +0100 Subject: [PATCH] fix(backend): fix sql errors in CreateUser and User.UpdateFromDiscord --- backend/db/user.go | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/backend/db/user.go b/backend/db/user.go index 589cbc2..902b31e 100644 --- a/backend/db/user.go +++ b/backend/db/user.go @@ -65,15 +65,18 @@ func (db *DB) CreateUser(ctx context.Context, tx pgx.Tx, username string) (u Use return u, ErrInvalidUsername } - sql, args, err := sq.Insert("users").Columns("id", "username").Values(xid.New(), username).Suffix("RETURNING *").ToSql() + sql, args, err := sq.Insert("users").Columns("id", "username").Values(xid.New(), username).Suffix("RETURNING id").ToSql() if err != nil { return u, errors.Wrap(err, "building sql") } - err = pgxscan.Get(ctx, tx, &u, sql, args...) + var id xid.ID + err = tx.QueryRow(ctx, sql, args...).Scan(&id) if err != nil { - if v, ok := errors.Cause(err).(*pgconn.PgError); ok { - if v.Code == "23505" { // unique constraint violation + pge := &pgconn.PgError{} + if errors.As(err, &pge) { + // unique constraint violation + if pge.Code == "23505" { return u, ErrUsernameTaken } } @@ -81,7 +84,7 @@ func (db *DB) CreateUser(ctx context.Context, tx pgx.Tx, username string) (u Use return u, errors.Cause(err) } - return u, nil + return db.getUser(ctx, tx, id) } // DiscordUser fetches a user by Discord user ID. @@ -103,18 +106,25 @@ func (db *DB) DiscordUser(ctx context.Context, discordID string) (u User, err er } func (u *User) UpdateFromDiscord(ctx context.Context, db querier, du *discordgo.User) error { - builder := sq.Update("users"). + sql, args, err := sq.Update("users"). Set("discord", du.ID). Set("discord_username", du.String()). Where("id = ?", u.ID). - Suffix("RETURNING *") - - sql, args, err := builder.ToSql() + ToSql() if err != nil { return errors.Wrap(err, "building sql") } - return pgxscan.Get(ctx, db, u, sql, args...) + _, err = db.Exec(ctx, sql, args...) + if err != nil { + return errors.Wrap(err, "executing query") + } + + u.Discord = &du.ID + username := du.String() + u.DiscordUsername = &username + + return nil } func (db *DB) getUser(ctx context.Context, q querier, id xid.ID) (u User, err error) {