refactor(user): refactor api

This commit is contained in:
Vyacheslav1557 2025-03-28 16:33:13 +05:00
parent 2a18e1d42f
commit 51c5b4943c
10 changed files with 156 additions and 102 deletions

View file

@ -0,0 +1,6 @@
package models
type Pagination struct {
Page int32 `json:"page"`
Total int32 `json:"total"`
}

View file

@ -38,6 +38,13 @@ func (s Session) Valid() error {
return nil
}
type SessionCreation struct {
UserId int32
Role Role
UserAgent string
Ip string
}
type JWT struct {
SessionId string `json:"session_id"`
UserId int32 `json:"user_id"`

View file

@ -10,6 +10,34 @@ import (
type Role int32
type User struct {
Id int32 `db:"id"`
Username string `db:"username"`
HashedPassword string `db:"hashed_pwd"`
CreatedAt time.Time `db:"created_at"`
ModifiedAt time.Time `db:"modified_at"`
Role Role `db:"role"`
}
type UsersListFilters struct {
PageSize int32
Page int32
}
func (f UsersListFilters) Offset() int32 {
return (f.Page - 1) * f.PageSize
}
type UsersList struct {
Users []*User
Pagination Pagination
}
type UserUpdate struct {
Username *string
Role *Role
}
const (
RoleGuest Role = -1
RoleStudent Role = 0
@ -96,15 +124,6 @@ allow if {
}
`
type User struct {
Id int32 `db:"id"`
Username string `db:"username"`
HashedPassword string `db:"hashed_pwd"`
CreatedAt time.Time `db:"created_at"`
ModifiedAt time.Time `db:"modified_at"`
Role Role `db:"role"`
}
func (user *User) MarshalJSON() ([]byte, error) {
m := map[string]interface{}{
"id": user.Id,

View file

@ -27,31 +27,19 @@ func NewUserHandlers(userUC users.UseCase, jwtSecret string) *UserHandlers {
}
func (h *UserHandlers) Login(c *fiber.Ctx) error {
const op = "UserHandlers.Login"
authHeader := c.Get("Authorization", "")
if authHeader == "" {
return c.SendStatus(fiber.StatusUnauthorized)
}
authParts := strings.Split(authHeader, " ")
if len(authParts) != 2 || strings.ToLower(authParts[0]) != "basic" {
return c.SendStatus(fiber.StatusUnauthorized)
}
decodedAuth, err := base64.StdEncoding.DecodeString(authParts[1])
username, pwd, err := parseBasicAuth(authHeader)
if err != nil {
return c.SendStatus(fiber.StatusUnauthorized)
}
authParts = strings.Split(string(decodedAuth), ":")
if len(authParts) != 2 {
return c.SendStatus(fiber.StatusUnauthorized)
}
ctx := c.Context()
user, err := h.userUC.ReadUserByUsername(ctx, authParts[0])
user, err := h.userUC.ReadUserByUsername(ctx, username)
if err != nil {
if errors.Is(err, pkg.ErrNotFound) {
return c.SendStatus(fiber.StatusUnauthorized)
@ -60,14 +48,19 @@ func (h *UserHandlers) Login(c *fiber.Ctx) error {
return c.SendStatus(pkg.ToREST(err))
}
if !user.IsSamePwd(authParts[1]) {
if !user.IsSamePwd(pwd) {
return c.SendStatus(fiber.StatusUnauthorized)
}
userAgent := c.Get("User-Agent", "")
ip := c.IP()
session, err := h.userUC.CreateSession(ctx, user.Id, user.Role, userAgent, ip)
session, err := h.userUC.CreateSession(ctx, &models.SessionCreation{
UserId: user.Id,
Role: user.Role,
UserAgent: userAgent,
Ip: ip,
})
if err != nil {
return c.SendStatus(pkg.ToREST(err))
}
@ -93,8 +86,6 @@ func (h *UserHandlers) Login(c *fiber.Ctx) error {
}
func (h *UserHandlers) Refresh(c *fiber.Ctx) error {
const op = "UserHandlers.Refresh"
token, ok := c.Locals(TokenKey).(*models.JWT)
if !ok {
return c.SendStatus(fiber.StatusUnauthorized)
@ -109,8 +100,6 @@ func (h *UserHandlers) Refresh(c *fiber.Ctx) error {
}
func (h *UserHandlers) Logout(c *fiber.Ctx) error {
const op = "UserHandlers.Logout"
token, ok := c.Locals(TokenKey).(*models.JWT)
if !ok {
return c.SendStatus(fiber.StatusUnauthorized)
@ -125,8 +114,6 @@ func (h *UserHandlers) Logout(c *fiber.Ctx) error {
}
func (h *UserHandlers) CompleteLogout(c *fiber.Ctx) error {
const op = "UserHandlers.CompleteLogout"
token, ok := c.Locals(TokenKey).(*models.JWT)
if !ok {
return c.SendStatus(fiber.StatusUnauthorized)
@ -143,14 +130,10 @@ func (h *UserHandlers) CompleteLogout(c *fiber.Ctx) error {
}
func (h *UserHandlers) Verify(c *fiber.Ctx) error {
const op = "UserHandlers.Verify"
return c.SendStatus(fiber.StatusNotImplemented)
}
func (h *UserHandlers) CreateUser(c *fiber.Ctx) error {
const op = "UserHandlers.CreateUser"
ctx := c.Context()
var req = &userv1.CreateUserRequest{}
@ -170,35 +153,31 @@ func (h *UserHandlers) CreateUser(c *fiber.Ctx) error {
return c.SendStatus(pkg.ToREST(err))
}
return c.JSON(map[string]interface{}{
"id": id,
})
return c.JSON(userv1.CreateUserResponse{Id: id})
}
func (h *UserHandlers) GetUser(c *fiber.Ctx, id int32) error {
const op = "UserHandlers.GetUser"
user, err := h.userUC.ReadUserById(c.Context(), id)
if err != nil {
return c.SendStatus(pkg.ToREST(err))
}
return c.JSON(map[string]interface{}{
"user": user,
return c.JSON(userv1.GetUserResponse{
User: U2U(*user),
})
}
func (h *UserHandlers) UpdateUser(c *fiber.Ctx, id int32) error {
const op = "UserHandlers.UpdateUser"
var req = &userv1.UpdateUserRequest{}
err := c.BodyParser(req)
if err != nil {
return c.SendStatus(fiber.StatusBadRequest)
}
err = h.userUC.UpdateUser(c.Context(), id, req.Username, int32PtrToRolePtr(req.Role))
err = h.userUC.UpdateUser(c.Context(), id, &models.UserUpdate{
Username: req.Username,
Role: I32P2RP(req.Role),
})
if err != nil {
return c.SendStatus(pkg.ToREST(err))
}
@ -207,8 +186,6 @@ func (h *UserHandlers) UpdateUser(c *fiber.Ctx, id int32) error {
}
func (h *UserHandlers) DeleteUser(c *fiber.Ctx, id int32) error {
const op = "UserHandlers.DeleteUser"
ctx := c.Context()
err := h.userUC.DeleteUser(ctx, id)
@ -220,35 +197,75 @@ func (h *UserHandlers) DeleteUser(c *fiber.Ctx, id int32) error {
}
func (h *UserHandlers) ListUsers(c *fiber.Ctx, params userv1.ListUsersParams) error {
const op = "UserHandlers.ListUsers"
usersList, count, err := h.userUC.ListUsers(c.Context(), params.Page, params.PageSize)
usersList, err := h.userUC.ListUsers(c.Context(), models.UsersListFilters{
PageSize: params.PageSize,
Page: params.Page,
})
if err != nil {
return c.SendStatus(pkg.ToREST(err))
}
return c.JSON(map[string]interface{}{
"users": usersList,
"page": params.Page,
"max_page": func() int32 {
if count%params.PageSize == 0 {
return count / params.PageSize
}
return count/params.PageSize + 1
}(),
})
resp := userv1.ListUsersResponse{
Users: make([]userv1.User, len(usersList.Users)),
Pagination: P2P(usersList.Pagination),
}
for i, user := range usersList.Users {
resp.Users[i] = U2U(*user)
}
return c.JSON(resp)
}
func (h *UserHandlers) ListSessions(c *fiber.Ctx) error {
const op = "UserHandlers.ListSessions"
return c.SendStatus(fiber.StatusNotImplemented)
}
func int32PtrToRolePtr(i *int32) *models.Role {
func parseBasicAuth(header string) (string, string, error) {
const (
op = "parseBasicAuth"
msg = "invalid auth header"
)
authParts := strings.Split(header, " ")
if len(authParts) != 2 || strings.ToLower(authParts[0]) != "basic" {
return "", "", pkg.Wrap(pkg.ErrUnauthenticated, nil, op, msg)
}
decodedAuth, err := base64.StdEncoding.DecodeString(authParts[1])
if err != nil {
return "", "", pkg.Wrap(pkg.ErrUnauthenticated, nil, op, msg)
}
authParts = strings.Split(string(decodedAuth), ":")
if len(authParts) != 2 {
return "", "", pkg.Wrap(pkg.ErrUnauthenticated, nil, op, msg)
}
return authParts[0], authParts[1], nil
}
func I32P2RP(i *int32) *models.Role {
if i == nil {
return nil
}
ii := models.Role(*i)
return &ii
}
func P2P(p models.Pagination) userv1.Pagination {
return userv1.Pagination{
Page: p.Page,
Total: p.Total,
}
}
func U2U(u models.User) userv1.User {
return userv1.User{
Id: u.Id,
Username: u.Username,
Role: int32(u.Role),
CreatedAt: u.CreatedAt,
ModifiedAt: u.ModifiedAt,
}
}

View file

@ -9,9 +9,9 @@ type Caller interface {
CreateUser(ctx context.Context, username string, password string, role models.Role) (int32, error)
ReadUserByUsername(ctx context.Context, username string) (*models.User, error)
ReadUserById(ctx context.Context, id int32) (*models.User, error)
UpdateUser(ctx context.Context, id int32, username *string, role *models.Role) error
UpdateUser(ctx context.Context, id int32, update *models.UserUpdate) error
DeleteUser(ctx context.Context, id int32) error
ListUsers(ctx context.Context, page int32, pageSize int32) ([]*models.User, int32, error)
ListUsers(ctx context.Context, filters models.UsersListFilters) (*models.UsersList, error)
}
type TxCaller interface {
@ -26,7 +26,7 @@ type PgRepository interface {
}
type ValkeyRepository interface {
CreateSession(ctx context.Context, userId int32, role models.Role, userAgent, ip string) (*models.Session, error)
CreateSession(ctx context.Context, creation *models.SessionCreation) (*models.Session, error)
ReadSession(ctx context.Context, sessionId string) (*models.Session, error)
UpdateSession(ctx context.Context, sessionId string) error
DeleteSession(ctx context.Context, sessionId string) error

View file

@ -164,12 +164,12 @@ SET username = COALESCE(?, trim(lower(username))),
WHERE id = ?
`
func (c *Caller) UpdateUser(ctx context.Context, id int32, username *string, role *models.Role) error {
func (c *Caller) UpdateUser(ctx context.Context, id int32, update *models.UserUpdate) error {
const op = "Caller.UpdateUser"
var err error
if username != nil {
if err = ValidUsername(*username); err != nil {
if update.Username != nil {
if err = ValidUsername(*update.Username); err != nil {
return pkg.Wrap(pkg.ErrBadInput, err, op, "username validation")
}
}
@ -178,8 +178,8 @@ func (c *Caller) UpdateUser(ctx context.Context, id int32, username *string, rol
_, err = c.db.ExecContext(
ctx,
query,
username,
role,
update.Username,
update.Role,
id,
)
@ -208,29 +208,33 @@ const (
CountUsers = "SELECT COUNT(*) FROM users"
)
func (c *Caller) ListUsers(ctx context.Context, page int32, pageSize int32) ([]*models.User, int32, error) {
func (c *Caller) ListUsers(ctx context.Context, filters models.UsersListFilters) (*models.UsersList, error) {
const op = "Caller.ListUsers"
if pageSize > 20 {
return nil, 0, pkg.Wrap(pkg.ErrBadInput, nil, op, "limit > 20")
if filters.PageSize > 20 {
return nil, pkg.Wrap(pkg.ErrBadInput, nil, op, "limit > 20")
}
var usersList []*models.User
usersList := &models.UsersList{
Users: make([]*models.User, 0),
Pagination: models.Pagination{},
}
query := c.db.Rebind(ListUsers)
err := c.db.SelectContext(ctx, &usersList, query, pageSize, (page-1)*pageSize)
err := c.db.SelectContext(ctx, &usersList.Users, query, filters.PageSize, filters.Offset())
if err != nil {
return nil, 0, handlePgErr(err, op)
return nil, handlePgErr(err, op)
}
query = c.db.Rebind(CountUsers)
var count int32
err = c.db.GetContext(ctx, &count, query)
err = c.db.GetContext(ctx, &usersList.Pagination.Total, query)
if err != nil {
return nil, 0, handlePgErr(err, op)
return nil, handlePgErr(err, op)
}
return usersList, count, nil
usersList.Pagination.Page = filters.Page
return usersList, nil
}
func handlePgErr(err error, op string) error {

View file

@ -32,17 +32,17 @@ func sha256string(s string) string {
return hex.EncodeToString(hasher.Sum(nil))
}
func (r *ValkeyRepository) CreateSession(ctx context.Context, userId int32, role models.Role, userAgent, ip string) (*models.Session, error) {
func (r *ValkeyRepository) CreateSession(ctx context.Context, creation *models.SessionCreation) (*models.Session, error) {
const op = "ValkeyRepository.CreateSession"
session := &models.Session{
Id: uuid.NewString(),
UserId: userId,
Role: role,
UserId: creation.UserId,
Role: creation.Role,
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(sessionLifetime),
UserAgent: userAgent,
Ip: ip,
UserAgent: creation.UserAgent,
Ip: creation.Ip,
}
err := session.Valid()
@ -50,7 +50,7 @@ func (r *ValkeyRepository) CreateSession(ctx context.Context, userId int32, role
return nil, pkg.Wrap(pkg.ErrInternal, err, op, "validating session")
}
userIdHash := sha256string(strconv.FormatInt(int64(userId), 10))
userIdHash := sha256string(strconv.FormatInt(int64(creation.UserId), 10))
sessionIdHash := sha256string(session.Id)
sessionData, err := json.Marshal(session)

View file

@ -9,12 +9,12 @@ type UseCase interface {
CreateUser(ctx context.Context, username string, password string, role models.Role) (int32, error)
ReadUserById(ctx context.Context, id int32) (*models.User, error)
ReadUserByUsername(ctx context.Context, username string) (*models.User, error)
UpdateUser(ctx context.Context, id int32, username *string, role *models.Role) error
UpdateUser(ctx context.Context, id int32, update *models.UserUpdate) error
DeleteUser(ctx context.Context, id int32) error
CreateSession(ctx context.Context, userId int32, role models.Role, userAgent, ip string) (*models.Session, error)
CreateSession(ctx context.Context, creation *models.SessionCreation) (*models.Session, error)
ReadSession(ctx context.Context, sessionId string) (*models.Session, error)
UpdateSession(ctx context.Context, sessionId string) error
DeleteSession(ctx context.Context, sessionId string) error
DeleteAllSessions(ctx context.Context, userId int32) error
ListUsers(ctx context.Context, page int32, pageSize int32) ([]*models.User, int32, error)
ListUsers(ctx context.Context, filters models.UsersListFilters) (*models.UsersList, error)
}

View file

@ -85,7 +85,7 @@ func (u *UseCase) ReadUserByUsername(ctx context.Context, username string) (*mod
return user, nil
}
func (u *UseCase) UpdateUser(ctx context.Context, id int32, username *string, role *models.Role) error {
func (u *UseCase) UpdateUser(ctx context.Context, id int32, update *models.UserUpdate) error {
const op = "UseCase.UpdateUser"
token, ok := ctx.Value(TokenKey).(*models.JWT)
@ -106,7 +106,7 @@ func (u *UseCase) UpdateUser(ctx context.Context, id int32, username *string, ro
return pkg.Wrap(nil, err, op, "cannot start transaction")
}
err = tx.UpdateUser(ctx, id, username, role)
err = tx.UpdateUser(ctx, id, update)
if err != nil {
return pkg.Wrap(nil, errors.Join(err, tx.Rollback()), op, "cannot update user")
}
@ -161,10 +161,10 @@ func (u *UseCase) DeleteUser(ctx context.Context, id int32) error {
}
// CreateSession is for login only. There are no permission checks! DO NOT USE IT AS AN ENDPOINT RESPONSE!
func (u *UseCase) CreateSession(ctx context.Context, userId int32, role models.Role, userAgent, ip string) (*models.Session, error) {
func (u *UseCase) CreateSession(ctx context.Context, creation *models.SessionCreation) (*models.Session, error) {
const op = "UseCase.CreateSession"
session, err := u.sessionRepo.CreateSession(ctx, userId, role, userAgent, ip)
session, err := u.sessionRepo.CreateSession(ctx, creation)
if err != nil {
return nil, pkg.Wrap(nil, err, op, "cannot create session")
}
@ -253,21 +253,22 @@ func (u *UseCase) DeleteAllSessions(ctx context.Context, userId int32) error {
return nil
}
func (u *UseCase) ListUsers(ctx context.Context, page int32, pageSize int32) ([]*models.User, int32, error) {
func (u *UseCase) ListUsers(ctx context.Context, filters models.UsersListFilters) (*models.UsersList, error) {
const op = "UseCase.ListUsers"
token, ok := ctx.Value(TokenKey).(*models.JWT)
if !ok {
return nil, 0, pkg.Wrap(pkg.ErrUnauthenticated, nil, op, "no token in context")
return nil, pkg.Wrap(pkg.ErrUnauthenticated, nil, op, "no token in context")
}
if !token.Role.HasPermission(models.Read, models.ResourceAnotherUser) {
return nil, 0, pkg.Wrap(pkg.NoPermission, nil, op, "no permission")
return nil, pkg.Wrap(pkg.NoPermission, nil, op, "no permission")
}
usersList, count, err := u.userRepo.C().ListUsers(ctx, page, pageSize)
usersList, err := u.userRepo.C().ListUsers(ctx, filters)
if err != nil {
return nil, 0, pkg.Wrap(nil, err, op, "can't list users")
return nil, pkg.Wrap(nil, err, op, "can't list users")
}
return usersList, count, nil
return usersList, nil
}

2
proto

@ -1 +1 @@
Subproject commit c1b7fd7a2d32678641ebd3acfe3d5b2eca5d0c72
Subproject commit 5c14d329bdff30eba7526f8f4e0d941ad454cb27