refactor(user): refactor api

This commit is contained in:
Vyacheslav1557 2025-03-28 16:33:13 +05:00
parent 2a18e1d42f
commit 51c5b4943c
10 changed files with 156 additions and 102 deletions

View file

@ -0,0 +1,6 @@
package models
type Pagination struct {
Page int32 `json:"page"`
Total int32 `json:"total"`
}

View file

@ -38,6 +38,13 @@ func (s Session) Valid() error {
return nil return nil
} }
type SessionCreation struct {
UserId int32
Role Role
UserAgent string
Ip string
}
type JWT struct { type JWT struct {
SessionId string `json:"session_id"` SessionId string `json:"session_id"`
UserId int32 `json:"user_id"` UserId int32 `json:"user_id"`

View file

@ -10,6 +10,34 @@ import (
type Role int32 type Role int32
type User struct {
Id int32 `db:"id"`
Username string `db:"username"`
HashedPassword string `db:"hashed_pwd"`
CreatedAt time.Time `db:"created_at"`
ModifiedAt time.Time `db:"modified_at"`
Role Role `db:"role"`
}
type UsersListFilters struct {
PageSize int32
Page int32
}
func (f UsersListFilters) Offset() int32 {
return (f.Page - 1) * f.PageSize
}
type UsersList struct {
Users []*User
Pagination Pagination
}
type UserUpdate struct {
Username *string
Role *Role
}
const ( const (
RoleGuest Role = -1 RoleGuest Role = -1
RoleStudent Role = 0 RoleStudent Role = 0
@ -96,15 +124,6 @@ allow if {
} }
` `
type User struct {
Id int32 `db:"id"`
Username string `db:"username"`
HashedPassword string `db:"hashed_pwd"`
CreatedAt time.Time `db:"created_at"`
ModifiedAt time.Time `db:"modified_at"`
Role Role `db:"role"`
}
func (user *User) MarshalJSON() ([]byte, error) { func (user *User) MarshalJSON() ([]byte, error) {
m := map[string]interface{}{ m := map[string]interface{}{
"id": user.Id, "id": user.Id,

View file

@ -27,31 +27,19 @@ func NewUserHandlers(userUC users.UseCase, jwtSecret string) *UserHandlers {
} }
func (h *UserHandlers) Login(c *fiber.Ctx) error { func (h *UserHandlers) Login(c *fiber.Ctx) error {
const op = "UserHandlers.Login"
authHeader := c.Get("Authorization", "") authHeader := c.Get("Authorization", "")
if authHeader == "" { if authHeader == "" {
return c.SendStatus(fiber.StatusUnauthorized) return c.SendStatus(fiber.StatusUnauthorized)
} }
authParts := strings.Split(authHeader, " ") username, pwd, err := parseBasicAuth(authHeader)
if len(authParts) != 2 || strings.ToLower(authParts[0]) != "basic" {
return c.SendStatus(fiber.StatusUnauthorized)
}
decodedAuth, err := base64.StdEncoding.DecodeString(authParts[1])
if err != nil { if err != nil {
return c.SendStatus(fiber.StatusUnauthorized) return c.SendStatus(fiber.StatusUnauthorized)
} }
authParts = strings.Split(string(decodedAuth), ":")
if len(authParts) != 2 {
return c.SendStatus(fiber.StatusUnauthorized)
}
ctx := c.Context() ctx := c.Context()
user, err := h.userUC.ReadUserByUsername(ctx, authParts[0]) user, err := h.userUC.ReadUserByUsername(ctx, username)
if err != nil { if err != nil {
if errors.Is(err, pkg.ErrNotFound) { if errors.Is(err, pkg.ErrNotFound) {
return c.SendStatus(fiber.StatusUnauthorized) return c.SendStatus(fiber.StatusUnauthorized)
@ -60,14 +48,19 @@ func (h *UserHandlers) Login(c *fiber.Ctx) error {
return c.SendStatus(pkg.ToREST(err)) return c.SendStatus(pkg.ToREST(err))
} }
if !user.IsSamePwd(authParts[1]) { if !user.IsSamePwd(pwd) {
return c.SendStatus(fiber.StatusUnauthorized) return c.SendStatus(fiber.StatusUnauthorized)
} }
userAgent := c.Get("User-Agent", "") userAgent := c.Get("User-Agent", "")
ip := c.IP() ip := c.IP()
session, err := h.userUC.CreateSession(ctx, user.Id, user.Role, userAgent, ip) session, err := h.userUC.CreateSession(ctx, &models.SessionCreation{
UserId: user.Id,
Role: user.Role,
UserAgent: userAgent,
Ip: ip,
})
if err != nil { if err != nil {
return c.SendStatus(pkg.ToREST(err)) return c.SendStatus(pkg.ToREST(err))
} }
@ -93,8 +86,6 @@ func (h *UserHandlers) Login(c *fiber.Ctx) error {
} }
func (h *UserHandlers) Refresh(c *fiber.Ctx) error { func (h *UserHandlers) Refresh(c *fiber.Ctx) error {
const op = "UserHandlers.Refresh"
token, ok := c.Locals(TokenKey).(*models.JWT) token, ok := c.Locals(TokenKey).(*models.JWT)
if !ok { if !ok {
return c.SendStatus(fiber.StatusUnauthorized) return c.SendStatus(fiber.StatusUnauthorized)
@ -109,8 +100,6 @@ func (h *UserHandlers) Refresh(c *fiber.Ctx) error {
} }
func (h *UserHandlers) Logout(c *fiber.Ctx) error { func (h *UserHandlers) Logout(c *fiber.Ctx) error {
const op = "UserHandlers.Logout"
token, ok := c.Locals(TokenKey).(*models.JWT) token, ok := c.Locals(TokenKey).(*models.JWT)
if !ok { if !ok {
return c.SendStatus(fiber.StatusUnauthorized) return c.SendStatus(fiber.StatusUnauthorized)
@ -125,8 +114,6 @@ func (h *UserHandlers) Logout(c *fiber.Ctx) error {
} }
func (h *UserHandlers) CompleteLogout(c *fiber.Ctx) error { func (h *UserHandlers) CompleteLogout(c *fiber.Ctx) error {
const op = "UserHandlers.CompleteLogout"
token, ok := c.Locals(TokenKey).(*models.JWT) token, ok := c.Locals(TokenKey).(*models.JWT)
if !ok { if !ok {
return c.SendStatus(fiber.StatusUnauthorized) return c.SendStatus(fiber.StatusUnauthorized)
@ -143,14 +130,10 @@ func (h *UserHandlers) CompleteLogout(c *fiber.Ctx) error {
} }
func (h *UserHandlers) Verify(c *fiber.Ctx) error { func (h *UserHandlers) Verify(c *fiber.Ctx) error {
const op = "UserHandlers.Verify"
return c.SendStatus(fiber.StatusNotImplemented) return c.SendStatus(fiber.StatusNotImplemented)
} }
func (h *UserHandlers) CreateUser(c *fiber.Ctx) error { func (h *UserHandlers) CreateUser(c *fiber.Ctx) error {
const op = "UserHandlers.CreateUser"
ctx := c.Context() ctx := c.Context()
var req = &userv1.CreateUserRequest{} var req = &userv1.CreateUserRequest{}
@ -170,35 +153,31 @@ func (h *UserHandlers) CreateUser(c *fiber.Ctx) error {
return c.SendStatus(pkg.ToREST(err)) return c.SendStatus(pkg.ToREST(err))
} }
return c.JSON(map[string]interface{}{ return c.JSON(userv1.CreateUserResponse{Id: id})
"id": id,
})
} }
func (h *UserHandlers) GetUser(c *fiber.Ctx, id int32) error { func (h *UserHandlers) GetUser(c *fiber.Ctx, id int32) error {
const op = "UserHandlers.GetUser"
user, err := h.userUC.ReadUserById(c.Context(), id) user, err := h.userUC.ReadUserById(c.Context(), id)
if err != nil { if err != nil {
return c.SendStatus(pkg.ToREST(err)) return c.SendStatus(pkg.ToREST(err))
} }
return c.JSON(map[string]interface{}{ return c.JSON(userv1.GetUserResponse{
"user": user, User: U2U(*user),
}) })
} }
func (h *UserHandlers) UpdateUser(c *fiber.Ctx, id int32) error { func (h *UserHandlers) UpdateUser(c *fiber.Ctx, id int32) error {
const op = "UserHandlers.UpdateUser"
var req = &userv1.UpdateUserRequest{} var req = &userv1.UpdateUserRequest{}
err := c.BodyParser(req) err := c.BodyParser(req)
if err != nil { if err != nil {
return c.SendStatus(fiber.StatusBadRequest) return c.SendStatus(fiber.StatusBadRequest)
} }
err = h.userUC.UpdateUser(c.Context(), id, req.Username, int32PtrToRolePtr(req.Role)) err = h.userUC.UpdateUser(c.Context(), id, &models.UserUpdate{
Username: req.Username,
Role: I32P2RP(req.Role),
})
if err != nil { if err != nil {
return c.SendStatus(pkg.ToREST(err)) return c.SendStatus(pkg.ToREST(err))
} }
@ -207,8 +186,6 @@ func (h *UserHandlers) UpdateUser(c *fiber.Ctx, id int32) error {
} }
func (h *UserHandlers) DeleteUser(c *fiber.Ctx, id int32) error { func (h *UserHandlers) DeleteUser(c *fiber.Ctx, id int32) error {
const op = "UserHandlers.DeleteUser"
ctx := c.Context() ctx := c.Context()
err := h.userUC.DeleteUser(ctx, id) err := h.userUC.DeleteUser(ctx, id)
@ -220,35 +197,75 @@ func (h *UserHandlers) DeleteUser(c *fiber.Ctx, id int32) error {
} }
func (h *UserHandlers) ListUsers(c *fiber.Ctx, params userv1.ListUsersParams) error { func (h *UserHandlers) ListUsers(c *fiber.Ctx, params userv1.ListUsersParams) error {
const op = "UserHandlers.ListUsers" usersList, err := h.userUC.ListUsers(c.Context(), models.UsersListFilters{
PageSize: params.PageSize,
usersList, count, err := h.userUC.ListUsers(c.Context(), params.Page, params.PageSize) Page: params.Page,
})
if err != nil { if err != nil {
return c.SendStatus(pkg.ToREST(err)) return c.SendStatus(pkg.ToREST(err))
} }
return c.JSON(map[string]interface{}{ resp := userv1.ListUsersResponse{
"users": usersList, Users: make([]userv1.User, len(usersList.Users)),
"page": params.Page, Pagination: P2P(usersList.Pagination),
"max_page": func() int32 { }
if count%params.PageSize == 0 {
return count / params.PageSize for i, user := range usersList.Users {
} resp.Users[i] = U2U(*user)
return count/params.PageSize + 1 }
}(),
}) return c.JSON(resp)
} }
func (h *UserHandlers) ListSessions(c *fiber.Ctx) error { func (h *UserHandlers) ListSessions(c *fiber.Ctx) error {
const op = "UserHandlers.ListSessions"
return c.SendStatus(fiber.StatusNotImplemented) return c.SendStatus(fiber.StatusNotImplemented)
} }
func int32PtrToRolePtr(i *int32) *models.Role { func parseBasicAuth(header string) (string, string, error) {
const (
op = "parseBasicAuth"
msg = "invalid auth header"
)
authParts := strings.Split(header, " ")
if len(authParts) != 2 || strings.ToLower(authParts[0]) != "basic" {
return "", "", pkg.Wrap(pkg.ErrUnauthenticated, nil, op, msg)
}
decodedAuth, err := base64.StdEncoding.DecodeString(authParts[1])
if err != nil {
return "", "", pkg.Wrap(pkg.ErrUnauthenticated, nil, op, msg)
}
authParts = strings.Split(string(decodedAuth), ":")
if len(authParts) != 2 {
return "", "", pkg.Wrap(pkg.ErrUnauthenticated, nil, op, msg)
}
return authParts[0], authParts[1], nil
}
func I32P2RP(i *int32) *models.Role {
if i == nil { if i == nil {
return nil return nil
} }
ii := models.Role(*i) ii := models.Role(*i)
return &ii return &ii
} }
func P2P(p models.Pagination) userv1.Pagination {
return userv1.Pagination{
Page: p.Page,
Total: p.Total,
}
}
func U2U(u models.User) userv1.User {
return userv1.User{
Id: u.Id,
Username: u.Username,
Role: int32(u.Role),
CreatedAt: u.CreatedAt,
ModifiedAt: u.ModifiedAt,
}
}

View file

@ -9,9 +9,9 @@ type Caller interface {
CreateUser(ctx context.Context, username string, password string, role models.Role) (int32, error) CreateUser(ctx context.Context, username string, password string, role models.Role) (int32, error)
ReadUserByUsername(ctx context.Context, username string) (*models.User, error) ReadUserByUsername(ctx context.Context, username string) (*models.User, error)
ReadUserById(ctx context.Context, id int32) (*models.User, error) ReadUserById(ctx context.Context, id int32) (*models.User, error)
UpdateUser(ctx context.Context, id int32, username *string, role *models.Role) error UpdateUser(ctx context.Context, id int32, update *models.UserUpdate) error
DeleteUser(ctx context.Context, id int32) error DeleteUser(ctx context.Context, id int32) error
ListUsers(ctx context.Context, page int32, pageSize int32) ([]*models.User, int32, error) ListUsers(ctx context.Context, filters models.UsersListFilters) (*models.UsersList, error)
} }
type TxCaller interface { type TxCaller interface {
@ -26,7 +26,7 @@ type PgRepository interface {
} }
type ValkeyRepository interface { type ValkeyRepository interface {
CreateSession(ctx context.Context, userId int32, role models.Role, userAgent, ip string) (*models.Session, error) CreateSession(ctx context.Context, creation *models.SessionCreation) (*models.Session, error)
ReadSession(ctx context.Context, sessionId string) (*models.Session, error) ReadSession(ctx context.Context, sessionId string) (*models.Session, error)
UpdateSession(ctx context.Context, sessionId string) error UpdateSession(ctx context.Context, sessionId string) error
DeleteSession(ctx context.Context, sessionId string) error DeleteSession(ctx context.Context, sessionId string) error

View file

@ -164,12 +164,12 @@ SET username = COALESCE(?, trim(lower(username))),
WHERE id = ? WHERE id = ?
` `
func (c *Caller) UpdateUser(ctx context.Context, id int32, username *string, role *models.Role) error { func (c *Caller) UpdateUser(ctx context.Context, id int32, update *models.UserUpdate) error {
const op = "Caller.UpdateUser" const op = "Caller.UpdateUser"
var err error var err error
if username != nil { if update.Username != nil {
if err = ValidUsername(*username); err != nil { if err = ValidUsername(*update.Username); err != nil {
return pkg.Wrap(pkg.ErrBadInput, err, op, "username validation") return pkg.Wrap(pkg.ErrBadInput, err, op, "username validation")
} }
} }
@ -178,8 +178,8 @@ func (c *Caller) UpdateUser(ctx context.Context, id int32, username *string, rol
_, err = c.db.ExecContext( _, err = c.db.ExecContext(
ctx, ctx,
query, query,
username, update.Username,
role, update.Role,
id, id,
) )
@ -208,29 +208,33 @@ const (
CountUsers = "SELECT COUNT(*) FROM users" CountUsers = "SELECT COUNT(*) FROM users"
) )
func (c *Caller) ListUsers(ctx context.Context, page int32, pageSize int32) ([]*models.User, int32, error) { func (c *Caller) ListUsers(ctx context.Context, filters models.UsersListFilters) (*models.UsersList, error) {
const op = "Caller.ListUsers" const op = "Caller.ListUsers"
if pageSize > 20 { if filters.PageSize > 20 {
return nil, 0, pkg.Wrap(pkg.ErrBadInput, nil, op, "limit > 20") return nil, pkg.Wrap(pkg.ErrBadInput, nil, op, "limit > 20")
} }
var usersList []*models.User usersList := &models.UsersList{
Users: make([]*models.User, 0),
Pagination: models.Pagination{},
}
query := c.db.Rebind(ListUsers) query := c.db.Rebind(ListUsers)
err := c.db.SelectContext(ctx, &usersList, query, pageSize, (page-1)*pageSize) err := c.db.SelectContext(ctx, &usersList.Users, query, filters.PageSize, filters.Offset())
if err != nil { if err != nil {
return nil, 0, handlePgErr(err, op) return nil, handlePgErr(err, op)
} }
query = c.db.Rebind(CountUsers) query = c.db.Rebind(CountUsers)
var count int32 err = c.db.GetContext(ctx, &usersList.Pagination.Total, query)
err = c.db.GetContext(ctx, &count, query)
if err != nil { if err != nil {
return nil, 0, handlePgErr(err, op) return nil, handlePgErr(err, op)
} }
return usersList, count, nil usersList.Pagination.Page = filters.Page
return usersList, nil
} }
func handlePgErr(err error, op string) error { func handlePgErr(err error, op string) error {

View file

@ -32,17 +32,17 @@ func sha256string(s string) string {
return hex.EncodeToString(hasher.Sum(nil)) return hex.EncodeToString(hasher.Sum(nil))
} }
func (r *ValkeyRepository) CreateSession(ctx context.Context, userId int32, role models.Role, userAgent, ip string) (*models.Session, error) { func (r *ValkeyRepository) CreateSession(ctx context.Context, creation *models.SessionCreation) (*models.Session, error) {
const op = "ValkeyRepository.CreateSession" const op = "ValkeyRepository.CreateSession"
session := &models.Session{ session := &models.Session{
Id: uuid.NewString(), Id: uuid.NewString(),
UserId: userId, UserId: creation.UserId,
Role: role, Role: creation.Role,
CreatedAt: time.Now(), CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(sessionLifetime), ExpiresAt: time.Now().Add(sessionLifetime),
UserAgent: userAgent, UserAgent: creation.UserAgent,
Ip: ip, Ip: creation.Ip,
} }
err := session.Valid() err := session.Valid()
@ -50,7 +50,7 @@ func (r *ValkeyRepository) CreateSession(ctx context.Context, userId int32, role
return nil, pkg.Wrap(pkg.ErrInternal, err, op, "validating session") return nil, pkg.Wrap(pkg.ErrInternal, err, op, "validating session")
} }
userIdHash := sha256string(strconv.FormatInt(int64(userId), 10)) userIdHash := sha256string(strconv.FormatInt(int64(creation.UserId), 10))
sessionIdHash := sha256string(session.Id) sessionIdHash := sha256string(session.Id)
sessionData, err := json.Marshal(session) sessionData, err := json.Marshal(session)

View file

@ -9,12 +9,12 @@ type UseCase interface {
CreateUser(ctx context.Context, username string, password string, role models.Role) (int32, error) CreateUser(ctx context.Context, username string, password string, role models.Role) (int32, error)
ReadUserById(ctx context.Context, id int32) (*models.User, error) ReadUserById(ctx context.Context, id int32) (*models.User, error)
ReadUserByUsername(ctx context.Context, username string) (*models.User, error) ReadUserByUsername(ctx context.Context, username string) (*models.User, error)
UpdateUser(ctx context.Context, id int32, username *string, role *models.Role) error UpdateUser(ctx context.Context, id int32, update *models.UserUpdate) error
DeleteUser(ctx context.Context, id int32) error DeleteUser(ctx context.Context, id int32) error
CreateSession(ctx context.Context, userId int32, role models.Role, userAgent, ip string) (*models.Session, error) CreateSession(ctx context.Context, creation *models.SessionCreation) (*models.Session, error)
ReadSession(ctx context.Context, sessionId string) (*models.Session, error) ReadSession(ctx context.Context, sessionId string) (*models.Session, error)
UpdateSession(ctx context.Context, sessionId string) error UpdateSession(ctx context.Context, sessionId string) error
DeleteSession(ctx context.Context, sessionId string) error DeleteSession(ctx context.Context, sessionId string) error
DeleteAllSessions(ctx context.Context, userId int32) error DeleteAllSessions(ctx context.Context, userId int32) error
ListUsers(ctx context.Context, page int32, pageSize int32) ([]*models.User, int32, error) ListUsers(ctx context.Context, filters models.UsersListFilters) (*models.UsersList, error)
} }

View file

@ -85,7 +85,7 @@ func (u *UseCase) ReadUserByUsername(ctx context.Context, username string) (*mod
return user, nil return user, nil
} }
func (u *UseCase) UpdateUser(ctx context.Context, id int32, username *string, role *models.Role) error { func (u *UseCase) UpdateUser(ctx context.Context, id int32, update *models.UserUpdate) error {
const op = "UseCase.UpdateUser" const op = "UseCase.UpdateUser"
token, ok := ctx.Value(TokenKey).(*models.JWT) token, ok := ctx.Value(TokenKey).(*models.JWT)
@ -106,7 +106,7 @@ func (u *UseCase) UpdateUser(ctx context.Context, id int32, username *string, ro
return pkg.Wrap(nil, err, op, "cannot start transaction") return pkg.Wrap(nil, err, op, "cannot start transaction")
} }
err = tx.UpdateUser(ctx, id, username, role) err = tx.UpdateUser(ctx, id, update)
if err != nil { if err != nil {
return pkg.Wrap(nil, errors.Join(err, tx.Rollback()), op, "cannot update user") return pkg.Wrap(nil, errors.Join(err, tx.Rollback()), op, "cannot update user")
} }
@ -161,10 +161,10 @@ func (u *UseCase) DeleteUser(ctx context.Context, id int32) error {
} }
// CreateSession is for login only. There are no permission checks! DO NOT USE IT AS AN ENDPOINT RESPONSE! // CreateSession is for login only. There are no permission checks! DO NOT USE IT AS AN ENDPOINT RESPONSE!
func (u *UseCase) CreateSession(ctx context.Context, userId int32, role models.Role, userAgent, ip string) (*models.Session, error) { func (u *UseCase) CreateSession(ctx context.Context, creation *models.SessionCreation) (*models.Session, error) {
const op = "UseCase.CreateSession" const op = "UseCase.CreateSession"
session, err := u.sessionRepo.CreateSession(ctx, userId, role, userAgent, ip) session, err := u.sessionRepo.CreateSession(ctx, creation)
if err != nil { if err != nil {
return nil, pkg.Wrap(nil, err, op, "cannot create session") return nil, pkg.Wrap(nil, err, op, "cannot create session")
} }
@ -253,21 +253,22 @@ func (u *UseCase) DeleteAllSessions(ctx context.Context, userId int32) error {
return nil return nil
} }
func (u *UseCase) ListUsers(ctx context.Context, page int32, pageSize int32) ([]*models.User, int32, error) { func (u *UseCase) ListUsers(ctx context.Context, filters models.UsersListFilters) (*models.UsersList, error) {
const op = "UseCase.ListUsers" const op = "UseCase.ListUsers"
token, ok := ctx.Value(TokenKey).(*models.JWT) token, ok := ctx.Value(TokenKey).(*models.JWT)
if !ok { if !ok {
return nil, 0, pkg.Wrap(pkg.ErrUnauthenticated, nil, op, "no token in context") return nil, pkg.Wrap(pkg.ErrUnauthenticated, nil, op, "no token in context")
} }
if !token.Role.HasPermission(models.Read, models.ResourceAnotherUser) { if !token.Role.HasPermission(models.Read, models.ResourceAnotherUser) {
return nil, 0, pkg.Wrap(pkg.NoPermission, nil, op, "no permission") return nil, pkg.Wrap(pkg.NoPermission, nil, op, "no permission")
} }
usersList, count, err := u.userRepo.C().ListUsers(ctx, page, pageSize) usersList, err := u.userRepo.C().ListUsers(ctx, filters)
if err != nil { if err != nil {
return nil, 0, pkg.Wrap(nil, err, op, "can't list users") return nil, pkg.Wrap(nil, err, op, "can't list users")
} }
return usersList, count, nil
return usersList, nil
} }

2
proto

@ -1 +1 @@
Subproject commit c1b7fd7a2d32678641ebd3acfe3d5b2eca5d0c72 Subproject commit 5c14d329bdff30eba7526f8f4e0d941ad454cb27