From b1a7ef89ca2f3650e1aa3e28fd253daac462454b Mon Sep 17 00:00:00 2001 From: sam Date: Thu, 17 Aug 2023 18:49:32 +0200 Subject: [PATCH] feat(backend): add snowflake IDs --- backend/common/generator.go | 65 ++++++++++++++ backend/common/snowflake.go | 83 ++++++++++++++++++ backend/common/snowflake_types.go | 39 +++++++++ backend/db/flags.go | 23 ++--- backend/db/member.go | 6 +- backend/db/user.go | 3 +- backend/routes/v1/auth/routes.go | 3 + backend/routes/v1/member/get_member.go | 19 ++-- backend/routes/v1/member/get_members.go | 3 + backend/routes/v1/user/get_user.go | 5 ++ main.go | 2 + scripts/migrate/021_snowflakes.sql | 13 +++ scripts/snowflakes/main.go | 111 ++++++++++++++++++++++++ 13 files changed, 355 insertions(+), 20 deletions(-) create mode 100644 backend/common/generator.go create mode 100644 backend/common/snowflake.go create mode 100644 backend/common/snowflake_types.go create mode 100644 scripts/migrate/021_snowflakes.sql create mode 100644 scripts/snowflakes/main.go diff --git a/backend/common/generator.go b/backend/common/generator.go new file mode 100644 index 0000000..33385e1 --- /dev/null +++ b/backend/common/generator.go @@ -0,0 +1,65 @@ +package common + +import ( + "math/rand" + "sync/atomic" + "time" +) + +// IDGenerator is a snowflake generator. +// For compatibility with other snowflake implementations, both worker and PID are set, +// but NewIDGenerator randomizes whichever of them is passed as zero. +type IDGenerator struct { + inc *uint64 + worker, pid uint64 +} + +var defaultGenerator = NewIDGenerator(0, 0) + +// NewIDGenerator creates a new ID generator with the given worker and pid, each reduced modulo 32. +// If worker or pid is zero, it is first replaced by a math/rand value (NOTE(review): the global source is only auto-seeded on Go 1.20+ — confirm the toolchain).
+func NewIDGenerator(worker, pid uint64) *IDGenerator { + if worker == 0 { + worker = rand.Uint64() + } + if pid == 0 { + pid = rand.Uint64() + } + + g := &IDGenerator{ + inc: new(uint64), + worker: worker % 32, + pid: pid % 32, + } + + return g +} + +// GenerateID generates a new snowflake with the default generator. +// If you need to customize the worker and PID, manually call (*IDGenerator).Generate. +func GenerateID() Snowflake { + return defaultGenerator.Generate() +} + +// GenerateIDWithTime generates a new snowflake with the given time with the default generator. +// If you need to customize the worker and PID, manually call (*IDGenerator).GenerateWithTime. +func GenerateIDWithTime(t time.Time) Snowflake { + return defaultGenerator.GenerateWithTime(t) +} + +// Generate generates a snowflake with the current time. +func (g *IDGenerator) Generate() Snowflake { + return g.GenerateWithTime(time.Now()) +} + +// GenerateWithTime generates a snowflake with the given time. +// To generate a snowflake for comparison, use the top-level NewSnowflake function instead. +func (g *IDGenerator) GenerateWithTime(t time.Time) Snowflake { + increment := atomic.AddUint64(g.inc, 1) + ts := uint64(t.UnixMilli() - Epoch) // NOTE(review): the difference is negative for t before Epoch and the uint64 conversion wraps — confirm callers never pass such times + + worker := g.worker << 17 + pid := g.pid << 12 + + return Snowflake(ts<<22 | worker | pid | (increment % 4096)) // layout: timestamp<<22 | worker<<17 | pid<<12 | 12-bit increment, which wraps every 4096 IDs +} diff --git a/backend/common/snowflake.go b/backend/common/snowflake.go new file mode 100644 index 0000000..694ec24 --- /dev/null +++ b/backend/common/snowflake.go @@ -0,0 +1,83 @@ +package common + +import ( + "strconv" + "strings" + "time" +) + +// Epoch is the pronouns.cc epoch (January 1st 2022 at 00:00:00 UTC) in milliseconds. +const Epoch = 1_640_995_200_000 +const epochDuration = Epoch * time.Millisecond + +const NullSnowflake = ^Snowflake(0) // all bits set; the value ParseSnowflake returns for the string "null" + +// Snowflake is a 64-bit integer used as a unique ID, with an embedded timestamp. +type Snowflake uint64 + +// ID is an alias to Snowflake. +type ID = Snowflake + +// ParseSnowflake parses a base-10 snowflake from a string; the literal string "null" yields NullSnowflake.
+func ParseSnowflake(sf string) (Snowflake, error) { + if sf == "null" { + return NullSnowflake, nil + } + + i, err := strconv.ParseUint(sf, 10, 64) + if err != nil { + return 0, err + } + + return Snowflake(i), nil +} + +// NewSnowflake creates a snowflake carrying only the timestamp of the given time; the low 22 bits (worker, PID, increment) are zero. +func NewSnowflake(t time.Time) Snowflake { + ts := time.Duration(t.UnixNano()) - epochDuration + + return Snowflake((ts / time.Millisecond) << 22) +} + +// String returns the snowflake as a base-10 string. +func (s Snowflake) String() string { return strconv.FormatUint(uint64(s), 10) } + +// Time returns the timestamp embedded in the snowflake (bits 22 and up, in milliseconds since Epoch) as a time.Time. +func (s Snowflake) Time() time.Time { + ts := time.Duration(s>>22)*time.Millisecond + epochDuration + return time.Unix(0, int64(ts)) +} + +func (s Snowflake) IsValid() bool { // IsValid reports whether s is neither zero nor NullSnowflake. + return s != 0 && s != NullSnowflake +} + +func (s Snowflake) MarshalJSON() ([]byte, error) { // MarshalJSON encodes s as a JSON string holding its decimal value, or as null when s is not valid. + if !s.IsValid() { + return []byte("null"), nil + } + + return []byte(`"` + strconv.FormatUint(uint64(s), 10) + `"`), nil +} + +func (s *Snowflake) UnmarshalJSON(src []byte) error { // UnmarshalJSON strips surrounding double quotes from src and parses the rest with ParseSnowflake. + sf, err := ParseSnowflake(strings.Trim(string(src), `"`)) + if err != nil { + return err + } + + *s = sf + return nil +} + +func (s Snowflake) Worker() uint8 { + return uint8(s & 0x3E0000 >> 17) // bits 17-21; in Go & and >> share one precedence level and associate left, so this is (s & mask) >> 17 +} + +func (s Snowflake) PID() uint8 { + return uint8(s & 0x1F000 >> 12) // bits 12-16; parsed as (s & mask) >> 12 +} + +func (s Snowflake) Increment() uint16 { + return uint16(s & 0xFFF) // low 12 bits +} diff --git a/backend/common/snowflake_types.go b/backend/common/snowflake_types.go new file mode 100644 index 0000000..3e9848f --- /dev/null +++ b/backend/common/snowflake_types.go @@ -0,0 +1,39 @@ +package common + +import "time" + +type UserID Snowflake + +func (id UserID) String() string { return Snowflake(id).String() } +func (id UserID) Time() time.Time { return Snowflake(id).Time() } +func (id UserID) IsValid() bool { return Snowflake(id).IsValid() } +func (id UserID) Worker() uint8 { return Snowflake(id).Worker() } +func (id UserID) PID() uint8 { return
Snowflake(id).PID() } +func (id UserID) Increment() uint16 { return Snowflake(id).Increment() } + +func (id UserID) MarshalJSON() ([]byte, error) { return Snowflake(id).MarshalJSON() } +func (id *UserID) UnmarshalJSON(src []byte) error { return (*Snowflake)(id).UnmarshalJSON(src) } + +type MemberID Snowflake + +func (id MemberID) String() string { return Snowflake(id).String() } +func (id MemberID) Time() time.Time { return Snowflake(id).Time() } +func (id MemberID) IsValid() bool { return Snowflake(id).IsValid() } +func (id MemberID) Worker() uint8 { return Snowflake(id).Worker() } +func (id MemberID) PID() uint8 { return Snowflake(id).PID() } +func (id MemberID) Increment() uint16 { return Snowflake(id).Increment() } + +func (id MemberID) MarshalJSON() ([]byte, error) { return Snowflake(id).MarshalJSON() } +func (id *MemberID) UnmarshalJSON(src []byte) error { return (*Snowflake)(id).UnmarshalJSON(src) } + +type FlagID Snowflake + +func (id FlagID) String() string { return Snowflake(id).String() } +func (id FlagID) Time() time.Time { return Snowflake(id).Time() } +func (id FlagID) IsValid() bool { return Snowflake(id).IsValid() } +func (id FlagID) Worker() uint8 { return Snowflake(id).Worker() } +func (id FlagID) PID() uint8 { return Snowflake(id).PID() } +func (id FlagID) Increment() uint16 { return Snowflake(id).Increment() } + +func (id FlagID) MarshalJSON() ([]byte, error) { return Snowflake(id).MarshalJSON() } +func (id *FlagID) UnmarshalJSON(src []byte) error { return (*Snowflake)(id).UnmarshalJSON(src) } diff --git a/backend/db/flags.go b/backend/db/flags.go index 151cb92..8187c46 100644 --- a/backend/db/flags.go +++ b/backend/db/flags.go @@ -9,6 +9,7 @@ import ( "io" "strings" + "codeberg.org/pronounscc/pronouns.cc/backend/common" "codeberg.org/pronounscc/pronouns.cc/backend/log" "emperror.dev/errors" "github.com/davidbyttow/govips/v2/vips" @@ -20,11 +21,12 @@ import ( ) type PrideFlag struct { - ID xid.ID `json:"id"` - UserID xid.ID `json:"-"` - Hash 
string `json:"hash"` - Name string `json:"name"` - Description *string `json:"description"` + ID xid.ID `json:"id"` + SnowflakeID common.FlagID `json:"id_new"` + UserID xid.ID `json:"-"` + Hash string `json:"hash"` + Name string `json:"name"` + Description *string `json:"description"` } type UserFlag struct { @@ -194,11 +196,12 @@ func (db *DB) CreateFlag(ctx context.Context, tx pgx.Tx, userID xid.ID, name, de sql, args, err := sq.Insert("pride_flags"). SetMap(map[string]any{ - "id": xid.New(), - "hash": "", - "user_id": userID.String(), - "name": name, - "description": description, + "id": xid.New(), + "snowflake_id": common.GenerateID(), + "hash": "", + "user_id": userID.String(), + "name": name, + "description": description, }).Suffix("RETURNING *").ToSql() if err != nil { return f, errors.Wrap(err, "building query") diff --git a/backend/db/member.go b/backend/db/member.go index 71b3bd0..5de82aa 100644 --- a/backend/db/member.go +++ b/backend/db/member.go @@ -6,6 +6,7 @@ import ( "strings" "time" + "codeberg.org/pronounscc/pronouns.cc/backend/common" "emperror.dev/errors" "github.com/Masterminds/squirrel" "github.com/georgysavva/scany/v2/pgxscan" @@ -22,6 +23,7 @@ const ( type Member struct { ID xid.ID UserID xid.ID + SnowflakeID common.MemberID SID string `db:"sid"` Name string DisplayName *string @@ -135,8 +137,8 @@ func (db *DB) CreateMember( name string, displayName *string, bio string, links []string, ) (m Member, err error) { sql, args, err := sq.Insert("members"). - Columns("user_id", "id", "sid", "name", "display_name", "bio", "links"). - Values(userID, xid.New(), squirrel.Expr("find_free_member_sid()"), name, displayName, bio, links). + Columns("user_id", "snowflake_id", "id", "sid", "name", "display_name", "bio", "links"). + Values(userID, common.GenerateID(), xid.New(), squirrel.Expr("find_free_member_sid()"), name, displayName, bio, links). 
Suffix("RETURNING *").ToSql() if err != nil { return m, errors.Wrap(err, "building sql") diff --git a/backend/db/user.go b/backend/db/user.go index 95a3355..8b9a0ca 100644 --- a/backend/db/user.go +++ b/backend/db/user.go @@ -22,6 +22,7 @@ import ( type User struct { ID xid.ID + SnowflakeID common.UserID SID string `db:"sid"` Username string DisplayName *string @@ -206,7 +207,7 @@ func (db *DB) CreateUser(ctx context.Context, tx pgx.Tx, username string) (u Use return u, err } - sql, args, err := sq.Insert("users").Columns("id", "username", "sid").Values(xid.New(), username, squirrel.Expr("find_free_user_sid()")).Suffix("RETURNING *").ToSql() + sql, args, err := sq.Insert("users").Columns("id", "snowflake_id", "username", "sid").Values(xid.New(), common.GenerateID(), username, squirrel.Expr("find_free_user_sid()")).Suffix("RETURNING *").ToSql() if err != nil { return u, errors.Wrap(err, "building sql") } diff --git a/backend/routes/v1/auth/routes.go b/backend/routes/v1/auth/routes.go index 7699084..103d43e 100644 --- a/backend/routes/v1/auth/routes.go +++ b/backend/routes/v1/auth/routes.go @@ -4,6 +4,7 @@ import ( "net/http" "os" + "codeberg.org/pronounscc/pronouns.cc/backend/common" "codeberg.org/pronounscc/pronouns.cc/backend/db" "codeberg.org/pronounscc/pronouns.cc/backend/log" "codeberg.org/pronounscc/pronouns.cc/backend/server" @@ -25,6 +26,7 @@ type Server struct { type userResponse struct { ID xid.ID `json:"id"` + SnowflakeID common.UserID `json:"id_new"` Username string `json:"name"` DisplayName *string `json:"display_name"` Bio *string `json:"bio"` @@ -51,6 +53,7 @@ type userResponse struct { func dbUserToUserResponse(u db.User, fields []db.Field) *userResponse { return &userResponse{ ID: u.ID, + SnowflakeID: u.SnowflakeID, Username: u.Username, DisplayName: u.DisplayName, Bio: u.Bio, diff --git a/backend/routes/v1/member/get_member.go b/backend/routes/v1/member/get_member.go index ddc94d8..f87bbe9 100644 --- a/backend/routes/v1/member/get_member.go +++ 
b/backend/routes/v1/member/get_member.go @@ -4,6 +4,7 @@ import ( "context" "net/http" + "codeberg.org/pronounscc/pronouns.cc/backend/common" "codeberg.org/pronounscc/pronouns.cc/backend/db" "codeberg.org/pronounscc/pronouns.cc/backend/server" "emperror.dev/errors" @@ -13,13 +14,14 @@ import ( ) type GetMemberResponse struct { - ID xid.ID `json:"id"` - SID string `json:"sid"` - Name string `json:"name"` - DisplayName *string `json:"display_name"` - Bio *string `json:"bio"` - Avatar *string `json:"avatar"` - Links []string `json:"links"` + ID xid.ID `json:"id"` + SnowflakeID common.MemberID `json:"id_new"` + SID string `json:"sid"` + Name string `json:"name"` + DisplayName *string `json:"display_name"` + Bio *string `json:"bio"` + Avatar *string `json:"avatar"` + Links []string `json:"links"` Names []db.FieldEntry `json:"names"` Pronouns []db.PronounEntry `json:"pronouns"` @@ -34,6 +36,7 @@ type GetMemberResponse struct { func dbMemberToMember(u db.User, m db.Member, fields []db.Field, flags []db.MemberFlag, isOwnMember bool) GetMemberResponse { r := GetMemberResponse{ ID: m.ID, + SnowflakeID: m.SnowflakeID, SID: m.SID, Name: m.Name, DisplayName: m.DisplayName, @@ -48,6 +51,7 @@ func dbMemberToMember(u db.User, m db.Member, fields []db.Field, flags []db.Memb User: PartialUser{ ID: u.ID, + SnowflakeID: u.SnowflakeID, Username: u.Username, DisplayName: u.DisplayName, Avatar: u.Avatar, @@ -64,6 +68,7 @@ func dbMemberToMember(u db.User, m db.Member, fields []db.Field, flags []db.Memb type PartialUser struct { ID xid.ID `json:"id"` + SnowflakeID common.UserID `json:"id_new"` Username string `json:"name"` DisplayName *string `json:"display_name"` Avatar *string `json:"avatar"` diff --git a/backend/routes/v1/member/get_members.go b/backend/routes/v1/member/get_members.go index 6b08239..6dba566 100644 --- a/backend/routes/v1/member/get_members.go +++ b/backend/routes/v1/member/get_members.go @@ -3,6 +3,7 @@ package member import ( "net/http" + 
"codeberg.org/pronounscc/pronouns.cc/backend/common" "codeberg.org/pronounscc/pronouns.cc/backend/db" "codeberg.org/pronounscc/pronouns.cc/backend/server" "github.com/go-chi/chi/v5" @@ -12,6 +13,7 @@ import ( type memberListResponse struct { ID xid.ID `json:"id"` + SnowflakeID common.MemberID `json:"id_new"` SID string `json:"sid"` Name string `json:"name"` DisplayName *string `json:"display_name"` @@ -28,6 +30,7 @@ func membersToMemberList(ms []db.Member, isSelf bool) []memberListResponse { for i := range ms { resps[i] = memberListResponse{ ID: ms[i].ID, + SnowflakeID: ms[i].SnowflakeID, SID: ms[i].SID, Name: ms[i].Name, DisplayName: ms[i].DisplayName, diff --git a/backend/routes/v1/user/get_user.go b/backend/routes/v1/user/get_user.go index d1163a3..4826ed0 100644 --- a/backend/routes/v1/user/get_user.go +++ b/backend/routes/v1/user/get_user.go @@ -4,6 +4,7 @@ import ( "net/http" "time" + "codeberg.org/pronounscc/pronouns.cc/backend/common" "codeberg.org/pronounscc/pronouns.cc/backend/db" "codeberg.org/pronounscc/pronouns.cc/backend/log" "codeberg.org/pronounscc/pronouns.cc/backend/server" @@ -14,6 +15,7 @@ import ( type GetUserResponse struct { ID xid.ID `json:"id"` + SnowflakeID common.UserID `json:"id_new"` SID string `json:"sid"` Username string `json:"name"` DisplayName *string `json:"display_name"` @@ -58,6 +60,7 @@ type GetMeResponse struct { type PartialMember struct { ID xid.ID `json:"id"` + SnowflakeID common.MemberID `json:"id_new"` SID string `json:"sid"` Name string `json:"name"` DisplayName *string `json:"display_name"` @@ -71,6 +74,7 @@ type PartialMember struct { func dbUserToResponse(u db.User, fields []db.Field, members []db.Member, flags []db.UserFlag) GetUserResponse { resp := GetUserResponse{ ID: u.ID, + SnowflakeID: u.SnowflakeID, SID: u.SID, Username: u.Username, DisplayName: u.DisplayName, @@ -97,6 +101,7 @@ func dbUserToResponse(u db.User, fields []db.Field, members []db.Member, flags [ for i := range members { resp.Members[i] = 
PartialMember{ ID: members[i].ID, + SnowflakeID: members[i].SnowflakeID, SID: members[i].SID, Name: members[i].Name, DisplayName: members[i].DisplayName, diff --git a/main.go b/main.go index 61cac6b..067aafd 100644 --- a/main.go +++ b/main.go @@ -13,6 +13,7 @@ import ( "codeberg.org/pronounscc/pronouns.cc/scripts/genkey" "codeberg.org/pronounscc/pronouns.cc/scripts/migrate" "codeberg.org/pronounscc/pronouns.cc/scripts/seeddb" + "codeberg.org/pronounscc/pronouns.cc/scripts/snowflakes" "github.com/urfave/cli/v2" ) @@ -32,6 +33,7 @@ var app = &cli.App{ migrate.Command, seeddb.Command, cleandb.Command, + snowflakes.Command, }, }, { diff --git a/scripts/migrate/021_snowflakes.sql b/scripts/migrate/021_snowflakes.sql new file mode 100644 index 0000000..76777c0 --- /dev/null +++ b/scripts/migrate/021_snowflakes.sql @@ -0,0 +1,13 @@ +-- 2023-08-17: Add snowflake ID columns + +-- +migrate Up + +alter table users add column snowflake_id bigint unique; +alter table members add column snowflake_id bigint unique; +alter table pride_flags add column snowflake_id bigint unique; + +-- +migrate Down + +alter table users drop column snowflake_id; +alter table members drop column snowflake_id; +alter table pride_flags drop column snowflake_id; diff --git a/scripts/snowflakes/main.go b/scripts/snowflakes/main.go new file mode 100644 index 0000000..382035a --- /dev/null +++ b/scripts/snowflakes/main.go @@ -0,0 +1,111 @@ +package snowflakes + +import ( + "os" + "time" + + "codeberg.org/pronounscc/pronouns.cc/backend/common" + "codeberg.org/pronounscc/pronouns.cc/backend/log" + "github.com/georgysavva/scany/v2/pgxscan" + "github.com/jackc/pgx/v5" + "github.com/joho/godotenv" + "github.com/rs/xid" + "github.com/urfave/cli/v2" +) + +var Command = &cli.Command{ + Name: "create-snowflakes", + Usage: "Give all users, members, and flags snowflake IDs.", + Action: run, +} + +func run(c *cli.Context) error { // run gives every user, member, and pride flag row whose snowflake_id is NULL a snowflake, all in one transaction. + err := godotenv.Load() + if err != nil { + log.Error("loading .env file:", err) +
return err + } + + conn, err := pgx.Connect(c.Context, os.Getenv("DATABASE_URL")) + if err != nil { + log.Error("opening database:", err) + return err + } + defer conn.Close(c.Context) + log.Info("opened database") + + tx, err := conn.Begin(c.Context) + if err != nil { + log.Error("creating transaction:", err) + return err + } + defer tx.Rollback(c.Context) // discards partial updates on any early return below + + var userIDs []xid.ID + err = pgxscan.Select(c.Context, conn, &userIDs, "SELECT id FROM users WHERE snowflake_id IS NULL") + if err != nil { + log.Error("selecting users without snowflake:", err) + return err + } + + t := time.Now() + for _, userID := range userIDs { + created := userID.Time() // timestamp stored in the xid, reused as the snowflake timestamp + snowflake := common.UserID(common.GenerateIDWithTime(created)) + + _, err = tx.Exec(c.Context, "UPDATE users SET snowflake_id = $1 WHERE id = $2", snowflake, userID) + if err != nil { + log.Errorf("updating user with ID %v: %v", userID, err) + return err + } + } + log.Infof("updated %v users in %v", len(userIDs), time.Since(t)) + + var memberIDs []xid.ID + err = pgxscan.Select(c.Context, conn, &memberIDs, "SELECT id FROM members WHERE snowflake_id IS NULL") + if err != nil { + log.Error("selecting members without snowflake:", err) + return err + } + + t = time.Now() + for _, memberID := range memberIDs { + created := memberID.Time() // timestamp stored in the xid, reused as the snowflake timestamp + snowflake := common.MemberID(common.GenerateIDWithTime(created)) + + _, err = tx.Exec(c.Context, "UPDATE members SET snowflake_id = $1 WHERE id = $2", snowflake, memberID) + if err != nil { + log.Errorf("updating member with ID %v: %v", memberID, err) + return err + } + } + log.Infof("updated %v members in %v", len(memberIDs), time.Since(t)) + + var flagIDs []xid.ID + err = pgxscan.Select(c.Context, conn, &flagIDs, "SELECT id FROM pride_flags WHERE snowflake_id IS NULL") + if err != nil { + log.Error("selecting flags without snowflake:", err) + return err + } + + t = time.Now() + for _, flagID := range flagIDs { + created := flagID.Time() // timestamp stored in the xid, reused as the snowflake timestamp + snowflake := common.FlagID(common.GenerateIDWithTime(created)) + + _, err =
tx.Exec(c.Context, "UPDATE pride_flags SET snowflake_id = $1 WHERE id = $2", snowflake, flagID) + if err != nil { + log.Errorf("updating flag with ID %v: %v", flagID, err) + return err + } + } + log.Infof("updated %v flags in %v", len(flagIDs), time.Since(t)) + + err = tx.Commit(c.Context) + if err != nil { + log.Error("committing transaction:", err) + return err + } + + return nil +}