268 lines
6.2 KiB
Go
268 lines
6.2 KiB
Go
|
package storage
|
||
|
|
||
|
import (
	"context"
	"database/sql"
	"errors"
	"strings"
	"time"

	"github.com/jackc/pgerrcode"
	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jmoiron/sqlx"
	"go.uber.org/zap"
	"golang.org/x/crypto/bcrypt"

	"ms-auth/internal/lib"
)
|
||
|
|
||
|
type PostgresqlStorage struct {
|
||
|
db *sqlx.DB
|
||
|
logger *zap.Logger
|
||
|
}
|
||
|
|
||
|
func NewUserStorage(dsn string, logger *zap.Logger) *PostgresqlStorage {
|
||
|
db, err := sqlx.Connect("pgx", dsn)
|
||
|
if err != nil {
|
||
|
panic(err.Error())
|
||
|
}
|
||
|
|
||
|
return &PostgresqlStorage{db: db, logger: logger}
|
||
|
}
|
||
|
|
||
|
func (storage *PostgresqlStorage) Stop() error {
|
||
|
return storage.db.Close()
|
||
|
}
|
||
|
|
||
|
// Expiry windows CreateUser applies when the caller passes no expiresAt.
const (
	// shortUserLifetime (30 days) is used for users created without an email.
	shortUserLifetime = 30 * 24 * time.Hour
	// defaultUserLifetime (100 years of 365 days) is used for users created with an email.
	defaultUserLifetime = 100 * 365 * 24 * time.Hour
)
|
||
|
|
||
|
// User is one row of the users table; sqlx maps columns to fields via the db tags.
type User struct {
	Id int32 `db:"id"`

	Username string `db:"username"`
	// HashedPassword is the bcrypt hash of the user's password. It must be a
	// slice, not a fixed-size array: database/sql (and therefore sqlx) cannot
	// Scan a column value into [N]byte and fails with "unsupported Scan".
	// Hashes produced by bcrypt.GenerateFromPassword are 60 bytes long.
	HashedPassword []byte `db:"hashed_pwd"`

	// Email is nil when the column is NULL; CreateUser allows users without one.
	Email *string `db:"email"`

	ExpiresAt time.Time `db:"expires_at"`
	CreatedAt time.Time `db:"created_at"`

	// Role is compared numerically by AtLeast and classified by the lib.Is* helpers.
	Role int32 `db:"role"`
}
|
||
|
|
||
|
func (user *User) IsAdmin() bool {
|
||
|
return lib.IsAdmin(user.Role)
|
||
|
}
|
||
|
|
||
|
func (user *User) IsModerator() bool {
|
||
|
return lib.IsModerator(user.Role)
|
||
|
}
|
||
|
|
||
|
func (user *User) IsParticipant() bool {
|
||
|
return lib.IsParticipant(user.Role)
|
||
|
}
|
||
|
|
||
|
func (user *User) IsSpectator() bool {
|
||
|
return lib.IsSpectator(user.Role)
|
||
|
}
|
||
|
|
||
|
func (user *User) AtLeast(role int32) bool {
|
||
|
return user.Role >= role
|
||
|
}
|
||
|
|
||
|
func (user *User) ComparePassword(password string) error {
|
||
|
if bcrypt.CompareHashAndPassword(user.HashedPassword[:], []byte(password)) != nil {
|
||
|
return lib.ErrBadHandleOrPassword
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (storage *PostgresqlStorage) CreateUser(
|
||
|
ctx context.Context,
|
||
|
username string,
|
||
|
password string,
|
||
|
email *string,
|
||
|
expiresAt *time.Time,
|
||
|
role *int32,
|
||
|
) (*int32, error) {
|
||
|
if err := lib.ValidUsername(username); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if err := lib.ValidPassword(password); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if email != nil {
|
||
|
if err := lib.ValidEmail(*email); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
if role != nil {
|
||
|
if err := lib.ValidRole(*role); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
username = strings.ToLower(username)
|
||
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||
|
if err != nil {
|
||
|
storage.logger.Error(err.Error())
|
||
|
return nil, lib.ErrInternal
|
||
|
}
|
||
|
now := time.Now()
|
||
|
username = strings.ToLower(username)
|
||
|
if email != nil {
|
||
|
*email = strings.ToLower(*email)
|
||
|
}
|
||
|
if role == nil {
|
||
|
role = lib.AsInt32P(lib.RoleSpectator)
|
||
|
}
|
||
|
if expiresAt == nil {
|
||
|
if email == nil {
|
||
|
expiresAt = lib.AsTimeP(now.Add(shortUserLifetime))
|
||
|
} else {
|
||
|
expiresAt = lib.AsTimeP(now.Add(defaultUserLifetime))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
query := storage.db.Rebind(`
|
||
|
INSERT INTO users
|
||
|
(username, hashed_pwd, email, expires_at, role)
|
||
|
VALUES (?, ?, ?, ?, ?)
|
||
|
RETURNING id
|
||
|
`)
|
||
|
|
||
|
rows, err := storage.db.QueryxContext(ctx, query, username, hashedPassword, email, expiresAt, role)
|
||
|
if err != nil {
|
||
|
return nil, storage.handlePgErr(err)
|
||
|
}
|
||
|
defer rows.Close()
|
||
|
var id int32
|
||
|
err = rows.StructScan(&id)
|
||
|
if err != nil {
|
||
|
return nil, storage.handlePgErr(err)
|
||
|
}
|
||
|
return &id, nil
|
||
|
}
|
||
|
func (storage *PostgresqlStorage) ReadUserByEmail(ctx context.Context, email string) (*User, error) {
|
||
|
if err := lib.ValidEmail(email); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
email = strings.ToLower(email)
|
||
|
|
||
|
var user User
|
||
|
query := storage.db.Rebind("SELECT * from users WHERE email=? LIMIT 1")
|
||
|
err := storage.db.GetContext(ctx, &user, query, email)
|
||
|
if err != nil {
|
||
|
return nil, storage.handlePgErr(err)
|
||
|
}
|
||
|
return &user, nil
|
||
|
}
|
||
|
func (storage *PostgresqlStorage) ReadUserByUsername(ctx context.Context, username string) (*User, error) {
|
||
|
if err := lib.ValidUsername(username); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
username = strings.ToLower(username)
|
||
|
|
||
|
var user User
|
||
|
query := storage.db.Rebind("SELECT * from users WHERE username=? LIMIT 1")
|
||
|
err := storage.db.GetContext(ctx, &user, query, username)
|
||
|
if err != nil {
|
||
|
return nil, storage.handlePgErr(err)
|
||
|
}
|
||
|
return &user, nil
|
||
|
}
|
||
|
func (storage *PostgresqlStorage) ReadUserById(ctx context.Context, id int32) (*User, error) {
|
||
|
var user User
|
||
|
query := storage.db.Rebind("SELECT * from users WHERE id=? LIMIT 1")
|
||
|
err := storage.db.GetContext(ctx, &user, query, id)
|
||
|
if err != nil {
|
||
|
return nil, storage.handlePgErr(err)
|
||
|
}
|
||
|
return &user, nil
|
||
|
}
|
||
|
|
||
|
func (storage *PostgresqlStorage) UpdateUser(
|
||
|
ctx context.Context,
|
||
|
id int32,
|
||
|
username *string,
|
||
|
password *string,
|
||
|
email *string,
|
||
|
expiresAt *time.Time,
|
||
|
role *int32,
|
||
|
) error {
|
||
|
var err error
|
||
|
if username != nil {
|
||
|
if err = lib.ValidUsername(*username); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
var hashedPassword []byte
|
||
|
if password != nil {
|
||
|
if err = lib.ValidPassword(*password); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
hashedPassword, err = bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost)
|
||
|
if err != nil {
|
||
|
storage.logger.Error(err.Error())
|
||
|
return lib.ErrInternal
|
||
|
}
|
||
|
}
|
||
|
if email != nil {
|
||
|
if err = lib.ValidEmail(*email); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
if role != nil {
|
||
|
if err = lib.ValidRole(*role); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if username != nil {
|
||
|
*username = strings.ToLower(*username)
|
||
|
}
|
||
|
if email != nil {
|
||
|
*email = strings.ToLower(*email)
|
||
|
}
|
||
|
|
||
|
query := storage.db.Rebind(`
|
||
|
UPDATE users
|
||
|
SET username = COALESCE(?, username),
|
||
|
hashed_pwd = COALESCE(?, hashed_pwd),
|
||
|
email = COALESCE(?, email),
|
||
|
expires_at = COALESCE(?, expires_at),
|
||
|
role = COALESCE(?, role)
|
||
|
WHERE id = ?`)
|
||
|
|
||
|
_, err = storage.db.ExecContext(ctx, query, username, hashedPassword, email, expiresAt, role, id)
|
||
|
if err != nil {
|
||
|
return storage.handlePgErr(err)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
func (storage *PostgresqlStorage) DeleteUser(ctx context.Context, id int32) error {
|
||
|
query := storage.db.Rebind("UPDATE users SET expired_at=NOW() WHERE id = ?")
|
||
|
_, err := storage.db.ExecContext(ctx, query, id)
|
||
|
if err != nil {
|
||
|
return storage.handlePgErr(err)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (storage *PostgresqlStorage) handlePgErr(err error) error {
|
||
|
var pgErr *pgconn.PgError
|
||
|
if !errors.As(err, &pgErr) {
|
||
|
storage.logger.DPanic("unexpected error from postgres", zap.String("err", err.Error()))
|
||
|
return lib.ErrUnexpected
|
||
|
}
|
||
|
if pgerrcode.IsIntegrityConstraintViolation(pgErr.Code) {
|
||
|
return errors.New("unique key violation") // FIXME
|
||
|
}
|
||
|
storage.logger.DPanic("unexpected internal error from postgres", zap.String("err", err.Error()))
|
||
|
return lib.ErrInternal
|
||
|
}
|