195 lines
4.6 KiB
Go
195 lines
4.6 KiB
Go
|
package repository
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"fmt"
|
||
|
"git.sch9.ru/new_gate/ms-auth/internal/models"
|
||
|
"git.sch9.ru/new_gate/ms-auth/pkg"
|
||
|
"github.com/google/uuid"
|
||
|
"github.com/valkey-io/valkey-go"
|
||
|
"strconv"
|
||
|
"time"
|
||
|
)
|
||
|
|
||
|
type ValkeyRepository struct {
|
||
|
db valkey.Client
|
||
|
}
|
||
|
|
||
|
func NewValkeyRepository(db valkey.Client) *ValkeyRepository {
|
||
|
return &ValkeyRepository{
|
||
|
db: db,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// sessionLifetime is the TTL given to a session when it is created and again
// each time it is refreshed.
const sessionLifetime = 40 * time.Minute
|
||
|
|
||
|
// CreateSession stores a new session for userId with the given role and
// returns the generated session id.
//
// The session payload is written with SET under
// "userid:<userId>:sessionid:<id>" with an expiry of sessionLifetime.
//
// Errors: pkg.ErrBadInput if models.NewSession fails or the reply is nil;
// pkg.ErrUnhandled for any other Valkey failure.
func (r *ValkeyRepository) CreateSession(ctx context.Context, userId int32, role models.Role) (string, error) {
	const op = "ValkeyRepository.CreateSession"

	data, id, err := models.NewSession(userId, role)
	if err != nil {
		return "", pkg.Wrap(pkg.ErrBadInput, err, op, "building session")
	}

	key := fmt.Sprintf("userid:%d:sessionid:%s", userId, id)
	cmd := r.db.B().Set().Key(key).Value(data).Ex(sessionLifetime).Build()

	switch err := r.db.Do(ctx, cmd).Error(); {
	case err == nil:
		return id, nil
	case valkey.IsValkeyNil(err):
		return "", pkg.Wrap(pkg.ErrBadInput, err, op, "nil response")
	default:
		return "", pkg.Wrap(pkg.ErrUnhandled, err, op, "unhandled valkey error")
	}
}
|
||
|
|
||
|
// readSessionScript returns the value of the first key matching the glob
// pattern in ARGV[1], or nil when no key matches.
//
// SCAN only returns one small batch of the keyspace per call, so a single
// "SCAN 0" can miss the key entirely once the database holds more than a
// handful of keys. The loop follows the cursor until it wraps back to "0".
//
// NOTE(review): this walks the whole keyspace inside a script, blocking the
// server for O(number of keys) per lookup. A secondary index key
// (e.g. sessionid:<id> -> full key) would make this O(1).
const (
	readSessionScript = `local cursor = '0'
repeat
	local result = redis.call('SCAN', cursor, 'MATCH', ARGV[1])
	if #result[2] > 0 then
		return redis.call('GET', result[2][1])
	end
	cursor = result[1]
until cursor == '0'
return nil`
)
|
||
|
|
||
|
func (r *ValkeyRepository) ReadSession(ctx context.Context, sessionId string) (*models.Session, error) {
|
||
|
const op = "ValkeyRepository.ReadSession"
|
||
|
|
||
|
err := uuid.Validate(sessionId)
|
||
|
if err != nil {
|
||
|
return nil, pkg.Wrap(pkg.ErrBadInput, err, op, "uuid validation")
|
||
|
}
|
||
|
|
||
|
resp := valkey.NewLuaScript(readSessionScript).Exec(
|
||
|
ctx,
|
||
|
r.db,
|
||
|
nil,
|
||
|
[]string{fmt.Sprintf("userid:*:sessionid:%s", sessionId)},
|
||
|
)
|
||
|
|
||
|
if err = resp.Error(); err != nil {
|
||
|
if valkey.IsValkeyNil(err) {
|
||
|
return nil, pkg.Wrap(pkg.ErrNotFound, err, op, "reading session")
|
||
|
}
|
||
|
return nil, pkg.Wrap(pkg.ErrUnhandled, err, op, "unhandled valkey error")
|
||
|
}
|
||
|
|
||
|
str, err := resp.ToString()
|
||
|
if err != nil {
|
||
|
return nil, pkg.Wrap(pkg.ErrInternal, err, op, "session storage corrupted")
|
||
|
}
|
||
|
|
||
|
session, err := models.ParseSession(str)
|
||
|
if err != nil {
|
||
|
return nil, pkg.Wrap(pkg.ErrInternal, err, op, "session corrupted")
|
||
|
}
|
||
|
|
||
|
return session, nil
|
||
|
}
|
||
|
|
||
|
// updateSessionScript sets the TTL of the first key matching the glob pattern
// in ARGV[1] to ARGV[2] seconds.
//
// It replies true (integer 1) when EXPIRE succeeded and false otherwise. Lua
// false is converted to a nil reply. Like the read script, it follows the SCAN
// cursor until it wraps to "0" rather than inspecting only the first batch,
// which would miss existing sessions in a non-trivial keyspace.
const (
	updateSessionScript = `local cursor = '0'
repeat
	local result = redis.call('SCAN', cursor, 'MATCH', ARGV[1])
	if #result[2] > 0 then
		return redis.call('EXPIRE', result[2][1], ARGV[2]) == 1
	end
	cursor = result[1]
until cursor == '0'
return false`
)
|
||
|
|
||
|
var (
|
||
|
sessionLifetimeString = strconv.Itoa(int(sessionLifetime.Seconds()))
|
||
|
)
|
||
|
|
||
|
func (r *ValkeyRepository) UpdateSession(ctx context.Context, sessionId string) error {
|
||
|
const op = "ValkeyRepository.UpdateSession"
|
||
|
|
||
|
err := uuid.Validate(sessionId)
|
||
|
if err != nil {
|
||
|
return pkg.Wrap(pkg.ErrBadInput, err, op, "uuid validation")
|
||
|
}
|
||
|
|
||
|
resp := valkey.NewLuaScript(updateSessionScript).Exec(
|
||
|
ctx,
|
||
|
r.db,
|
||
|
nil,
|
||
|
[]string{fmt.Sprintf("userid:*:sessionid:%s", sessionId), sessionLifetimeString},
|
||
|
)
|
||
|
|
||
|
err = resp.Error()
|
||
|
if err != nil {
|
||
|
if valkey.IsValkeyNil(err) {
|
||
|
return pkg.Wrap(pkg.ErrBadInput, err, op, "nil response")
|
||
|
}
|
||
|
return pkg.Wrap(pkg.ErrUnhandled, err, op, "unhandled valkey error")
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// deleteSessionScript deletes the first key matching the glob pattern in
// ARGV[1]. It replies true (integer 1) when a key was removed and false (nil
// reply) otherwise.
//
// The pattern is read from ARGV[1], like the other session scripts, because
// DeleteSession sends it as an argument with no keys. Reading KEYS[1] would
// hand nil to SCAN and make every call fail with a script error. The SCAN
// cursor is followed until it wraps to "0" so matches beyond the first batch
// are found.
const deleteSessionScript = `local cursor = '0'
repeat
	local result = redis.call('SCAN', cursor, 'MATCH', ARGV[1])
	if #result[2] > 0 then
		return redis.call('DEL', result[2][1]) == 1
	end
	cursor = result[1]
until cursor == '0'
return false`
|
||
|
|
||
|
func (r *ValkeyRepository) DeleteSession(ctx context.Context, sessionId string) error {
|
||
|
const op = "ValkeyRepository.DeleteSession"
|
||
|
|
||
|
err := uuid.Validate(sessionId)
|
||
|
if err != nil {
|
||
|
return pkg.Wrap(pkg.ErrBadInput, err, op, "uuid validation")
|
||
|
}
|
||
|
|
||
|
resp := valkey.NewLuaScript(deleteSessionScript).Exec(
|
||
|
ctx,
|
||
|
r.db,
|
||
|
nil,
|
||
|
[]string{fmt.Sprintf("userid:*:sessionid:%s", sessionId)},
|
||
|
)
|
||
|
|
||
|
err = resp.Error()
|
||
|
if err != nil {
|
||
|
if valkey.IsValkeyNil(err) {
|
||
|
return pkg.Wrap(pkg.ErrBadInput, err, op, "nil response")
|
||
|
}
|
||
|
return pkg.Wrap(pkg.ErrUnhandled, err, op, "unhandled valkey error")
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// deleteUserSessionsScript deletes every key matching the glob pattern in
// ARGV[1] and replies with the number of keys actually removed.
//
// It walks the full SCAN cursor sequence, stopping when the cursor returns
// to "0". DEL's own reply is summed instead of counting SCAN results, because
// SCAN may return the same key more than once. The cursor is kept as the
// string SCAN replies with instead of being converted to a Lua number.
const (
	deleteUserSessionsScript = `local cursor = '0'
local dels = 0
repeat
	local result = redis.call('SCAN', cursor, 'MATCH', ARGV[1])
	for _, key in ipairs(result[2]) do
		dels = dels + redis.call('DEL', key)
	end
	cursor = result[1]
until cursor == '0'
return dels`
)
|
||
|
|
||
|
func (r *ValkeyRepository) DeleteAllSessions(ctx context.Context, userId int32) error {
|
||
|
const op = "ValkeyRepository.DeleteAllSessions"
|
||
|
|
||
|
resp := valkey.NewLuaScript(deleteUserSessionsScript).Exec(
|
||
|
ctx,
|
||
|
r.db,
|
||
|
nil,
|
||
|
[]string{fmt.Sprintf("userid:%d:sessionid:*", userId)},
|
||
|
)
|
||
|
|
||
|
err := resp.Error()
|
||
|
if err != nil {
|
||
|
if valkey.IsValkeyNil(err) {
|
||
|
return pkg.Wrap(pkg.ErrBadInput, err, op, "nil response")
|
||
|
}
|
||
|
return pkg.Wrap(pkg.ErrUnhandled, err, op, "unhandled valkey error")
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|