package storage import ( "context" "encoding/json" "errors" "go.uber.org/zap" "time" "ms-auth/internal/lib" "github.com/golang-jwt/jwt" "github.com/google/uuid" "github.com/valkey-io/valkey-go" "github.com/valkey-io/valkey-go/valkeylock" ) type ValkeyStorage struct { db valkey.Client locker valkeylock.Locker cfg *lib.Config logger *zap.Logger } func NewValkeyStorage(dsn string, cfg *lib.Config, logger *zap.Logger) *ValkeyStorage { opts, err := valkey.ParseURL(dsn) if err != nil { panic(err.Error()) } db, err := valkey.NewClient(opts) if err != nil { panic(err.Error()) } locker, err := valkeylock.NewLocker(valkeylock.LockerOption{ ClientOption: opts, KeyMajority: 1, NoLoopTracking: true, }) if err != nil { panic(err.Error()) } return &ValkeyStorage{ db: db, locker: locker, cfg: cfg, logger: logger, } } func (storage *ValkeyStorage) Stop() error { storage.db.Close() storage.locker.Close() return nil } const ( sessionLifetime = time.Minute * 40 confirmationLifetime = time.Hour * 5 ) func (storage *ValkeyStorage) CreateSession( ctx context.Context, user_id int32, ) error { session := NewSession(user_id) resp := storage.db.Do(ctx, storage.db. B().Set(). Key(string(*session.UserId)). Value(*session.Id). Nx(). Exat(time.Now().Add(sessionLifetime)). Build(), ) if err := resp.Error(); err != nil { storage.logger.Error(err.Error()) return lib.ErrInternal } return nil } func (storage *ValkeyStorage) ReadSessionByToken(ctx context.Context, token string) (*Session, error) { session, err := Parse(token, storage.cfg.JWTSecret) if err != nil { storage.logger.Error(err.Error()) return nil, err } real_session, err := storage.ReadSessionByUserId(ctx, *session.UserId) if err != nil { storage.logger.Error(err.Error()) return nil, err } if *session.Id != *real_session.Id { storage.logger.Error(err.Error()) return nil, lib.ErrInternal } return session, err } func (storage *ValkeyStorage) ReadSessionByUserId(ctx context.Context, user_id int32) (*Session, error) { resp := storage.db.Do(ctx, storage.db.B().Get().Key(string(user_id)).Build()) if err := resp.Error(); err != nil { storage.logger.Error(err.Error()) return nil, lib.ErrInternal } id, err := resp.ToString() if err != nil { storage.logger.Error(err.Error()) return nil, lib.ErrInternal } return &Session{ Id: &id, UserId: &user_id, }, err } func (storage *ValkeyStorage) UpdateSession(ctx context.Context, session *Session) error { resp := storage.db.Do(ctx, storage.db. B().Set(). Key(string(*session.UserId)). Value(*session.Id). Xx(). Exat(time.Now().Add(sessionLifetime)). Build(), ) if err := resp.Error(); err != nil { storage.logger.Error(err.Error()) return lib.ErrInternal } return nil } func (storage *ValkeyStorage) DeleteSessionByToken(ctx context.Context, token string) error { session, err := Parse(token, storage.cfg.JWTSecret) if err != nil { storage.logger.Error(err.Error()) return err } err = storage.DeleteSessionByUserId(ctx, *session.UserId) if err != nil { storage.logger.Error(err.Error()) return err } return nil } func (storage *ValkeyStorage) DeleteSessionByUserId(ctx context.Context, user_id int32) error { resp := storage.db.Do(ctx, storage.db. B().Del(). Key(string(user_id)). Build(), ) if err := resp.Error(); err != nil { storage.logger.Error(err.Error()) return lib.ErrInternal } return nil } func (storage *ValkeyStorage) CreateConfirmation(ctx context.Context, conf *Confirmation) error { resp := storage.db.Do(ctx, storage.db. B().Set(). Key(*conf.Id). Value(string(conf.JSON())). Exat(time.Now().Add(confirmationLifetime)). Build(), ) if err := resp.Error(); err != nil { storage.logger.Error(err.Error()) return lib.ErrInternal } return nil } func (storage *ValkeyStorage) ReadConfirmation(ctx context.Context, conf_id string) (*Confirmation, error) { resp := storage.db.Do(ctx, storage.db. B().Get(). Key(conf_id). Build(), ) if err := resp.Error(); err != nil { storage.logger.Error(err.Error()) return nil, lib.ErrInternal } b, err := resp.AsBytes() if err != nil { storage.logger.Error(err.Error()) return nil, lib.ErrInternal } var conf Confirmation err = json.Unmarshal(b, &conf) if err != nil { storage.logger.Error(err.Error()) return nil, lib.ErrInternal } return &conf, nil } func (storage *ValkeyStorage) DeleteConfirmation(ctx context.Context, conf_id string) error { resp := storage.db.Do(ctx, storage.db. B().Del(). Key(conf_id). Build(), ) if err := resp.Error(); err != nil { storage.logger.Error(err.Error()) return lib.ErrInternal } return nil } var ( ErrBadSession = errors.New("bad session") ErrBadConfirmation = errors.New("bad confirmation") ) type Confirmation struct { Id *string `json:"id"` UserId *int32 `json:"user_id,omitempty"` Email *string `json:"email"` } func NewConfirmation(userId *int32, email string) (*Confirmation, error) { c := &Confirmation{ Id: lib.AsStringP(uuid.NewString()), UserId: userId, Email: &email, } if err := c.Valid(); err != nil { return nil, err } return c, nil } func (c *Confirmation) Valid() error { if c.Id == nil { return ErrBadConfirmation } // FIXME // if c.userId == nil { // return ErrBadConfirmation // } if c.Email == nil { return ErrBadConfirmation } if err := lib.ValidEmail(*c.Email); err != nil { return err } return nil } func (c *Confirmation) JSON() []byte { b, err := json.Marshal(c) if err != nil { panic(err.Error()) } return b } type Session struct { Id *string UserId *int32 } func NewSession(userId int32) *Session { return &Session{ Id: lib.AsStringP(uuid.NewString()), UserId: &userId, } } func (s Session) Valid() error { if s.Id == nil { return ErrBadSession } if s.UserId == nil { return ErrBadSession } return nil } func (s Session) Token(secret string) (string, error) { if err := s.Valid(); err != nil { return "", err } refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, s) str, err := refreshToken.SignedString([]byte(secret)) if err != nil { return "", ErrBadSession } return str, nil } func Parse(tkn string, secret string) (*Session, error) { parsedToken, err := jwt.ParseWithClaims(tkn, &Session{}, func(token *jwt.Token) (interface{}, error) { return []byte(secret), nil }) if err != nil { return nil, ErrBadSession } session := parsedToken.Claims.(*Session) if err := session.Valid(); err != nil { return nil, err } return session, nil }