ms-auth/internal/users/repository/valkey_repository.go
2025-02-25 18:33:15 +05:00

237 lines
5.6 KiB
Go

package repository
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"git.sch9.ru/new_gate/ms-auth/internal/models"
"git.sch9.ru/new_gate/ms-auth/pkg"
"github.com/google/uuid"
"github.com/valkey-io/valkey-go"
"strconv"
"time"
)
// ValkeyRepository keeps user sessions in Valkey.
type ValkeyRepository struct {
	// db is the Valkey client every repository method issues commands through.
	db valkey.Client
}

// NewValkeyRepository wraps the given Valkey client in a ValkeyRepository.
func NewValkeyRepository(db valkey.Client) *ValkeyRepository {
	repo := new(ValkeyRepository)
	repo.db = db
	return repo
}
// sessionLifetime is how long a session stays valid after creation; it is also
// the TTL (in seconds) that UpdateSession re-applies to an existing session key.
const sessionLifetime = 40 * time.Minute
// sha256string returns the lowercase hexadecimal SHA-256 digest of s.
// It is used to derive the user-id and session-id components of session keys.
func sha256string(s string) string {
	digest := sha256.Sum256([]byte(s))
	return hex.EncodeToString(digest[:])
}
// CreateSession creates a new session for userId, stores it in Valkey as JSON and
// returns it.
//
// The value is written under "userid:<sha256(userId)>:sessionid:<sha256(sessionId)>"
// with an absolute expiry (EXAT) equal to the session's ExpiresAt, which is
// sessionLifetime after CreatedAt. The user-id component is what DeleteAllSessions
// matches on to find every session of a user.
//
// Errors are wrapped with pkg.Wrap:
//   - pkg.ErrInternal if the session fails validation, cannot be marshaled, or
//     Valkey answers with a nil reply;
//   - pkg.ErrUnhandled for any other Valkey error.
func (r *ValkeyRepository) CreateSession(ctx context.Context, userId int32, role models.Role, userAgent, ip string) (*models.Session, error) {
	const op = "ValkeyRepository.CreateSession"

	// Take one timestamp so ExpiresAt is exactly sessionLifetime after CreatedAt
	// (two separate time.Now() calls would drift by a few nanoseconds).
	now := time.Now()
	session := &models.Session{
		Id:        uuid.NewString(),
		UserId:    userId,
		Role:      role,
		CreatedAt: now,
		ExpiresAt: now.Add(sessionLifetime),
		UserAgent: userAgent,
		Ip:        ip,
	}

	if err := session.Valid(); err != nil {
		return nil, pkg.Wrap(pkg.ErrInternal, err, op, "validating session")
	}

	sessionData, err := json.Marshal(session)
	if err != nil {
		return nil, pkg.Wrap(pkg.ErrInternal, err, op, "marshaling session")
	}

	// Both identifiers are hashed before going into the key name (presumably so
	// raw ids are not exposed in the keyspace).
	userIdHash := sha256string(strconv.FormatInt(int64(userId), 10))
	sessionIdHash := sha256string(session.Id)

	resp := r.db.Do(ctx, r.db.
		B().Set().
		Key(fmt.Sprintf("userid:%s:sessionid:%s", userIdHash, sessionIdHash)).
		Value(string(sessionData)).
		Exat(session.ExpiresAt).
		Build(),
	)
	if err := resp.Error(); err != nil {
		// A plain SET (no NX/XX/GET) should not yield a nil reply; kept defensively.
		if valkey.IsValkeyNil(err) {
			return nil, pkg.Wrap(pkg.ErrInternal, err, op, "nil response")
		}
		return nil, pkg.Wrap(pkg.ErrUnhandled, err, op, "unhandled valkey error")
	}

	return session, nil
}
const (
readSessionScript = `local result = redis.call('SCAN', 0, 'MATCH', ARGV[1])
if #result[2] == 0 then
return nil
else
return redis.call('GET', result[2][1])
end`
)
func (r *ValkeyRepository) ReadSession(ctx context.Context, sessionId string) (*models.Session, error) {
const op = "ValkeyRepository.ReadSession"
err := uuid.Validate(sessionId)
if err != nil {
return nil, pkg.Wrap(pkg.ErrBadInput, err, op, "uuid validation")
}
sessionIdHash := sha256string(sessionId)
resp := valkey.NewLuaScript(readSessionScript).Exec(
ctx,
r.db,
nil,
[]string{fmt.Sprintf("userid:*:sessionid:%s", sessionIdHash)},
)
if err = resp.Error(); err != nil {
if valkey.IsValkeyNil(err) {
return nil, pkg.Wrap(pkg.ErrNotFound, err, op, "reading session")
}
return nil, pkg.Wrap(pkg.ErrUnhandled, err, op, "unhandled valkey error")
}
str, err := resp.ToString()
if err != nil {
return nil, pkg.Wrap(pkg.ErrInternal, err, op, "session storage corrupted")
}
var session models.Session
err = json.Unmarshal([]byte(str), &session)
if err != nil {
return nil, pkg.Wrap(pkg.ErrInternal, err, op, "session corrupted")
}
err = session.Valid()
if err != nil {
return nil, pkg.Wrap(pkg.ErrInternal, err, op, "validating session")
}
return &session, nil
}
// updateSessionScript sets the TTL of the first key matching the glob pattern in
// ARGV[1] to ARGV[2] seconds. It returns true when EXPIRE succeeded, and false
// (a nil reply to the client) when no key matches or EXPIRE did not apply.
//
// SCAN is iterated until the cursor comes back to "0". A single SCAN call only
// examines a small part of the keyspace, so stopping after the first call misses
// keys that do exist.
const updateSessionScript = `local cursor = '0'
repeat
  local result = redis.call('SCAN', cursor, 'MATCH', ARGV[1], 'COUNT', 1000)
  if #result[2] > 0 then
    return redis.call('EXPIRE', result[2][1], ARGV[2]) == 1
  end
  cursor = result[1]
until cursor == '0'
return false`

