333 lines
6.5 KiB
Go
333 lines
6.5 KiB
Go
|
package storage
|
||
|
|
||
|
import (
	"context"
	"encoding/json"
	"errors"
	"strconv"
	"time"

	"github.com/golang-jwt/jwt"
	"github.com/google/uuid"
	"github.com/valkey-io/valkey-go"
	"github.com/valkey-io/valkey-go/valkeylock"
	"go.uber.org/zap"

	"ms-auth/internal/lib"
)
|
||
|
|
||
|
type ValkeyStorage struct {
|
||
|
db valkey.Client
|
||
|
locker valkeylock.Locker
|
||
|
cfg *lib.Config
|
||
|
logger *zap.Logger
|
||
|
}
|
||
|
|
||
|
func NewValkeyStorage(dsn string, cfg *lib.Config, logger *zap.Logger) *ValkeyStorage {
|
||
|
opts, err := valkey.ParseURL(dsn)
|
||
|
if err != nil {
|
||
|
panic(err.Error())
|
||
|
}
|
||
|
|
||
|
db, err := valkey.NewClient(opts)
|
||
|
if err != nil {
|
||
|
panic(err.Error())
|
||
|
}
|
||
|
|
||
|
locker, err := valkeylock.NewLocker(valkeylock.LockerOption{
|
||
|
ClientOption: opts,
|
||
|
KeyMajority: 1,
|
||
|
NoLoopTracking: true,
|
||
|
})
|
||
|
if err != nil {
|
||
|
panic(err.Error())
|
||
|
}
|
||
|
|
||
|
return &ValkeyStorage{
|
||
|
db: db,
|
||
|
locker: locker,
|
||
|
cfg: cfg,
|
||
|
logger: logger,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (storage *ValkeyStorage) Stop() error {
|
||
|
storage.db.Close()
|
||
|
storage.locker.Close()
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Expiry windows applied to the keys written by ValkeyStorage.
const (
	// sessionLifetime is how long a session key lives after it is created or
	// updated.
	sessionLifetime = 40 * time.Minute
	// confirmationLifetime is how long a stored confirmation lives.
	confirmationLifetime = 5 * time.Hour
)
|
||
|
|
||
|
func (storage *ValkeyStorage) CreateSession(
|
||
|
ctx context.Context,
|
||
|
user_id int32,
|
||
|
) error {
|
||
|
session := NewSession(user_id)
|
||
|
|
||
|
resp := storage.db.Do(ctx, storage.db.
|
||
|
B().Set().
|
||
|
Key(string(*session.UserId)).
|
||
|
Value(*session.Id).
|
||
|
Nx().
|
||
|
Exat(time.Now().Add(sessionLifetime)).
|
||
|
Build(),
|
||
|
)
|
||
|
|
||
|
if err := resp.Error(); err != nil {
|
||
|
storage.logger.Error(err.Error())
|
||
|
return lib.ErrInternal
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (storage *ValkeyStorage) ReadSessionByToken(ctx context.Context, token string) (*Session, error) {
|
||
|
session, err := Parse(token, storage.cfg.JWTSecret)
|
||
|
if err != nil {
|
||
|
storage.logger.Error(err.Error())
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
real_session, err := storage.ReadSessionByUserId(ctx, *session.UserId)
|
||
|
if err != nil {
|
||
|
storage.logger.Error(err.Error())
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
if *session.Id != *real_session.Id {
|
||
|
storage.logger.Error(err.Error())
|
||
|
return nil, lib.ErrInternal
|
||
|
}
|
||
|
|
||
|
return session, err
|
||
|
}
|
||
|
|
||
|
func (storage *ValkeyStorage) ReadSessionByUserId(ctx context.Context, user_id int32) (*Session, error) {
|
||
|
resp := storage.db.Do(ctx, storage.db.B().Get().Key(string(user_id)).Build())
|
||
|
if err := resp.Error(); err != nil {
|
||
|
storage.logger.Error(err.Error())
|
||
|
return nil, lib.ErrInternal
|
||
|
}
|
||
|
|
||
|
id, err := resp.ToString()
|
||
|
if err != nil {
|
||
|
storage.logger.Error(err.Error())
|
||
|
return nil, lib.ErrInternal
|
||
|
}
|
||
|
|
||
|
return &Session{
|
||
|
Id: &id,
|
||
|
UserId: &user_id,
|
||
|
}, err
|
||
|
}
|
||
|
|
||
|
func (storage *ValkeyStorage) UpdateSession(ctx context.Context, session *Session) error {
|
||
|
resp := storage.db.Do(ctx, storage.db.
|
||
|
B().Set().
|
||
|
Key(string(*session.UserId)).
|
||
|
Value(*session.Id).
|
||
|
Xx().
|
||
|
Exat(time.Now().Add(sessionLifetime)).
|
||
|
Build(),
|
||
|
)
|
||
|
|
||
|
if err := resp.Error(); err != nil {
|
||
|
storage.logger.Error(err.Error())
|
||
|
return lib.ErrInternal
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (storage *ValkeyStorage) DeleteSessionByToken(ctx context.Context, token string) error {
|
||
|
session, err := Parse(token, storage.cfg.JWTSecret)
|
||
|
if err != nil {
|
||
|
storage.logger.Error(err.Error())
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
err = storage.DeleteSessionByUserId(ctx, *session.UserId)
|
||
|
if err != nil {
|
||
|
storage.logger.Error(err.Error())
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (storage *ValkeyStorage) DeleteSessionByUserId(ctx context.Context, user_id int32) error {
|
||
|
resp := storage.db.Do(ctx, storage.db.
|
||
|
B().Del().
|
||
|
Key(string(user_id)).
|
||
|
Build(),
|
||
|
)
|
||
|
|
||
|
if err := resp.Error(); err != nil {
|
||
|
storage.logger.Error(err.Error())
|
||
|
return lib.ErrInternal
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (storage *ValkeyStorage) CreateConfirmation(ctx context.Context, conf *Confirmation) error {
|
||
|
resp := storage.db.Do(ctx, storage.db.
|
||
|
B().Set().
|
||
|
Key(*conf.Id).
|
||
|
Value(string(conf.JSON())).
|
||
|
Exat(time.Now().Add(confirmationLifetime)).
|
||
|
Build(),
|
||
|
)
|
||
|
|
||
|
if err := resp.Error(); err != nil {
|
||
|
storage.logger.Error(err.Error())
|
||
|
return lib.ErrInternal
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (storage *ValkeyStorage) ReadConfirmation(ctx context.Context, conf_id string) (*Confirmation, error) {
|
||
|
resp := storage.db.Do(ctx, storage.db.
|
||
|
B().Get().
|
||
|
Key(conf_id).
|
||
|
Build(),
|
||
|
)
|
||
|
|
||
|
if err := resp.Error(); err != nil {
|
||
|
storage.logger.Error(err.Error())
|
||
|
return nil, lib.ErrInternal
|
||
|
}
|
||
|
|
||
|
b, err := resp.AsBytes()
|
||
|
if err != nil {
|
||
|
storage.logger.Error(err.Error())
|
||
|
return nil, lib.ErrInternal
|
||
|
}
|
||
|
|
||
|
var conf Confirmation
|
||
|
err = json.Unmarshal(b, &conf)
|
||
|
if err != nil {
|
||
|
storage.logger.Error(err.Error())
|
||
|
return nil, lib.ErrInternal
|
||
|
}
|
||
|
|
||
|
return &conf, nil
|
||
|
}
|
||
|
|
||
|
func (storage *ValkeyStorage) DeleteConfirmation(ctx context.Context, conf_id string) error {
|
||
|
resp := storage.db.Do(ctx, storage.db.
|
||
|
B().Del().
|
||
|
Key(conf_id).
|
||
|
Build(),
|
||
|
)
|
||
|
|
||
|
if err := resp.Error(); err != nil {
|
||
|
storage.logger.Error(err.Error())
|
||
|
return lib.ErrInternal
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Sentinel errors returned by the session and confirmation helpers below.
var (
	// ErrBadSession is returned for a session that lacks required fields and
	// for a token that cannot be signed or parsed.
	ErrBadSession = errors.New("bad session")
	// ErrBadConfirmation is returned for a confirmation that lacks required
	// fields.
	ErrBadConfirmation = errors.New("bad confirmation")
)
|
||
|
|
||
|
// Confirmation is the record stored as JSON by CreateConfirmation and read
// back by ReadConfirmation.
type Confirmation struct {
	// Id identifies the confirmation and is used as its storage key.
	Id *string `json:"id"`
	// UserId is optional; it is left out of the JSON when nil.
	UserId *int32 `json:"user_id,omitempty"`
	// Email is the address checked by Valid.
	Email *string `json:"email"`
}
|
||
|
|
||
|
func NewConfirmation(userId *int32, email string) (*Confirmation, error) {
|
||
|
c := &Confirmation{
|
||
|
Id: lib.AsStringP(uuid.NewString()),
|
||
|
UserId: userId,
|
||
|
Email: &email,
|
||
|
}
|
||
|
|
||
|
if err := c.Valid(); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return c, nil
|
||
|
}
|
||
|
|
||
|
func (c *Confirmation) Valid() error {
|
||
|
if c.Id == nil {
|
||
|
return ErrBadConfirmation
|
||
|
}
|
||
|
// FIXME
|
||
|
// if c.userId == nil {
|
||
|
// return ErrBadConfirmation
|
||
|
// }
|
||
|
if c.Email == nil {
|
||
|
return ErrBadConfirmation
|
||
|
}
|
||
|
if err := lib.ValidEmail(*c.Email); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (c *Confirmation) JSON() []byte {
|
||
|
b, err := json.Marshal(c)
|
||
|
if err != nil {
|
||
|
panic(err.Error())
|
||
|
}
|
||
|
return b
|
||
|
}
|
||
|
|
||
|
// Session ties a session ID to a user. It is also the claims value that Token
// signs and Parse decodes.
type Session struct {
	// Id is the session identifier stored in Valkey for the user.
	Id *string
	// UserId is the user the session belongs to.
	UserId *int32
}
|
||
|
|
||
|
func NewSession(userId int32) *Session {
|
||
|
return &Session{
|
||
|
Id: lib.AsStringP(uuid.NewString()),
|
||
|
UserId: &userId,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s Session) Valid() error {
|
||
|
if s.Id == nil {
|
||
|
return ErrBadSession
|
||
|
}
|
||
|
if s.UserId == nil {
|
||
|
return ErrBadSession
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (s Session) Token(secret string) (string, error) {
|
||
|
if err := s.Valid(); err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, s)
|
||
|
str, err := refreshToken.SignedString([]byte(secret))
|
||
|
if err != nil {
|
||
|
return "", ErrBadSession
|
||
|
}
|
||
|
return str, nil
|
||
|
}
|
||
|
|
||
|
func Parse(tkn string, secret string) (*Session, error) {
|
||
|
parsedToken, err := jwt.ParseWithClaims(tkn, &Session{}, func(token *jwt.Token) (interface{}, error) {
|
||
|
return []byte(secret), nil
|
||
|
})
|
||
|
if err != nil {
|
||
|
return nil, ErrBadSession
|
||
|
}
|
||
|
session := parsedToken.Claims.(*Session)
|
||
|
if err := session.Valid(); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return session, nil
|
||
|
}
|