From 51c5b4943c8eb1667ed6bb9fd10edaac11b9b0a9 Mon Sep 17 00:00:00 2001 From: Vyacheslav1557 Date: Fri, 28 Mar 2025 16:33:13 +0500 Subject: [PATCH] refactor(user): refactor api --- internal/models/pagination.go | 6 + internal/models/session.go | 7 + internal/models/user.go | 37 +++-- internal/users/delivery/rest/handlers.go | 127 ++++++++++-------- internal/users/repository.go | 6 +- internal/users/repository/pg_repository.go | 34 ++--- .../users/repository/valkey_repository.go | 12 +- internal/users/usecase.go | 6 +- internal/users/usecase/usecase.go | 21 +-- proto | 2 +- 10 files changed, 156 insertions(+), 102 deletions(-) create mode 100644 internal/models/pagination.go diff --git a/internal/models/pagination.go b/internal/models/pagination.go new file mode 100644 index 0000000..c2fd852 --- /dev/null +++ b/internal/models/pagination.go @@ -0,0 +1,6 @@ +package models + +type Pagination struct { + Page int32 `json:"page"` + Total int32 `json:"total"` +} diff --git a/internal/models/session.go b/internal/models/session.go index 28e8ab9..d8ba2e7 100644 --- a/internal/models/session.go +++ b/internal/models/session.go @@ -38,6 +38,13 @@ func (s Session) Valid() error { return nil } +type SessionCreation struct { + UserId int32 + Role Role + UserAgent string + Ip string +} + type JWT struct { SessionId string `json:"session_id"` UserId int32 `json:"user_id"` diff --git a/internal/models/user.go b/internal/models/user.go index b23cd84..f6c8450 100644 --- a/internal/models/user.go +++ b/internal/models/user.go @@ -10,6 +10,34 @@ import ( type Role int32 +type User struct { + Id int32 `db:"id"` + Username string `db:"username"` + HashedPassword string `db:"hashed_pwd"` + CreatedAt time.Time `db:"created_at"` + ModifiedAt time.Time `db:"modified_at"` + Role Role `db:"role"` +} + +type UsersListFilters struct { + PageSize int32 + Page int32 +} + +func (f UsersListFilters) Offset() int32 { + return (f.Page - 1) * f.PageSize +} + +type UsersList struct { + Users []*User + Pagination Pagination +} + +type UserUpdate struct { + Username *string + Role *Role +} + const ( RoleGuest Role = -1 RoleStudent Role = 0 @@ -96,15 +124,6 @@ allow if { } ` -type User struct { - Id int32 `db:"id"` - Username string `db:"username"` - HashedPassword string `db:"hashed_pwd"` - CreatedAt time.Time `db:"created_at"` - ModifiedAt time.Time `db:"modified_at"` - Role Role `db:"role"` -} - func (user *User) MarshalJSON() ([]byte, error) { m := map[string]interface{}{ "id": user.Id, diff --git a/internal/users/delivery/rest/handlers.go b/internal/users/delivery/rest/handlers.go index ab333c0..40f085c 100644 --- a/internal/users/delivery/rest/handlers.go +++ b/internal/users/delivery/rest/handlers.go @@ -27,31 +27,19 @@ func NewUserHandlers(userUC users.UseCase, jwtSecret string) *UserHandlers { } func (h *UserHandlers) Login(c *fiber.Ctx) error { - const op = "UserHandlers.Login" - authHeader := c.Get("Authorization", "") if authHeader == "" { return c.SendStatus(fiber.StatusUnauthorized) } - authParts := strings.Split(authHeader, " ") - if len(authParts) != 2 || strings.ToLower(authParts[0]) != "basic" { - return c.SendStatus(fiber.StatusUnauthorized) - } - - decodedAuth, err := base64.StdEncoding.DecodeString(authParts[1]) + username, pwd, err := parseBasicAuth(authHeader) if err != nil { return c.SendStatus(fiber.StatusUnauthorized) } - authParts = strings.Split(string(decodedAuth), ":") - if len(authParts) != 2 { - return c.SendStatus(fiber.StatusUnauthorized) - } - ctx := c.Context() - user, err := h.userUC.ReadUserByUsername(ctx, authParts[0]) + user, err := h.userUC.ReadUserByUsername(ctx, username) if err != nil { if errors.Is(err, pkg.ErrNotFound) { return c.SendStatus(fiber.StatusUnauthorized) @@ -60,14 +48,19 @@ func (h *UserHandlers) Login(c *fiber.Ctx) error { return c.SendStatus(pkg.ToREST(err)) } - if !user.IsSamePwd(authParts[1]) { + if !user.IsSamePwd(pwd) { return c.SendStatus(fiber.StatusUnauthorized) } userAgent := c.Get("User-Agent", "") ip := c.IP() - session, err := h.userUC.CreateSession(ctx, user.Id, user.Role, userAgent, ip) + session, err := h.userUC.CreateSession(ctx, &models.SessionCreation{ + UserId: user.Id, + Role: user.Role, + UserAgent: userAgent, + Ip: ip, + }) if err != nil { return c.SendStatus(pkg.ToREST(err)) } @@ -93,8 +86,6 @@ func (h *UserHandlers) Login(c *fiber.Ctx) error { } func (h *UserHandlers) Refresh(c *fiber.Ctx) error { - const op = "UserHandlers.Refresh" - token, ok := c.Locals(TokenKey).(*models.JWT) if !ok { return c.SendStatus(fiber.StatusUnauthorized) @@ -109,8 +100,6 @@ func (h *UserHandlers) Refresh(c *fiber.Ctx) error { } func (h *UserHandlers) Logout(c *fiber.Ctx) error { - const op = "UserHandlers.Logout" - token, ok := c.Locals(TokenKey).(*models.JWT) if !ok { return c.SendStatus(fiber.StatusUnauthorized) @@ -125,8 +114,6 @@ func (h *UserHandlers) Logout(c *fiber.Ctx) error { } func (h *UserHandlers) CompleteLogout(c *fiber.Ctx) error { - const op = "UserHandlers.CompleteLogout" - token, ok := c.Locals(TokenKey).(*models.JWT) if !ok { return c.SendStatus(fiber.StatusUnauthorized) @@ -143,14 +130,10 @@ func (h *UserHandlers) CompleteLogout(c *fiber.Ctx) error { } func (h *UserHandlers) Verify(c *fiber.Ctx) error { - const op = "UserHandlers.Verify" - return c.SendStatus(fiber.StatusNotImplemented) } func (h *UserHandlers) CreateUser(c *fiber.Ctx) error { - const op = "UserHandlers.CreateUser" - ctx := c.Context() var req = &userv1.CreateUserRequest{} @@ -170,35 +153,31 @@ func (h *UserHandlers) CreateUser(c *fiber.Ctx) error { return c.SendStatus(pkg.ToREST(err)) } - return c.JSON(map[string]interface{}{ - "id": id, - }) + return c.JSON(userv1.CreateUserResponse{Id: id}) } func (h *UserHandlers) GetUser(c *fiber.Ctx, id int32) error { - const op = "UserHandlers.GetUser" - user, err := h.userUC.ReadUserById(c.Context(), id) if err != nil { return c.SendStatus(pkg.ToREST(err)) } - return c.JSON(map[string]interface{}{ - "user": user, + return c.JSON(userv1.GetUserResponse{ + User: U2U(*user), }) } func (h *UserHandlers) UpdateUser(c *fiber.Ctx, id int32) error { - const op = "UserHandlers.UpdateUser" - var req = &userv1.UpdateUserRequest{} - err := c.BodyParser(req) if err != nil { return c.SendStatus(fiber.StatusBadRequest) } - err = h.userUC.UpdateUser(c.Context(), id, req.Username, int32PtrToRolePtr(req.Role)) + err = h.userUC.UpdateUser(c.Context(), id, &models.UserUpdate{ + Username: req.Username, + Role: I32P2RP(req.Role), + }) if err != nil { return c.SendStatus(pkg.ToREST(err)) } @@ -207,8 +186,6 @@ func (h *UserHandlers) UpdateUser(c *fiber.Ctx, id int32) error { } func (h *UserHandlers) DeleteUser(c *fiber.Ctx, id int32) error { - const op = "UserHandlers.DeleteUser" - ctx := c.Context() err := h.userUC.DeleteUser(ctx, id) @@ -220,35 +197,75 @@ func (h *UserHandlers) DeleteUser(c *fiber.Ctx, id int32) error { } func (h *UserHandlers) ListUsers(c *fiber.Ctx, params userv1.ListUsersParams) error { - const op = "UserHandlers.ListUsers" - - usersList, count, err := h.userUC.ListUsers(c.Context(), params.Page, params.PageSize) + usersList, err := h.userUC.ListUsers(c.Context(), models.UsersListFilters{ + PageSize: params.PageSize, + Page: params.Page, + }) if err != nil { return c.SendStatus(pkg.ToREST(err)) } - return c.JSON(map[string]interface{}{ - "users": usersList, - "page": params.Page, - "max_page": func() int32 { - if count%params.PageSize == 0 { - return count / params.PageSize - } - return count/params.PageSize + 1 - }(), - }) + resp := userv1.ListUsersResponse{ + Users: make([]userv1.User, len(usersList.Users)), + Pagination: P2P(usersList.Pagination), + } + + for i, user := range usersList.Users { + resp.Users[i] = U2U(*user) + } + + return c.JSON(resp) } func (h *UserHandlers) ListSessions(c *fiber.Ctx) error { - const op = "UserHandlers.ListSessions" - return c.SendStatus(fiber.StatusNotImplemented) } -func int32PtrToRolePtr(i *int32) *models.Role { +func parseBasicAuth(header string) (string, string, error) { + const ( + op = "parseBasicAuth" + msg = "invalid auth header" + ) + + authParts := strings.Split(header, " ") + if len(authParts) != 2 || strings.ToLower(authParts[0]) != "basic" { + return "", "", pkg.Wrap(pkg.ErrUnauthenticated, nil, op, msg) + } + + decodedAuth, err := base64.StdEncoding.DecodeString(authParts[1]) + if err != nil { + return "", "", pkg.Wrap(pkg.ErrUnauthenticated, nil, op, msg) + } + + authParts = strings.Split(string(decodedAuth), ":") + if len(authParts) != 2 { + return "", "", pkg.Wrap(pkg.ErrUnauthenticated, nil, op, msg) + } + + return authParts[0], authParts[1], nil +} + +func I32P2RP(i *int32) *models.Role { if i == nil { return nil } ii := models.Role(*i) return &ii } + +func P2P(p models.Pagination) userv1.Pagination { + return userv1.Pagination{ + Page: p.Page, + Total: p.Total, + } +} + +func U2U(u models.User) userv1.User { + return userv1.User{ + Id: u.Id, + Username: u.Username, + Role: int32(u.Role), + CreatedAt: u.CreatedAt, + ModifiedAt: u.ModifiedAt, + } +} diff --git a/internal/users/repository.go b/internal/users/repository.go index 09868de..1cea642 100644 --- a/internal/users/repository.go +++ b/internal/users/repository.go @@ -9,9 +9,9 @@ type Caller interface { CreateUser(ctx context.Context, username string, password string, role models.Role) (int32, error) ReadUserByUsername(ctx context.Context, username string) (*models.User, error) ReadUserById(ctx context.Context, id int32) (*models.User, error) - UpdateUser(ctx context.Context, id int32, username *string, role *models.Role) error + UpdateUser(ctx context.Context, id int32, update *models.UserUpdate) error DeleteUser(ctx context.Context, id int32) error - ListUsers(ctx context.Context, page int32, pageSize int32) ([]*models.User, int32, error) + ListUsers(ctx context.Context, filters models.UsersListFilters) (*models.UsersList, error) } type TxCaller interface { @@ -26,7 +26,7 @@ type PgRepository interface { } type ValkeyRepository interface { - CreateSession(ctx context.Context, userId int32, role models.Role, userAgent, ip string) (*models.Session, error) + CreateSession(ctx context.Context, creation *models.SessionCreation) (*models.Session, error) ReadSession(ctx context.Context, sessionId string) (*models.Session, error) UpdateSession(ctx context.Context, sessionId string) error DeleteSession(ctx context.Context, sessionId string) error diff --git a/internal/users/repository/pg_repository.go b/internal/users/repository/pg_repository.go index ef70fc1..0606cbd 100644 --- a/internal/users/repository/pg_repository.go +++ b/internal/users/repository/pg_repository.go @@ -164,12 +164,12 @@ SET username = COALESCE(?, trim(lower(username))), WHERE id = ? ` -func (c *Caller) UpdateUser(ctx context.Context, id int32, username *string, role *models.Role) error { +func (c *Caller) UpdateUser(ctx context.Context, id int32, update *models.UserUpdate) error { const op = "Caller.UpdateUser" var err error - if username != nil { - if err = ValidUsername(*username); err != nil { + if update.Username != nil { + if err = ValidUsername(*update.Username); err != nil { return pkg.Wrap(pkg.ErrBadInput, err, op, "username validation") } } @@ -178,8 +178,8 @@ func (c *Caller) UpdateUser(ctx context.Context, id int32, username *string, rol _, err = c.db.ExecContext( ctx, query, - username, - role, + update.Username, + update.Role, id, ) @@ -208,29 +208,33 @@ const ( CountUsers = "SELECT COUNT(*) FROM users" ) -func (c *Caller) ListUsers(ctx context.Context, page int32, pageSize int32) ([]*models.User, int32, error) { +func (c *Caller) ListUsers(ctx context.Context, filters models.UsersListFilters) (*models.UsersList, error) { const op = "Caller.ListUsers" - if pageSize > 20 { - return nil, 0, pkg.Wrap(pkg.ErrBadInput, nil, op, "limit > 20") + if filters.PageSize > 20 { + return nil, pkg.Wrap(pkg.ErrBadInput, nil, op, "limit > 20") } - var usersList []*models.User + usersList := &models.UsersList{ + Users: make([]*models.User, 0), + Pagination: models.Pagination{}, + } query := c.db.Rebind(ListUsers) - err := c.db.SelectContext(ctx, &usersList, query, pageSize, (page-1)*pageSize) + err := c.db.SelectContext(ctx, &usersList.Users, query, filters.PageSize, filters.Offset()) if err != nil { - return nil, 0, handlePgErr(err, op) + return nil, handlePgErr(err, op) } query = c.db.Rebind(CountUsers) - var count int32 - err = c.db.GetContext(ctx, &count, query) + err = c.db.GetContext(ctx, &usersList.Pagination.Total, query) if err != nil { - return nil, 0, handlePgErr(err, op) + return nil, handlePgErr(err, op) } - return usersList, count, nil + usersList.Pagination.Page = filters.Page + + return usersList, nil } func handlePgErr(err error, op string) error { diff --git a/internal/users/repository/valkey_repository.go b/internal/users/repository/valkey_repository.go index 18fc949..3fc8e72 100644 --- a/internal/users/repository/valkey_repository.go +++ b/internal/users/repository/valkey_repository.go @@ -32,17 +32,17 @@ func sha256string(s string) string { return hex.EncodeToString(hasher.Sum(nil)) } -func (r *ValkeyRepository) CreateSession(ctx context.Context, userId int32, role models.Role, userAgent, ip string) (*models.Session, error) { +func (r *ValkeyRepository) CreateSession(ctx context.Context, creation *models.SessionCreation) (*models.Session, error) { const op = "ValkeyRepository.CreateSession" session := &models.Session{ Id: uuid.NewString(), - UserId: userId, - Role: role, + UserId: creation.UserId, + Role: creation.Role, CreatedAt: time.Now(), ExpiresAt: time.Now().Add(sessionLifetime), - UserAgent: userAgent, - Ip: ip, + UserAgent: creation.UserAgent, + Ip: creation.Ip, } err := session.Valid() @@ -50,7 +50,7 @@ func (r *ValkeyRepository) CreateSession(ctx context.Context, userId int32, role return nil, pkg.Wrap(pkg.ErrInternal, err, op, "validating session") } - userIdHash := sha256string(strconv.FormatInt(int64(userId), 10)) + userIdHash := sha256string(strconv.FormatInt(int64(creation.UserId), 10)) sessionIdHash := sha256string(session.Id) sessionData, err := json.Marshal(session) diff --git a/internal/users/usecase.go b/internal/users/usecase.go index e72924f..9b29297 100644 --- a/internal/users/usecase.go +++ b/internal/users/usecase.go @@ -9,12 +9,12 @@ type UseCase interface { CreateUser(ctx context.Context, username string, password string, role models.Role) (int32, error) ReadUserById(ctx context.Context, id int32) (*models.User, error) ReadUserByUsername(ctx context.Context, username string) (*models.User, error) - UpdateUser(ctx context.Context, id int32, username *string, role *models.Role) error + UpdateUser(ctx context.Context, id int32, update *models.UserUpdate) error DeleteUser(ctx context.Context, id int32) error - CreateSession(ctx context.Context, userId int32, role models.Role, userAgent, ip string) (*models.Session, error) + CreateSession(ctx context.Context, creation *models.SessionCreation) (*models.Session, error) ReadSession(ctx context.Context, sessionId string) (*models.Session, error) UpdateSession(ctx context.Context, sessionId string) error DeleteSession(ctx context.Context, sessionId string) error DeleteAllSessions(ctx context.Context, userId int32) error - ListUsers(ctx context.Context, page int32, pageSize int32) ([]*models.User, int32, error) + ListUsers(ctx context.Context, filters models.UsersListFilters) (*models.UsersList, error) } diff --git a/internal/users/usecase/usecase.go b/internal/users/usecase/usecase.go index 5ddcd67..003eebc 100644 --- a/internal/users/usecase/usecase.go +++ b/internal/users/usecase/usecase.go @@ -85,7 +85,7 @@ func (u *UseCase) ReadUserByUsername(ctx context.Context, username string) (*mod return user, nil } -func (u *UseCase) UpdateUser(ctx context.Context, id int32, username *string, role *models.Role) error { +func (u *UseCase) UpdateUser(ctx context.Context, id int32, update *models.UserUpdate) error { const op = "UseCase.UpdateUser" token, ok := ctx.Value(TokenKey).(*models.JWT) @@ -106,7 +106,7 @@ func (u *UseCase) UpdateUser(ctx context.Context, id int32, username *string, ro return pkg.Wrap(nil, err, op, "cannot start transaction") } - err = tx.UpdateUser(ctx, id, username, role) + err = tx.UpdateUser(ctx, id, update) if err != nil { return pkg.Wrap(nil, errors.Join(err, tx.Rollback()), op, "cannot update user") } @@ -161,10 +161,10 @@ func (u *UseCase) DeleteUser(ctx context.Context, id int32) error { } // CreateSession is for login only. There are no permission checks! DO NOT USE IT AS AN ENDPOINT RESPONSE! -func (u *UseCase) CreateSession(ctx context.Context, userId int32, role models.Role, userAgent, ip string) (*models.Session, error) { +func (u *UseCase) CreateSession(ctx context.Context, creation *models.SessionCreation) (*models.Session, error) { const op = "UseCase.CreateSession" - session, err := u.sessionRepo.CreateSession(ctx, userId, role, userAgent, ip) + session, err := u.sessionRepo.CreateSession(ctx, creation) if err != nil { return nil, pkg.Wrap(nil, err, op, "cannot create session") } @@ -253,21 +253,22 @@ func (u *UseCase) DeleteAllSessions(ctx context.Context, userId int32) error { return nil } -func (u *UseCase) ListUsers(ctx context.Context, page int32, pageSize int32) ([]*models.User, int32, error) { +func (u *UseCase) ListUsers(ctx context.Context, filters models.UsersListFilters) (*models.UsersList, error) { const op = "UseCase.ListUsers" token, ok := ctx.Value(TokenKey).(*models.JWT) if !ok { - return nil, 0, pkg.Wrap(pkg.ErrUnauthenticated, nil, op, "no token in context") + return nil, pkg.Wrap(pkg.ErrUnauthenticated, nil, op, "no token in context") } if !token.Role.HasPermission(models.Read, models.ResourceAnotherUser) { - return nil, 0, pkg.Wrap(pkg.NoPermission, nil, op, "no permission") + return nil, pkg.Wrap(pkg.NoPermission, nil, op, "no permission") } - usersList, count, err := u.userRepo.C().ListUsers(ctx, page, pageSize) + usersList, err := u.userRepo.C().ListUsers(ctx, filters) if err != nil { - return nil, 0, pkg.Wrap(nil, err, op, "can't list users") + return nil, pkg.Wrap(nil, err, op, "can't list users") } - return usersList, count, nil + + return usersList, nil } diff --git a/proto b/proto index c1b7fd7..5c14d32 160000 --- a/proto +++ b/proto @@ -1 +1 @@ -Subproject commit c1b7fd7a2d32678641ebd3acfe3d5b2eca5d0c72 +Subproject commit 5c14d329bdff30eba7526f8f4e0d941ad454cb27