271 lines
5.9 KiB
Go
271 lines
5.9 KiB
Go
package rest
|
|
|
|
import (
|
|
"encoding/base64"
|
|
"errors"
|
|
userv1 "git.sch9.ru/new_gate/ms-auth/contracts/user/v1"
|
|
"git.sch9.ru/new_gate/ms-auth/internal/models"
|
|
"git.sch9.ru/new_gate/ms-auth/internal/users"
|
|
"git.sch9.ru/new_gate/ms-auth/pkg"
|
|
"github.com/gofiber/fiber/v2"
|
|
"github.com/golang-jwt/jwt/v4"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
type UserHandlers struct {
|
|
userUC users.UseCase
|
|
|
|
jwtSecret string
|
|
}
|
|
|
|
func NewUserHandlers(userUC users.UseCase, jwtSecret string) *UserHandlers {
|
|
return &UserHandlers{
|
|
userUC: userUC,
|
|
jwtSecret: jwtSecret,
|
|
}
|
|
}
|
|
|
|
func (h *UserHandlers) Login(c *fiber.Ctx) error {
|
|
authHeader := c.Get("Authorization", "")
|
|
if authHeader == "" {
|
|
return c.SendStatus(fiber.StatusUnauthorized)
|
|
}
|
|
|
|
username, pwd, err := parseBasicAuth(authHeader)
|
|
if err != nil {
|
|
return c.SendStatus(fiber.StatusUnauthorized)
|
|
}
|
|
|
|
ctx := c.Context()
|
|
|
|
user, err := h.userUC.ReadUserByUsername(ctx, username)
|
|
if err != nil {
|
|
if errors.Is(err, pkg.ErrNotFound) {
|
|
return c.SendStatus(fiber.StatusUnauthorized)
|
|
}
|
|
|
|
return c.SendStatus(pkg.ToREST(err))
|
|
}
|
|
|
|
if !user.IsSamePwd(pwd) {
|
|
return c.SendStatus(fiber.StatusUnauthorized)
|
|
}
|
|
|
|
userAgent := c.Get("User-Agent", "")
|
|
ip := c.IP()
|
|
|
|
session, err := h.userUC.CreateSession(ctx, &models.SessionCreation{
|
|
UserId: user.Id,
|
|
Role: user.Role,
|
|
UserAgent: userAgent,
|
|
Ip: ip,
|
|
})
|
|
if err != nil {
|
|
return c.SendStatus(pkg.ToREST(err))
|
|
}
|
|
|
|
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, models.JWT{
|
|
SessionId: session.Id,
|
|
UserId: user.Id,
|
|
Role: user.Role,
|
|
ExpiresAt: session.ExpiresAt.Unix(),
|
|
IssuedAt: time.Now().Unix(),
|
|
NotBefore: time.Now().Unix(),
|
|
Permissions: models.Grants[user.Role.String()],
|
|
})
|
|
|
|
token, err := claims.SignedString([]byte(h.jwtSecret))
|
|
if err != nil {
|
|
return c.SendStatus(fiber.StatusInternalServerError)
|
|
}
|
|
|
|
c.Set("Authorization", "Bearer "+token)
|
|
|
|
return c.SendStatus(fiber.StatusOK)
|
|
}
|
|
|
|
func (h *UserHandlers) Refresh(c *fiber.Ctx) error {
|
|
token, ok := c.Locals(TokenKey).(*models.JWT)
|
|
if !ok {
|
|
return c.SendStatus(fiber.StatusUnauthorized)
|
|
}
|
|
|
|
err := h.userUC.UpdateSession(c.Context(), token.SessionId)
|
|
if err != nil {
|
|
return c.SendStatus(pkg.ToREST(err))
|
|
}
|
|
|
|
return c.SendStatus(fiber.StatusOK)
|
|
}
|
|
|
|
func (h *UserHandlers) Logout(c *fiber.Ctx) error {
|
|
token, ok := c.Locals(TokenKey).(*models.JWT)
|
|
if !ok {
|
|
return c.SendStatus(fiber.StatusUnauthorized)
|
|
}
|
|
|
|
err := h.userUC.DeleteSession(c.Context(), token.SessionId)
|
|
if err != nil {
|
|
return c.SendStatus(pkg.ToREST(err))
|
|
}
|
|
|
|
return c.SendStatus(fiber.StatusOK)
|
|
}
|
|
|
|
func (h *UserHandlers) CompleteLogout(c *fiber.Ctx) error {
|
|
token, ok := c.Locals(TokenKey).(*models.JWT)
|
|
if !ok {
|
|
return c.SendStatus(fiber.StatusUnauthorized)
|
|
}
|
|
|
|
ctx := c.Context()
|
|
|
|
err := h.userUC.DeleteAllSessions(ctx, token.UserId)
|
|
if err != nil {
|
|
return c.SendStatus(pkg.ToREST(err))
|
|
}
|
|
|
|
return c.SendStatus(fiber.StatusOK)
|
|
}
|
|
|
|
func (h *UserHandlers) Verify(c *fiber.Ctx) error {
|
|
return c.SendStatus(fiber.StatusNotImplemented)
|
|
}
|
|
|
|
func (h *UserHandlers) CreateUser(c *fiber.Ctx) error {
|
|
ctx := c.Context()
|
|
|
|
var req = &userv1.CreateUserRequest{}
|
|
|
|
err := c.BodyParser(req)
|
|
if err != nil {
|
|
return c.SendStatus(fiber.StatusBadRequest)
|
|
}
|
|
|
|
id, err := h.userUC.CreateUser(
|
|
ctx,
|
|
req.Username,
|
|
req.Password,
|
|
models.RoleStudent,
|
|
)
|
|
if err != nil {
|
|
return c.SendStatus(pkg.ToREST(err))
|
|
}
|
|
|
|
return c.JSON(userv1.CreateUserResponse{Id: id})
|
|
}
|
|
|
|
func (h *UserHandlers) GetUser(c *fiber.Ctx, id int32) error {
|
|
user, err := h.userUC.ReadUserById(c.Context(), id)
|
|
if err != nil {
|
|
return c.SendStatus(pkg.ToREST(err))
|
|
}
|
|
|
|
return c.JSON(userv1.GetUserResponse{
|
|
User: U2U(*user),
|
|
})
|
|
}
|
|
|
|
func (h *UserHandlers) UpdateUser(c *fiber.Ctx, id int32) error {
|
|
var req = &userv1.UpdateUserRequest{}
|
|
err := c.BodyParser(req)
|
|
if err != nil {
|
|
return c.SendStatus(fiber.StatusBadRequest)
|
|
}
|
|
|
|
err = h.userUC.UpdateUser(c.Context(), id, &models.UserUpdate{
|
|
Username: req.Username,
|
|
Role: I32P2RP(req.Role),
|
|
})
|
|
if err != nil {
|
|
return c.SendStatus(pkg.ToREST(err))
|
|
}
|
|
|
|
return c.SendStatus(fiber.StatusOK)
|
|
}
|
|
|
|
func (h *UserHandlers) DeleteUser(c *fiber.Ctx, id int32) error {
|
|
ctx := c.Context()
|
|
|
|
err := h.userUC.DeleteUser(ctx, id)
|
|
if err != nil {
|
|
return c.SendStatus(pkg.ToREST(err))
|
|
}
|
|
|
|
return c.SendStatus(fiber.StatusOK)
|
|
}
|
|
|
|
func (h *UserHandlers) ListUsers(c *fiber.Ctx, params userv1.ListUsersParams) error {
|
|
usersList, err := h.userUC.ListUsers(c.Context(), models.UsersListFilters{
|
|
PageSize: params.PageSize,
|
|
Page: params.Page,
|
|
})
|
|
if err != nil {
|
|
return c.SendStatus(pkg.ToREST(err))
|
|
}
|
|
|
|
resp := userv1.ListUsersResponse{
|
|
Users: make([]userv1.User, len(usersList.Users)),
|
|
Pagination: P2P(usersList.Pagination),
|
|
}
|
|
|
|
for i, user := range usersList.Users {
|
|
resp.Users[i] = U2U(*user)
|
|
}
|
|
|
|
return c.JSON(resp)
|
|
}
|
|
|
|
func (h *UserHandlers) ListSessions(c *fiber.Ctx) error {
|
|
return c.SendStatus(fiber.StatusNotImplemented)
|
|
}
|
|
|
|
func parseBasicAuth(header string) (string, string, error) {
|
|
const (
|
|
op = "parseBasicAuth"
|
|
msg = "invalid auth header"
|
|
)
|
|
|
|
authParts := strings.Split(header, " ")
|
|
if len(authParts) != 2 || strings.ToLower(authParts[0]) != "basic" {
|
|
return "", "", pkg.Wrap(pkg.ErrUnauthenticated, nil, op, msg)
|
|
}
|
|
|
|
decodedAuth, err := base64.StdEncoding.DecodeString(authParts[1])
|
|
if err != nil {
|
|
return "", "", pkg.Wrap(pkg.ErrUnauthenticated, nil, op, msg)
|
|
}
|
|
|
|
authParts = strings.Split(string(decodedAuth), ":")
|
|
if len(authParts) != 2 {
|
|
return "", "", pkg.Wrap(pkg.ErrUnauthenticated, nil, op, msg)
|
|
}
|
|
|
|
return authParts[0], authParts[1], nil
|
|
}
|
|
|
|
func I32P2RP(i *int32) *models.Role {
|
|
if i == nil {
|
|
return nil
|
|
}
|
|
ii := models.Role(*i)
|
|
return &ii
|
|
}
|
|
|
|
func P2P(p models.Pagination) userv1.Pagination {
|
|
return userv1.Pagination{
|
|
Page: p.Page,
|
|
Total: p.Total,
|
|
}
|
|
}
|
|
|
|
func U2U(u models.User) userv1.User {
|
|
return userv1.User{
|
|
Id: u.Id,
|
|
Username: u.Username,
|
|
Role: int32(u.Role),
|
|
CreatedAt: u.CreatedAt,
|
|
ModifiedAt: u.ModifiedAt,
|
|
}
|
|
}
|