package repository import ( "context" "crypto/sha256" "encoding/hex" "encoding/json" "fmt" "git.sch9.ru/new_gate/ms-auth/internal/models" "git.sch9.ru/new_gate/ms-auth/pkg" "github.com/google/uuid" "github.com/valkey-io/valkey-go" "strconv" "time" ) type ValkeyRepository struct { db valkey.Client } func NewValkeyRepository(db valkey.Client) *ValkeyRepository { return &ValkeyRepository{ db: db, } } const sessionLifetime = time.Minute * 40 func sha256string(s string) string { hasher := sha256.New() hasher.Write([]byte(s)) return hex.EncodeToString(hasher.Sum(nil)) } func (r *ValkeyRepository) CreateSession(ctx context.Context, userId int32, role models.Role, userAgent, ip string) (*models.Session, error) { const op = "ValkeyRepository.CreateSession" session := &models.Session{ Id: uuid.NewString(), UserId: userId, Role: role, CreatedAt: time.Now(), ExpiresAt: time.Now().Add(sessionLifetime), UserAgent: userAgent, Ip: ip, } err := session.Valid() if err != nil { return nil, pkg.Wrap(pkg.ErrInternal, err, op, "validating session") } userIdHash := sha256string(strconv.FormatInt(int64(userId), 10)) sessionIdHash := sha256string(session.Id) sessionData, err := json.Marshal(session) if err != nil { return nil, pkg.Wrap(pkg.ErrInternal, err, op, "marshaling session") } resp := r.db.Do(ctx, r.db. B().Set(). Key(fmt.Sprintf("userid:%s:sessionid:%s", userIdHash, sessionIdHash)). Value(string(sessionData)). Exat(session.ExpiresAt). Build(), ) err = resp.Error() if err != nil { if valkey.IsValkeyNil(err) { return nil, pkg.Wrap(pkg.ErrInternal, err, op, "nil response") } return nil, pkg.Wrap(pkg.ErrUnhandled, err, op, "unhandled valkey error") } return session, nil } const ( readSessionScript = `local result = redis.call('SCAN', 0, 'MATCH', ARGV[1]) if #result[2] == 0 then return nil else return redis.call('GET', result[2][1]) end` ) func (r *ValkeyRepository) ReadSession(ctx context.Context, sessionId string) (*models.Session, error) { const op = "ValkeyRepository.ReadSession" err := uuid.Validate(sessionId) if err != nil { return nil, pkg.Wrap(pkg.ErrBadInput, err, op, "uuid validation") } sessionIdHash := sha256string(sessionId) resp := valkey.NewLuaScript(readSessionScript).Exec( ctx, r.db, nil, []string{fmt.Sprintf("userid:*:sessionid:%s", sessionIdHash)}, ) if err = resp.Error(); err != nil { if valkey.IsValkeyNil(err) { return nil, pkg.Wrap(pkg.ErrNotFound, err, op, "reading session") } return nil, pkg.Wrap(pkg.ErrUnhandled, err, op, "unhandled valkey error") } str, err := resp.ToString() if err != nil { return nil, pkg.Wrap(pkg.ErrInternal, err, op, "session storage corrupted") } var session models.Session err = json.Unmarshal([]byte(str), &session) if err != nil { return nil, pkg.Wrap(pkg.ErrInternal, err, op, "session corrupted") } err = session.Valid() if err != nil { return nil, pkg.Wrap(pkg.ErrInternal, err, op, "validating session") } return &session, nil } const ( updateSessionScript = `local result = redis.call('SCAN', 0, 'MATCH', ARGV[1]) return #result[2] > 0 and redis.call('EXPIRE', result[2][1], ARGV[2]) == 1` ) var ( sessionLifetimeString = strconv.Itoa(int(sessionLifetime.Seconds())) ) func (r *ValkeyRepository) UpdateSession(ctx context.Context, sessionId string) error { const op = "ValkeyRepository.UpdateSession" err := uuid.Validate(sessionId) if err != nil { return pkg.Wrap(pkg.ErrBadInput, err, op, "uuid validation") } sessionIdHash := sha256string(sessionId) resp := valkey.NewLuaScript(updateSessionScript).Exec( ctx, r.db, nil, []string{fmt.Sprintf("userid:*:sessionid:%s", sessionIdHash), sessionLifetimeString}, ) err = resp.Error() if err != nil { if valkey.IsValkeyNil(err) { return pkg.Wrap(pkg.ErrBadInput, err, op, "nil response") } return pkg.Wrap(pkg.ErrUnhandled, err, op, "unhandled valkey error") } return nil } const deleteSessionScript = `local result = redis.call('SCAN', 0, 'MATCH', ARGV[1]) return #result[2] > 0 and redis.call('DEL', result[2][1]) == 1` func (r *ValkeyRepository) DeleteSession(ctx context.Context, sessionId string) error { const op = "ValkeyRepository.DeleteSession" err := uuid.Validate(sessionId) if err != nil { return pkg.Wrap(pkg.ErrBadInput, err, op, "uuid validation") } sessionIdHash := sha256string(sessionId) resp := valkey.NewLuaScript(deleteSessionScript).Exec( ctx, r.db, nil, []string{fmt.Sprintf("userid:*:sessionid:%s", sessionIdHash)}, ) err = resp.Error() if err != nil { if valkey.IsValkeyNil(err) { return pkg.Wrap(pkg.ErrBadInput, err, op, "nil response") } return pkg.Wrap(pkg.ErrUnhandled, err, op, "unhandled valkey error") } return nil } const ( deleteUserSessionsScript = `local cursor = 0 local dels = 0 repeat local result = redis.call('SCAN', cursor, 'MATCH', ARGV[1]) for _,key in ipairs(result[2]) do redis.call('DEL', key) dels = dels + 1 end cursor = tonumber(result[1]) until cursor == 0 return dels` ) func (r *ValkeyRepository) DeleteAllSessions(ctx context.Context, userId int32) error { const op = "ValkeyRepository.DeleteAllSessions" userIdHash := sha256string(strconv.FormatInt(int64(userId), 10)) resp := valkey.NewLuaScript(deleteUserSessionsScript).Exec( ctx, r.db, nil, []string{fmt.Sprintf("userid:%s:sessionid:*", userIdHash)}, ) err := resp.Error() if err != nil { if valkey.IsValkeyNil(err) { return pkg.Wrap(pkg.ErrBadInput, err, op, "nil response") } return pkg.Wrap(pkg.ErrUnhandled, err, op, "unhandled valkey error") } return nil }