package rest

import (
	"context"
	"encoding/base64"
	"git.sch9.ru/new_gate/ms-tester/internal/auth"
	"git.sch9.ru/new_gate/ms-tester/internal/models"
	"git.sch9.ru/new_gate/ms-tester/pkg"
	"github.com/gofiber/fiber/v2"
	"github.com/golang-jwt/jwt/v4"
	"strings"
	"time"
)

// Handlers holds the Fiber REST handlers for the auth endpoints
// (login, logout, refresh, terminate, list sessions).
type Handlers struct {
	authUC    auth.UseCase // use case every handler delegates session operations to
	jwtSecret string       // HMAC key Login uses to sign issued tokens (HS256)
}

// NewHandlers builds the auth REST handlers. authUC receives every session
// operation; jwtSecret is the key used to sign tokens handed out by Login.
func NewHandlers(authUC auth.UseCase, jwtSecret string) *Handlers {
	h := new(Handlers)
	h.authUC = authUC
	h.jwtSecret = jwtSecret
	return h
}

// sessionKey is the request-context key sessionFromCtx reads the current
// *models.Session from. NOTE(review): presumably set by an auth middleware
// elsewhere (e.g. via c.Locals) — the value must stay in sync with it.
const sessionKey = "session"

// sessionFromCtx returns the session stored in ctx under sessionKey.
//
// It returns an error wrapping pkg.ErrUnauthenticated in three cases: nothing
// is stored under the key, the stored value is not a *models.Session, or the
// stored value is a nil *models.Session.
func sessionFromCtx(ctx context.Context) (*models.Session, error) {
	const op = "sessionFromCtx"

	session, ok := ctx.Value(sessionKey).(*models.Session)
	// A typed-nil pointer satisfies the type assertion (ok == true). The
	// handlers dereference the result immediately, so reject it here instead
	// of letting them panic.
	if !ok || session == nil {
		return nil, pkg.Wrap(pkg.ErrUnauthenticated, nil, op, "")
	}

	return session, nil
}

// ListSessions is a placeholder endpoint: it always answers
// 501 Not Implemented.
func (h *Handlers) ListSessions(c *fiber.Ctx) error {
	const status = fiber.StatusNotImplemented
	return c.SendStatus(status)
}

// Terminate calls authUC.Terminate with the caller's user id rather than the
// session id. NOTE(review): presumably this ends every session of that user
// — confirm against the use case.
//
// Responds 200 on success. Errors from session lookup or the use case are
// mapped to a status code via pkg.ToREST.
func (h *Handlers) Terminate(c *fiber.Ctx) error {
	ctx := c.Context()

	session, err := sessionFromCtx(ctx)
	if err == nil {
		err = h.authUC.Terminate(ctx, session.UserId)
	}
	if err != nil {
		return c.SendStatus(pkg.ToREST(err))
	}

	return c.SendStatus(fiber.StatusOK)
}

// Login authenticates a user with HTTP Basic credentials taken from the
// request's Authorization header.
//
// The username is lower-cased before it is passed to authUC.Login. The client
// IP and User-Agent are passed along as the device.
//
// On success, an HS256 JWT carrying the session id, user id, role and
// issued-at time is signed with h.jwtSecret. It is returned in the response's
// Authorization header as "Bearer <token>", with status 200.
//
// Status codes:
//   - 401: the header is missing or malformed.
//   - pkg.ToREST(err): the use case fails.
//   - 500: signing the token fails.
func (h *Handlers) Login(c *fiber.Ctx) error {
	header := c.Get("Authorization", "")
	if header == "" {
		return c.SendStatus(fiber.StatusUnauthorized)
	}

	user, password, err := parseBasicAuth(header)
	if err != nil {
		return c.SendStatus(fiber.StatusUnauthorized)
	}

	ctx := c.Context()
	session, err := h.authUC.Login(
		ctx,
		&models.Credentials{
			Username: strings.ToLower(user),
			Password: password,
		},
		&models.Device{
			Ip:       c.IP(),
			UseAgent: c.Get("User-Agent", ""),
		},
	)
	if err != nil {
		return c.SendStatus(pkg.ToREST(err))
	}

	signed, err := jwt.NewWithClaims(jwt.SigningMethodHS256, models.JWT{
		SessionId: session.Id,
		UserId:    session.UserId,
		Role:      session.Role,
		IssuedAt:  time.Now().Unix(),
	}).SignedString([]byte(h.jwtSecret))
	if err != nil {
		return c.SendStatus(fiber.StatusInternalServerError)
	}

	c.Set("Authorization", "Bearer "+signed)
	return c.SendStatus(fiber.StatusOK)
}

// Logout passes the caller's current session id (taken from the request
// context) to authUC.Logout.
//
// Responds 200 on success. Errors from session lookup or the use case are
// mapped to a status code via pkg.ToREST.
func (h *Handlers) Logout(c *fiber.Ctx) error {
	ctx := c.Context()

	session, err := sessionFromCtx(ctx)
	if err == nil {
		err = h.authUC.Logout(ctx, session.Id)
	}
	if err != nil {
		return c.SendStatus(pkg.ToREST(err))
	}

	return c.SendStatus(fiber.StatusOK)
}

// Refresh passes the caller's current session id (taken from the request
// context) to authUC.Refresh. No new token is issued by this handler.
//
// Responds 200 on success. Errors from session lookup or the use case are
// mapped to a status code via pkg.ToREST.
func (h *Handlers) Refresh(c *fiber.Ctx) error {
	ctx := c.Context()

	session, err := sessionFromCtx(ctx)
	if err == nil {
		err = h.authUC.Refresh(ctx, session.Id)
	}
	if err != nil {
		return c.SendStatus(pkg.ToREST(err))
	}

	return c.SendStatus(fiber.StatusOK)
}

// parseBasicAuth extracts the user-id and password from an HTTP Basic
// Authorization header value of the form "Basic <base64(user:password)>".
//
// The scheme name is matched case-insensitively. The decoded credentials are
// split at the FIRST colon only. RFC 7617 forbids a colon in the user-id but
// allows one in the password, so "alice:pa:ss" yields ("alice", "pa:ss").
//
// Malformed input returns an error wrapping pkg.ErrUnauthenticated:
// a wrong scheme, extra spaces, invalid base64, or no colon.
func parseBasicAuth(header string) (string, string, error) {
	const (
		op  = "parseBasicAuth"
		msg = "invalid auth header"
	)

	authParts := strings.Split(header, " ")
	if len(authParts) != 2 || !strings.EqualFold(authParts[0], "basic") {
		return "", "", pkg.Wrap(pkg.ErrUnauthenticated, nil, op, msg)
	}

	decodedAuth, err := base64.StdEncoding.DecodeString(authParts[1])
	if err != nil {
		return "", "", pkg.Wrap(pkg.ErrUnauthenticated, nil, op, msg)
	}

	// n=2 keeps any further colons inside the password instead of rejecting it.
	credentials := strings.SplitN(string(decodedAuth), ":", 2)
	if len(credentials) != 2 {
		return "", "", pkg.Wrap(pkg.ErrUnauthenticated, nil, op, msg)
	}

	return credentials[0], credentials[1], nil
}