package rest

import (
	"encoding/base64"
	"errors"
	"strings"
	"time"

	userv1 "git.sch9.ru/new_gate/ms-auth/contracts/user/v1"
	"git.sch9.ru/new_gate/ms-auth/internal/models"
	"git.sch9.ru/new_gate/ms-auth/internal/users"
	"git.sch9.ru/new_gate/ms-auth/pkg"
	"github.com/gofiber/fiber/v2"
	"github.com/golang-jwt/jwt/v4"
)

// UserHandlers groups the HTTP handlers for the user and session endpoints.
// Business logic is delegated to users.UseCase; jwtSecret is the HMAC key used
// to sign the tokens issued by Login.
type UserHandlers struct {
	userUC    users.UseCase
	jwtSecret string
}

// NewUserHandlers returns a UserHandlers backed by userUC that signs tokens
// with jwtSecret.
func NewUserHandlers(userUC users.UseCase, jwtSecret string) *UserHandlers {
	return &UserHandlers{
		userUC:    userUC,
		jwtSecret: jwtSecret,
	}
}

// Login authenticates the caller from the HTTP Basic credentials in the
// Authorization request header, creates a session and responds with a signed
// HS256 JWT in the Authorization response header ("Bearer <token>").
//
// A missing or malformed header, an unknown user and a wrong password all
// yield 401 so that the response does not reveal which check failed.
func (h *UserHandlers) Login(c *fiber.Ctx) error {
	authHeader := c.Get("Authorization", "")
	if authHeader == "" {
		return c.SendStatus(fiber.StatusUnauthorized)
	}

	username, pwd, err := parseBasicAuth(authHeader)
	if err != nil {
		return c.SendStatus(fiber.StatusUnauthorized)
	}

	ctx := c.Context()

	user, err := h.userUC.ReadUserByUsername(ctx, username)
	if err != nil {
		// An unknown user is reported exactly like a wrong password.
		if errors.Is(err, pkg.ErrNotFound) {
			return c.SendStatus(fiber.StatusUnauthorized)
		}
		return c.SendStatus(pkg.ToREST(err))
	}

	if !user.IsSamePwd(pwd) {
		return c.SendStatus(fiber.StatusUnauthorized)
	}

	session, err := h.userUC.CreateSession(ctx, &models.SessionCreation{
		UserId:    user.Id,
		Role:      user.Role,
		UserAgent: c.Get("User-Agent", ""),
		Ip:        c.IP(),
	})
	if err != nil {
		return c.SendStatus(pkg.ToREST(err))
	}

	// Capture the time once so IssuedAt and NotBefore can never disagree.
	now := time.Now().Unix()

	unsigned := jwt.NewWithClaims(jwt.SigningMethodHS256, models.JWT{
		SessionId:   session.Id,
		UserId:      user.Id,
		Role:        user.Role,
		ExpiresAt:   session.ExpiresAt.Unix(),
		IssuedAt:    now,
		NotBefore:   now,
		Permissions: models.Grants[user.Role.String()],
	})

	token, err := unsigned.SignedString([]byte(h.jwtSecret))
	if err != nil {
		return c.SendStatus(fiber.StatusInternalServerError)
	}

	c.Set("Authorization", "Bearer "+token)
	return c.SendStatus(fiber.StatusOK)
}

// Refresh updates the session referenced by the *models.JWT stored in the
// request locals under TokenKey. It responds 401 when no such token is
// present.
//
// NOTE(review): the token is presumably placed in locals by an auth
// middleware defined elsewhere in this package — confirm.
func (h *UserHandlers) Refresh(c *fiber.Ctx) error {
	token, ok := c.Locals(TokenKey).(*models.JWT)
	if !ok {
		return c.SendStatus(fiber.StatusUnauthorized)
	}

	if err := h.userUC.UpdateSession(c.Context(), token.SessionId); err != nil {
		return c.SendStatus(pkg.ToREST(err))
	}
	return c.SendStatus(fiber.StatusOK)
}

// Logout deletes the session referenced by the token stored in the request
// locals under TokenKey. It responds 401 when no such token is present.
func (h *UserHandlers) Logout(c *fiber.Ctx) error {
	token, ok := c.Locals(TokenKey).(*models.JWT)
	if !ok {
		return c.SendStatus(fiber.StatusUnauthorized)
	}

	if err := h.userUC.DeleteSession(c.Context(), token.SessionId); err != nil {
		return c.SendStatus(pkg.ToREST(err))
	}
	return c.SendStatus(fiber.StatusOK)
}

// CompleteLogout deletes every session of the user referenced by the token
// stored in the request locals under TokenKey. It responds 401 when no such
// token is present.
func (h *UserHandlers) CompleteLogout(c *fiber.Ctx) error {
	token, ok := c.Locals(TokenKey).(*models.JWT)
	if !ok {
		return c.SendStatus(fiber.StatusUnauthorized)
	}

	if err := h.userUC.DeleteAllSessions(c.Context(), token.UserId); err != nil {
		return c.SendStatus(pkg.ToREST(err))
	}
	return c.SendStatus(fiber.StatusOK)
}

// Verify is not implemented yet and always responds 501.
func (h *UserHandlers) Verify(c *fiber.Ctx) error {
	return c.SendStatus(fiber.StatusNotImplemented)
}

