package storage

import (
	"context"
	"database/sql"
	"errors"
	"strings"
	"time"

	"github.com/jackc/pgerrcode"
	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jmoiron/sqlx"
	"go.uber.org/zap"
	"golang.org/x/crypto/bcrypt"

	"ms-auth/internal/lib"
)

// PostgresqlStorage stores users in a PostgreSQL database accessed via sqlx.
type PostgresqlStorage struct {
	db     *sqlx.DB
	logger *zap.Logger
}

// NewUserStorage connects to the database described by dsn through the
// "pgx" database/sql driver and returns a storage bound to it.
//
// It panics if sqlx.Connect fails (sqlx.Connect opens the pool and pings the
// database). The "pgx" driver is presumably registered elsewhere, e.g. by a
// blank import of github.com/jackc/pgx/v5/stdlib — TODO confirm.
func NewUserStorage(dsn string, logger *zap.Logger) *PostgresqlStorage {
	db, err := sqlx.Connect("pgx", dsn)
	if err != nil {
		panic(err.Error())
	}
	return &PostgresqlStorage{db: db, logger: logger}
}

// Stop closes the underlying database handle.
func (storage *PostgresqlStorage) Stop() error {
	return storage.db.Close()
}

const (
	// shortUserLifetime is the default lifetime of a user created without an email.
	shortUserLifetime = time.Hour * 24 * 30
	// defaultUserLifetime is the default lifetime of a user created with an
	// email (roughly 100 years).
	defaultUserLifetime = time.Hour * 24 * 365 * 100
)

// User is one row of the users table; fields map to columns via db tags.
//
// NOTE(review): database/sql cannot Scan a column value ([]byte or string)
// into a fixed-size array, so loading HashedPassword ([60]byte) through the
// SELECT * queries in this file is likely to fail with "unsupported Scan".
// Confirm against the schema/driver, or switch the field to []byte.
type User struct {
	Id             int32     `db:"id"`
	Username       string    `db:"username"`
	HashedPassword [60]byte  `db:"hashed_pwd"`
	Email          *string   `db:"email"`
	ExpiresAt      time.Time `db:"expires_at"`
	CreatedAt      time.Time `db:"created_at"`
	Role           int32     `db:"role"`
}

// IsAdmin reports whether the user's role satisfies lib.IsAdmin.
func (user *User) IsAdmin() bool {
	return lib.IsAdmin(user.Role)
}

// IsModerator reports whether the user's role satisfies lib.IsModerator.
func (user *User) IsModerator() bool {
	return lib.IsModerator(user.Role)
}

// IsParticipant reports whether the user's role satisfies lib.IsParticipant.
func (user *User) IsParticipant() bool {
	return lib.IsParticipant(user.Role)
}

// IsSpectator reports whether the user's role satisfies lib.IsSpectator.
func (user *User) IsSpectator() bool {
	return lib.IsSpectator(user.Role)
}

// AtLeast reports whether the user's numeric role is greater than or equal
// to role.
func (user *User) AtLeast(role int32) bool {
	return user.Role >= role
}

// ComparePassword checks password against the stored bcrypt hash. Any bcrypt
// failure (mismatch or malformed hash) is reported as
// lib.ErrBadHandleOrPassword.
func (user *User) ComparePassword(password string) error {
	if bcrypt.CompareHashAndPassword(user.HashedPassword[:], []byte(password)) != nil {
		return lib.ErrBadHandleOrPassword
	}
	return nil
}

// CreateUser validates and inserts a new user and returns its generated id.
//
// username and email are stored lower-cased; the value behind the caller's
// email pointer is not modified. A nil role defaults to lib.RoleSpectator.
// A nil expiresAt defaults to now+shortUserLifetime when email is nil and to
// now+defaultUserLifetime otherwise.
//
// Validation errors from lib are returned as-is, hashing failures are logged
// and reported as lib.ErrInternal, and database errors go through handlePgErr.
func (storage *PostgresqlStorage) CreateUser(
	ctx context.Context,
	username string,
	password string,
	email *string,
	expiresAt *time.Time,
	role *int32,
) (*int32, error) {
	if err := lib.ValidUsername(username); err != nil {
		return nil, err
	}
	if err := lib.ValidPassword(password); err != nil {
		return nil, err
	}
	if email != nil {
		if err := lib.ValidEmail(*email); err != nil {
			return nil, err
		}
	}
	if role != nil {
		if err := lib.ValidRole(*role); err != nil {
			return nil, err
		}
	}

	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		storage.logger.Error(err.Error())
		return nil, lib.ErrInternal
	}

	now := time.Now()
	username = strings.ToLower(username)
	if email != nil {
		// Normalize into a local copy instead of writing through the caller's pointer.
		lowered := strings.ToLower(*email)
		email = &lowered
	}
	if role == nil {
		role = lib.AsInt32P(lib.RoleSpectator)
	}
	if expiresAt == nil {
		if email == nil {
			expiresAt = lib.AsTimeP(now.Add(shortUserLifetime))
		} else {
			expiresAt = lib.AsTimeP(now.Add(defaultUserLifetime))
		}
	}

	query := storage.db.Rebind(`
		INSERT INTO users (username, hashed_pwd, email, expires_at, role)
		VALUES (?, ?, ?, ?, ?)
		RETURNING id
	`)
	// GetContext reads the single RETURNING row into a scalar destination and
	// releases the underlying rows itself.
	var id int32
	err = storage.db.GetContext(ctx, &id, query, username, hashedPassword, email, expiresAt, role)
	if err != nil {
		return nil, storage.handlePgErr(err)
	}
	return &id, nil
}