// sessionLifetimeString is sessionLifetime in whole seconds, as EXPIRE expects.
var sessionLifetimeString = strconv.Itoa(int(sessionLifetime.Seconds()))

// updateSessionLua is built once so the script's SHA-1 is not recomputed per call.
var updateSessionLua = valkey.NewLuaScript(updateSessionScript)

// UpdateSession resets the TTL of the session identified by sessionId to
// sessionLifetime.
//
// NOTE(review): only the key TTL is extended; the ExpiresAt stored in the session
// JSON is not rewritten. If models.Session.Valid checks ExpiresAt, ReadSession may
// reject a refreshed session; confirm against models.
//
// Errors are wrapped with pkg.Wrap:
//   - pkg.ErrBadInput if sessionId is not a valid UUID, or if no session matched
//     (the script answers nil);
//   - pkg.ErrUnhandled for any other Valkey error.
func (r *ValkeyRepository) UpdateSession(ctx context.Context, sessionId string) error {
	const op = "ValkeyRepository.UpdateSession"

	if err := uuid.Validate(sessionId); err != nil {
		return pkg.Wrap(pkg.ErrBadInput, err, op, "uuid validation")
	}

	pattern := fmt.Sprintf("userid:*:sessionid:%s", sha256string(sessionId))
	resp := updateSessionLua.Exec(ctx, r.db, nil, []string{pattern, sessionLifetimeString})
	if err := resp.Error(); err != nil {
		if valkey.IsValkeyNil(err) {
			return pkg.Wrap(pkg.ErrBadInput, err, op, "nil response")
		}
		return pkg.Wrap(pkg.ErrUnhandled, err, op, "unhandled valkey error")
	}

	return nil
}
const deleteSessionScript = `local result = redis.call('SCAN', 0, 'MATCH', ARGV[1])
return #result[2] > 0 and redis.call('DEL', result[2][1]) == 1`
func (r *ValkeyRepository) DeleteSession(ctx context.Context, sessionId string) error {
const op = "ValkeyRepository.DeleteSession"
err := uuid.Validate(sessionId)
if err != nil {
return pkg.Wrap(pkg.ErrBadInput, err, op, "uuid validation")
}
sessionIdHash := sha256string(sessionId)
resp := valkey.NewLuaScript(deleteSessionScript).Exec(
ctx,
r.db,
nil,
[]string{fmt.Sprintf("userid:*:sessionid:%s", sessionIdHash)},
)
err = resp.Error()
if err != nil {
if valkey.IsValkeyNil(err) {
return pkg.Wrap(pkg.ErrBadInput, err, op, "nil response")
}
return pkg.Wrap(pkg.ErrUnhandled, err, op, "unhandled valkey error")
}
return nil
}
const (
deleteUserSessionsScript = `local cursor = 0
local dels = 0
repeat
local result = redis.call('SCAN', cursor, 'MATCH', ARGV[1])
for _,key in ipairs(result[2]) do
redis.call('DEL', key)
dels = dels + 1
end
cursor = tonumber(result[1])
until cursor == 0
return dels`
)
func (r *ValkeyRepository) DeleteAllSessions(ctx context.Context, userId int32) error {
const op = "ValkeyRepository.DeleteAllSessions"
userIdHash := sha256string(strconv.FormatInt(int64(userId), 10))
resp := valkey.NewLuaScript(deleteUserSessionsScript).Exec(
ctx,
r.db,
nil,
[]string{fmt.Sprintf("userid:%s:sessionid:*", userIdHash)},
)
err := resp.Error()
if err != nil {
if valkey.IsValkeyNil(err) {
return pkg.Wrap(pkg.ErrBadInput, err, op, "nil response")
}
return pkg.Wrap(pkg.ErrUnhandled, err, op, "unhandled valkey error")
}
return nil
}