// CreateUser parses a userv1.CreateUserRequest from the request body and
// creates the user with the models.RoleStudent role. It responds with the new
// user's id as JSON, or 400 when the body cannot be parsed.
func (h *UserHandlers) CreateUser(c *fiber.Ctx) error {
	req := &userv1.CreateUserRequest{}
	if err := c.BodyParser(req); err != nil {
		return c.SendStatus(fiber.StatusBadRequest)
	}

	id, err := h.userUC.CreateUser(
		c.Context(),
		req.Username,
		req.Password,
		models.RoleStudent,
	)
	if err != nil {
		return c.SendStatus(pkg.ToREST(err))
	}

	return c.JSON(userv1.CreateUserResponse{Id: id})
}

// GetUser responds with the user identified by id as JSON.
func (h *UserHandlers) GetUser(c *fiber.Ctx, id int32) error {
	user, err := h.userUC.ReadUserById(c.Context(), id)
	if err != nil {
		return c.SendStatus(pkg.ToREST(err))
	}

	return c.JSON(userv1.GetUserResponse{
		User: U2U(*user),
	})
}

// UpdateUser parses a userv1.UpdateUserRequest from the request body and
// applies its username and role to the user identified by id. It responds 400
// when the body cannot be parsed.
func (h *UserHandlers) UpdateUser(c *fiber.Ctx, id int32) error {
	req := &userv1.UpdateUserRequest{}
	if err := c.BodyParser(req); err != nil {
		return c.SendStatus(fiber.StatusBadRequest)
	}

	err := h.userUC.UpdateUser(c.Context(), id, &models.UserUpdate{
		Username: req.Username,
		Role:     I32P2RP(req.Role),
	})
	if err != nil {
		return c.SendStatus(pkg.ToREST(err))
	}

	return c.SendStatus(fiber.StatusOK)
}

// DeleteUser deletes the user identified by id.
func (h *UserHandlers) DeleteUser(c *fiber.Ctx, id int32) error {
	if err := h.userUC.DeleteUser(c.Context(), id); err != nil {
		return c.SendStatus(pkg.ToREST(err))
	}
	return c.SendStatus(fiber.StatusOK)
}

// ListUsers responds with one page of users, selected by params.Page and
// params.PageSize, together with pagination info.
func (h *UserHandlers) ListUsers(c *fiber.Ctx, params userv1.ListUsersParams) error {
	usersList, err := h.userUC.ListUsers(c.Context(), models.UsersListFilters{
		PageSize: params.PageSize,
		Page:     params.Page,
	})
	if err != nil {
		return c.SendStatus(pkg.ToREST(err))
	}

	resp := userv1.ListUsersResponse{
		Users:      make([]userv1.User, len(usersList.Users)),
		Pagination: P2P(usersList.Pagination),
	}
	for i, user := range usersList.Users {
		resp.Users[i] = U2U(*user)
	}

	return c.JSON(resp)
}

// ListSessions is not implemented yet and always responds 501.
func (h *UserHandlers) ListSessions(c *fiber.Ctx) error {
	return c.SendStatus(fiber.StatusNotImplemented)
}

// parseBasicAuth extracts the username and password from an HTTP Basic
// Authorization header value of the form "Basic <base64(user:password)>".
// The scheme name is matched case-insensitively.
//
// The decoded credentials are split at the FIRST colon only: the user-id may
// not contain a colon, but the password may (RFC 7617, section 2), so a
// password such as "a:b" is returned intact instead of being rejected.
//
// Any malformed input yields an error wrapping pkg.ErrUnauthenticated.
func parseBasicAuth(header string) (string, string, error) {
	const (
		op  = "parseBasicAuth"
		msg = "invalid auth header"
	)

	authParts := strings.Split(header, " ")
	if len(authParts) != 2 || !strings.EqualFold(authParts[0], "basic") {
		return "", "", pkg.Wrap(pkg.ErrUnauthenticated, nil, op, msg)
	}

	decoded, err := base64.StdEncoding.DecodeString(authParts[1])
	if err != nil {
		return "", "", pkg.Wrap(pkg.ErrUnauthenticated, nil, op, msg)
	}

	// Limit to two parts so that colons inside the password are preserved.
	creds := strings.SplitN(string(decoded), ":", 2)
	if len(creds) != 2 {
		return "", "", pkg.Wrap(pkg.ErrUnauthenticated, nil, op, msg)
	}

	return creds[0], creds[1], nil
}

// I32P2RP converts an optional int32 into an optional models.Role. A nil
// input yields nil; otherwise the result points to a fresh copy, so it does
// not alias the input.
func I32P2RP(i *int32) *models.Role {
	if i == nil {
		return nil
	}
	role := models.Role(*i)
	return &role
}

// P2P converts a models.Pagination into its userv1 API representation.
func P2P(p models.Pagination) userv1.Pagination {
	return userv1.Pagination{
		Page:  p.Page,
		Total: p.Total,
	}
}

// U2U converts a models.User into its userv1 API representation.
func U2U(u models.User) userv1.User {
	return userv1.User{
		Id:         u.Id,
		Username:   u.Username,
		Role:       int32(u.Role),
		CreatedAt:  u.CreatedAt,
		ModifiedAt: u.ModifiedAt,
	}
}