// ReadUserByEmail returns the user whose email equals the lower-cased email.
// If no such row exists, the error is sql.ErrNoRows (passed through by
// handlePgErr).
func (storage *PostgresqlStorage) ReadUserByEmail(ctx context.Context, email string) (*User, error) {
	if err := lib.ValidEmail(email); err != nil {
		return nil, err
	}
	email = strings.ToLower(email)
	var user User
	query := storage.db.Rebind("SELECT * from users WHERE email=? LIMIT 1")
	if err := storage.db.GetContext(ctx, &user, query, email); err != nil {
		return nil, storage.handlePgErr(err)
	}
	return &user, nil
}

// ReadUserByUsername returns the user whose username equals the lower-cased
// username. If no such row exists, the error is sql.ErrNoRows (passed through
// by handlePgErr).
func (storage *PostgresqlStorage) ReadUserByUsername(ctx context.Context, username string) (*User, error) {
	if err := lib.ValidUsername(username); err != nil {
		return nil, err
	}
	username = strings.ToLower(username)
	var user User
	query := storage.db.Rebind("SELECT * from users WHERE username=? LIMIT 1")
	if err := storage.db.GetContext(ctx, &user, query, username); err != nil {
		return nil, storage.handlePgErr(err)
	}
	return &user, nil
}

// ReadUserById returns the user with the given id. If no such row exists, the
// error is sql.ErrNoRows (passed through by handlePgErr).
func (storage *PostgresqlStorage) ReadUserById(ctx context.Context, id int32) (*User, error) {
	var user User
	query := storage.db.Rebind("SELECT * from users WHERE id=? LIMIT 1")
	if err := storage.db.GetContext(ctx, &user, query, id); err != nil {
		return nil, storage.handlePgErr(err)
	}
	return &user, nil
}

// UpdateUser changes the non-nil fields of the user with the given id. Nil
// arguments leave the corresponding column unchanged (via COALESCE).
//
// All provided values are validated before the password is hashed. username
// and email are stored lower-cased; the values behind the caller's pointers
// are not modified. An id that matches no row is not reported as an error.
func (storage *PostgresqlStorage) UpdateUser(
	ctx context.Context,
	id int32,
	username *string,
	password *string,
	email *string,
	expiresAt *time.Time,
	role *int32,
) error {
	if username != nil {
		if err := lib.ValidUsername(*username); err != nil {
			return err
		}
	}
	if password != nil {
		if err := lib.ValidPassword(*password); err != nil {
			return err
		}
	}
	if email != nil {
		if err := lib.ValidEmail(*email); err != nil {
			return err
		}
	}
	if role != nil {
		if err := lib.ValidRole(*role); err != nil {
			return err
		}
	}

	// A nil slice is sent as NULL, so COALESCE keeps the stored hash.
	var hashedPassword []byte
	if password != nil {
		hashed, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost)
		if err != nil {
			storage.logger.Error(err.Error())
			return lib.ErrInternal
		}
		hashedPassword = hashed
	}

	// Normalize into local copies instead of writing through the caller's pointers.
	if username != nil {
		lowered := strings.ToLower(*username)
		username = &lowered
	}
	if email != nil {
		lowered := strings.ToLower(*email)
		email = &lowered
	}

	query := storage.db.Rebind(`
		UPDATE users SET
			username = COALESCE(?, username),
			hashed_pwd = COALESCE(?, hashed_pwd),
			email = COALESCE(?, email),
			expires_at = COALESCE(?, expires_at),
			role = COALESCE(?, role)
		WHERE id = ?`)
	if _, err := storage.db.ExecContext(ctx, query, username, hashedPassword, email, expiresAt, role, id); err != nil {
		return storage.handlePgErr(err)
	}
	return nil
}

// DeleteUser soft-deletes the user with the given id by setting its
// expires_at column to the database's NOW(); the row itself is kept. An id
// that matches no row is not reported as an error.
func (storage *PostgresqlStorage) DeleteUser(ctx context.Context, id int32) error {
	query := storage.db.Rebind("UPDATE users SET expires_at=NOW() WHERE id = ?")
	if _, err := storage.db.ExecContext(ctx, query, id); err != nil {
		return storage.handlePgErr(err)
	}
	return nil
}

// handlePgErr maps an error from a database call to the error returned to
// callers:
//   - sql.ErrNoRows and context cancellation/deadline errors are returned
//     unchanged, so callers can detect them with errors.Is;
//   - Postgres integrity-constraint violations become a "unique key
//     violation" error;
//   - anything else is logged with DPanic (which panics when the zap logger is
//     in development mode) and reported as lib.ErrUnexpected for non-Postgres
//     errors or lib.ErrInternal for other Postgres errors.
func (storage *PostgresqlStorage) handlePgErr(err error) error {
	// Expected outcomes of normal operation: a missing row or an abandoned request.
	if errors.Is(err, sql.ErrNoRows) ||
		errors.Is(err, context.Canceled) ||
		errors.Is(err, context.DeadlineExceeded) {
		return err
	}

	var pgErr *pgconn.PgError
	if !errors.As(err, &pgErr) {
		storage.logger.DPanic("unexpected error from postgres", zap.String("err", err.Error()))
		return lib.ErrUnexpected
	}
	if pgerrcode.IsIntegrityConstraintViolation(pgErr.Code) {
		// FIXME: this class also covers not-null, foreign-key and check
		// violations, and a fresh error value cannot be matched with errors.Is.
		return errors.New("unique key violation")
	}
	storage.logger.DPanic("unexpected internal error from postgres", zap.String("err", err.Error()))
	return lib.ErrInternal
}