237 lines
5.6 KiB
Go
237 lines
5.6 KiB
Go
package repository
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"fmt"
|
|
"git.sch9.ru/new_gate/ms-auth/internal/models"
|
|
"git.sch9.ru/new_gate/ms-auth/pkg"
|
|
"github.com/google/uuid"
|
|
"github.com/valkey-io/valkey-go"
|
|
"strconv"
|
|
"time"
|
|
)
|
|
|
|
type ValkeyRepository struct {
|
|
db valkey.Client
|
|
}
|
|
|
|
func NewValkeyRepository(db valkey.Client) *ValkeyRepository {
|
|
return &ValkeyRepository{
|
|
db: db,
|
|
}
|
|
}
|
|
|
|
// sessionLifetime is how long a freshly created (or refreshed) session stays valid.
const sessionLifetime = 40 * time.Minute
|
|
|
|
// sha256string returns the lowercase hexadecimal encoding of the SHA-256
// digest of s.
func sha256string(s string) string {
	digest := sha256.Sum256([]byte(s))
	return hex.EncodeToString(digest[:])
}
|
|
|
|
func (r *ValkeyRepository) CreateSession(ctx context.Context, userId int32, role models.Role, userAgent, ip string) (*models.Session, error) {
|
|
const op = "ValkeyRepository.CreateSession"
|
|
|
|
session := &models.Session{
|
|
Id: uuid.NewString(),
|
|
UserId: userId,
|
|
Role: role,
|
|
CreatedAt: time.Now(),
|
|
ExpiresAt: time.Now().Add(sessionLifetime),
|
|
UserAgent: userAgent,
|
|
Ip: ip,
|
|
}
|
|
|
|
err := session.Valid()
|
|
if err != nil {
|
|
return nil, pkg.Wrap(pkg.ErrInternal, err, op, "validating session")
|
|
}
|
|
|
|
userIdHash := sha256string(strconv.FormatInt(int64(userId), 10))
|
|
sessionIdHash := sha256string(session.Id)
|
|
|
|
sessionData, err := json.Marshal(session)
|
|
if err != nil {
|
|
return nil, pkg.Wrap(pkg.ErrInternal, err, op, "marshaling session")
|
|
}
|
|
|
|
resp := r.db.Do(ctx, r.db.
|
|
B().Set().
|
|
Key(fmt.Sprintf("userid:%s:sessionid:%s", userIdHash, sessionIdHash)).
|
|
Value(string(sessionData)).
|
|
Exat(session.ExpiresAt).
|
|
Build(),
|
|
)
|
|
|
|
err = resp.Error()
|
|
if err != nil {
|
|
if valkey.IsValkeyNil(err) {
|
|
return nil, pkg.Wrap(pkg.ErrInternal, err, op, "nil response")
|
|
}
|
|
return nil, pkg.Wrap(pkg.ErrUnhandled, err, op, "unhandled valkey error")
|
|
}
|
|
|
|
return session, nil
|
|
}
|
|
|
|
const (
	// readSessionScript returns the value of the first key matching the glob
	// pattern in ARGV[1], or nil when no key matches.
	//
	// A single SCAN call inspects only one page of the keyspace (COUNT defaults
	// to 10), so it can return no keys while a match exists further on. The
	// cursor is therefore followed until it wraps back to '0'. The cursor is
	// kept as the string SCAN returns instead of being passed through tonumber.
	readSessionScript = `local cursor = '0'
repeat
	local result = redis.call('SCAN', cursor, 'MATCH', ARGV[1])
	cursor = result[1]
	local keys = result[2]
	if #keys > 0 then
		return redis.call('GET', keys[1])
	end
until cursor == '0'
return nil`
)
|
|
|
|
func (r *ValkeyRepository) ReadSession(ctx context.Context, sessionId string) (*models.Session, error) {
|
|
const op = "ValkeyRepository.ReadSession"
|
|
|
|
err := uuid.Validate(sessionId)
|
|
if err != nil {
|
|
return nil, pkg.Wrap(pkg.ErrBadInput, err, op, "uuid validation")
|
|
}
|
|
|
|
sessionIdHash := sha256string(sessionId)
|
|
|
|
resp := valkey.NewLuaScript(readSessionScript).Exec(
|
|
ctx,
|
|
r.db,
|
|
nil,
|
|
[]string{fmt.Sprintf("userid:*:sessionid:%s", sessionIdHash)},
|
|
)
|
|
|
|
if err = resp.Error(); err != nil {
|
|
if valkey.IsValkeyNil(err) {
|
|
return nil, pkg.Wrap(pkg.ErrNotFound, err, op, "reading session")
|
|
}
|
|
return nil, pkg.Wrap(pkg.ErrUnhandled, err, op, "unhandled valkey error")
|
|
}
|
|
|
|
str, err := resp.ToString()
|
|
if err != nil {
|
|
return nil, pkg.Wrap(pkg.ErrInternal, err, op, "session storage corrupted")
|
|
}
|
|
|
|
var session models.Session
|
|
|
|
err = json.Unmarshal([]byte(str), &session)
|
|
|
|
if err != nil {
|
|
return nil, pkg.Wrap(pkg.ErrInternal, err, op, "session corrupted")
|
|
}
|
|
|
|
err = session.Valid()
|
|
if err != nil {
|
|
return nil, pkg.Wrap(pkg.ErrInternal, err, op, "validating session")
|
|
}
|
|
|
|
return &session, nil
|
|
}
|
|
|
|
const (
	// updateSessionScript sets the TTL of the first key matching the glob
	// pattern in ARGV[1] to ARGV[2] seconds. It returns true (integer 1) when
	// EXPIRE succeeds and false (a nil reply) when no key matches.
	//
	// A single SCAN call inspects only one page of the keyspace, so the cursor
	// is followed until it wraps back to '0'. Otherwise a match on a later page
	// would be reported as missing.
	updateSessionScript = `local cursor = '0'
repeat
	local result = redis.call('SCAN', cursor, 'MATCH', ARGV[1])
	cursor = result[1]
	local keys = result[2]
	if #keys > 0 then
		return redis.call('EXPIRE', keys[1], ARGV[2]) == 1
	end
until cursor == '0'
return false`
)
|
|
|
|
var (
|
|
sessionLifetimeString = strconv.Itoa(int(sessionLifetime.Seconds()))
|
|
)
|
|
|
|
func (r *ValkeyRepository) UpdateSession(ctx context.Context, sessionId string) error {
|
|
const op = "ValkeyRepository.UpdateSession"
|
|
|
|
err := uuid.Validate(sessionId)
|
|
if err != nil {
|
|
return pkg.Wrap(pkg.ErrBadInput, err, op, "uuid validation")
|
|
}
|
|
|
|
sessionIdHash := sha256string(sessionId)
|
|
|
|
resp := valkey.NewLuaScript(updateSessionScript).Exec(
|
|
ctx,
|
|
r.db,
|
|
nil,
|
|
[]string{fmt.Sprintf("userid:*:sessionid:%s", sessionIdHash), sessionLifetimeString},
|
|
)
|
|
|
|
err = resp.Error()
|
|
if err != nil {
|
|
if valkey.IsValkeyNil(err) {
|
|
return pkg.Wrap(pkg.ErrBadInput, err, op, "nil response")
|
|
}
|
|
return pkg.Wrap(pkg.ErrUnhandled, err, op, "unhandled valkey error")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// deleteSessionScript deletes the first key matching the glob pattern in
// ARGV[1]. It returns true (integer 1) when a key was deleted and false (a nil
// reply) when no key matches.
//
// A single SCAN call inspects only one page of the keyspace, so the cursor is
// followed until it wraps back to '0'. Otherwise a match on a later page would
// be reported as missing.
const deleteSessionScript = `local cursor = '0'
repeat
	local result = redis.call('SCAN', cursor, 'MATCH', ARGV[1])
	cursor = result[1]
	local keys = result[2]
	if #keys > 0 then
		return redis.call('DEL', keys[1]) == 1
	end
until cursor == '0'
return false`
|
|
|
|
func (r *ValkeyRepository) DeleteSession(ctx context.Context, sessionId string) error {
|
|
const op = "ValkeyRepository.DeleteSession"
|
|
|
|
err := uuid.Validate(sessionId)
|
|
if err != nil {
|
|
return pkg.Wrap(pkg.ErrBadInput, err, op, "uuid validation")
|
|
}
|
|
|
|
sessionIdHash := sha256string(sessionId)
|
|
|
|
resp := valkey.NewLuaScript(deleteSessionScript).Exec(
|
|
ctx,
|
|
r.db,
|
|
nil,
|
|
[]string{fmt.Sprintf("userid:*:sessionid:%s", sessionIdHash)},
|
|
)
|
|
|
|
err = resp.Error()
|
|
if err != nil {
|
|
if valkey.IsValkeyNil(err) {
|
|
return pkg.Wrap(pkg.ErrBadInput, err, op, "nil response")
|
|
}
|
|
return pkg.Wrap(pkg.ErrUnhandled, err, op, "unhandled valkey error")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
const (
	// deleteUserSessionsScript walks the whole keyspace with SCAN, deletes
	// every key matching the glob pattern in ARGV[1] and returns how many
	// matched keys it issued DEL for. The cursor is kept as the string SCAN
	// returns; the loop ends when it wraps back to '0'.
	deleteUserSessionsScript = `local deleted = 0
local cursor = '0'
repeat
	local page = redis.call('SCAN', cursor, 'MATCH', ARGV[1])
	cursor = page[1]
	local keys = page[2]
	for i = 1, #keys do
		redis.call('DEL', keys[i])
		deleted = deleted + 1
	end
until cursor == '0'
return deleted`
)
|
|
|
|
func (r *ValkeyRepository) DeleteAllSessions(ctx context.Context, userId int32) error {
|
|
const op = "ValkeyRepository.DeleteAllSessions"
|
|
|
|
userIdHash := sha256string(strconv.FormatInt(int64(userId), 10))
|
|
|
|
resp := valkey.NewLuaScript(deleteUserSessionsScript).Exec(
|
|
ctx,
|
|
r.db,
|
|
nil,
|
|
[]string{fmt.Sprintf("userid:%s:sessionid:*", userIdHash)},
|
|
)
|
|
|
|
err := resp.Error()
|
|
if err != nil {
|
|
if valkey.IsValkeyNil(err) {
|
|
return pkg.Wrap(pkg.ErrBadInput, err, op, "nil response")
|
|
}
|
|
return pkg.Wrap(pkg.ErrUnhandled, err, op, "unhandled valkey error")
|
|
}
|
|
|
|
return nil
|
|
}
|