Initial commit
This commit is contained in:
commit
2fa110e760
28 changed files with 2346 additions and 0 deletions
267
internal/storage/postgresql.go
Normal file
267
internal/storage/postgresql.go
Normal file
|
@ -0,0 +1,267 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/jackc/pgerrcode"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"ms-auth/internal/lib"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
type PostgresqlStorage struct {
|
||||
db *sqlx.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
func NewUserStorage(dsn string, logger *zap.Logger) *PostgresqlStorage {
|
||||
db, err := sqlx.Connect("pgx", dsn)
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
|
||||
return &PostgresqlStorage{db: db, logger: logger}
|
||||
}
|
||||
|
||||
func (storage *PostgresqlStorage) Stop() error {
|
||||
return storage.db.Close()
|
||||
}
|
||||
|
||||
const (
|
||||
shortUserLifetime = time.Hour * 24 * 30
|
||||
defaultUserLifetime = time.Hour * 24 * 365 * 100
|
||||
)
|
||||
|
||||
type User struct {
|
||||
Id int32 `db:"id"`
|
||||
|
||||
Username string `db:"username"`
|
||||
HashedPassword [60]byte `db:"hashed_pwd"`
|
||||
|
||||
Email *string `db:"email"`
|
||||
|
||||
ExpiresAt time.Time `db:"expires_at"`
|
||||
CreatedAt time.Time `db:"created_at"`
|
||||
|
||||
Role int32 `db:"role"`
|
||||
}
|
||||
|
||||
func (user *User) IsAdmin() bool {
|
||||
return lib.IsAdmin(user.Role)
|
||||
}
|
||||
|
||||
func (user *User) IsModerator() bool {
|
||||
return lib.IsModerator(user.Role)
|
||||
}
|
||||
|
||||
func (user *User) IsParticipant() bool {
|
||||
return lib.IsParticipant(user.Role)
|
||||
}
|
||||
|
||||
func (user *User) IsSpectator() bool {
|
||||
return lib.IsSpectator(user.Role)
|
||||
}
|
||||
|
||||
func (user *User) AtLeast(role int32) bool {
|
||||
return user.Role >= role
|
||||
}
|
||||
|
||||
func (user *User) ComparePassword(password string) error {
|
||||
if bcrypt.CompareHashAndPassword(user.HashedPassword[:], []byte(password)) != nil {
|
||||
return lib.ErrBadHandleOrPassword
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (storage *PostgresqlStorage) CreateUser(
|
||||
ctx context.Context,
|
||||
username string,
|
||||
password string,
|
||||
email *string,
|
||||
expiresAt *time.Time,
|
||||
role *int32,
|
||||
) (*int32, error) {
|
||||
if err := lib.ValidUsername(username); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := lib.ValidPassword(password); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if email != nil {
|
||||
if err := lib.ValidEmail(*email); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if role != nil {
|
||||
if err := lib.ValidRole(*role); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
username = strings.ToLower(username)
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
storage.logger.Error(err.Error())
|
||||
return nil, lib.ErrInternal
|
||||
}
|
||||
now := time.Now()
|
||||
username = strings.ToLower(username)
|
||||
if email != nil {
|
||||
*email = strings.ToLower(*email)
|
||||
}
|
||||
if role == nil {
|
||||
role = lib.AsInt32P(lib.RoleSpectator)
|
||||
}
|
||||
if expiresAt == nil {
|
||||
if email == nil {
|
||||
expiresAt = lib.AsTimeP(now.Add(shortUserLifetime))
|
||||
} else {
|
||||
expiresAt = lib.AsTimeP(now.Add(defaultUserLifetime))
|
||||
}
|
||||
}
|
||||
|
||||
query := storage.db.Rebind(`
|
||||
INSERT INTO users
|
||||
(username, hashed_pwd, email, expires_at, role)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
RETURNING id
|
||||
`)
|
||||
|
||||
rows, err := storage.db.QueryxContext(ctx, query, username, hashedPassword, email, expiresAt, role)
|
||||
if err != nil {
|
||||
return nil, storage.handlePgErr(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
var id int32
|
||||
err = rows.StructScan(&id)
|
||||
if err != nil {
|
||||
return nil, storage.handlePgErr(err)
|
||||
}
|
||||
return &id, nil
|
||||
}
|
||||
func (storage *PostgresqlStorage) ReadUserByEmail(ctx context.Context, email string) (*User, error) {
|
||||
if err := lib.ValidEmail(email); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
email = strings.ToLower(email)
|
||||
|
||||
var user User
|
||||
query := storage.db.Rebind("SELECT * from users WHERE email=? LIMIT 1")
|
||||
err := storage.db.GetContext(ctx, &user, query, email)
|
||||
if err != nil {
|
||||
return nil, storage.handlePgErr(err)
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
func (storage *PostgresqlStorage) ReadUserByUsername(ctx context.Context, username string) (*User, error) {
|
||||
if err := lib.ValidUsername(username); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
username = strings.ToLower(username)
|
||||
|
||||
var user User
|
||||
query := storage.db.Rebind("SELECT * from users WHERE username=? LIMIT 1")
|
||||
err := storage.db.GetContext(ctx, &user, query, username)
|
||||
if err != nil {
|
||||
return nil, storage.handlePgErr(err)
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
func (storage *PostgresqlStorage) ReadUserById(ctx context.Context, id int32) (*User, error) {
|
||||
var user User
|
||||
query := storage.db.Rebind("SELECT * from users WHERE id=? LIMIT 1")
|
||||
err := storage.db.GetContext(ctx, &user, query, id)
|
||||
if err != nil {
|
||||
return nil, storage.handlePgErr(err)
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (storage *PostgresqlStorage) UpdateUser(
|
||||
ctx context.Context,
|
||||
id int32,
|
||||
username *string,
|
||||
password *string,
|
||||
email *string,
|
||||
expiresAt *time.Time,
|
||||
role *int32,
|
||||
) error {
|
||||
var err error
|
||||
if username != nil {
|
||||
if err = lib.ValidUsername(*username); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
var hashedPassword []byte
|
||||
if password != nil {
|
||||
if err = lib.ValidPassword(*password); err != nil {
|
||||
return err
|
||||
}
|
||||
hashedPassword, err = bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
storage.logger.Error(err.Error())
|
||||
return lib.ErrInternal
|
||||
}
|
||||
}
|
||||
if email != nil {
|
||||
if err = lib.ValidEmail(*email); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if role != nil {
|
||||
if err = lib.ValidRole(*role); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if username != nil {
|
||||
*username = strings.ToLower(*username)
|
||||
}
|
||||
if email != nil {
|
||||
*email = strings.ToLower(*email)
|
||||
}
|
||||
|
||||
query := storage.db.Rebind(`
|
||||
UPDATE users
|
||||
SET username = COALESCE(?, username),
|
||||
hashed_pwd = COALESCE(?, hashed_pwd),
|
||||
email = COALESCE(?, email),
|
||||
expires_at = COALESCE(?, expires_at),
|
||||
role = COALESCE(?, role)
|
||||
WHERE id = ?`)
|
||||
|
||||
_, err = storage.db.ExecContext(ctx, query, username, hashedPassword, email, expiresAt, role, id)
|
||||
if err != nil {
|
||||
return storage.handlePgErr(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (storage *PostgresqlStorage) DeleteUser(ctx context.Context, id int32) error {
|
||||
query := storage.db.Rebind("UPDATE users SET expired_at=NOW() WHERE id = ?")
|
||||
_, err := storage.db.ExecContext(ctx, query, id)
|
||||
if err != nil {
|
||||
return storage.handlePgErr(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (storage *PostgresqlStorage) handlePgErr(err error) error {
|
||||
var pgErr *pgconn.PgError
|
||||
if !errors.As(err, &pgErr) {
|
||||
storage.logger.DPanic("unexpected error from postgres", zap.String("err", err.Error()))
|
||||
return lib.ErrUnexpected
|
||||
}
|
||||
if pgerrcode.IsIntegrityConstraintViolation(pgErr.Code) {
|
||||
return errors.New("unique key violation") // FIXME
|
||||
}
|
||||
storage.logger.DPanic("unexpected internal error from postgres", zap.String("err", err.Error()))
|
||||
return lib.ErrInternal
|
||||
}
|
332
internal/storage/valkey.go
Normal file
332
internal/storage/valkey.go
Normal file
|
@ -0,0 +1,332 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"go.uber.org/zap"
|
||||
"time"
|
||||
|
||||
"ms-auth/internal/lib"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/valkey-io/valkey-go"
|
||||
"github.com/valkey-io/valkey-go/valkeylock"
|
||||
)
|
||||
|
||||
type ValkeyStorage struct {
|
||||
db valkey.Client
|
||||
locker valkeylock.Locker
|
||||
cfg *lib.Config
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
func NewValkeyStorage(dsn string, cfg *lib.Config, logger *zap.Logger) *ValkeyStorage {
|
||||
opts, err := valkey.ParseURL(dsn)
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
|
||||
db, err := valkey.NewClient(opts)
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
|
||||
locker, err := valkeylock.NewLocker(valkeylock.LockerOption{
|
||||
ClientOption: opts,
|
||||
KeyMajority: 1,
|
||||
NoLoopTracking: true,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
|
||||
return &ValkeyStorage{
|
||||
db: db,
|
||||
locker: locker,
|
||||
cfg: cfg,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (storage *ValkeyStorage) Stop() error {
|
||||
storage.db.Close()
|
||||
storage.locker.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
sessionLifetime = time.Minute * 40
|
||||
confirmationLifetime = time.Hour * 5
|
||||
)
|
||||
|
||||
func (storage *ValkeyStorage) CreateSession(
|
||||
ctx context.Context,
|
||||
user_id int32,
|
||||
) error {
|
||||
session := NewSession(user_id)
|
||||
|
||||
resp := storage.db.Do(ctx, storage.db.
|
||||
B().Set().
|
||||
Key(string(*session.UserId)).
|
||||
Value(*session.Id).
|
||||
Nx().
|
||||
Exat(time.Now().Add(sessionLifetime)).
|
||||
Build(),
|
||||
)
|
||||
|
||||
if err := resp.Error(); err != nil {
|
||||
storage.logger.Error(err.Error())
|
||||
return lib.ErrInternal
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (storage *ValkeyStorage) ReadSessionByToken(ctx context.Context, token string) (*Session, error) {
|
||||
session, err := Parse(token, storage.cfg.JWTSecret)
|
||||
if err != nil {
|
||||
storage.logger.Error(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
real_session, err := storage.ReadSessionByUserId(ctx, *session.UserId)
|
||||
if err != nil {
|
||||
storage.logger.Error(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if *session.Id != *real_session.Id {
|
||||
storage.logger.Error(err.Error())
|
||||
return nil, lib.ErrInternal
|
||||
}
|
||||
|
||||
return session, err
|
||||
}
|
||||
|
||||
func (storage *ValkeyStorage) ReadSessionByUserId(ctx context.Context, user_id int32) (*Session, error) {
|
||||
resp := storage.db.Do(ctx, storage.db.B().Get().Key(string(user_id)).Build())
|
||||
if err := resp.Error(); err != nil {
|
||||
storage.logger.Error(err.Error())
|
||||
return nil, lib.ErrInternal
|
||||
}
|
||||
|
||||
id, err := resp.ToString()
|
||||
if err != nil {
|
||||
storage.logger.Error(err.Error())
|
||||
return nil, lib.ErrInternal
|
||||
}
|
||||
|
||||
return &Session{
|
||||
Id: &id,
|
||||
UserId: &user_id,
|
||||
}, err
|
||||
}
|
||||
|
||||
func (storage *ValkeyStorage) UpdateSession(ctx context.Context, session *Session) error {
|
||||
resp := storage.db.Do(ctx, storage.db.
|
||||
B().Set().
|
||||
Key(string(*session.UserId)).
|
||||
Value(*session.Id).
|
||||
Xx().
|
||||
Exat(time.Now().Add(sessionLifetime)).
|
||||
Build(),
|
||||
)
|
||||
|
||||
if err := resp.Error(); err != nil {
|
||||
storage.logger.Error(err.Error())
|
||||
return lib.ErrInternal
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (storage *ValkeyStorage) DeleteSessionByToken(ctx context.Context, token string) error {
|
||||
session, err := Parse(token, storage.cfg.JWTSecret)
|
||||
if err != nil {
|
||||
storage.logger.Error(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
err = storage.DeleteSessionByUserId(ctx, *session.UserId)
|
||||
if err != nil {
|
||||
storage.logger.Error(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (storage *ValkeyStorage) DeleteSessionByUserId(ctx context.Context, user_id int32) error {
|
||||
resp := storage.db.Do(ctx, storage.db.
|
||||
B().Del().
|
||||
Key(string(user_id)).
|
||||
Build(),
|
||||
)
|
||||
|
||||
if err := resp.Error(); err != nil {
|
||||
storage.logger.Error(err.Error())
|
||||
return lib.ErrInternal
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (storage *ValkeyStorage) CreateConfirmation(ctx context.Context, conf *Confirmation) error {
|
||||
resp := storage.db.Do(ctx, storage.db.
|
||||
B().Set().
|
||||
Key(*conf.Id).
|
||||
Value(string(conf.JSON())).
|
||||
Exat(time.Now().Add(confirmationLifetime)).
|
||||
Build(),
|
||||
)
|
||||
|
||||
if err := resp.Error(); err != nil {
|
||||
storage.logger.Error(err.Error())
|
||||
return lib.ErrInternal
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (storage *ValkeyStorage) ReadConfirmation(ctx context.Context, conf_id string) (*Confirmation, error) {
|
||||
resp := storage.db.Do(ctx, storage.db.
|
||||
B().Get().
|
||||
Key(conf_id).
|
||||
Build(),
|
||||
)
|
||||
|
||||
if err := resp.Error(); err != nil {
|
||||
storage.logger.Error(err.Error())
|
||||
return nil, lib.ErrInternal
|
||||
}
|
||||
|
||||
b, err := resp.AsBytes()
|
||||
if err != nil {
|
||||
storage.logger.Error(err.Error())
|
||||
return nil, lib.ErrInternal
|
||||
}
|
||||
|
||||
var conf Confirmation
|
||||
err = json.Unmarshal(b, &conf)
|
||||
if err != nil {
|
||||
storage.logger.Error(err.Error())
|
||||
return nil, lib.ErrInternal
|
||||
}
|
||||
|
||||
return &conf, nil
|
||||
}
|
||||
|
||||
func (storage *ValkeyStorage) DeleteConfirmation(ctx context.Context, conf_id string) error {
|
||||
resp := storage.db.Do(ctx, storage.db.
|
||||
B().Del().
|
||||
Key(conf_id).
|
||||
Build(),
|
||||
)
|
||||
|
||||
if err := resp.Error(); err != nil {
|
||||
storage.logger.Error(err.Error())
|
||||
return lib.ErrInternal
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrBadSession = errors.New("bad session")
|
||||
ErrBadConfirmation = errors.New("bad confirmation")
|
||||
)
|
||||
|
||||
type Confirmation struct {
|
||||
Id *string `json:"id"`
|
||||
UserId *int32 `json:"user_id,omitempty"`
|
||||
Email *string `json:"email"`
|
||||
}
|
||||
|
||||
func NewConfirmation(userId *int32, email string) (*Confirmation, error) {
|
||||
c := &Confirmation{
|
||||
Id: lib.AsStringP(uuid.NewString()),
|
||||
UserId: userId,
|
||||
Email: &email,
|
||||
}
|
||||
|
||||
if err := c.Valid(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Confirmation) Valid() error {
|
||||
if c.Id == nil {
|
||||
return ErrBadConfirmation
|
||||
}
|
||||
// FIXME
|
||||
// if c.userId == nil {
|
||||
// return ErrBadConfirmation
|
||||
// }
|
||||
if c.Email == nil {
|
||||
return ErrBadConfirmation
|
||||
}
|
||||
if err := lib.ValidEmail(*c.Email); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Confirmation) JSON() []byte {
|
||||
b, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
Id *string
|
||||
UserId *int32
|
||||
}
|
||||
|
||||
func NewSession(userId int32) *Session {
|
||||
return &Session{
|
||||
Id: lib.AsStringP(uuid.NewString()),
|
||||
UserId: &userId,
|
||||
}
|
||||
}
|
||||
|
||||
func (s Session) Valid() error {
|
||||
if s.Id == nil {
|
||||
return ErrBadSession
|
||||
}
|
||||
if s.UserId == nil {
|
||||
return ErrBadSession
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s Session) Token(secret string) (string, error) {
|
||||
if err := s.Valid(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, s)
|
||||
str, err := refreshToken.SignedString([]byte(secret))
|
||||
if err != nil {
|
||||
return "", ErrBadSession
|
||||
}
|
||||
return str, nil
|
||||
}
|
||||
|
||||
func Parse(tkn string, secret string) (*Session, error) {
|
||||
parsedToken, err := jwt.ParseWithClaims(tkn, &Session{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(secret), nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, ErrBadSession
|
||||
}
|
||||
session := parsedToken.Claims.(*Session)
|
||||
if err := session.Valid(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return session, nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue