diff --git a/.gitignore b/.gitignore index 95a8cdd..451fb35 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,2 @@ .env -.idea -pkg/go/gen \ No newline at end of file +.idea \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 0cc94e2..d0ff4d7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -8,8 +8,8 @@ RUN --mount=type=cache,target=/go/pkg/mod/ \ FROM base AS builder RUN --mount=type=cache,target=/go/pkg/mod/ \ --mount=type=bind,target=. \ - go build -o /bin/server . + go build -o /bin/server cmd/ms-auth/main.go FROM scratch AS runner COPY --from=builder /bin/server /bin/ -ENTRYPOINT [ "/bin/server" ] \ No newline at end of file +ENTRYPOINT ["/bin/server"] \ No newline at end of file diff --git a/Makefile b/Makefile index 50ae327..1c5c1ab 100644 --- a/Makefile +++ b/Makefile @@ -1,12 +1,18 @@ tag = latest - gen: - @buf generate + @protoc --proto_path=proto --go_opt=paths=source_relative \ + --go_out=proto --go-grpc_out=proto --grpc-gateway_out=proto \ + proto/user/v1/user.proto +gen-openapi: + @protoc --proto_path=proto --openapi_out=proto/user/v1 \ + proto/user/v1/user.proto dev: @make gen - @go run main.go - + @go run cmd/ms-auth/main.go +proxy: + @make gen + @go run cmd/ms-auth-proxy/main.go build: @make gen @docker build . -t ms-auth:${tag} diff --git a/buf.gen.yaml b/buf.gen.yaml deleted file mode 100644 index 92b08eb..0000000 --- a/buf.gen.yaml +++ /dev/null @@ -1,12 +0,0 @@ -version: v1 -managed: - enabled: true - go_package_prefix: - default: git.sch9.ru/new_gate/ms-auth/pkg/go/gen -plugins: - - name: go - out: pkg/go/gen - opt: paths=source_relative - - name: go-grpc - out: pkg/go/gen - opt: paths=source_relative diff --git a/buf.yaml b/buf.yaml deleted file mode 100644 index 1a51945..0000000 --- a/buf.yaml +++ /dev/null @@ -1,7 +0,0 @@ -version: v1 -breaking: - use: - - FILE -lint: - use: - - DEFAULT diff --git a/cmd/ms-auth-proxy/main.go b/cmd/ms-auth-proxy/main.go new file mode 100644 index 0000000..b4dc11c --- /dev/null +++ b/cmd/ms-auth-proxy/main.go @@ -0,0 +1,50 @@ +package main + +import ( + "context" + "fmt" + "git.sch9.ru/new_gate/ms-auth/config" + userv1gw "git.sch9.ru/new_gate/ms-auth/proto/user/v1" + "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" + "github.com/ilyakaznacheev/cleanenv" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "net/http" + "os" + "os/signal" + "syscall" +) + +func main() { + var cfg config.Config + err := cleanenv.ReadConfig(".env", &cfg) + if err != nil { + panic(fmt.Sprintf("error reading config: %s", err.Error())) + } + + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + mux := runtime.NewServeMux() + opts := []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())} + err = userv1gw.RegisterUserServiceHandlerFromEndpoint(ctx, mux, cfg.Address, opts) + if err != nil { + panic(err) + } + + go func() { + err = http.ListenAndServe(cfg.ProxyAddress, mux) + if err != nil { + panic(err) + } + }() + + fmt.Println("server proxy started") + + stop := make(chan os.Signal, 1) + signal.Notify(stop, syscall.SIGTERM, syscall.SIGINT) + + <-stop + return +} diff --git a/cmd/ms-auth/main.go b/cmd/ms-auth/main.go new file mode 100644 index 0000000..3e7fd99 --- /dev/null +++ b/cmd/ms-auth/main.go @@ -0,0 +1,129 @@ +package main + +import ( + "context" + "fmt" + "git.sch9.ru/new_gate/ms-auth/config" + "git.sch9.ru/new_gate/ms-auth/internal/models" + usersDelivery "git.sch9.ru/new_gate/ms-auth/internal/users/delivery/grpc" + usersRepository "git.sch9.ru/new_gate/ms-auth/internal/users/repository" + usersUseCase "git.sch9.ru/new_gate/ms-auth/internal/users/usecase" + "git.sch9.ru/new_gate/ms-auth/pkg" + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" + "github.com/ilyakaznacheev/cleanenv" + _ "github.com/jackc/pgx/v5/stdlib" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/reflection" + "net" + "os" + "os/signal" + "syscall" +) + +// InterceptorLogger adapts zap logger to interceptor logger. +// This code is simple enough to be copied and not imported. +func InterceptorLogger(l *zap.Logger) logging.Logger { + return logging.LoggerFunc(func(ctx context.Context, lvl logging.Level, msg string, fields ...any) { + f := make([]zap.Field, 0, len(fields)/2) + + for i := 0; i < len(fields); i += 2 { + key := fields[i] + value := fields[i+1] + + switch v := value.(type) { + case string: + f = append(f, zap.String(key.(string), v)) + case int: + f = append(f, zap.Int(key.(string), v)) + case bool: + f = append(f, zap.Bool(key.(string), v)) + default: + f = append(f, zap.Any(key.(string), v)) + } + } + + logger := l.WithOptions(zap.AddCallerSkip(1)).With(f...) + + switch lvl { + case logging.LevelDebug: + logger.Debug(msg) + case logging.LevelInfo: + logger.Info(msg) + case logging.LevelWarn: + logger.Warn(msg) + case logging.LevelError: + logger.Error(msg) + default: + panic(fmt.Sprintf("unknown level %v", lvl)) + } + }) +} + +func main() { + var cfg config.Config + err := cleanenv.ReadConfig(".env", &cfg) + if err != nil { + panic(fmt.Sprintf("error reading config: %s", err.Error())) + } + + var logger *zap.Logger + if cfg.Env == "prod" { + logger = zap.Must(zap.NewProduction()) + } else if cfg.Env == "dev" { + logger = zap.Must(zap.NewDevelopment()) + } else { + panic(fmt.Sprintf(`error reading config: env expected "prod" or "dev", got "%s"`, cfg.Env)) + } + + logger.Info("connecting to postgres") + db, err := pkg.NewPostgresDB(cfg.PostgresDSN) + if err != nil { + logger.Fatal(fmt.Sprintf("error connecting to postgres: %s", err.Error())) + } + defer db.Close() + logger.Info("successfully connected to postgres") + + logger.Info("connecting to redis") + vk, err := pkg.NewValkeyClient(cfg.RedisDSN) + if err != nil { + logger.Fatal(fmt.Sprintf("error connecting to redis: %s", err.Error())) + } + logger.Info("successfully connected to redis") + + userRepo := usersRepository.NewUserRepository(db) + + _, err = userRepo.C().CreateUser(context.Background(), cfg.AdminUsername, cfg.AdminPassword, models.RoleAdmin) + if err != nil { + logger.Error(fmt.Sprintf("error creating admin user: %s", err.Error())) + } + + sessionRepo := usersRepository.NewValkeyRepository(vk) + userUC := usersUseCase.NewUseCase(userRepo, sessionRepo, cfg) + + gserver := grpc.NewServer(grpc.ChainUnaryInterceptor( + logging.UnaryServerInterceptor(InterceptorLogger(logger)), + )) + defer gserver.GracefulStop() + + usersDelivery.NewUserHandlers(gserver, userUC) + reflection.Register(gserver) + + ln, err := net.Listen("tcp", cfg.Address) + if err != nil { + panic(err) + } + + go func() { + if err = gserver.Serve(ln); err != nil { + panic(err) + } + }() + + logger.Info(fmt.Sprintf("server started on %s", cfg.Address)) + + stop := make(chan os.Signal, 1) + signal.Notify(stop, syscall.SIGTERM, syscall.SIGINT) + + <-stop +} diff --git a/config/config.go b/config/config.go index bbdfce0..807f18a 100644 --- a/config/config.go +++ b/config/config.go @@ -1,11 +1,16 @@ package config type Config struct { - Env string `env:"ENV" env-default:"prod"` - Address string `env:"ADDRESS" env-default:":8090"` + Env string `env:"ENV" env-default:"prod"` + + Address string `env:"ADDRESS" env-default:":8090"` + ProxyAddress string `env:"PROXY_ADDRESS" env-default:":8091"` PostgresDSN string `env:"POSTGRES_DSN" required:"true"` RedisDSN string `env:"REDIS_DSN" required:"true"` JWTSecret string `env:"JWT_SECRET" required:"true"` + + AdminUsername string `env:"ADMIN_USERNAME" env-default:"admin"` + AdminPassword string `env:"ADMIN_PASSWORD" env-default:"admin"` } diff --git a/go.mod b/go.mod index 14d664f..cd428a2 100644 --- a/go.mod +++ b/go.mod @@ -1,53 +1,45 @@ module git.sch9.ru/new_gate/ms-auth -go 1.21.3 +go 1.22.7 + +toolchain go1.22.10 require ( github.com/DATA-DOG/go-sqlmock v1.5.2 + github.com/golang-jwt/jwt/v4 v4.5.1 github.com/google/uuid v1.6.0 + github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.2.0 + github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0 github.com/ilyakaznacheev/cleanenv v1.5.0 github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 - github.com/stretchr/testify v1.9.0 + github.com/stretchr/testify v1.10.0 github.com/valkey-io/valkey-go v1.0.47 + github.com/valkey-io/valkey-go/mock v1.0.47 go.uber.org/mock v0.4.0 go.uber.org/zap v1.27.0 - golang.org/x/crypto v0.26.0 - google.golang.org/grpc v1.67.1 - google.golang.org/protobuf v1.34.2 + golang.org/x/crypto v0.31.0 + google.golang.org/genproto/googleapis/api v0.0.0-20241118233622-e639e219e697 + google.golang.org/grpc v1.68.0 + google.golang.org/protobuf v1.35.2 ) require ( - cel.dev/expr v0.16.0 // indirect - cloud.google.com/go/compute/metadata v0.5.0 // indirect - github.com/census-instrumentation/opencensus-proto v0.4.1 // indirect - github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/cncf/xds/go v0.0.0-20240723142845-024c85f92f20 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/envoyproxy/go-control-plane v0.13.0 // indirect - github.com/envoyproxy/protoc-gen-validate v1.1.0 // indirect - github.com/golang/glog v1.2.2 // indirect - github.com/google/go-cmp v0.6.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect - github.com/kr/pretty v0.3.1 // indirect - github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect - github.com/valkey-io/valkey-go/mock v1.0.47 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/net v0.28.0 // indirect - golang.org/x/oauth2 v0.22.0 // indirect - golang.org/x/sync v0.8.0 // indirect - golang.org/x/sys v0.24.0 // indirect - golang.org/x/text v0.17.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20240814211410-ddb44dafa142 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142 // indirect + golang.org/x/net v0.29.0 // indirect + golang.org/x/sync v0.10.0 // indirect + golang.org/x/sys v0.28.0 // indirect + golang.org/x/text v0.21.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20241118233622-e639e219e697 // indirect ) require ( github.com/BurntSushi/toml v1.2.1 // indirect - github.com/golang-jwt/jwt v3.2.2+incompatible github.com/jackc/pgx/v5 v5.6.0 github.com/jmoiron/sqlx v1.4.0 github.com/joho/godotenv v1.5.1 // indirect diff --git a/go.sum b/go.sum index 631b12f..948a2fc 100644 --- a/go.sum +++ b/go.sum @@ -1,37 +1,26 @@ -cel.dev/expr v0.16.0 h1:yloc84fytn4zmJX2GU3TkXGsaieaV7dQ057Qs4sIG2Y= -cel.dev/expr v0.16.0/go.mod h1:TRSuuV7DlVCE/uwv5QbAiW/v8l5O8C4eEPHeu7gf7Sg= -cloud.google.com/go/compute/metadata v0.5.0 h1:Zr0eK8JbFv6+Wi4ilXAR8FJ3wyNdpxHKJNPos6LTZOY= -cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY= filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/BurntSushi/toml v1.2.1 h1:9F2/+DoOYIOksmaJFPw1tGFy1eDnIJXg+UHjuD8lTak= github.com/BurntSushi/toml v1.2.1/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= -github.com/census-instrumentation/opencensus-proto v0.4.1 h1:iKLQ0xPNFxR/2hzXZMrBo8f1j86j5WHzznCCQxV/b8g= -github.com/census-instrumentation/opencensus-proto v0.4.1/go.mod h1:4T9NM4+4Vw91VeyqjLS6ao50K5bOcLKN6Q42XnYaRYw= -github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= -github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/cncf/xds/go v0.0.0-20240723142845-024c85f92f20 h1:N+3sFI5GUjRKBi+i0TxYVST9h4Ie192jJWpHvthBBgg= -github.com/cncf/xds/go v0.0.0-20240723142845-024c85f92f20/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/envoyproxy/go-control-plane v0.13.0 h1:HzkeUz1Knt+3bK+8LG1bxOO/jzWZmdxpwC51i202les= -github.com/envoyproxy/go-control-plane v0.13.0/go.mod h1:GRaKG3dwvFoTg4nj7aXdZnvMg4d7nvT/wl9WgVXn3Q8= -github.com/envoyproxy/protoc-gen-validate v1.1.0 h1:tntQDh69XqOCOZsDz0lVJQez/2L6Uu2PdjCQwWCJ3bM= -github.com/envoyproxy/protoc-gen-validate v1.1.0/go.mod h1:sXRDRVmzEbkM7CVcM06s9shE/m23dg3wzjl0UWqJ2q4= github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= -github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= -github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= -github.com/golang/glog v1.2.2 h1:1+mZ9upx1Dh6FmUTFR1naJ77miKiXgALjWOZ3NVFPmY= -github.com/golang/glog v1.2.2/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= +github.com/golang-jwt/jwt/v4 v4.5.1 h1:JdqV9zKUdtaa9gdPlywC3aeoEsR681PlKC+4F5gQgeo= +github.com/golang-jwt/jwt/v4 v4.5.1/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.2.0 h1:kQ0NI7W1B3HwiN5gAYtY+XFItDPbLBwYRxAqbFTyDes= +github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.2.0/go.mod h1:zrT2dxOAjNFPRGjTUe2Xmb4q4YdUwVvQFV6xiCSf+z0= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0 h1:TmHmbvxPmaegwhDubVz0lICL0J5Ka2vwTzhoePEXsGE= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0/go.mod h1:qztMSjm835F2bXf+5HKAPIS5qsmQDqZna/PgVt4rWtI= github.com/ilyakaznacheev/cleanenv v1.5.0 h1:0VNZXggJE2OYdXE87bfSSwGxeiGt9moSR2lOrsHHvr4= github.com/ilyakaznacheev/cleanenv v1.5.0/go.mod h1:a5aDzaJrLCQZsazHol1w8InnDcOX0OColm64SlIi6gk= github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 h1:Dj0L5fhJ9F82ZJyVOmBx6msDp/kfd1t9GRfny/mfJA0= @@ -57,23 +46,17 @@ github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/onsi/gomega v1.31.1 h1:KYppCUK+bUgAZwHOu7EXVBKyQA6ILvOESHkn/tgoqvo= -github.com/onsi/gomega v1.31.1/go.mod h1:y40C95dwAD1Nz36SsEnxvfFe8FFfNxzI5eJ0EYGyAy0= -github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= -github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= -github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= +github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k= +github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/valkey-io/valkey-go v1.0.41 h1:pWgh9MP24Vl0ANZ0KxEMwB/LHvTUKwlm2SPuWIrSlFw= -github.com/valkey-io/valkey-go v1.0.41/go.mod h1:LXqAbjygRuA1YRocojTslAGx2dQB4p8feaseGviWka4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/valkey-io/valkey-go v1.0.47 h1:fW5+m2BaLAbxB1EWEEWmj+i2n+YcYFBDG/jKs6qu5j8= github.com/valkey-io/valkey-go v1.0.47/go.mod h1:BXlVAPIL9rFQinSFM+N32JfWzfCaUAqBpZkc4vPY6fM= github.com/valkey-io/valkey-go/mock v1.0.47 h1:fQZUJrJEx4IG7vH1CjSSqPmx+5Gd6cwwdr7gcDDAIe0= @@ -86,43 +69,26 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= -golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= -golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= -golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= -golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= -golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= -golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= -golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= -golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= -golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= -golang.org/x/oauth2 v0.22.0 h1:BzDx2FehcG7jJwgWLELCdmLuxk2i+x9UDpSiss2u0ZA= -golang.org/x/oauth2 v0.22.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= -golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= -golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= -golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= -golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= -google.golang.org/genproto/googleapis/api v0.0.0-20240814211410-ddb44dafa142 h1:wKguEg1hsxI2/L3hUYrpo1RVi48K+uTyzKqprwLXsb8= -google.golang.org/genproto/googleapis/api v0.0.0-20240814211410-ddb44dafa142/go.mod h1:d6be+8HhtEtucleCbxpPW9PA9XwISACu8nvpPqF0BVo= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237 h1:NnYq6UN9ReLM9/Y01KWNOWyI5xQ9kbIms5GGJVwS/Yc= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237/go.mod h1:WtryC6hu0hhx87FDGxWCDptyssuo68sk10vYjF+T9fY= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142 h1:e7S5W7MGGLaSu8j3YjdezkZ+m1/Nm0uRVRMEMGk26Xs= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= -google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA= -google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0= -google.golang.org/grpc v1.67.1 h1:zWnc1Vrcno+lHZCOofnIMvycFcc0QRGIzm9dhnDX68E= -google.golang.org/grpc v1.67.1/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA= -google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= -google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= -google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= -google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= +golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8= +golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= +golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= +golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +google.golang.org/genproto/googleapis/api v0.0.0-20241118233622-e639e219e697 h1:pgr/4QbFyktUv9CtQ/Fq4gzEE6/Xs7iCXbktaGzLHbQ= +google.golang.org/genproto/googleapis/api v0.0.0-20241118233622-e639e219e697/go.mod h1:+D9ySVjN8nY8YCVjc5O7PZDIdZporIDY3KaGfJunh88= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241118233622-e639e219e697 h1:LWZqQOEjDyONlF1H6afSWpAL/znlREo2tHfLoe+8LMA= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241118233622-e639e219e697/go.mod h1:5uTbfoYQed2U9p3KIj2/Zzm02PYhndfdmML0qC3q3FU= +google.golang.org/grpc v1.68.0 h1:aHQeeJbo8zAkAa3pRzrVjZlbz6uSfeOXlJNQM0RAbz0= +google.golang.org/grpc v1.68.0/go.mod h1:fmSPC5AsjSBCK54MyHRx48kpOti1/jRfOlwEWywNjWA= +google.golang.org/protobuf v1.35.2 h1:8Ar7bF+apOIoThw1EdZl0p1oWvMqTHmpA2fRTyZO8io= +google.golang.org/protobuf v1.35.2/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/internal/models/role.go b/internal/models/role.go deleted file mode 100644 index 62f2bd2..0000000 --- a/internal/models/role.go +++ /dev/null @@ -1,46 +0,0 @@ -package models - -import ( - "git.sch9.ru/new_gate/ms-auth/pkg/utils" -) - -type Role int32 - -const ( - RoleSpectator Role = 0 - RoleParticipant Role = 1 - RoleModerator Role = 2 - RoleAdmin Role = 3 -) - -func (role Role) IsAdmin() bool { - return role == RoleAdmin -} - -func (role Role) IsModerator() bool { - return role == RoleModerator -} - -func (role Role) IsParticipant() bool { - return role == RoleParticipant -} - -func (role Role) IsSpectator() bool { - return role == RoleSpectator -} - -func (role Role) AtLeast(other Role) bool { - return role >= other -} - -func (role Role) AtMost(other Role) bool { - return role <= other -} - -func (role Role) Valid() error { - switch role { - case RoleSpectator, RoleParticipant, RoleModerator, RoleAdmin: - return nil - } - return utils.ErrBadRole -} diff --git a/internal/models/session.go b/internal/models/session.go index dd709e2..8f92d66 100644 --- a/internal/models/session.go +++ b/internal/models/session.go @@ -1,55 +1,65 @@ package models import ( - "git.sch9.ru/new_gate/ms-auth/pkg/utils" - "github.com/golang-jwt/jwt" + "encoding/json" + "errors" "github.com/google/uuid" + "time" ) type Session struct { - Id *string - UserId *int32 -} - -func NewSession(userId int32) *Session { - return &Session{ - Id: utils.AsStringP(uuid.NewString()), - UserId: &userId, - } + Id string `json:"id" db:"id"` + UserId int32 `json:"user_id" db:"user_id"` + Role Role `json:"role" db:"role"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + //UserAgent string `json:"user_agent"` + //Ip string `json:"ip"` } func (s Session) Valid() error { - if s.Id == nil { - return utils.ErrBadSession + if uuid.Validate(s.Id) != nil { + return errors.New("invalid session id") } - if s.UserId == nil { - return utils.ErrBadSession + if s.UserId == 0 { + return errors.New("empty user id") + } + if s.CreatedAt.IsZero() { + return errors.New("empty created at") + } + if !s.Role.IsAdmin() && !s.Role.IsModerator() && !s.Role.IsParticipant() { + return errors.New("invalid role") } return nil } -func (s Session) Token(secret string) (string, error) { +func NewSession(userId int32, role Role) (string, string, error) { + s := &Session{ + Id: uuid.NewString(), + UserId: userId, + Role: role, + CreatedAt: time.Now(), + } if err := s.Valid(); err != nil { - return "", err + return "", "", err } - refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, s) - str, err := refreshToken.SignedString([]byte(secret)) + + b, err := json.Marshal(s) if err != nil { - return "", utils.ErrBadSession + return "", "", err } - return str, nil + + return string(b), s.Id, nil } -func Parse(tkn string, secret string) (*Session, error) { - parsedToken, err := jwt.ParseWithClaims(tkn, &Session{}, func(token *jwt.Token) (interface{}, error) { - return []byte(secret), nil - }) - if err != nil { - return nil, utils.ErrBadSession - } - session := parsedToken.Claims.(*Session) - if err = session.Valid(); err != nil { +func ParseSession(s string) (*Session, error) { + sess := &Session{} + if err := json.Unmarshal([]byte(s), sess); err != nil { return nil, err } - return session, nil + + if err := sess.Valid(); err != nil { + return nil, err + } + + return sess, nil } diff --git a/internal/models/user.go b/internal/models/user.go index 69e6ba7..e44ed9a 100644 --- a/internal/models/user.go +++ b/internal/models/user.go @@ -1,81 +1,52 @@ package models import ( - "git.sch9.ru/new_gate/ms-auth/pkg/utils" + "errors" "golang.org/x/crypto/bcrypt" "time" ) +type Role int32 + +const ( + RoleParticipant Role = 0 + RoleModerator Role = 1 + RoleAdmin Role = 2 +) + +func (role Role) IsAdmin() bool { + return role == RoleAdmin +} + +func (role Role) IsModerator() bool { + return role == RoleModerator +} + +func (role Role) IsParticipant() bool { + return role == RoleParticipant +} + +func (role Role) AtLeast(other Role) bool { + return role >= other +} + +func (role Role) AtMost(other Role) bool { + return role <= other +} + type User struct { - Id *int32 `db:"id"` - Username *string `db:"username"` - Password *string `db:"hashed_pwd"` - Email *string `db:"email"` - ExpiresAt *time.Time `db:"expires_at"` - CreatedAt *time.Time `db:"created_at"` - UpdatedAt *time.Time `db:"updated_at"` - Role *Role `db:"role"` -} - -func (user *User) ValidUsername() error { - if user.Username == nil { - return utils.ErrBadUsername - } - - err := utils.ValidUsername(*user.Username) - if err != nil { - return err - } - - return nil -} - -func (user *User) ValidPassword() error { - if user.Password == nil { - return utils.ErrBadHandleOrPassword - } - - err := utils.ValidPassword(*user.Password) - if err != nil { - return err - } - - return nil -} - -func (user *User) ValidEmail() error { - if user.Email == nil { - return utils.ErrBadEmail - } - return utils.ValidEmail(*user.Email) -} - -func (user *User) ValidRole() error { - if user.Role == nil { - return utils.ErrBadRole - } - return user.Role.Valid() -} - -func (user *User) HashPassword() error { - if user.Password == nil { - return utils.ErrBadHandleOrPassword - } - hashedPassword, err := bcrypt.GenerateFromPassword([]byte(*user.Password), bcrypt.DefaultCost) - if err != nil { - return utils.ErrInternal - } - user.Password = utils.AsStringP(string(hashedPassword)) - return nil + Id int32 `db:"id"` + Username string `db:"username"` + HashedPassword string `db:"hashed_pwd"` + CreatedAt time.Time `db:"created_at"` + ModifiedAt time.Time `db:"modified_at"` + Role Role `db:"role"` } func (user *User) ComparePassword(password string) error { - if user.Password == nil { - return utils.ErrBadHandleOrPassword - } - err := bcrypt.CompareHashAndPassword([]byte(*user.Password), []byte(password)) + err := bcrypt.CompareHashAndPassword([]byte(user.HashedPassword), []byte(password)) if err != nil { - return utils.ErrBadHandleOrPassword + return errors.New("bad username or password") } return nil } diff --git a/internal/sessions/delivery.go b/internal/sessions/delivery.go deleted file mode 100644 index 1788f4d..0000000 --- a/internal/sessions/delivery.go +++ /dev/null @@ -1,14 +0,0 @@ -package sessions - -import ( - "context" - sessionv1 "git.sch9.ru/new_gate/ms-auth/pkg/go/gen/proto/session/v1" - "google.golang.org/protobuf/types/known/emptypb" -) - -type SessionHandlers interface { - Create(ctx context.Context, req *sessionv1.CreateSessionRequest) (*sessionv1.CreateSessionResponse, error) - Read(ctx context.Context, req *sessionv1.ReadSessionRequest) (*sessionv1.ReadSessionResponse, error) - Update(ctx context.Context, req *sessionv1.UpdateSessionRequest) (*emptypb.Empty, error) - Delete(ctx context.Context, req *sessionv1.DeleteSessionRequest) (*emptypb.Empty, error) -} diff --git a/internal/sessions/delivery/grpc/handlers.go b/internal/sessions/delivery/grpc/handlers.go deleted file mode 100644 index fec2d79..0000000 --- a/internal/sessions/delivery/grpc/handlers.go +++ /dev/null @@ -1,89 +0,0 @@ -package grpc - -import ( - "context" - "git.sch9.ru/new_gate/ms-auth/internal/models" - "git.sch9.ru/new_gate/ms-auth/internal/sessions" - "git.sch9.ru/new_gate/ms-auth/internal/users" - sessionv1 "git.sch9.ru/new_gate/ms-auth/pkg/go/gen/proto/session/v1" - "git.sch9.ru/new_gate/ms-auth/pkg/utils" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "google.golang.org/protobuf/types/known/emptypb" -) - -type sessionHandlers struct { - sessionv1.UnimplementedSessionServiceServer - sessionUC sessions.UseCase - userUC users.UseCase -} - -func NewSessionHandlers(gserver *grpc.Server, sessionUC sessions.UseCase, userUC users.UseCase) { - handlers := &sessionHandlers{ - sessionUC: sessionUC, - userUC: userUC, - } - - sessionv1.RegisterSessionServiceServer(gserver, handlers) -} - -func (s *sessionHandlers) Create(ctx context.Context, req *sessionv1.CreateSessionRequest) (*sessionv1.CreateSessionResponse, error) { - var ( - err error - user *models.User - ) - - handle := req.GetHandle() - password := req.GetPassword() - - if utils.ValidUsername(handle) == nil { - user, err = s.userUC.ReadUserByUsername(ctx, req.GetHandle()) - } else if utils.ValidEmail(handle) == nil { - user, err = s.userUC.ReadUserByEmail(ctx, handle) - } else { - return nil, utils.ErrBadHandleOrPassword - } - if err != nil { - return nil, err - } - - err = user.ComparePassword(password) - if err != nil { - return nil, err - } - - token, err := s.sessionUC.Create(ctx, *user.Id) - if err != nil { - return nil, status.Errorf(codes.Unknown, err.Error()) // FIXME - } - return &sessionv1.CreateSessionResponse{ - Token: *token, - }, nil -} - -func (s *sessionHandlers) Read(ctx context.Context, req *sessionv1.ReadSessionRequest) (*sessionv1.ReadSessionResponse, error) { - id, err := s.sessionUC.Read(ctx, req.GetToken()) - if err != nil { - return nil, status.Errorf(codes.Unknown, err.Error()) // FIXME - } - return &sessionv1.ReadSessionResponse{ - UserId: *id, - }, nil -} - -func (s *sessionHandlers) Update(ctx context.Context, req *sessionv1.UpdateSessionRequest) (*emptypb.Empty, error) { - err := s.sessionUC.Update(ctx, req.GetToken()) - if err != nil { - return nil, status.Errorf(codes.Unknown, err.Error()) // FIXME - } - return &emptypb.Empty{}, nil -} - -func (s *sessionHandlers) Delete(ctx context.Context, req *sessionv1.DeleteSessionRequest) (*emptypb.Empty, error) { - err := s.sessionUC.Delete(ctx, req.GetToken()) - if err != nil { - return nil, status.Errorf(codes.Unknown, err.Error()) // FIXME - } - return &emptypb.Empty{}, nil -} diff --git a/internal/sessions/repository/valkey_repository.go b/internal/sessions/repository/valkey_repository.go deleted file mode 100644 index 649871b..0000000 --- a/internal/sessions/repository/valkey_repository.go +++ /dev/null @@ -1,127 +0,0 @@ -package repository - -import ( - "context" - "git.sch9.ru/new_gate/ms-auth/config" - "git.sch9.ru/new_gate/ms-auth/internal/models" - "git.sch9.ru/new_gate/ms-auth/pkg/utils" - "github.com/valkey-io/valkey-go" - "go.uber.org/zap" - "strconv" - "time" -) - -type valkeyRepository struct { - db valkey.Client - cfg config.Config - logger *zap.Logger -} - -func NewValkeyRepository(db valkey.Client, cfg config.Config, logger *zap.Logger) *valkeyRepository { - return &valkeyRepository{ - db: db, - cfg: cfg, - logger: logger, - } -} - -const sessionLifetime = time.Minute * 40 - -func (r *valkeyRepository) CreateSession(ctx context.Context, userId int32) error { - session := models.NewSession(userId) - - resp := r.db.Do(ctx, r.db. - B().Set(). - Key(strconv.Itoa(int(userId))). - Value(*session.Id). - Nx(). - Exat(time.Now().Add(sessionLifetime)). - Build(), - ) - - if err := resp.Error(); err != nil { - return utils.ErrInternal - } - - return nil -} - -func (r *valkeyRepository) ReadSessionByToken(ctx context.Context, token string) (*models.Session, error) { - session, err := models.Parse(token, r.cfg.JWTSecret) - if err != nil { - return nil, err - } - - sessionRecord, err := r.ReadSessionByUserId(ctx, *session.UserId) - if err != nil { - return nil, err - } - - if *session.Id != *sessionRecord.Id { - return nil, utils.ErrInternal - } - - return session, err -} - -func (r *valkeyRepository) ReadSessionByUserId(ctx context.Context, userId int32) (*models.Session, error) { - resp := r.db.Do(ctx, r.db.B().Get().Key(strconv.Itoa(int(userId))).Build()) - if err := resp.Error(); err != nil { - return nil, utils.ErrInternal - } - - id, err := resp.ToString() - if err != nil { - return nil, utils.ErrInternal - } - - return &models.Session{ - Id: &id, - UserId: &userId, - }, err -} - -func (r *valkeyRepository) UpdateSession(ctx context.Context, session *models.Session) error { - resp := r.db.Do(ctx, r.db. - B().Set(). - Key(strconv.Itoa(int(*session.UserId))). - Value(*session.Id). - Xx(). - Exat(time.Now().Add(sessionLifetime)). - Build(), - ) - - if err := resp.Error(); err != nil { - return utils.ErrInternal - } - - return nil -} - -func (r *valkeyRepository) DeleteSessionByToken(ctx context.Context, token string) error { - session, err := models.Parse(token, r.cfg.JWTSecret) - if err != nil { - return err - } - - err = r.DeleteSessionByUserId(ctx, *session.UserId) - if err != nil { - return err - } - - return nil -} - -func (r *valkeyRepository) DeleteSessionByUserId(ctx context.Context, userId int32) error { - resp := r.db.Do(ctx, r.db. - B().Del(). - Key(strconv.Itoa(int(userId))). - Build(), - ) - - if err := resp.Error(); err != nil { - return utils.ErrInternal - } - - return nil -} diff --git a/internal/sessions/repository/valkey_repository_test.go b/internal/sessions/repository/valkey_repository_test.go deleted file mode 100644 index 8b5091e..0000000 --- a/internal/sessions/repository/valkey_repository_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package repository - -import ( - "context" - "git.sch9.ru/new_gate/ms-auth/config" - "github.com/stretchr/testify/require" - "github.com/valkey-io/valkey-go/mock" - "go.uber.org/mock/gomock" - "go.uber.org/zap" - "testing" -) - -func TestValkeyRepository_CreateSession(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - client := mock.NewClient(ctrl) - - sessionRepo := NewValkeyRepository(client, config.Config{JWTSecret: "secret"}, zap.NewNop()) - - t.Run("valid session creation", func(t *testing.T) { - var userId int32 = 1 - - matcher := mock.MatchFn(func(cmd []string) bool { - if cmd[0] != "SET" { - return false - } - if cmd[1] != "1" { - return false - } - if cmd[3] != "NX" { - return false - } - if cmd[4] != "EXAT" { - return false - } - return true - }) - - ctx := context.Background() - client.EXPECT().Do(ctx, matcher) - - err := sessionRepo.CreateSession(context.Background(), userId) - require.NoError(t, err) - }) -} diff --git a/internal/sessions/usecase.go b/internal/sessions/usecase.go deleted file mode 100644 index 43b56ac..0000000 --- a/internal/sessions/usecase.go +++ /dev/null @@ -1,10 +0,0 @@ -package sessions - -import "context" - -type UseCase interface { - Create(ctx context.Context, userId int32) (*string, error) - Read(ctx context.Context, token string) (*int32, error) - Update(ctx context.Context, token string) error - Delete(ctx context.Context, token string) error -} diff --git a/internal/sessions/usecase/usecase.go b/internal/sessions/usecase/usecase.go deleted file mode 100644 index 0f8bb73..0000000 --- a/internal/sessions/usecase/usecase.go +++ /dev/null @@ -1,71 +0,0 @@ -package usecase - -import ( - "context" - "git.sch9.ru/new_gate/ms-auth/config" - "git.sch9.ru/new_gate/ms-auth/internal/sessions" -) - -type useCase struct { - sessionRepo sessions.ValkeyRepository - cfg config.Config -} - -func NewUseCase(sessionRepo sessions.ValkeyRepository, cfg config.Config) *useCase { - return &useCase{ - sessionRepo: sessionRepo, - cfg: cfg, - } -} - -func (s *useCase) Create(ctx context.Context, userId int32) (*string, error) { - var ( - err error - ) - - s.sessionRepo.CreateSession(ctx, userId) // FIXME - - session, err := s.sessionRepo.ReadSessionByUserId(ctx, userId) - if err != nil { - return nil, err - } - - token, err := session.Token(s.cfg.JWTSecret) - if err != nil { - return nil, err - } - - return &token, nil -} - -func (s *useCase) Read(ctx context.Context, token string) (*int32, error) { - session, err := s.sessionRepo.ReadSessionByToken(ctx, token) - if err != nil { - return nil, err - } - return session.UserId, nil -} - -func (s *useCase) Update(ctx context.Context, token string) error { - session, err := s.sessionRepo.ReadSessionByToken(ctx, token) - if err != nil { - return err - } - err = s.sessionRepo.UpdateSession(ctx, session) - if err != nil { - return err - } - return nil -} - -func (s *useCase) Delete(ctx context.Context, token string) error { - session, err := s.sessionRepo.ReadSessionByToken(ctx, token) - if err != nil { - return err - } - err = s.sessionRepo.DeleteSessionByUserId(ctx, *session.UserId) - if err != nil { - return err - } - return nil -} diff --git a/internal/sessions/valkey_repository.go b/internal/sessions/valkey_repository.go deleted file mode 100644 index 88e16f7..0000000 --- a/internal/sessions/valkey_repository.go +++ /dev/null @@ -1,15 +0,0 @@ -package sessions - -import ( - "context" - "git.sch9.ru/new_gate/ms-auth/internal/models" -) - -type ValkeyRepository interface { - CreateSession(ctx context.Context, userId int32) error - ReadSessionByToken(ctx context.Context, token string) (*models.Session, error) - ReadSessionByUserId(ctx context.Context, userId int32) (*models.Session, error) - UpdateSession(ctx context.Context, session *models.Session) error - DeleteSessionByToken(ctx context.Context, token string) error - DeleteSessionByUserId(ctx context.Context, userId int32) error -} diff --git a/internal/users/delivery.go b/internal/users/delivery.go index e6cfa06..3c5bb8b 100644 --- a/internal/users/delivery.go +++ b/internal/users/delivery.go @@ -2,13 +2,18 @@ package users import ( "context" - userv1 "git.sch9.ru/new_gate/ms-auth/pkg/go/gen/proto/user/v1" + userv1 "git.sch9.ru/new_gate/ms-auth/proto/user/v1" "google.golang.org/protobuf/types/known/emptypb" ) type UserHandlers interface { CreateUser(ctx context.Context, req *userv1.CreateUserRequest) (*userv1.CreateUserResponse, error) - ReadUser(ctx context.Context, req *userv1.ReadUserRequest) (*userv1.ReadUserResponse, error) + GetUser(ctx context.Context, req *userv1.GetUserRequest) (*userv1.GetUserResponse, error) UpdateUser(ctx context.Context, req *userv1.UpdateUserRequest) (*emptypb.Empty, error) DeleteUser(ctx context.Context, req *userv1.DeleteUserRequest) (*emptypb.Empty, error) + Login(ctx context.Context, req *userv1.LoginRequest) (*emptypb.Empty, error) + Verify(ctx context.Context, req *emptypb.Empty) (*emptypb.Empty, error) + Refresh(ctx context.Context, req *emptypb.Empty) (*emptypb.Empty, error) + Logout(ctx context.Context, req *emptypb.Empty) (*emptypb.Empty, error) + CompleteLogout(ctx context.Context, req *emptypb.Empty) (*emptypb.Empty, error) } diff --git a/internal/users/delivery/grpc/handlers.go b/internal/users/delivery/grpc/handlers.go index 5a7eee6..c29493d 100644 --- a/internal/users/delivery/grpc/handlers.go +++ b/internal/users/delivery/grpc/handlers.go @@ -2,46 +2,176 @@ package grpc import ( "context" + "errors" "git.sch9.ru/new_gate/ms-auth/internal/models" "git.sch9.ru/new_gate/ms-auth/internal/users" - userv1 "git.sch9.ru/new_gate/ms-auth/pkg/go/gen/proto/user/v1" - "git.sch9.ru/new_gate/ms-auth/pkg/utils" + "git.sch9.ru/new_gate/ms-auth/pkg" + userv1 "git.sch9.ru/new_gate/ms-auth/proto/user/v1" "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" + "google.golang.org/grpc/metadata" "google.golang.org/protobuf/types/known/emptypb" + "google.golang.org/protobuf/types/known/timestamppb" + "strings" ) -type userHandlers struct { +type UserHandlers struct { userv1.UnimplementedUserServiceServer userUC users.UseCase } func NewUserHandlers(gserver *grpc.Server, userUC users.UseCase) { - handlers := &userHandlers{ + handlers := &UserHandlers{ userUC: userUC, } userv1.RegisterUserServiceServer(gserver, handlers) } -func (h *userHandlers) CreateUser(ctx context.Context, req *userv1.CreateUserRequest) (*userv1.CreateUserResponse, error) { - user := req.GetUser() - if user == nil { - return nil, status.Errorf(codes.Unknown, "") // FIXME +const ( + SessionHeaderName = "x-session-id" + AuthUserHeaderName = "x-auth-user-id" +) + +func (h *UserHandlers) Login(ctx context.Context, req *userv1.LoginRequest) (*emptypb.Empty, error) { + const op = "UserHandlers.Login" + + var ( + err error + user *models.User + ) + + username := req.GetUsername() + password := req.GetPassword() + + user, err = h.userUC.ReadUserByUsername(ctx, username) + if err != nil { + return nil, pkg.ToGRPC(err) } + + err = user.ComparePassword(password) + if err != nil { + return nil, pkg.ToGRPC(pkg.Wrap(pkg.ErrNotFound, err, op, "bad username or password")) + } + + sessionId, err := h.userUC.CreateSession(ctx, user.Id, user.Role) + if err != nil { + return nil, pkg.ToGRPC(err) + } + + header := metadata.New(map[string]string{ + SessionHeaderName: sessionId, + }) + err = grpc.SendHeader(ctx, header) + if err != nil { + return nil, err + } + + return &emptypb.Empty{}, nil +} + +func AuthSessionIdFromContext(ctx context.Context) (string, error) { + md, ok := metadata.FromIncomingContext(ctx) + + if !ok { + return "", errors.New("failed to get metadata") + } + tokens := md.Get(SessionHeaderName) + sessionId := strings.Join(tokens, "") + if len(sessionId) == 0 { + return "", errors.New("no session id in context") + } + return sessionId, nil +} + +func (h *UserHandlers) Refresh(ctx context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) { + const op = "UserHandlers.Refresh" + + sessionId, err := AuthSessionIdFromContext(ctx) + if err != nil { + return nil, pkg.ToGRPC(pkg.Wrap(err, pkg.ErrUnauthenticated, op, "no session id in context")) + } + err = h.userUC.UpdateSession(ctx, sessionId) + if err != nil { + return nil, pkg.ToGRPC(err) + } + return &emptypb.Empty{}, nil +} + +func (h *UserHandlers) Logout(ctx context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) { + const op = "UserHandlers.Logout" + + sessionId, err := AuthSessionIdFromContext(ctx) + if err != nil { + return nil, pkg.ToGRPC(pkg.Wrap(err, pkg.ErrUnauthenticated, op, "no session id in context")) + } + err = h.userUC.DeleteSession(ctx, sessionId) + if err != nil { + return nil, pkg.ToGRPC(err) + } + return &emptypb.Empty{}, nil +} + +func (h *UserHandlers) CompleteLogout(ctx context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) { + const op = "UserHandlers.CompleteLogout" + + sessionId, err := AuthSessionIdFromContext(ctx) + if err != nil { + return nil, pkg.ToGRPC(pkg.Wrap(err, pkg.ErrUnauthenticated, op, "no session id in context")) + } + + session, err := h.userUC.ReadSession(ctx, sessionId) + if err != nil { + return nil, pkg.ToGRPC(err) + } + + err = h.userUC.DeleteAllSessions(ctx, session.UserId) + if err != nil { + return nil, pkg.ToGRPC(err) + } + return &emptypb.Empty{}, nil +} + +func (h *UserHandlers) Verify(ctx context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) { + const op = "UserHandlers.Verify" + + sessionId, err := AuthSessionIdFromContext(ctx) + if err != nil { + return nil, pkg.ToGRPC(pkg.Wrap(err, pkg.ErrUnauthenticated, op, "no session id in context")) + } + token, err := h.userUC.Verify(ctx, sessionId) + if err != nil { + return nil, pkg.ToGRPC(err) + } + + header := metadata.New(map[string]string{ + AuthUserHeaderName: token, + }) + err = grpc.SendHeader(ctx, header) + if err != nil { + return nil, err + } + + return &emptypb.Empty{}, nil +} + +func (h *UserHandlers) CreateUser(ctx context.Context, req *userv1.CreateUserRequest) (*userv1.CreateUserResponse, error) { + const op = "UserHandlers.CreateUser" + + sessionId, err := AuthSessionIdFromContext(ctx) + if err != nil { + return nil, pkg.ToGRPC(pkg.Wrap(err, pkg.ErrUnauthenticated, op, "no session id in context")) + } + + ctx = context.WithValue(ctx, "userId", sessionId) + id, err := h.userUC.CreateUser( ctx, - &models.User{ - Username: utils.AsStringP(user.GetUsername()), - Password: utils.AsStringP(user.GetPassword()), - Email: nil, - ExpiresAt: utils.TimeP(user.ExpiresAt), - Role: AsMRoleP(user.GetRole()), - }, + req.GetUsername(), + req.GetPassword(), + models.RoleParticipant, ) if err != nil { - return nil, status.Errorf(codes.Unknown, err.Error()) // FIXME + return nil, pkg.ToGRPC(err) } return &userv1.CreateUserResponse{ @@ -49,55 +179,64 @@ func (h *userHandlers) CreateUser(ctx context.Context, req *userv1.CreateUserReq }, nil } -func (h *userHandlers) ReadUser(ctx context.Context, req *userv1.ReadUserRequest) (*userv1.ReadUserResponse, error) { - user, err := h.userUC.ReadUser( +func (h *UserHandlers) GetUser(ctx context.Context, req *userv1.GetUserRequest) (*userv1.GetUserResponse, error) { + user, err := h.userUC.ReadUserById( ctx, req.GetId(), ) if err != nil { - return nil, status.Errorf(codes.Unknown, err.Error()) // FIXME + return nil, pkg.ToGRPC(err) } - return &userv1.ReadUserResponse{ - User: &userv1.ReadUserResponse_User{ - Id: *user.Id, - Username: *user.Username, - ExpiresAt: utils.TimestampP(user.ExpiresAt), - CreatedAt: utils.TimestampP(user.CreatedAt), - Role: *AsRoleP(user.Role), + return &userv1.GetUserResponse{ + User: &userv1.User{ + Id: user.Id, + Username: user.Username, + CreatedAt: timestamppb.New(user.CreatedAt), + ModifiedAt: timestamppb.New(user.ModifiedAt), + Role: userv1.Role(user.Role), }, }, nil } -func (h *userHandlers) UpdateUser(ctx context.Context, req *userv1.UpdateUserRequest) (*emptypb.Empty, error) { - user := req.GetUser() - if user == nil { - return nil, status.Errorf(codes.Unknown, "") // FIXME +func (h *UserHandlers) UpdateUser(ctx context.Context, req *userv1.UpdateUserRequest) (*emptypb.Empty, error) { + const op = "UserHandlers.UpdateUser" + + sessionId, err := AuthSessionIdFromContext(ctx) + if err != nil { + return nil, pkg.ToGRPC(pkg.Wrap(err, pkg.ErrUnauthenticated, op, "no session id in context")) } - err := h.userUC.UpdateUser( + + ctx = context.WithValue(ctx, "userId", sessionId) + + err = h.userUC.UpdateUser( ctx, - &models.User{ - Id: utils.AsInt32P(user.GetId()), - Username: utils.AsStringP(user.GetUsername()), - Password: utils.AsStringP(user.GetPassword()), - Email: nil, - ExpiresAt: utils.TimeP(user.ExpiresAt), - Role: AsMRoleP(user.GetRole()), - }, + req.GetId(), + AsStringP(req.Username), + AsMRoleP(req.Role), ) if err != nil { - return nil, status.Errorf(codes.Unknown, err.Error()) // FIXME + return nil, pkg.ToGRPC(err) } return &emptypb.Empty{}, nil } -func (h *userHandlers) DeleteUser(ctx context.Context, req *userv1.DeleteUserRequest) (*emptypb.Empty, error) { - err := h.userUC.DeleteUser( +func (h *UserHandlers) DeleteUser(ctx context.Context, req *userv1.DeleteUserRequest) (*emptypb.Empty, error) { + const op = "UserHandlers.DeleteUser" + + sessionId, err := AuthSessionIdFromContext(ctx) + if err != nil { + return nil, pkg.ToGRPC(pkg.Wrap(err, pkg.ErrUnauthenticated, op, "no session id in context")) + } + + ctx = context.WithValue(ctx, "userId", sessionId) + + err = h.userUC.DeleteUser( ctx, req.GetId(), ) if err != nil { - return nil, status.Errorf(codes.Unknown, err.Error()) // FIXME + return nil, pkg.ToGRPC(err) } return &emptypb.Empty{}, nil } @@ -107,10 +246,10 @@ func AsMRoleP(v userv1.Role) *models.Role { return &vv } -func AsRoleP(r *models.Role) *userv1.Role { - if r == nil { - return nil - } - rr := userv1.Role(*r) - return &rr +func AsRoleP(v models.Role) *models.Role { + return &v +} + +func AsStringP(str string) *string { + return &str } diff --git a/internal/users/delivery/grpc/handlers_test.go b/internal/users/delivery/grpc/handlers_test.go new file mode 100644 index 0000000..34edc51 --- /dev/null +++ b/internal/users/delivery/grpc/handlers_test.go @@ -0,0 +1,428 @@ +package grpc + +import ( + "context" + "git.sch9.ru/new_gate/ms-auth/internal/models" + "git.sch9.ru/new_gate/ms-auth/internal/users" + mock_users "git.sch9.ru/new_gate/ms-auth/internal/users/delivery/mock" + userv1 "git.sch9.ru/new_gate/ms-auth/proto/user/v1" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "golang.org/x/crypto/bcrypt" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/emptypb" + "net" + "testing" + "time" +) + +func startServer(t *testing.T, uc users.UseCase, addr string) { + t.Helper() + + gserver := grpc.NewServer() + NewUserHandlers(gserver, uc) + + ln, err := net.Listen("tcp", addr) + if err != nil { + panic(err) + } + + go func() { + if err = gserver.Serve(ln); err != nil { + panic(err) + } + }() + + t.Cleanup(func() { + gserver.Stop() + }) +} + +func buildClient(t *testing.T, addr string) userv1.UserServiceClient { + t.Helper() + conn, err := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + + return userv1.NewUserServiceClient(conn) +} + +func TestUserHandlers_Login(t *testing.T) { + t.Parallel() + + const addr = "127.0.0.1:62999" + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + uc := mock_users.NewMockUseCase(ctrl) + startServer(t, uc, addr) + + client := buildClient(t, addr) + + t.Run("valid login", func(t *testing.T) { + password := "password" + hpwd, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + require.NoError(t, err) + + user := &models.User{ + Id: 1, + Username: "username", + HashedPassword: string(hpwd), + Role: models.RoleAdmin, + } + sid := uuid.NewString() + + uc.EXPECT().ReadUserByUsername(gomock.Any(), user.Username).Return(user, nil) + uc.EXPECT().CreateSession(gomock.Any(), user.Id, user.Role).Return(sid, nil) + + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + t.Cleanup(cancel) + + var header metadata.MD + _, err = client.Login(ctx, &userv1.LoginRequest{ + Username: user.Username, + Password: password, + }, grpc.Header(&header)) + require.NoError(t, err) + + require.Equal(t, sid, header.Get(SessionHeaderName)[0]) + }) + + t.Run("invalid login (wrong password)", func(t *testing.T) { + password := "password" + hpwd, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + require.NoError(t, err) + + user := &models.User{ + Id: 1, + Username: "username", + HashedPassword: string(hpwd), + Role: models.RoleAdmin, + } + + uc.EXPECT().ReadUserByUsername(gomock.Any(), user.Username).Return(user, nil) + + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + t.Cleanup(cancel) + + _, err = client.Login(ctx, &userv1.LoginRequest{ + Username: user.Username, + Password: "wrongpassword", + }) + + s, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, codes.NotFound, s.Code()) + }) +} + +func TestUserHandlers_Refresh(t *testing.T) { + t.Parallel() + + const addr = "127.0.0.1:62998" + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + uc := mock_users.NewMockUseCase(ctrl) + startServer(t, uc, addr) + + client := buildClient(t, addr) + + t.Run("valid refresh", func(t *testing.T) { + sid := uuid.NewString() + uc.EXPECT().UpdateSession(gomock.Any(), sid).Return(nil) + + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + t.Cleanup(cancel) + + ctx = metadata.AppendToOutgoingContext(ctx, SessionHeaderName, sid) + + _, err := client.Refresh(ctx, &emptypb.Empty{}) + require.NoError(t, err) + }) + + t.Run("invalid refresh (no session id in context)", func(t *testing.T) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + t.Cleanup(cancel) + + _, err := client.Refresh(ctx, &emptypb.Empty{}) + + s, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, codes.Unauthenticated, s.Code()) + }) +} + +func TestUserHandlers_Logout(t *testing.T) { + t.Parallel() + + const addr = "127.0.0.1:62997" + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + uc := mock_users.NewMockUseCase(ctrl) + startServer(t, uc, addr) + + client := buildClient(t, addr) + + t.Run("valid logout", func(t *testing.T) { + sid := uuid.NewString() + uc.EXPECT().DeleteSession(gomock.Any(), sid).Return(nil) + + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + t.Cleanup(cancel) + + ctx = metadata.AppendToOutgoingContext(ctx, SessionHeaderName, sid) + + _, err := client.Logout(ctx, &emptypb.Empty{}) + require.NoError(t, err) + }) + + t.Run("invalid logout (no session id in context)", func(t *testing.T) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + t.Cleanup(cancel) + + _, err := client.Logout(ctx, &emptypb.Empty{}) + + s, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, codes.Unauthenticated, s.Code()) + }) +} + +func TestUserHandlers_CompleteLogout(t *testing.T) { + t.Parallel() + + const addr = "127.0.0.1:62996" + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + uc := mock_users.NewMockUseCase(ctrl) + startServer(t, uc, addr) + + client := buildClient(t, addr) + + t.Run("valid complete logout", func(t *testing.T) { + sid := uuid.NewString() + + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + ctx = metadata.AppendToOutgoingContext(ctx, SessionHeaderName, sid) + t.Cleanup(cancel) + + uc.EXPECT().ReadSession(gomock.Any(), sid).Return(&models.Session{UserId: 1}, nil) + uc.EXPECT().DeleteAllSessions(gomock.Any(), int32(1)).Return(nil) + + _, err := client.CompleteLogout(ctx, &emptypb.Empty{}) + require.NoError(t, err) + }) + + t.Run("invalid complete logout (no session id in context)", func(t *testing.T) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + t.Cleanup(cancel) + + _, err := client.CompleteLogout(ctx, &emptypb.Empty{}) + + s, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, codes.Unauthenticated, s.Code()) + }) +} + +func TestUserHandlers_Verify(t *testing.T) { + t.Parallel() + + const addr = "127.0.0.1:62995" + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + uc := mock_users.NewMockUseCase(ctrl) + startServer(t, uc, addr) + + client := buildClient(t, addr) + + t.Run("valid verify", func(t *testing.T) { + sid := uuid.NewString() + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + ctx = metadata.AppendToOutgoingContext(ctx, SessionHeaderName, sid) + t.Cleanup(cancel) + + uc.EXPECT().Verify(gomock.Any(), sid).Return("jwt", nil) + + var header metadata.MD + _, err := client.Verify(ctx, &emptypb.Empty{}, grpc.Header(&header)) + require.NoError(t, err) + require.Equal(t, header.Get(AuthUserHeaderName)[0], "jwt") + }) + + t.Run("invalid verify (no session id in context)", func(t *testing.T) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + t.Cleanup(cancel) + + _, err := client.Verify(ctx, &emptypb.Empty{}) + + s, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, codes.Unauthenticated, s.Code()) + }) +} + +func TestUserHandlers_CreateUser(t *testing.T) { + t.Parallel() + + const addr = "127.0.0.1:62994" + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + uc := mock_users.NewMockUseCase(ctrl) + startServer(t, uc, addr) + + client := buildClient(t, addr) + + t.Run("valid create user", func(t *testing.T) { + username := "username" + password := "password" + + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + ctx = metadata.AppendToOutgoingContext(ctx, SessionHeaderName, uuid.NewString()) + t.Cleanup(cancel) + + uc.EXPECT().CreateUser(gomock.Any(), username, password, models.RoleParticipant).Return(int32(2), nil) + + _, err := client.CreateUser(ctx, &userv1.CreateUserRequest{ + Username: username, + Password: password, + }) + require.NoError(t, err) + }) + + t.Run("invalid create user (no session id in context)", func(t *testing.T) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + t.Cleanup(cancel) + + _, err := client.CreateUser(ctx, &userv1.CreateUserRequest{ + Username: "username", + Password: "password", + }) + + s, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, codes.Unauthenticated, s.Code()) + }) +} + +func TestUserHandlers_GetUser(t *testing.T) { + t.Parallel() + + const addr = "127.0.0.1:62993" + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + uc := mock_users.NewMockUseCase(ctrl) + startServer(t, uc, addr) + + client := buildClient(t, addr) + + t.Run("valid get user", func(t *testing.T) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + t.Cleanup(cancel) + + uc.EXPECT().ReadUserById(gomock.Any(), int32(1)).Return(&models.User{ + Id: 1, + Username: "username", + CreatedAt: time.Now(), + ModifiedAt: time.Now(), + Role: models.RoleParticipant, + }, nil) + + _, err := client.GetUser(ctx, &userv1.GetUserRequest{ + Id: 1, + }) + require.NoError(t, err) + }) +} + +func TestUserHandlers_UpdateUser(t *testing.T) { + t.Parallel() + + const addr = "127.0.0.1:62992" + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + uc := mock_users.NewMockUseCase(ctrl) + startServer(t, uc, addr) + + client := buildClient(t, addr) + + t.Run("valid update user", func(t *testing.T) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + ctx = metadata.AppendToOutgoingContext(ctx, SessionHeaderName, uuid.NewString()) + t.Cleanup(cancel) + + uc.EXPECT().UpdateUser(gomock.Any(), + int32(1), + AsStringP("username"), + AsRoleP(models.RoleModerator), + ).Return(nil) + + _, err := client.UpdateUser(ctx, &userv1.UpdateUserRequest{ + Id: 1, + Username: "username", + Role: userv1.Role_ROLE_MODERATOR, + }) + require.NoError(t, err) + }) + + t.Run("invalid update user (no session id in context)", func(t *testing.T) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + t.Cleanup(cancel) + + _, err := client.UpdateUser(ctx, &userv1.UpdateUserRequest{ + Id: 1, + Username: "username", + Role: userv1.Role_ROLE_MODERATOR, + }) + + s, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, codes.Unauthenticated, s.Code()) + }) +} + +func TestUserHandlers_DeleteUser(t *testing.T) { + t.Parallel() + + const addr = "127.0.0.1:62991" + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + uc := mock_users.NewMockUseCase(ctrl) + startServer(t, uc, addr) + + client := buildClient(t, addr) + + t.Run("valid delete user", func(t *testing.T) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + ctx = metadata.AppendToOutgoingContext(ctx, SessionHeaderName, uuid.NewString()) + t.Cleanup(cancel) + + uc.EXPECT().DeleteUser(gomock.Any(), int32(1)).Return(nil) + + _, err := client.DeleteUser(ctx, &userv1.DeleteUserRequest{ + Id: 1, + }) + require.NoError(t, err) + }) +} diff --git a/internal/users/delivery/grpc/token_interceptor.go b/internal/users/delivery/grpc/token_interceptor.go deleted file mode 100644 index 9adf82d..0000000 --- a/internal/users/delivery/grpc/token_interceptor.go +++ /dev/null @@ -1,30 +0,0 @@ -package grpc - -import ( - "context" - "git.sch9.ru/new_gate/ms-auth/internal/sessions" - "google.golang.org/grpc" - "google.golang.org/grpc/metadata" -) - -func TokenInterceptor(s sessions.UseCase) grpc.UnaryServerInterceptor { - return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - md, ok := metadata.FromIncomingContext(ctx) - if !ok { - return handler(ctx, req) - } - auth := md.Get("authorization") - if len(auth) == 0 { - return handler(ctx, req) - } - - userId, err := s.Read(ctx, auth[0]) - if err != nil { - return handler(ctx, req) - } - - ctx = context.WithValue(ctx, "userId", *userId) - - return handler(ctx, req) - } -} diff --git a/internal/users/delivery/mock/usecase_mock.go b/internal/users/delivery/mock/usecase_mock.go new file mode 100644 index 0000000..334db77 --- /dev/null +++ b/internal/users/delivery/mock/usecase_mock.go @@ -0,0 +1,202 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: git.sch9.ru/new_gate/ms-auth/internal/users (interfaces: UseCase) +// +// Generated by this command: +// +// mockgen . UseCase +// + +// Package mock_users is a generated GoMock package. +package mock_users + +import ( + context "context" + reflect "reflect" + + models "git.sch9.ru/new_gate/ms-auth/internal/models" + gomock "go.uber.org/mock/gomock" +) + +// MockUseCase is a mock of UseCase interface. +type MockUseCase struct { + ctrl *gomock.Controller + recorder *MockUseCaseMockRecorder + isgomock struct{} +} + +// MockUseCaseMockRecorder is the mock recorder for MockUseCase. +type MockUseCaseMockRecorder struct { + mock *MockUseCase +} + +// NewMockUseCase creates a new mock instance. +func NewMockUseCase(ctrl *gomock.Controller) *MockUseCase { + mock := &MockUseCase{ctrl: ctrl} + mock.recorder = &MockUseCaseMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockUseCase) EXPECT() *MockUseCaseMockRecorder { + return m.recorder +} + +// CreateSession mocks base method. +func (m *MockUseCase) CreateSession(ctx context.Context, userId int32, role models.Role) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateSession", ctx, userId, role) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateSession indicates an expected call of CreateSession. +func (mr *MockUseCaseMockRecorder) CreateSession(ctx, userId, role any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSession", reflect.TypeOf((*MockUseCase)(nil).CreateSession), ctx, userId, role) +} + +// CreateUser mocks base method. +func (m *MockUseCase) CreateUser(ctx context.Context, username, password string, role models.Role) (int32, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateUser", ctx, username, password, role) + ret0, _ := ret[0].(int32) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateUser indicates an expected call of CreateUser. +func (mr *MockUseCaseMockRecorder) CreateUser(ctx, username, password, role any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateUser", reflect.TypeOf((*MockUseCase)(nil).CreateUser), ctx, username, password, role) +} + +// DeleteAllSessions mocks base method. +func (m *MockUseCase) DeleteAllSessions(ctx context.Context, userId int32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteAllSessions", ctx, userId) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteAllSessions indicates an expected call of DeleteAllSessions. +func (mr *MockUseCaseMockRecorder) DeleteAllSessions(ctx, userId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllSessions", reflect.TypeOf((*MockUseCase)(nil).DeleteAllSessions), ctx, userId) +} + +// DeleteSession mocks base method. +func (m *MockUseCase) DeleteSession(ctx context.Context, sessionId string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteSession", ctx, sessionId) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteSession indicates an expected call of DeleteSession. +func (mr *MockUseCaseMockRecorder) DeleteSession(ctx, sessionId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSession", reflect.TypeOf((*MockUseCase)(nil).DeleteSession), ctx, sessionId) +} + +// DeleteUser mocks base method. +func (m *MockUseCase) DeleteUser(ctx context.Context, id int32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteUser", ctx, id) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteUser indicates an expected call of DeleteUser. +func (mr *MockUseCaseMockRecorder) DeleteUser(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUser", reflect.TypeOf((*MockUseCase)(nil).DeleteUser), ctx, id) +} + +// ReadSession mocks base method. +func (m *MockUseCase) ReadSession(ctx context.Context, sessionId string) (*models.Session, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadSession", ctx, sessionId) + ret0, _ := ret[0].(*models.Session) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadSession indicates an expected call of ReadSession. +func (mr *MockUseCaseMockRecorder) ReadSession(ctx, sessionId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadSession", reflect.TypeOf((*MockUseCase)(nil).ReadSession), ctx, sessionId) +} + +// ReadUserById mocks base method. +func (m *MockUseCase) ReadUserById(ctx context.Context, id int32) (*models.User, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadUserById", ctx, id) + ret0, _ := ret[0].(*models.User) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadUserById indicates an expected call of ReadUserById. +func (mr *MockUseCaseMockRecorder) ReadUserById(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadUserById", reflect.TypeOf((*MockUseCase)(nil).ReadUserById), ctx, id) +} + +// ReadUserByUsername mocks base method. +func (m *MockUseCase) ReadUserByUsername(ctx context.Context, username string) (*models.User, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadUserByUsername", ctx, username) + ret0, _ := ret[0].(*models.User) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadUserByUsername indicates an expected call of ReadUserByUsername. +func (mr *MockUseCaseMockRecorder) ReadUserByUsername(ctx, username any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadUserByUsername", reflect.TypeOf((*MockUseCase)(nil).ReadUserByUsername), ctx, username) +} + +// UpdateSession mocks base method. +func (m *MockUseCase) UpdateSession(ctx context.Context, sessionId string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateSession", ctx, sessionId) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateSession indicates an expected call of UpdateSession. +func (mr *MockUseCaseMockRecorder) UpdateSession(ctx, sessionId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSession", reflect.TypeOf((*MockUseCase)(nil).UpdateSession), ctx, sessionId) +} + +// UpdateUser mocks base method. +func (m *MockUseCase) UpdateUser(ctx context.Context, id int32, username *string, role *models.Role) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateUser", ctx, id, username, role) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateUser indicates an expected call of UpdateUser. +func (mr *MockUseCaseMockRecorder) UpdateUser(ctx, id, username, role any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUser", reflect.TypeOf((*MockUseCase)(nil).UpdateUser), ctx, id, username, role) +} + +// Verify mocks base method. +func (m *MockUseCase) Verify(ctx context.Context, sessionId string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Verify", ctx, sessionId) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Verify indicates an expected call of Verify. +func (mr *MockUseCaseMockRecorder) Verify(ctx, sessionId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Verify", reflect.TypeOf((*MockUseCase)(nil).Verify), ctx, sessionId) +} diff --git a/internal/users/pg_repository.go b/internal/users/pg_repository.go deleted file mode 100644 index 4fe5b2c..0000000 --- a/internal/users/pg_repository.go +++ /dev/null @@ -1,15 +0,0 @@ -package users - -import ( - "context" - "git.sch9.ru/new_gate/ms-auth/internal/models" -) - -type PgRepository interface { - CreateUser(ctx context.Context, user *models.User) (int32, error) - ReadUserByEmail(ctx context.Context, email string) (*models.User, error) - ReadUserByUsername(ctx context.Context, username string) (*models.User, error) - ReadUserById(ctx context.Context, id int32) (*models.User, error) - UpdateUser(ctx context.Context, user *models.User) error - DeleteUser(ctx context.Context, id int32) error -} diff --git a/internal/users/repository.go b/internal/users/repository.go new file mode 100644 index 0000000..14d77e8 --- /dev/null +++ b/internal/users/repository.go @@ -0,0 +1,33 @@ +package users + +import ( + "context" + "git.sch9.ru/new_gate/ms-auth/internal/models" +) + +type Caller interface { + CreateUser(ctx context.Context, username string, password string, role models.Role) (int32, error) + ReadUserByUsername(ctx context.Context, username string) (*models.User, error) + ReadUserById(ctx context.Context, id int32) (*models.User, error) + UpdateUser(ctx context.Context, id int32, username *string, role *models.Role) error + DeleteUser(ctx context.Context, id int32) error +} + +type TxCaller interface { + Caller + Commit() error + Rollback() error +} + +type PgRepository interface { + BeginTx(ctx context.Context) (TxCaller, error) + C() Caller +} + +type ValkeyRepository interface { + CreateSession(ctx context.Context, userId int32, role models.Role) (string, error) + ReadSession(ctx context.Context, sessionId string) (*models.Session, error) + UpdateSession(ctx context.Context, sessionId string) error + DeleteSession(ctx context.Context, sessionId string) error + DeleteAllSessions(ctx context.Context, userId int32) error +} diff --git a/internal/users/repository/pg_repository.go b/internal/users/repository/pg_repository.go index b3216ce..34c092d 100644 --- a/internal/users/repository/pg_repository.go +++ b/internal/users/repository/pg_repository.go @@ -2,69 +2,119 @@ package repository import ( "context" + "database/sql" "errors" "git.sch9.ru/new_gate/ms-auth/internal/models" - "git.sch9.ru/new_gate/ms-auth/pkg/utils" + "git.sch9.ru/new_gate/ms-auth/internal/users" + "git.sch9.ru/new_gate/ms-auth/pkg" "github.com/jackc/pgerrcode" "github.com/jackc/pgx/v5/pgconn" "github.com/jmoiron/sqlx" - "go.uber.org/zap" - "time" + "golang.org/x/crypto/bcrypt" + "net/mail" ) type UsersRepository struct { - db *sqlx.DB - logger *zap.Logger + db *sqlx.DB } -func NewUserRepository(db *sqlx.DB, logger *zap.Logger) *UsersRepository { +func NewUserRepository(db *sqlx.DB) *UsersRepository { return &UsersRepository{ - db: db, - logger: logger, + db: db, } } -const year = time.Hour * 24 * 365 +func (r *UsersRepository) BeginTx(ctx context.Context) (users.TxCaller, error) { + const op = "UsersRepository.BeginTx" + + tx, err := r.db.BeginTxx(ctx, nil) + if err != nil { + return nil, pkg.Wrap(pkg.ErrInternal, err, op, "database error") + } + + return &TxCaller{ + Caller: Caller{db: tx}, + db: tx, + }, nil +} + +func (r *UsersRepository) C() users.Caller { + return &Caller{db: r.db} +} + +type TxOrDB interface { + Rebind(query string) string + GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error + QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) +} + +type Caller struct { + db TxOrDB +} + +type TxCaller struct { + Caller + db *sqlx.Tx +} + +func (c *TxCaller) Commit() error { + const op = "TxCaller.Commit" + + err := c.db.Commit() + if err != nil { + return pkg.Wrap(pkg.ErrInternal, err, op, "database error") + } + return nil +} + +func (c *TxCaller) Rollback() error { + const op = "TxCaller.Rollback" + + err := c.db.Rollback() + if err != nil { + return pkg.Wrap(pkg.ErrInternal, err, op, "database error") + } + + return nil +} const createUser = ` INSERT INTO users - (username, hashed_pwd, email, expires_at, role) -VALUES (?, ?, ?, ?, ?) + (username, hashed_pwd, role) +VALUES (trim(lower(?)), ?, ?) RETURNING id ` -func (storage *UsersRepository) CreateUser(ctx context.Context, user *models.User) (int32, error) { - if err := user.ValidUsername(); err != nil { - return 0, err +func (c *Caller) CreateUser(ctx context.Context, username, password string, role models.Role) (int32, error) { + const op = "Caller.CreateUser" + + if err := ValidUsername(username); err != nil { + return 0, pkg.Wrap(pkg.ErrBadInput, err, op, "username validation") } - if err := user.ValidPassword(); err != nil { - return 0, err + if err := ValidPassword(password); err != nil { + return 0, pkg.Wrap(pkg.ErrBadInput, err, op, "password validation") } - if err := user.ValidEmail(); err != nil { - return 0, err - } - if err := user.ValidRole(); err != nil { - return 0, err - } - if err := user.HashPassword(); err != nil { // FIXME: get rid of mutation - return 0, err + if err := ValidRole(role); err != nil { + return 0, pkg.Wrap(pkg.ErrBadInput, err, op, "role validation") } - user.ExpiresAt = utils.AsTimeP(time.Now().Add(100 * year)) // FIXME: get rid of mutation + hpwd, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return 0, pkg.Wrap(pkg.ErrBadInput, err, op, "password validation") + } - query := storage.db.Rebind(createUser) + query := c.db.Rebind(createUser) - rows, err := storage.db.QueryxContext( + rows, err := c.db.QueryxContext( ctx, query, - user.Username, - user.Password, - user.Email, - user.ExpiresAt, - user.Role, + username, + string(hpwd), + role, ) if err != nil { - return 0, storage.handlePgErr(err) + return 0, handlePgErr(err, op) } defer rows.Close() @@ -72,124 +122,138 @@ func (storage *UsersRepository) CreateUser(ctx context.Context, user *models.Use rows.Next() err = rows.Scan(&id) if err != nil { - return 0, storage.handlePgErr(err) + return 0, handlePgErr(err, op) } return id, nil } -const readUserByEmail = "SELECT * from users WHERE email=? LIMIT 1" - -func (storage *UsersRepository) ReadUserByEmail(ctx context.Context, email string) (*models.User, error) { - var user models.User - query := storage.db.Rebind(readUserByEmail) - err := storage.db.GetContext(ctx, &user, query, email) - if err != nil { - return nil, storage.handlePgErr(err) - } - return &user, nil -} - const readUserByUsername = "SELECT * from users WHERE username=? LIMIT 1" -func (storage *UsersRepository) ReadUserByUsername(ctx context.Context, username string) (*models.User, error) { +func (c *Caller) ReadUserByUsername(ctx context.Context, username string) (*models.User, error) { + const op = "Caller.ReadUserByUsername" + var user models.User - query := storage.db.Rebind(readUserByUsername) - err := storage.db.GetContext(ctx, &user, query, username) + query := c.db.Rebind(readUserByUsername) + err := c.db.GetContext(ctx, &user, query, username) if err != nil { - return nil, storage.handlePgErr(err) + return nil, handlePgErr(err, op) } return &user, nil } const readUserById = "SELECT * from users WHERE id=? LIMIT 1" -func (storage *UsersRepository) ReadUserById(ctx context.Context, id int32) (*models.User, error) { +func (c *Caller) ReadUserById(ctx context.Context, id int32) (*models.User, error) { + const op = "Caller.ReadUserById" + var user models.User - query := storage.db.Rebind(readUserById) - err := storage.db.GetContext(ctx, &user, query, id) + query := c.db.Rebind(readUserById) + err := c.db.GetContext(ctx, &user, query, id) if err != nil { - return nil, storage.handlePgErr(err) + return nil, handlePgErr(err, op) } return &user, nil } const updateUser = ` UPDATE users -SET username = COALESCE(?, username), - hashed_pwd = COALESCE(?, hashed_pwd), - email = COALESCE(?, email), - expires_at = COALESCE(?, expires_at), +SET username = COALESCE(?, trim(lower(username))), role = COALESCE(?, role) WHERE id = ? ` -func (storage *UsersRepository) UpdateUser(ctx context.Context, user *models.User) error { +func (c *Caller) UpdateUser(ctx context.Context, id int32, username *string, role *models.Role) error { + const op = "Caller.UpdateUser" + var err error - if user.Username != nil { - if err = user.ValidUsername(); err != nil { - return err - } - } - if user.Password != nil { - if err = user.ValidPassword(); err != nil { - return err - } - if err = user.HashPassword(); err != nil { - return err - } - } - if user.Email != nil { - if err = utils.ValidEmail(*user.Email); err != nil { - return err - } - } - if user.Role != nil { - if err = user.Role.Valid(); err != nil { - return err + if username != nil { + if err = ValidUsername(*username); err != nil { + return pkg.Wrap(pkg.ErrBadInput, err, op, "username validation") } } - query := storage.db.Rebind(updateUser) - - _, err = storage.db.ExecContext( + query := c.db.Rebind(updateUser) + _, err = c.db.ExecContext( ctx, query, - user.Username, - user.Password, - user.Email, - user.ExpiresAt, - user.Role, - user.Id, + username, + role, + id, ) if err != nil { - return storage.handlePgErr(err) + return handlePgErr(err, op) } return nil } -const deleteUser = "UPDATE users SET expired_at=now() WHERE id = ?" +const deleteUser = "DELETE FROM users WHERE id = ?" -func (storage *UsersRepository) DeleteUser(ctx context.Context, id int32) error { - query := storage.db.Rebind(deleteUser) - _, err := storage.db.ExecContext(ctx, query, id) +func (c *Caller) DeleteUser(ctx context.Context, id int32) error { + const op = "Caller.DeleteUser" + + query := c.db.Rebind(deleteUser) + _, err := c.db.ExecContext(ctx, query, id) if err != nil { - return storage.handlePgErr(err) + return handlePgErr(err, op) } return nil } -func (storage *UsersRepository) handlePgErr(err error) error { +func handlePgErr(err error, op string) error { var pgErr *pgconn.PgError - if !errors.As(err, &pgErr) { - storage.logger.DPanic("unexpected error from postgres", zap.String("err", err.Error())) - return utils.ErrUnexpected + if errors.As(err, &pgErr) { + if pgerrcode.IsIntegrityConstraintViolation(pgErr.Code) { + return pkg.Wrap(pkg.ErrBadInput, err, op, pgErr.Message) + } + if pgerrcode.IsNoData(pgErr.Code) { + return pkg.Wrap(pkg.ErrNotFound, err, op, pgErr.Message) + } } - if pgerrcode.IsIntegrityConstraintViolation(pgErr.Code) { - return errors.New("unique key violation") // FIXME - } - storage.logger.DPanic("unexpected internal error from postgres", zap.String("err", err.Error())) - return utils.ErrInternal + return pkg.Wrap(pkg.ErrUnhandled, err, op, "unexpected error") +} + +func ValidEmail(str string) error { + emailAddress, err := mail.ParseAddress(str) + if err != nil || emailAddress.Address != str { + return errors.New("invalid email") + } + return nil +} + +func ValidUsername(str string) error { + if len(str) < 5 { + return errors.New("too short username") + } + if len(str) > 70 { + return errors.New("too long username") + } + if err := ValidEmail(str); err == nil { + return errors.New("username cannot be an email") + } + return nil +} + +func ValidPassword(str string) error { + if len(str) < 5 { + return errors.New("too short password") + } + if len(str) > 70 { + return errors.New("too long password") + } + return nil +} + +func ValidRole(role models.Role) error { + switch role { + case models.RoleAdmin: + return nil + case models.RoleModerator: + return nil + case models.RoleParticipant: + return nil + } + return errors.New("invalid role") } diff --git a/internal/users/repository/pg_repository_test.go b/internal/users/repository/pg_repository_test.go index d8bc664..3f77509 100644 --- a/internal/users/repository/pg_repository_test.go +++ b/internal/users/repository/pg_repository_test.go @@ -3,18 +3,16 @@ package repository import ( "context" "database/sql/driver" - "git.sch9.ru/new_gate/ms-auth/pkg/utils" + "git.sch9.ru/new_gate/ms-auth/internal/models" + "git.sch9.ru/new_gate/ms-auth/pkg" "github.com/DATA-DOG/go-sqlmock" "github.com/jmoiron/sqlx" "github.com/stretchr/testify/require" - "go.uber.org/zap" "testing" "time" - - "git.sch9.ru/new_gate/ms-auth/internal/models" ) -func TestUsersRepository_CreateUser(t *testing.T) { +func TestCaller_CreateUser(t *testing.T) { t.Parallel() db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) @@ -24,39 +22,98 @@ func TestUsersRepository_CreateUser(t *testing.T) { sqlxDB := sqlx.NewDb(db, "sqlmock") defer sqlxDB.Close() - userRepo := NewUserRepository(sqlxDB, zap.NewNop()) + userRepo := NewUserRepository(sqlxDB) t.Run("valid user creation", func(t *testing.T) { rows := sqlmock.NewRows([]string{"id"}).AddRow(1) - user := &models.User{ - Username: utils.AsStringP("testuser"), - Password: utils.AsStringP("testpassword"), - Email: utils.AsStringP("test@example.com"), - Role: AsRoleP(models.RoleAdmin), - } + username := "testuser" + password := "testpassword" + role := models.RoleAdmin mock.ExpectQuery(sqlxDB.Rebind(createUser)).WithArgs( - user.Username, + username, AnyString{}, - user.Email, - AnyTime{}, - user.Role, + role, ).WillReturnRows(rows) - _, err = userRepo.CreateUser(context.Background(), user) + _, err = userRepo.C().CreateUser(context.Background(), username, password, role) require.NoError(t, err) }) - // TODO: add more tests - // invalid username - // invalid password - // invalid email - // invalid role - // password hashing error - // database query error - // database scan error - // etc + t.Run("invalid user creation (invalid username)", func(t *testing.T) { + username := "test" + password := "testpassword" + role := models.RoleAdmin + + _, err = userRepo.C().CreateUser(context.Background(), username, password, role) + require.ErrorIs(t, err, pkg.ErrBadInput) + }) + + t.Run("invalid user creation (invalid password)", func(t *testing.T) { + username := "testuser" + password := "test" + role := models.RoleAdmin + + _, err = userRepo.C().CreateUser(context.Background(), username, password, role) + require.ErrorIs(t, err, pkg.ErrBadInput) + }) + + t.Run("invalid user creation (invalid role)", func(t *testing.T) { + username := "testuser" + password := "testpassword" + _, err = userRepo.C().CreateUser(context.Background(), username, password, 123) + require.ErrorIs(t, err, pkg.ErrBadInput) + }) +} + +func TestCaller_ReadUserByUsername(t *testing.T) { + t.Parallel() + + db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + require.NoError(t, err) + defer db.Close() + + sqlxDB := sqlx.NewDb(db, "sqlmock") + defer sqlxDB.Close() + + userRepo := NewUserRepository(sqlxDB) + + t.Run("valid user read", func(t *testing.T) { + user := &models.User{ + Id: 1, + Username: "testuser", + HashedPassword: "hashedtestpassword", + CreatedAt: time.Now(), + ModifiedAt: time.Now(), + Role: models.RoleAdmin, + } + + rows := sqlmock.NewRows( + []string{ + "id", + "username", + "hashed_pwd", + "created_at", + "modified_at", + "role", + }).AddRow( + user.Id, + user.Username, + user.HashedPassword, + user.CreatedAt, + user.ModifiedAt, + user.Role, + ) + + mock.ExpectQuery(sqlxDB.Rebind(readUserByUsername)).WithArgs( + user.Username, + ).WillReturnRows(rows) + + readUser, err := userRepo.C().ReadUserByUsername(context.Background(), user.Username) + require.NoError(t, err) + require.Equal(t, user, readUser) + }) } func TestUsersRepository_ReadUserById(t *testing.T) { @@ -69,18 +126,16 @@ func TestUsersRepository_ReadUserById(t *testing.T) { sqlxDB := sqlx.NewDb(db, "sqlmock") defer sqlxDB.Close() - userRepo := NewUserRepository(sqlxDB, zap.NewNop()) + userRepo := NewUserRepository(sqlxDB) t.Run("valid user read", func(t *testing.T) { user := &models.User{ - Id: utils.AsInt32P(1), - Username: utils.AsStringP("testuser"), - Password: utils.AsStringP("testpassword"), - Email: utils.AsStringP("test@example.com"), - ExpiresAt: utils.AsTimeP(time.Now().Add(1 * time.Hour)), - CreatedAt: utils.AsTimeP(time.Now()), - UpdatedAt: utils.AsTimeP(time.Now()), - Role: AsRoleP(models.RoleAdmin), + Id: 1, + Username: "testuser", + HashedPassword: "hashedtestpassword", + CreatedAt: time.Now(), + ModifiedAt: time.Now(), + Role: models.RoleAdmin, } rows := sqlmock.NewRows( @@ -88,27 +143,23 @@ func TestUsersRepository_ReadUserById(t *testing.T) { "id", "username", "hashed_pwd", - "email", - "expires_at", "created_at", - "updated_at", + "modified_at", "role", }).AddRow( - *user.Id, - *user.Username, - *user.Password, - *user.Email, - *user.ExpiresAt, - *user.CreatedAt, - *user.UpdatedAt, - *user.Role, + user.Id, + user.Username, + user.HashedPassword, + user.CreatedAt, + user.ModifiedAt, + user.Role, ) mock.ExpectQuery(sqlxDB.Rebind(readUserById)).WithArgs( - *user.Id, + user.Id, ).WillReturnRows(rows) - readUser, err := userRepo.ReadUserById(context.Background(), *user.Id) + readUser, err := userRepo.C().ReadUserById(context.Background(), user.Id) require.NoError(t, err) require.Equal(t, user, readUser) }) @@ -124,38 +175,30 @@ func TestUsersRepository_UpdateUser(t *testing.T) { sqlxDB := sqlx.NewDb(db, "sqlmock") defer sqlxDB.Close() - userRepo := NewUserRepository(sqlxDB, zap.NewNop()) + userRepo := NewUserRepository(sqlxDB) t.Run("valid user update", func(t *testing.T) { user := &models.User{ - Id: utils.AsInt32P(1), - Username: utils.AsStringP("testuser"), - Password: utils.AsStringP("testpassword"), - Email: utils.AsStringP("test@example.com"), - ExpiresAt: utils.AsTimeP(time.Now().Add(1 * time.Hour)), - Role: AsRoleP(models.RoleAdmin), + Id: 1, + Username: "testuser", + HashedPassword: "hashedtestpassword", + CreatedAt: time.Now(), + ModifiedAt: time.Now(), + Role: models.RoleAdmin, } - require.NoError(t, err) + mock.ExpectExec(sqlxDB.Rebind(updateUser)).WithArgs( + user.Username, + user.Role, + user.Id, + ).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec(sqlxDB.Rebind(updateUser)). - WithArgs( - *user.Username, - AnyString{}, - *user.Email, - *user.ExpiresAt, - *user.Role, - *user.Id, - ).WillReturnResult(sqlmock.NewResult(1, 1)) - - err = userRepo.UpdateUser(context.Background(), user) + err := userRepo.C().UpdateUser(context.Background(), user.Id, AsStringP(user.Username), AsRoleP(user.Role)) require.NoError(t, err) }) - - // TODO: add more tests } -func TestUsersRepository_DeleteUser(t *testing.T) { +func TestCaller_DeleteUser(t *testing.T) { t.Parallel() db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) @@ -165,20 +208,16 @@ func TestUsersRepository_DeleteUser(t *testing.T) { sqlxDB := sqlx.NewDb(db, "sqlmock") defer sqlxDB.Close() - userRepo := NewUserRepository(sqlxDB, zap.NewNop()) + userRepo := NewUserRepository(sqlxDB) - t.Run("valid user deletion", func(t *testing.T) { - user := &models.User{ - Id: utils.AsInt32P(1), - } + t.Run("valid user delete", func(t *testing.T) { + mock.ExpectExec(sqlxDB.Rebind(deleteUser)).WithArgs( + 1, + ).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec(sqlxDB.Rebind(deleteUser)).WithArgs(*user.Id).WillReturnResult(sqlmock.NewResult(1, 1)) - - err = userRepo.DeleteUser(context.Background(), *user.Id) + err := userRepo.C().DeleteUser(context.Background(), 1) require.NoError(t, err) }) - - // TODO: add more tests } func AsRoleP(r models.Role) *models.Role { @@ -199,3 +238,7 @@ func (a AnyString) Match(v driver.Value) bool { _, ok := v.(string) return ok } + +func AsStringP(str string) *string { + return &str +} diff --git a/internal/users/repository/valkey_repository.go b/internal/users/repository/valkey_repository.go new file mode 100644 index 0000000..d4d3516 --- /dev/null +++ b/internal/users/repository/valkey_repository.go @@ -0,0 +1,194 @@ +package repository + +import ( + "context" + "fmt" + "git.sch9.ru/new_gate/ms-auth/internal/models" + "git.sch9.ru/new_gate/ms-auth/pkg" + "github.com/google/uuid" + "github.com/valkey-io/valkey-go" + "strconv" + "time" +) + +type ValkeyRepository struct { + db valkey.Client +} + +func NewValkeyRepository(db valkey.Client) *ValkeyRepository { + return &ValkeyRepository{ + db: db, + } +} + +const sessionLifetime = time.Minute * 40 + +func (r *ValkeyRepository) CreateSession(ctx context.Context, userId int32, role models.Role) (string, error) { + const op = "ValkeyRepository.CreateSession" + + sessionData, sessionId, err := models.NewSession(userId, role) + if err != nil { + return "", pkg.Wrap(pkg.ErrBadInput, err, op, "building session") + } + + resp := r.db.Do(ctx, r.db. + B().Set(). + Key(fmt.Sprintf("userid:%d:sessionid:%s", userId, sessionId)). + Value(sessionData). + Ex(sessionLifetime). + Build(), + ) + + err = resp.Error() + if err != nil { + if valkey.IsValkeyNil(err) { + return "", pkg.Wrap(pkg.ErrBadInput, err, op, "nil response") + } + return "", pkg.Wrap(pkg.ErrUnhandled, err, op, "unhandled valkey error") + } + + return sessionId, nil +} + +const ( + readSessionScript = `local result = redis.call('SCAN', 0, 'MATCH', ARGV[1]) +if #result[2] == 0 then + return nil +else + return redis.call('GET', result[2][1]) +end` +) + +func (r *ValkeyRepository) ReadSession(ctx context.Context, sessionId string) (*models.Session, error) { + const op = "ValkeyRepository.ReadSession" + + err := uuid.Validate(sessionId) + if err != nil { + return nil, pkg.Wrap(pkg.ErrBadInput, err, op, "uuid validation") + } + + resp := valkey.NewLuaScript(readSessionScript).Exec( + ctx, + r.db, + nil, + []string{fmt.Sprintf("userid:*:sessionid:%s", sessionId)}, + ) + + if err = resp.Error(); err != nil { + if valkey.IsValkeyNil(err) { + return nil, pkg.Wrap(pkg.ErrNotFound, err, op, "reading session") + } + return nil, pkg.Wrap(pkg.ErrUnhandled, err, op, "unhandled valkey error") + } + + str, err := resp.ToString() + if err != nil { + return nil, pkg.Wrap(pkg.ErrInternal, err, op, "session storage corrupted") + } + + session, err := models.ParseSession(str) + if err != nil { + return nil, pkg.Wrap(pkg.ErrInternal, err, op, "session corrupted") + } + + return session, nil +} + +const ( + updateSessionScript = `local result = redis.call('SCAN', 0, 'MATCH', ARGV[1]) +return #result[2] > 0 and redis.call('EXPIRE', result[2][1], ARGV[2]) == 1` +) + +var ( + sessionLifetimeString = strconv.Itoa(int(sessionLifetime.Seconds())) +) + +func (r *ValkeyRepository) UpdateSession(ctx context.Context, sessionId string) error { + const op = "ValkeyRepository.UpdateSession" + + err := uuid.Validate(sessionId) + if err != nil { + return pkg.Wrap(pkg.ErrBadInput, err, op, "uuid validation") + } + + resp := valkey.NewLuaScript(updateSessionScript).Exec( + ctx, + r.db, + nil, + []string{fmt.Sprintf("userid:*:sessionid:%s", sessionId), sessionLifetimeString}, + ) + + err = resp.Error() + if err != nil { + if valkey.IsValkeyNil(err) { + return pkg.Wrap(pkg.ErrBadInput, err, op, "nil response") + } + return pkg.Wrap(pkg.ErrUnhandled, err, op, "unhandled valkey error") + } + + return nil +} + +const deleteSessionScript = `local result = redis.call('SCAN', 0, 'MATCH', KEYS[1]) +return #result[2] > 0 and redis.call('DEL', result[2][1]) == 1` + +func (r *ValkeyRepository) DeleteSession(ctx context.Context, sessionId string) error { + const op = "ValkeyRepository.DeleteSession" + + err := uuid.Validate(sessionId) + if err != nil { + return pkg.Wrap(pkg.ErrBadInput, err, op, "uuid validation") + } + + resp := valkey.NewLuaScript(deleteSessionScript).Exec( + ctx, + r.db, + nil, + []string{fmt.Sprintf("userid:*:sessionid:%s", sessionId)}, + ) + + err = resp.Error() + if err != nil { + if valkey.IsValkeyNil(err) { + return pkg.Wrap(pkg.ErrBadInput, err, op, "nil response") + } + return pkg.Wrap(pkg.ErrUnhandled, err, op, "unhandled valkey error") + } + + return nil +} + +const ( + deleteUserSessionsScript = `local cursor = 0 +local dels = 0 +repeat + local result = redis.call('SCAN', cursor, 'MATCH', ARGV[1]) + for _,key in ipairs(result[2]) do + redis.call('DEL', key) + dels = dels + 1 + end + cursor = tonumber(result[1]) +until cursor == 0 +return dels` +) + +func (r *ValkeyRepository) DeleteAllSessions(ctx context.Context, userId int32) error { + const op = "ValkeyRepository.DeleteAllSessions" + + resp := valkey.NewLuaScript(deleteUserSessionsScript).Exec( + ctx, + r.db, + nil, + []string{fmt.Sprintf("userid:%d:sessionid:*", userId)}, + ) + + err := resp.Error() + if err != nil { + if valkey.IsValkeyNil(err) { + return pkg.Wrap(pkg.ErrBadInput, err, op, "nil response") + } + return pkg.Wrap(pkg.ErrUnhandled, err, op, "unhandled valkey error") + } + + return nil +} diff --git a/internal/users/repository/valkey_repository_test.go b/internal/users/repository/valkey_repository_test.go new file mode 100644 index 0000000..e652055 --- /dev/null +++ b/internal/users/repository/valkey_repository_test.go @@ -0,0 +1,267 @@ +package repository + +import ( + "context" + "fmt" + "git.sch9.ru/new_gate/ms-auth/internal/models" + "git.sch9.ru/new_gate/ms-auth/pkg" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "github.com/valkey-io/valkey-go" + "github.com/valkey-io/valkey-go/mock" + "go.uber.org/mock/gomock" + "strings" + "testing" +) + +var ( + matcherAny = mock.MatchFn(func(cmd []string) bool { return true }) +) + +func TestValkeyRepository_CreateSession(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + client := mock.NewClient(ctrl) + sessionRepo := NewValkeyRepository(client) + + var userId int32 = 1 + + matcher := mock.MatchFn(func(cmd []string) bool { + if cmd[0] != "SET" { + return false + } + if !strings.HasPrefix(cmd[1], fmt.Sprintf("userid:%d:sessionid:", userId)) { + return false + } + if cmd[3] != "EX" { + return false + } + if cmd[4] != "2400" { + return false + } + return true + }) + + t.Run("valid session creation", func(t *testing.T) { + ctx := context.Background() + client.EXPECT().Do(ctx, matcher) + sessionId, err := sessionRepo.CreateSession(context.Background(), userId, models.RoleAdmin) + require.NoError(t, err) + require.NotEmpty(t, sessionId) + }) + + t.Run("invalid session creation 1", func(t *testing.T) { + ctx := context.Background() + client.EXPECT().Do(ctx, matcher).Return(mock.ErrorResult(valkey.Nil)) + sessionId, err := sessionRepo.CreateSession(context.Background(), userId, models.RoleAdmin) + require.ErrorIs(t, err, pkg.ErrBadInput) + require.ErrorIs(t, err, valkey.Nil) + require.Empty(t, sessionId) + }) + + t.Run("invalid session creation 2 (invalid userid)", func(t *testing.T) { + sessionId, err := sessionRepo.CreateSession(context.Background(), 0, models.RoleAdmin) + require.ErrorIs(t, err, pkg.ErrBadInput) + require.Empty(t, sessionId) + }) + + t.Run("invalid session creation 3 (invalid role)", func(t *testing.T) { + sessionId, err := sessionRepo.CreateSession(context.Background(), userId, 123) + require.ErrorIs(t, err, pkg.ErrBadInput) + require.Empty(t, sessionId) + }) +} + +func TestValkeyRepository_ReadSession(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + client := mock.NewClient(ctrl) + sessionRepo := NewValkeyRepository(client) + + matcher := mock.MatchFn(func(cmd []string) bool { + if cmd[0] != "EVALSHA" { + return false + } + if cmd[2] != "0" { + return false + } + if !strings.HasPrefix(cmd[3], "userid:*:sessionid:") { + return false + } + return true + }) + + t.Run("valid session read", func(t *testing.T) { + data, id, err := models.NewSession(1, models.RoleAdmin) + require.NoError(t, err) + ctx := context.Background() + client.EXPECT().Do(ctx, matcher).Return(mock.Result(mock.ValkeyString(data))) + res, err := sessionRepo.ReadSession(context.Background(), id) + require.NoError(t, err) + require.Equal(t, int32(1), res.UserId) + require.Equal(t, id, res.Id) + require.Equal(t, models.RoleAdmin, res.Role) + }) + + t.Run("invalid session read 1 (not found)", func(t *testing.T) { + _, id, err := models.NewSession(1, models.RoleAdmin) + require.NoError(t, err) + ctx := context.Background() + client.EXPECT().Do(ctx, matcher).Return(mock.ErrorResult(valkey.Nil)) + res, err := sessionRepo.ReadSession(context.Background(), id) + require.ErrorIs(t, err, pkg.ErrNotFound) + require.ErrorIs(t, err, valkey.Nil) + require.Empty(t, res) + }) + + t.Run("invalid session read 2 (corrupted session storage)", func(t *testing.T) { + _, id, err := models.NewSession(1, models.RoleAdmin) + require.NoError(t, err) + ctx := context.Background() + client.EXPECT().Do(ctx, matcher).Return(mock.Result(mock.ValkeyInt64(123))) + res, err := sessionRepo.ReadSession(context.Background(), id) + require.ErrorIs(t, err, pkg.ErrInternal) + require.True(t, valkey.IsParseErr(err)) + require.Empty(t, res) + }) + + t.Run("invalid session read 3 (bad sessionid)", func(t *testing.T) { + res, err := sessionRepo.ReadSession(context.Background(), "123") + require.ErrorIs(t, err, pkg.ErrBadInput) + require.Empty(t, res) + }) +} + +func TestValkeyRepository_UpdateSession(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + client := mock.NewClient(ctrl) + sessionRepo := NewValkeyRepository(client) + + matcher := mock.MatchFn(func(cmd []string) bool { + if cmd[0] != "EVALSHA" { + return false + } + if cmd[2] != "0" { + return false + } + if !strings.HasPrefix(cmd[3], "userid:*:sessionid:") { + return false + } + return true + }) + + t.Run("valid session update", func(t *testing.T) { + id := uuid.NewString() + ctx := context.Background() + client.EXPECT().Do(ctx, matcher) + err := sessionRepo.UpdateSession(context.Background(), id) + require.NoError(t, err) + }) + + t.Run("invalid session update 1 (nil response)", func(t *testing.T) { + id := uuid.NewString() + ctx := context.Background() + client.EXPECT().Do(ctx, matcherAny).Return(mock.ErrorResult(valkey.Nil)) + err := sessionRepo.UpdateSession(context.Background(), id) + require.ErrorIs(t, err, pkg.ErrBadInput) + require.ErrorIs(t, err, valkey.Nil) + }) + + t.Run("invalid session update 2 (bad sessionid)", func(t *testing.T) { + err := sessionRepo.UpdateSession(context.Background(), "123") + require.ErrorIs(t, err, pkg.ErrBadInput) + }) +} + +func TestValkeyRepository_DeleteSession(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + client := mock.NewClient(ctrl) + sessionRepo := NewValkeyRepository(client) + + matcher := mock.MatchFn(func(cmd []string) bool { + if cmd[0] != "EVALSHA" { + return false + } + if cmd[2] != "0" { + return false + } + if !strings.HasPrefix(cmd[3], "userid:*:sessionid:") { + return false + } + return true + }) + + t.Run("valid session delete", func(t *testing.T) { + id := uuid.NewString() + ctx := context.Background() + client.EXPECT().Do(ctx, matcher).Return(mock.Result(mock.ValkeyInt64(1))) + err := sessionRepo.DeleteSession(context.Background(), id) + require.NoError(t, err) + }) + + t.Run("invalid session delete 1", func(t *testing.T) { + id := uuid.NewString() + ctx := context.Background() + client.EXPECT().Do(ctx, matcher).Return(mock.Result(mock.ValkeyNil())) + err := sessionRepo.DeleteSession(context.Background(), id) + require.ErrorIs(t, err, pkg.ErrBadInput) + require.ErrorIs(t, err, valkey.Nil) + }) + + t.Run("invalid session delete 2 (bad sessionid)", func(t *testing.T) { + err := sessionRepo.DeleteSession(context.Background(), "123") + require.ErrorIs(t, err, pkg.ErrBadInput) + }) +} + +func TestValkeyRepository_DeleteAllSessions(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + client := mock.NewClient(ctrl) + sessionRepo := NewValkeyRepository(client) + + matcher := mock.MatchFn(func(cmd []string) bool { + if cmd[0] != "EVALSHA" { + return false + } + if cmd[2] != "0" { + return false + } + if !strings.HasPrefix(cmd[3], "userid:1:sessionid:*") { + return false + } + return true + }) + + t.Run("valid all sessions deletion", func(t *testing.T) { + ctx := context.Background() + client.EXPECT().Do(ctx, matcher).Return(mock.Result(mock.ValkeyInt64(1))) + err := sessionRepo.DeleteAllSessions(context.Background(), 1) + require.NoError(t, err) + }) + + t.Run("invalid all sessions deletion 1 (nil response)", func(t *testing.T) { + ctx := context.Background() + client.EXPECT().Do(ctx, matcher).Return(mock.Result(mock.ValkeyNil())) + err := sessionRepo.DeleteAllSessions(context.Background(), 1) + require.ErrorIs(t, err, pkg.ErrBadInput) + require.ErrorIs(t, err, valkey.Nil) + }) +} diff --git a/internal/users/usecase.go b/internal/users/usecase.go index d848193..12724b6 100644 --- a/internal/users/usecase.go +++ b/internal/users/usecase.go @@ -6,11 +6,15 @@ import ( ) type UseCase interface { - CreateUser(ctx context.Context, user *models.User) (int32, error) - ReadUserBySessionToken(ctx context.Context, token string) (*models.User, error) - ReadUser(ctx context.Context, id int32) (*models.User, error) - ReadUserByEmail(ctx context.Context, email string) (*models.User, error) + CreateUser(ctx context.Context, username string, password string, role models.Role) (int32, error) + ReadUserById(ctx context.Context, id int32) (*models.User, error) ReadUserByUsername(ctx context.Context, username string) (*models.User, error) - UpdateUser(ctx context.Context, modifiedUser *models.User) error + UpdateUser(ctx context.Context, id int32, username *string, role *models.Role) error DeleteUser(ctx context.Context, id int32) error + CreateSession(ctx context.Context, userId int32, role models.Role) (string, error) + ReadSession(ctx context.Context, sessionId string) (*models.Session, error) + UpdateSession(ctx context.Context, sessionId string) error + DeleteSession(ctx context.Context, sessionId string) error + DeleteAllSessions(ctx context.Context, userId int32) error + Verify(ctx context.Context, sessionId string) (string, error) } diff --git a/internal/users/usecase/mock/repository_mock.go b/internal/users/usecase/mock/repository_mock.go new file mode 100644 index 0000000..0f0b68f --- /dev/null +++ b/internal/users/usecase/mock/repository_mock.go @@ -0,0 +1,390 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: git.sch9.ru/new_gate/ms-auth/internal/users (interfaces: PgRepository,Caller,TxCaller,ValkeyRepository) +// +// Generated by this command: +// +// mockgen . PgRepository,Caller,TxCaller,ValkeyRepository +// + +// Package mock_users is a generated GoMock package. +package mock_users + +import ( + context "context" + reflect "reflect" + + models "git.sch9.ru/new_gate/ms-auth/internal/models" + users "git.sch9.ru/new_gate/ms-auth/internal/users" + gomock "go.uber.org/mock/gomock" +) + +// MockPgRepository is a mock of PgRepository interface. +type MockPgRepository struct { + ctrl *gomock.Controller + recorder *MockPgRepositoryMockRecorder + isgomock struct{} +} + +// MockPgRepositoryMockRecorder is the mock recorder for MockPgRepository. +type MockPgRepositoryMockRecorder struct { + mock *MockPgRepository +} + +// NewMockPgRepository creates a new mock instance. +func NewMockPgRepository(ctrl *gomock.Controller) *MockPgRepository { + mock := &MockPgRepository{ctrl: ctrl} + mock.recorder = &MockPgRepositoryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPgRepository) EXPECT() *MockPgRepositoryMockRecorder { + return m.recorder +} + +// BeginTx mocks base method. +func (m *MockPgRepository) BeginTx(ctx context.Context) (users.TxCaller, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BeginTx", ctx) + ret0, _ := ret[0].(users.TxCaller) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BeginTx indicates an expected call of BeginTx. +func (mr *MockPgRepositoryMockRecorder) BeginTx(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockPgRepository)(nil).BeginTx), ctx) +} + +// C mocks base method. +func (m *MockPgRepository) C() users.Caller { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "C") + ret0, _ := ret[0].(users.Caller) + return ret0 +} + +// C indicates an expected call of C. +func (mr *MockPgRepositoryMockRecorder) C() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "C", reflect.TypeOf((*MockPgRepository)(nil).C)) +} + +// MockCaller is a mock of Caller interface. +type MockCaller struct { + ctrl *gomock.Controller + recorder *MockCallerMockRecorder + isgomock struct{} +} + +// MockCallerMockRecorder is the mock recorder for MockCaller. +type MockCallerMockRecorder struct { + mock *MockCaller +} + +// NewMockCaller creates a new mock instance. +func NewMockCaller(ctrl *gomock.Controller) *MockCaller { + mock := &MockCaller{ctrl: ctrl} + mock.recorder = &MockCallerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockCaller) EXPECT() *MockCallerMockRecorder { + return m.recorder +} + +// CreateUser mocks base method. +func (m *MockCaller) CreateUser(ctx context.Context, username, password string, role models.Role) (int32, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateUser", ctx, username, password, role) + ret0, _ := ret[0].(int32) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateUser indicates an expected call of CreateUser. +func (mr *MockCallerMockRecorder) CreateUser(ctx, username, password, role any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateUser", reflect.TypeOf((*MockCaller)(nil).CreateUser), ctx, username, password, role) +} + +// DeleteUser mocks base method. +func (m *MockCaller) DeleteUser(ctx context.Context, id int32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteUser", ctx, id) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteUser indicates an expected call of DeleteUser. +func (mr *MockCallerMockRecorder) DeleteUser(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUser", reflect.TypeOf((*MockCaller)(nil).DeleteUser), ctx, id) +} + +// ReadUserById mocks base method. +func (m *MockCaller) ReadUserById(ctx context.Context, id int32) (*models.User, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadUserById", ctx, id) + ret0, _ := ret[0].(*models.User) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadUserById indicates an expected call of ReadUserById. +func (mr *MockCallerMockRecorder) ReadUserById(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadUserById", reflect.TypeOf((*MockCaller)(nil).ReadUserById), ctx, id) +} + +// ReadUserByUsername mocks base method. +func (m *MockCaller) ReadUserByUsername(ctx context.Context, username string) (*models.User, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadUserByUsername", ctx, username) + ret0, _ := ret[0].(*models.User) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadUserByUsername indicates an expected call of ReadUserByUsername. +func (mr *MockCallerMockRecorder) ReadUserByUsername(ctx, username any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadUserByUsername", reflect.TypeOf((*MockCaller)(nil).ReadUserByUsername), ctx, username) +} + +// UpdateUser mocks base method. +func (m *MockCaller) UpdateUser(ctx context.Context, id int32, username *string, role *models.Role) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateUser", ctx, id, username, role) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateUser indicates an expected call of UpdateUser. +func (mr *MockCallerMockRecorder) UpdateUser(ctx, id, username, role any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUser", reflect.TypeOf((*MockCaller)(nil).UpdateUser), ctx, id, username, role) +} + +// MockTxCaller is a mock of TxCaller interface. +type MockTxCaller struct { + ctrl *gomock.Controller + recorder *MockTxCallerMockRecorder + isgomock struct{} +} + +// MockTxCallerMockRecorder is the mock recorder for MockTxCaller. +type MockTxCallerMockRecorder struct { + mock *MockTxCaller +} + +// NewMockTxCaller creates a new mock instance. +func NewMockTxCaller(ctrl *gomock.Controller) *MockTxCaller { + mock := &MockTxCaller{ctrl: ctrl} + mock.recorder = &MockTxCallerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTxCaller) EXPECT() *MockTxCallerMockRecorder { + return m.recorder +} + +// Commit mocks base method. +func (m *MockTxCaller) Commit() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit") + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit. +func (mr *MockTxCallerMockRecorder) Commit() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTxCaller)(nil).Commit)) +} + +// CreateUser mocks base method. +func (m *MockTxCaller) CreateUser(ctx context.Context, username, password string, role models.Role) (int32, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateUser", ctx, username, password, role) + ret0, _ := ret[0].(int32) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateUser indicates an expected call of CreateUser. +func (mr *MockTxCallerMockRecorder) CreateUser(ctx, username, password, role any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateUser", reflect.TypeOf((*MockTxCaller)(nil).CreateUser), ctx, username, password, role) +} + +// DeleteUser mocks base method. +func (m *MockTxCaller) DeleteUser(ctx context.Context, id int32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteUser", ctx, id) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteUser indicates an expected call of DeleteUser. +func (mr *MockTxCallerMockRecorder) DeleteUser(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUser", reflect.TypeOf((*MockTxCaller)(nil).DeleteUser), ctx, id) +} + +// ReadUserById mocks base method. +func (m *MockTxCaller) ReadUserById(ctx context.Context, id int32) (*models.User, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadUserById", ctx, id) + ret0, _ := ret[0].(*models.User) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadUserById indicates an expected call of ReadUserById. +func (mr *MockTxCallerMockRecorder) ReadUserById(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadUserById", reflect.TypeOf((*MockTxCaller)(nil).ReadUserById), ctx, id) +} + +// ReadUserByUsername mocks base method. +func (m *MockTxCaller) ReadUserByUsername(ctx context.Context, username string) (*models.User, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadUserByUsername", ctx, username) + ret0, _ := ret[0].(*models.User) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadUserByUsername indicates an expected call of ReadUserByUsername. +func (mr *MockTxCallerMockRecorder) ReadUserByUsername(ctx, username any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadUserByUsername", reflect.TypeOf((*MockTxCaller)(nil).ReadUserByUsername), ctx, username) +} + +// Rollback mocks base method. +func (m *MockTxCaller) Rollback() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Rollback") + ret0, _ := ret[0].(error) + return ret0 +} + +// Rollback indicates an expected call of Rollback. +func (mr *MockTxCallerMockRecorder) Rollback() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockTxCaller)(nil).Rollback)) +} + +// UpdateUser mocks base method. +func (m *MockTxCaller) UpdateUser(ctx context.Context, id int32, username *string, role *models.Role) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateUser", ctx, id, username, role) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateUser indicates an expected call of UpdateUser. +func (mr *MockTxCallerMockRecorder) UpdateUser(ctx, id, username, role any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUser", reflect.TypeOf((*MockTxCaller)(nil).UpdateUser), ctx, id, username, role) +} + +// MockValkeyRepository is a mock of ValkeyRepository interface. +type MockValkeyRepository struct { + ctrl *gomock.Controller + recorder *MockValkeyRepositoryMockRecorder + isgomock struct{} +} + +// MockValkeyRepositoryMockRecorder is the mock recorder for MockValkeyRepository. +type MockValkeyRepositoryMockRecorder struct { + mock *MockValkeyRepository +} + +// NewMockValkeyRepository creates a new mock instance. +func NewMockValkeyRepository(ctrl *gomock.Controller) *MockValkeyRepository { + mock := &MockValkeyRepository{ctrl: ctrl} + mock.recorder = &MockValkeyRepositoryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockValkeyRepository) EXPECT() *MockValkeyRepositoryMockRecorder { + return m.recorder +} + +// CreateSession mocks base method. +func (m *MockValkeyRepository) CreateSession(ctx context.Context, userId int32, role models.Role) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateSession", ctx, userId, role) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateSession indicates an expected call of CreateSession. +func (mr *MockValkeyRepositoryMockRecorder) CreateSession(ctx, userId, role any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSession", reflect.TypeOf((*MockValkeyRepository)(nil).CreateSession), ctx, userId, role) +} + +// DeleteAllSessions mocks base method. +func (m *MockValkeyRepository) DeleteAllSessions(ctx context.Context, userId int32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteAllSessions", ctx, userId) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteAllSessions indicates an expected call of DeleteAllSessions. +func (mr *MockValkeyRepositoryMockRecorder) DeleteAllSessions(ctx, userId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllSessions", reflect.TypeOf((*MockValkeyRepository)(nil).DeleteAllSessions), ctx, userId) +} + +// DeleteSession mocks base method. +func (m *MockValkeyRepository) DeleteSession(ctx context.Context, sessionId string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteSession", ctx, sessionId) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteSession indicates an expected call of DeleteSession. +func (mr *MockValkeyRepositoryMockRecorder) DeleteSession(ctx, sessionId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSession", reflect.TypeOf((*MockValkeyRepository)(nil).DeleteSession), ctx, sessionId) +} + +// ReadSession mocks base method. +func (m *MockValkeyRepository) ReadSession(ctx context.Context, sessionId string) (*models.Session, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadSession", ctx, sessionId) + ret0, _ := ret[0].(*models.Session) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadSession indicates an expected call of ReadSession. +func (mr *MockValkeyRepositoryMockRecorder) ReadSession(ctx, sessionId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadSession", reflect.TypeOf((*MockValkeyRepository)(nil).ReadSession), ctx, sessionId) +} + +// UpdateSession mocks base method. +func (m *MockValkeyRepository) UpdateSession(ctx context.Context, sessionId string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateSession", ctx, sessionId) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateSession indicates an expected call of UpdateSession. +func (mr *MockValkeyRepositoryMockRecorder) UpdateSession(ctx, sessionId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSession", reflect.TypeOf((*MockValkeyRepository)(nil).UpdateSession), ctx, sessionId) +} diff --git a/internal/users/usecase/usecase.go b/internal/users/usecase/usecase.go index d644154..450eb5a 100644 --- a/internal/users/usecase/usecase.go +++ b/internal/users/usecase/usecase.go @@ -2,148 +2,284 @@ package usecase import ( "context" + "errors" "git.sch9.ru/new_gate/ms-auth/config" "git.sch9.ru/new_gate/ms-auth/internal/models" - "git.sch9.ru/new_gate/ms-auth/internal/sessions" "git.sch9.ru/new_gate/ms-auth/internal/users" - "git.sch9.ru/new_gate/ms-auth/pkg/utils" + "git.sch9.ru/new_gate/ms-auth/pkg" + "github.com/golang-jwt/jwt/v4" + "github.com/google/uuid" + "time" ) -type useCase struct { - userRepo users.PgRepository - sessionProvider sessions.ValkeyRepository - cfg config.Config +type UseCase struct { + userRepo users.PgRepository + sessionRepo users.ValkeyRepository + cfg config.Config } func NewUseCase( userRepo users.PgRepository, - sessionRepo sessions.ValkeyRepository, + sessionRepo users.ValkeyRepository, cfg config.Config, -) *useCase { - return &useCase{ - userRepo: userRepo, - sessionProvider: sessionRepo, - cfg: cfg, +) *UseCase { + return &UseCase{ + userRepo: userRepo, + sessionRepo: sessionRepo, + cfg: cfg, } } -func (u *useCase) CreateUser(ctx context.Context, user *models.User) (int32, error) { - meId, ok := ctx.Value("userId").(*int32) +func (u *UseCase) CreateUser(ctx context.Context, username string, password string, role models.Role) (int32, error) { + const op = "UseCase.CreateUser" + + meId, ok := ctx.Value("userId").(int32) if !ok { - return 0, utils.ErrNoPermission + return 0, pkg.Wrap(pkg.ErrUnauthenticated, nil, op, "no user id in context") } - me, err := u.ReadUser(ctx, *meId) + me, err := u.userRepo.C().ReadUserById(ctx, meId) if err != nil { - return 0, err + return 0, pkg.Wrap(nil, err, op, "can't read user by id") } - switch *me.Role { - case models.RoleAdmin: - break - case models.RoleModerator: - if !user.Role.AtMost(models.RoleParticipant) { - return 0, utils.ErrNoPermission + if !me.Role.AtLeast(models.RoleModerator) || me.Role.AtMost(role) && !me.Role.IsAdmin() { + return 0, pkg.Wrap(pkg.NoPermission, nil, op, "no permission") + } + + id, err := u.userRepo.C().CreateUser(ctx, username, password, role) + if err != nil { + return 0, pkg.Wrap(nil, err, op, "can't create user") + } + + return id, nil +} + +func (u *UseCase) ReadUserById(ctx context.Context, id int32) (*models.User, error) { + const op = "UseCase.ReadUserById" + + user, err := u.userRepo.C().ReadUserById(ctx, id) + if err != nil { + return nil, pkg.Wrap(nil, err, op, "can't read user by id") + } + return user, nil +} + +func (u *UseCase) ReadUserByUsername(ctx context.Context, username string) (*models.User, error) { + const op = "UseCase.ReadUserByUsername" + + user, err := u.userRepo.C().ReadUserByUsername(ctx, username) + if err != nil { + return nil, pkg.Wrap(nil, err, op, "can't read user by username") + } + return user, nil +} + +func (u *UseCase) UpdateUser(ctx context.Context, id int32, username *string, role *models.Role) error { + const op = "UseCase.UpdateUser" + + meId, ok := ctx.Value("userId").(int32) + if !ok { + return pkg.Wrap(pkg.ErrUnauthenticated, nil, op, "no user id in context") + } + + me, err := u.userRepo.C().ReadUserById(ctx, meId) + if err != nil { + return pkg.Wrap(nil, err, op, "can't read user by id") + } + + user, err := u.userRepo.C().ReadUserById(ctx, id) + if err != nil { + return pkg.Wrap(nil, err, op, "can't read user by id") + } + + hasPermission := func() bool { + if me.Id == user.Id && role != nil { + return false } - default: - return 0, utils.ErrNoPermission - } - - return u.userRepo.CreateUser(ctx, user) -} - -func (u *useCase) ReadUserBySessionToken(ctx context.Context, token string) (*models.User, error) { - session, err := u.sessionProvider.ReadSessionByToken(ctx, token) - if err != nil { - return nil, err - } - - return u.userRepo.ReadUserById(ctx, *session.UserId) -} - -func (u *useCase) ReadUser(ctx context.Context, id int32) (*models.User, error) { - return u.userRepo.ReadUserById(ctx, id) -} - -func (u *useCase) ReadUserByEmail(ctx context.Context, email string) (*models.User, error) { - return u.userRepo.ReadUserByEmail(ctx, email) -} - -func (u *useCase) ReadUserByUsername(ctx context.Context, username string) (*models.User, error) { - return u.userRepo.ReadUserByUsername(ctx, username) -} - -func (u *useCase) UpdateUser(ctx context.Context, modifiedUser *models.User) error { - meId, ok := ctx.Value("userId").(*int32) - if !ok { - return utils.ErrNoPermission - } - - me, err := u.ReadUser(ctx, *meId) - if err != nil { - return err - } - - user, err := u.userRepo.ReadUserById(ctx, *modifiedUser.Id) - if err != nil { - return err - } - - hasAccess := func() bool { if me.Role.IsAdmin() { return true } - if me.Role.IsModerator() { - if !user.Role.AtMost(models.RoleParticipant) { - return false - } - return true - } - if me.Role.IsParticipant() { - if me.Id != user.Id { - return false - } - if modifiedUser.Username != nil { - return false - } - if modifiedUser.Email != nil { - return false - } - if modifiedUser.ExpiresAt != nil { - return false - } - if modifiedUser.Role != nil { - return false - } - return true - } - if me.Role.IsSpectator() { + if role != nil && me.Role.AtMost(*role) { return false } + if !me.Role.AtMost(user.Role) { + return true + } return false }() - if !hasAccess { - return utils.ErrNoPermission + if !hasPermission { + return pkg.Wrap(pkg.NoPermission, nil, op, "no permission") } - return u.userRepo.UpdateUser(ctx, user) + tx, err := u.userRepo.BeginTx(ctx) + if err != nil { + return pkg.Wrap(nil, err, op, "cannot start transaction") + } + + err = tx.UpdateUser(ctx, id, username, role) + if err != nil { + return pkg.Wrap(nil, errors.Join(err, tx.Rollback()), op, "cannot update user") + } + err = u.sessionRepo.DeleteAllSessions(ctx, id) + if err != nil { + return pkg.Wrap(nil, errors.Join(err, tx.Rollback()), op, "cannot delete all sessions") + } + err = tx.Commit() + if err != nil { + return pkg.Wrap(nil, err, op, "cannot commit transaction") + } + + return nil } -func (u *useCase) DeleteUser(ctx context.Context, id int32) error { - userId, ok := ctx.Value("userId").(*int32) +func (u *UseCase) DeleteUser(ctx context.Context, id int32) error { + const op = "UseCase.DeleteUser" + + userId, ok := ctx.Value("userId").(int32) if !ok { - return utils.ErrNoPermission + return pkg.Wrap(pkg.ErrUnauthenticated, nil, op, "no user id in context") } - me, err := u.ReadUser(ctx, *userId) + me, err := u.ReadUserById(ctx, userId) if err != nil { + return pkg.Wrap(nil, err, op, "can't read user by id") + } + + if me.Id == id || !me.Role.IsAdmin() { + return pkg.Wrap(pkg.NoPermission, nil, op, "no permission") + } + + tx, err := u.userRepo.BeginTx(ctx) + if err != nil { + return pkg.Wrap(nil, err, op, "cannot start transaction") + } + + err = tx.DeleteUser(ctx, id) + if err != nil { + return pkg.Wrap(nil, errors.Join(err, tx.Rollback()), op, "cannot delete user") + } + + err = u.sessionRepo.DeleteAllSessions(ctx, id) + if err != nil { + return pkg.Wrap(nil, errors.Join(err, tx.Rollback()), op, "cannot delete all sessions") + } + err = tx.Commit() + if err != nil { + return pkg.Wrap(nil, err, op, "cannot commit transaction") + } + + return nil +} + +func (u *UseCase) CreateSession(ctx context.Context, userId int32, role models.Role) (string, error) { + const op = "UseCase.CreateSession" + + sessionId, err := u.sessionRepo.CreateSession(ctx, userId, role) + if err != nil { + return "", pkg.Wrap(nil, err, op, "cannot create session") + } + + return sessionId, nil +} + +func (u *UseCase) ReadSession(ctx context.Context, sessionId string) (*models.Session, error) { + const op = "UseCase.ReadSession" + + session, err := u.sessionRepo.ReadSession(ctx, sessionId) + if err != nil { + return nil, pkg.Wrap(nil, err, op, "cannot read session") + } + return session, nil +} + +func (u *UseCase) UpdateSession(ctx context.Context, sessionId string) error { + const op = "UseCase.UpdateSession" + + err := u.sessionRepo.UpdateSession(ctx, sessionId) + if err != nil { + return pkg.Wrap(nil, err, op, "cannot update session") + } + return nil +} + +func (u *UseCase) DeleteSession(ctx context.Context, sessionId string) error { + const op = "UseCase.DeleteSession" + + err := u.sessionRepo.DeleteSession(ctx, sessionId) + if err != nil { + return pkg.Wrap(nil, err, op, "cannot delete session") + } + return nil +} + +func (u *UseCase) DeleteAllSessions(ctx context.Context, userId int32) error { + const op = "UseCase.DeleteAllSessions" + + err := u.sessionRepo.DeleteAllSessions(ctx, userId) + if err != nil { + return pkg.Wrap(nil, err, op, "cannot delete all sessions") + } + + return nil +} + +type Token struct { + SessionId string `json:"sid"` + UserId int32 `json:"sub"` + Role models.Role `json:"rle"` + ExpiresAt time.Time `json:"exp"` + IssuedAt time.Time `json:"iat"` + NotBefore time.Time `json:"nbf"` +} + +func (t Token) Valid() error { + if err := uuid.Validate(t.SessionId); err != nil { return err } + if t.UserId <= 0 { + return errors.New("invalid user id") + } + if t.Role <= 0 { + return errors.New("invalid role") + } + if t.ExpiresAt.IsZero() { + return errors.New("invalid exp") + } + if t.IssuedAt.IsZero() { + return errors.New("invalid iat") + } + if t.NotBefore.IsZero() { + return errors.New("invalid nbf") + } + return nil +} - if *me.Id == id || !me.Role.IsAdmin() { - return utils.ErrNoPermission +func (u *UseCase) Verify(ctx context.Context, sessionId string) (string, error) { + const op = "UseCase.Verify" + + session, err := u.sessionRepo.ReadSession(ctx, sessionId) + if err != nil { + return "", pkg.Wrap(nil, err, op, "cannot read session") } - return u.userRepo.DeleteUser(ctx, id) + token := jwt.NewWithClaims( + jwt.SigningMethodHS256, + Token{ + SessionId: sessionId, + UserId: session.UserId, + Role: session.Role, + ExpiresAt: time.Now().Add(time.Hour * 24), + IssuedAt: time.Now(), + NotBefore: time.Now(), + }, + ) + + signedToken, err := token.SignedString([]byte(u.cfg.JWTSecret)) + if err != nil { + return "", pkg.Wrap(pkg.ErrInternal, err, op, "cannot sign token") + } + + return signedToken, nil } diff --git a/internal/users/usecase/usecase_test.go b/internal/users/usecase/usecase_test.go new file mode 100644 index 0000000..f4ed3e1 --- /dev/null +++ b/internal/users/usecase/usecase_test.go @@ -0,0 +1,611 @@ +package usecase + +import ( + "context" + "git.sch9.ru/new_gate/ms-auth/config" + "git.sch9.ru/new_gate/ms-auth/internal/models" + mock_users "git.sch9.ru/new_gate/ms-auth/internal/users/usecase/mock" + "git.sch9.ru/new_gate/ms-auth/pkg" + "github.com/golang-jwt/jwt/v4" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "testing" +) + +func TestUseCase_CreateUser(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + pgRepository := mock_users.NewMockPgRepository(ctrl) + vkRepository := mock_users.NewMockValkeyRepository(ctrl) + caller := mock_users.NewMockCaller(ctrl) + + uc := NewUseCase( + pgRepository, + vkRepository, + config.Config{ + JWTSecret: "abc", + }, + ) + + t.Run("valid create user (admin > moderator)", func(t *testing.T) { + userId := int32(1) + username := "username" + password := "password" + role := models.RoleModerator + ctx := context.WithValue(context.Background(), "userId", userId) + + pgRepository.EXPECT().C().Return(caller).Times(2) + caller.EXPECT().ReadUserById(ctx, userId).Return(&models.User{ + Id: userId, + Role: models.RoleAdmin, + }, nil) + caller.EXPECT().CreateUser(ctx, username, password, role).Return(int32(2), nil) + + id, err := uc.CreateUser(ctx, username, password, role) + require.NoError(t, err) + require.Equal(t, int32(2), id) + }) + + t.Run("valid create user (moderator > participant)", func(t *testing.T) { + userId := int32(1) + username := "username" + password := "password" + role := models.RoleParticipant + ctx := context.WithValue(context.Background(), "userId", userId) + + pgRepository.EXPECT().C().Return(caller).Times(2) + caller.EXPECT().ReadUserById(ctx, userId).Return(&models.User{ + Id: userId, + Role: models.RoleModerator, + }, nil) + caller.EXPECT().CreateUser(ctx, username, password, role).Return(int32(2), nil) + + id, err := uc.CreateUser(ctx, username, password, role) + require.NoError(t, err) + require.Equal(t, int32(2), id) + }) + + t.Run("valid create user (admin > participant)", func(t *testing.T) { + userId := int32(1) + username := "username" + password := "password" + role := models.RoleParticipant + ctx := context.WithValue(context.Background(), "userId", userId) + + pgRepository.EXPECT().C().Return(caller).Times(2) + caller.EXPECT().ReadUserById(ctx, userId).Return(&models.User{ + Id: userId, + Role: models.RoleAdmin, + }, nil) + caller.EXPECT().CreateUser(ctx, username, password, role).Return(int32(2), nil) + + id, err := uc.CreateUser(ctx, username, password, role) + require.NoError(t, err) + require.Equal(t, int32(2), id) + }) + + t.Run("invalid user create 1 (no user id in context)", func(t *testing.T) { + _, err := uc.CreateUser(context.Background(), "username", "password", models.RoleModerator) + require.Error(t, err) + require.ErrorIs(t, err, pkg.ErrUnauthenticated) + }) + + t.Run("invalid user create 2 (user not found)", func(t *testing.T) { + userId := int32(1) + ctx := context.WithValue(context.Background(), "userId", userId) + + pgRepository.EXPECT().C().Return(caller) + caller.EXPECT().ReadUserById(ctx, userId).Return(nil, pkg.ErrNotFound) + + _, err := uc.CreateUser(ctx, "username", "password", models.RoleModerator) + require.Error(t, err) + require.ErrorIs(t, err, pkg.ErrNotFound) + }) + + t.Run("invalid user create 3 (no permission, participant < admin)", func(t *testing.T) { + userId := int32(1) + ctx := context.WithValue(context.Background(), "userId", userId) + + pgRepository.EXPECT().C().Return(caller) + caller.EXPECT().ReadUserById(ctx, userId).Return(&models.User{ + Id: userId, + Role: models.RoleParticipant, + }, nil) + + _, err := uc.CreateUser(ctx, "username", "password", models.RoleAdmin) + require.Error(t, err) + require.ErrorIs(t, err, pkg.NoPermission) + }) + + t.Run("invalid user create 4 (no permission, participant < moderator)", func(t *testing.T) { + userId := int32(1) + ctx := context.WithValue(context.Background(), "userId", userId) + + pgRepository.EXPECT().C().Return(caller) + caller.EXPECT().ReadUserById(ctx, userId).Return(&models.User{ + Id: userId, + Role: models.RoleParticipant, + }, nil) + + _, err := uc.CreateUser(ctx, "username", "password", models.RoleModerator) + require.Error(t, err) + require.ErrorIs(t, err, pkg.NoPermission) + }) + + t.Run("invalid user create 5 (no permission, moderator < admin)", func(t *testing.T) { + userId := int32(1) + ctx := context.WithValue(context.Background(), "userId", userId) + + pgRepository.EXPECT().C().Return(caller) + caller.EXPECT().ReadUserById(ctx, userId).Return(&models.User{ + Id: userId, + Role: models.RoleModerator, + }, nil) + + _, err := uc.CreateUser(ctx, "username", "password", models.RoleAdmin) + require.Error(t, err) + require.ErrorIs(t, err, pkg.NoPermission) + }) + + t.Run("invalid user create 6 (bad input, bad username)", func(t *testing.T) { + userId := int32(1) + ctx := context.WithValue(context.Background(), "userId", userId) + + pgRepository.EXPECT().C().Return(caller).Times(2) + caller.EXPECT().ReadUserById(ctx, userId).Return(&models.User{ + Id: userId, + Role: models.RoleModerator, + }, nil) + caller.EXPECT().CreateUser(ctx, + "test", + "password", + models.RoleParticipant, + ).Return(int32(0), pkg.ErrBadInput) + + _, err := uc.CreateUser(ctx, "test", "password", models.RoleParticipant) + require.Error(t, err) + require.ErrorIs(t, err, pkg.ErrBadInput) + }) +} + +func TestUseCase_ReadUserById(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + pgRepository := mock_users.NewMockPgRepository(ctrl) + vkRepository := mock_users.NewMockValkeyRepository(ctrl) + caller := mock_users.NewMockCaller(ctrl) + + uc := NewUseCase( + pgRepository, + vkRepository, + config.Config{ + JWTSecret: "abc", + }, + ) + + t.Run("valid user read", func(t *testing.T) { + userId := int32(1) + ctx := context.Background() + + pgRepository.EXPECT().C().Return(caller) + caller.EXPECT().ReadUserById(ctx, userId).Return(&models.User{Id: userId}, nil) + + user, err := uc.ReadUserById(ctx, userId) + require.NoError(t, err) + require.Equal(t, userId, user.Id) + }) + + t.Run("invalid user read 1 (not found)", func(t *testing.T) { + userId := int32(0) + ctx := context.Background() + + pgRepository.EXPECT().C().Return(caller) + caller.EXPECT().ReadUserById(ctx, userId).Return(nil, pkg.ErrNotFound) + + _, err := uc.ReadUserById(ctx, userId) + require.Error(t, err) + require.ErrorIs(t, err, pkg.ErrNotFound) + }) +} + +func TestUseCase_ReadUserByUsername(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + pgRepository := mock_users.NewMockPgRepository(ctrl) + vkRepository := mock_users.NewMockValkeyRepository(ctrl) + caller := mock_users.NewMockCaller(ctrl) + + uc := NewUseCase( + pgRepository, + vkRepository, + config.Config{ + JWTSecret: "abc", + }, + ) + + t.Run("valid user read", func(t *testing.T) { + username := "username" + ctx := context.Background() + + pgRepository.EXPECT().C().Return(caller) + caller.EXPECT().ReadUserByUsername(ctx, username).Return(&models.User{Username: username}, nil) + + user, err := uc.ReadUserByUsername(ctx, username) + require.NoError(t, err) + require.Equal(t, username, user.Username) + }) + + t.Run("invalid user read 1 (not found)", func(t *testing.T) { + username := "username" + ctx := context.Background() + + pgRepository.EXPECT().C().Return(caller) + caller.EXPECT().ReadUserByUsername(ctx, username).Return(nil, pkg.ErrNotFound) + + _, err := uc.ReadUserByUsername(ctx, username) + require.Error(t, err) + require.ErrorIs(t, err, pkg.ErrNotFound) + }) +} + +func TestUseCase_UpdateUser(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + pgRepository := mock_users.NewMockPgRepository(ctrl) + vkRepository := mock_users.NewMockValkeyRepository(ctrl) + caller := mock_users.NewMockCaller(ctrl) + txCaller := mock_users.NewMockTxCaller(ctrl) + + uc := NewUseCase( + pgRepository, + vkRepository, + config.Config{ + JWTSecret: "abc", + }, + ) + + t.Run("valid user update", func(t *testing.T) { + userId := int32(1) + userId2 := int32(2) + ctx := context.WithValue(context.Background(), "userId", userId) + + pgRepository.EXPECT().C().Return(caller).AnyTimes() + caller.EXPECT().ReadUserById(ctx, userId).Return(&models.User{ + Id: userId, + Role: models.RoleAdmin, + }, nil) + caller.EXPECT().ReadUserById(ctx, userId2).Return(&models.User{ + Id: userId2, + Role: models.RoleModerator, + }, nil) + + pgRepository.EXPECT().BeginTx(ctx).Return(txCaller, nil) + txCaller.EXPECT().UpdateUser(ctx, + userId2, + StringP("newusername"), + RoleP(models.RoleParticipant), + ).Return(nil) + vkRepository.EXPECT().DeleteAllSessions(ctx, userId2).Return(nil) + txCaller.EXPECT().Commit().Return(nil) + + err := uc.UpdateUser(ctx, userId2, StringP("newusername"), RoleP(models.RoleParticipant)) + require.NoError(t, err) + }) + + t.Run("invalid user update 1 (no user id in context)", func(t *testing.T) { + pgRepository.EXPECT().C().Return(caller).AnyTimes() + + err := uc.UpdateUser(context.Background(), 0, StringP("newusername"), RoleP(models.RoleParticipant)) + require.Error(t, err) + require.ErrorIs(t, err, pkg.ErrUnauthenticated) + }) + + t.Run("invalid user update 2 (cant update role of myself)", func(t *testing.T) { + userId := int32(1) + ctx := context.WithValue(context.Background(), "userId", userId) + + pgRepository.EXPECT().C().Return(caller).AnyTimes() + caller.EXPECT().ReadUserById(ctx, userId).Return(&models.User{ + Id: userId, + Role: models.RoleAdmin, + }, nil).Times(2) + + err := uc.UpdateUser(ctx, userId, StringP("newusername"), RoleP(models.RoleParticipant)) + require.Error(t, err) + require.ErrorIs(t, err, pkg.NoPermission) + }) + + t.Run("invalid user update 3 (cant set role >= my role)", func(t *testing.T) { + userId := int32(1) + userId2 := int32(2) + ctx := context.WithValue(context.Background(), "userId", userId) + + pgRepository.EXPECT().C().Return(caller).AnyTimes() + caller.EXPECT().ReadUserById(ctx, userId).Return(&models.User{ + Id: userId, + Role: models.RoleModerator, + }, nil) + caller.EXPECT().ReadUserById(ctx, userId2).Return(&models.User{ + Id: userId2, + Role: models.RoleParticipant, + }, nil) + + err := uc.UpdateUser(ctx, userId2, StringP("newusername"), RoleP(models.RoleModerator)) + require.Error(t, err) + require.ErrorIs(t, err, pkg.NoPermission) + }) + + t.Run("invalid user update 4 (cant edit user with >= role than mine)", func(t *testing.T) { + userId := int32(1) + userId2 := int32(2) + ctx := context.WithValue(context.Background(), "userId", userId) + + pgRepository.EXPECT().C().Return(caller).AnyTimes() + caller.EXPECT().ReadUserById(ctx, userId).Return(&models.User{ + Id: userId, + Role: models.RoleModerator, + }, nil) + caller.EXPECT().ReadUserById(ctx, userId2).Return(&models.User{ + Id: userId2, + Role: models.RoleModerator, + }, nil) + + err := uc.UpdateUser(ctx, userId2, StringP("newusername"), RoleP(models.RoleParticipant)) + require.Error(t, err) + require.ErrorIs(t, err, pkg.NoPermission) + }) +} + +func TestUseCase_DeleteUser(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + pgRepository := mock_users.NewMockPgRepository(ctrl) + vkRepository := mock_users.NewMockValkeyRepository(ctrl) + caller := mock_users.NewMockCaller(ctrl) + txCaller := mock_users.NewMockTxCaller(ctrl) + + uc := NewUseCase( + pgRepository, + vkRepository, + config.Config{ + JWTSecret: "abc", + }, + ) + + t.Run("valid user delete", func(t *testing.T) { + userId := int32(1) + userId2 := int32(2) + ctx := context.WithValue(context.Background(), "userId", userId) + + pgRepository.EXPECT().C().Return(caller).AnyTimes() + caller.EXPECT().ReadUserById(ctx, userId).Return(&models.User{ + Id: userId, + Role: models.RoleAdmin, + }, nil) + pgRepository.EXPECT().BeginTx(ctx).Return(txCaller, nil) + vkRepository.EXPECT().DeleteAllSessions(ctx, userId2).Return(nil) + txCaller.EXPECT().DeleteUser(ctx, userId2).Return(nil) + txCaller.EXPECT().Commit().Return(nil) + + err := uc.DeleteUser(ctx, userId2) + require.NoError(t, err) + }) + + t.Run("invalid delete (cant delete myself)", func(t *testing.T) { + userId := int32(1) + ctx := context.WithValue(context.Background(), "userId", userId) + + pgRepository.EXPECT().C().Return(caller).AnyTimes() + caller.EXPECT().ReadUserById(ctx, userId).Return(&models.User{ + Id: userId, + Role: models.RoleAdmin, + }, nil) + + err := uc.DeleteUser(ctx, userId) + require.Error(t, err) + require.ErrorIs(t, err, pkg.NoPermission) + }) +} + +func TestUseCase_CreateSession(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + vkRepository := mock_users.NewMockValkeyRepository(ctrl) + + uc := NewUseCase( + nil, + vkRepository, + config.Config{ + JWTSecret: "abc", + }, + ) + + t.Run("valid session creation", func(t *testing.T) { + ctx := context.Background() + sid := uuid.NewString() + vkRepository.EXPECT().CreateSession(ctx, int32(1), models.RoleAdmin).Return(sid, nil) + + sessionId, err := uc.CreateSession(ctx, int32(1), models.RoleAdmin) + require.NoError(t, err) + require.Equal(t, sessionId, sid) + }) +} + +func TestUseCase_ReadSession(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + vkRepository := mock_users.NewMockValkeyRepository(ctrl) + + uc := NewUseCase( + nil, + vkRepository, + config.Config{ + JWTSecret: "abc", + }, + ) + + t.Run("valid session read", func(t *testing.T) { + ctx := context.Background() + sid := uuid.NewString() + vkRepository.EXPECT().ReadSession(ctx, sid).Return(&models.Session{ + UserId: 1, + Id: sid, + Role: models.RoleAdmin, + }, nil) + + session, err := uc.ReadSession(ctx, sid) + require.NoError(t, err) + require.Equal(t, session.Id, sid) + require.Equal(t, session.UserId, int32(1)) + require.Equal(t, session.Role, models.RoleAdmin) + }) +} + +func TestUseCase_UpdateSession(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + vkRepository := mock_users.NewMockValkeyRepository(ctrl) + + uc := NewUseCase( + nil, + vkRepository, + config.Config{ + JWTSecret: "abc", + }, + ) + + t.Run("valid session update", func(t *testing.T) { + ctx := context.Background() + sid := uuid.NewString() + vkRepository.EXPECT().UpdateSession(ctx, sid).Return(nil) + + err := uc.UpdateSession(ctx, sid) + require.NoError(t, err) + }) +} + +func TestUseCase_DeleteSession(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + vkRepository := mock_users.NewMockValkeyRepository(ctrl) + + uc := NewUseCase( + nil, + vkRepository, + config.Config{ + JWTSecret: "abc", + }, + ) + + t.Run("valid session delete", func(t *testing.T) { + ctx := context.Background() + sid := uuid.NewString() + vkRepository.EXPECT().DeleteSession(ctx, sid).Return(nil) + + err := uc.DeleteSession(ctx, sid) + require.NoError(t, err) + }) +} + +func TestUseCase_DeleteAllSessions(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + vkRepository := mock_users.NewMockValkeyRepository(ctrl) + + uc := NewUseCase( + nil, + vkRepository, + config.Config{ + JWTSecret: "abc", + }, + ) + + t.Run("valid session delete", func(t *testing.T) { + ctx := context.Background() + userId := int32(1) + vkRepository.EXPECT().DeleteAllSessions(ctx, userId).Return(nil) + + err := uc.DeleteAllSessions(ctx, userId) + require.NoError(t, err) + }) +} + +func TestUseCase_Verify(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + vkRepository := mock_users.NewMockValkeyRepository(ctrl) + cfg := config.Config{JWTSecret: "abc"} + uc := NewUseCase( + nil, + vkRepository, + cfg, + ) + + t.Run("valid verification", func(t *testing.T) { + ctx := context.Background() + rsess := &models.Session{ + Id: uuid.NewString(), + UserId: 1, + Role: models.RoleAdmin, + } + + vkRepository.EXPECT().ReadSession(ctx, rsess.Id).Return(rsess, nil) + + session, err := uc.Verify(ctx, rsess.Id) + require.NoError(t, err) + + token, err := jwt.ParseWithClaims(session, &Token{}, func(token *jwt.Token) (interface{}, error) { + return []byte(cfg.JWTSecret), nil + }) + + claims, ok := token.Claims.(*Token) + + require.True(t, ok) + require.NoError(t, err) + require.Equal(t, claims.SessionId, rsess.Id) + require.Equal(t, claims.UserId, rsess.UserId) + require.Equal(t, claims.Role, rsess.Role) + }) +} + +func StringP(s string) *string { + return &s +} + +func RoleP(r models.Role) *models.Role { + return &r +} diff --git a/main.go b/main.go deleted file mode 100644 index 66d319b..0000000 --- a/main.go +++ /dev/null @@ -1,77 +0,0 @@ -package main - -import ( - "fmt" - "git.sch9.ru/new_gate/ms-auth/config" - sessionsDelivery "git.sch9.ru/new_gate/ms-auth/internal/sessions/delivery/grpc" - sessionsRepository "git.sch9.ru/new_gate/ms-auth/internal/sessions/repository" - sessionsUseCase "git.sch9.ru/new_gate/ms-auth/internal/sessions/usecase" - usersDelivery "git.sch9.ru/new_gate/ms-auth/internal/users/delivery/grpc" - usersRepository "git.sch9.ru/new_gate/ms-auth/internal/users/repository" - usersUseCase "git.sch9.ru/new_gate/ms-auth/internal/users/usecase" - "git.sch9.ru/new_gate/ms-auth/pkg/external/postgres" - "git.sch9.ru/new_gate/ms-auth/pkg/external/valkey" - "github.com/ilyakaznacheev/cleanenv" - _ "github.com/jackc/pgx/v5/stdlib" - "go.uber.org/zap" - "google.golang.org/grpc" - "google.golang.org/grpc/reflection" - "net" - "os" - "os/signal" - "syscall" -) - -func main() { - var cfg config.Config - err := cleanenv.ReadConfig(".env", &cfg) - if err != nil { - panic(fmt.Sprintf("error reading config: %s", err.Error())) - } - - var logger *zap.Logger - if cfg.Env == "prod" { - logger = zap.Must(zap.NewProduction()) - } else if cfg.Env == "dev" { - logger = zap.Must(zap.NewDevelopment()) - } else { - panic(fmt.Sprintf(`error reading config: env expected "prod" or "dev", got "%s"`, cfg.Env)) - } - - db, err := postgres.NewPostgresDB(cfg.PostgresDSN) - if err != nil { - panic(err) - } - defer db.Close() - - vk, err := valkey.NewValkeyClient(cfg.RedisDSN) - - userRepo := usersRepository.NewUserRepository(db, logger) - userUC := usersUseCase.NewUseCase(userRepo, nil, cfg) - - sessionRepo := sessionsRepository.NewValkeyRepository(vk, cfg, logger) - sessionUC := sessionsUseCase.NewUseCase(sessionRepo, cfg) - - gserver := grpc.NewServer(grpc.UnaryInterceptor(usersDelivery.TokenInterceptor(sessionUC))) - defer gserver.GracefulStop() - - usersDelivery.NewUserHandlers(gserver, userUC) - sessionsDelivery.NewSessionHandlers(gserver, sessionUC, userUC) - reflection.Register(gserver) - - ln, err := net.Listen("tcp", cfg.Address) - if err != nil { - panic(err) - } - - go func() { - if err = gserver.Serve(ln); err != nil { - panic(err) - } - }() - - stop := make(chan os.Signal, 1) - signal.Notify(stop, syscall.SIGTERM, syscall.SIGINT) - - <-stop -} diff --git a/migrations/20240608163806_initial.sql b/migrations/20240608163806_initial.sql index 54136b9..d10eec6 100644 --- a/migrations/20240608163806_initial.sql +++ b/migrations/20240608163806_initial.sql @@ -2,41 +2,33 @@ -- +goose StatementBegin CREATE TABLE IF NOT EXISTS users ( - id serial NOT NULL, - username VARCHAR(70) UNIQUE NOT NULL, - hashed_pwd VARCHAR(60) NOT NULL, - email VARCHAR(70) UNIQUE, - role INT NOT NULL DEFAULT 0, - expires_at TIMESTAMPTZ NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + id serial NOT NULL, + username VARCHAR(70) UNIQUE NOT NULL, + hashed_pwd VARCHAR(60) NOT NULL, + role INT NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + modified_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), PRIMARY KEY (id), CHECK (length(username) != 0 AND username = lower(username) AND username = trim(username)), - CHECK (length(email) != 0 AND email = lower(email) AND email = trim(email)), - CHECK (lower(username) != lower(email)), CHECK (length(hashed_pwd) != 0), - CHECK (role BETWEEN 0 AND 3) + CHECK (role BETWEEN 0 AND 2) ); -CREATE INDEX ON users (id); -CREATE INDEX ON users (username); -CREATE INDEX ON users (email); - -CREATE FUNCTION usr_upd_trg_fn() RETURNS TRIGGER +CREATE FUNCTION modified_at_update() RETURNS TRIGGER LANGUAGE plpgsql AS $$ BEGIN - NEW.updated_at = NOW(); + NEW.modified_at = NOW(); RETURN NEW; END; $$; -CREATE TRIGGER usr_upd_trg +CREATE TRIGGER on_users_update BEFORE UPDATE ON users FOR EACH ROW -EXECUTE PROCEDURE usr_upd_trg_fn(); +EXECUTE PROCEDURE modified_at_update(); -- +goose StatementEnd -- +goose Down diff --git a/pkg/errors.go b/pkg/errors.go new file mode 100644 index 0000000..20e6efd --- /dev/null +++ b/pkg/errors.go @@ -0,0 +1,38 @@ +package pkg + +import ( + "errors" + "fmt" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +var ( + NoPermission = errors.New("no permission") + ErrUnauthenticated = errors.New("unauthenticated") + ErrUnhandled = errors.New("unhandled") + ErrNotFound = errors.New("not found") + ErrBadInput = errors.New("bad input") + ErrInternal = errors.New("internal") +) + +func Wrap(basic error, err error, op string, msg string) error { + return errors.Join(basic, err, fmt.Errorf("during %s: %s", op, msg)) +} + +func ToGRPC(err error) error { + switch { + case errors.Is(err, ErrUnauthenticated): + return status.Errorf(codes.Unauthenticated, err.Error()) + case errors.Is(err, ErrBadInput): + return status.Errorf(codes.InvalidArgument, err.Error()) + case errors.Is(err, ErrNotFound): + return status.Errorf(codes.NotFound, err.Error()) + case errors.Is(err, ErrInternal): + return status.Errorf(codes.Internal, err.Error()) + case errors.Is(err, NoPermission): + return status.Errorf(codes.PermissionDenied, err.Error()) + } + + return status.Errorf(codes.Unknown, err.Error()) +} diff --git a/pkg/external/postgres/postgres.go b/pkg/pg_client.go similarity index 96% rename from pkg/external/postgres/postgres.go rename to pkg/pg_client.go index b83463e..686b199 100644 --- a/pkg/external/postgres/postgres.go +++ b/pkg/pg_client.go @@ -1,4 +1,4 @@ -package postgres +package pkg import ( "github.com/jmoiron/sqlx" diff --git a/pkg/utils/converter.go b/pkg/utils/converter.go deleted file mode 100644 index 4996a5f..0000000 --- a/pkg/utils/converter.go +++ /dev/null @@ -1,33 +0,0 @@ -package utils - -import ( - "google.golang.org/protobuf/types/known/timestamppb" - "time" -) - -func TimeP(t *timestamppb.Timestamp) *time.Time { - if t == nil { - return nil - } - tt := t.AsTime() - return &tt -} - -func TimestampP(t *time.Time) *timestamppb.Timestamp { - if t == nil { - return nil - } - return timestamppb.New(*t) -} - -func AsTimeP(t time.Time) *time.Time { - return &t -} - -func AsInt32P(v int32) *int32 { - return &v -} - -func AsStringP(str string) *string { - return &str -} diff --git a/pkg/utils/errors.go b/pkg/utils/errors.go deleted file mode 100644 index 3382290..0000000 --- a/pkg/utils/errors.go +++ /dev/null @@ -1,26 +0,0 @@ -package utils - -import ( - "errors" -) - -var ( - ErrInternal = errors.New("internal") - ErrUnexpected = errors.New("unexpected") - ErrNoPermission = errors.New("no permission") -) - -var ( - ErrBadHandleOrPassword = errors.New("bad handle or password") - ErrBadRole = errors.New("bad role") - ErrTooShortPassword = errors.New("too short password") - ErrTooLongPassword = errors.New("too long password") - ErrBadEmail = errors.New("bad email") - ErrBadUsername = errors.New("bad username") - ErrTooShortUsername = errors.New("too short username") - ErrTooLongUsername = errors.New("too long username") -) - -var ( - ErrBadSession = errors.New("bad session") -) diff --git a/pkg/utils/validation.go b/pkg/utils/validation.go deleted file mode 100644 index 1806b0e..0000000 --- a/pkg/utils/validation.go +++ /dev/null @@ -1,34 +0,0 @@ -package utils - -import "net/mail" - -func ValidPassword(str string) error { - if len(str) < 5 { - return ErrTooShortPassword - } - if len(str) > 70 { - return ErrTooLongPassword - } - return nil -} - -func ValidEmail(str string) error { - emailAddress, err := mail.ParseAddress(str) - if err != nil || emailAddress.Address != str { - return ErrBadEmail - } - return nil -} - -func ValidUsername(str string) error { - if len(str) < 5 { - return ErrTooShortUsername - } - if len(str) > 70 { - return ErrTooLongUsername - } - if err := ValidEmail(str); err == nil { - return ErrBadUsername - } - return nil -} diff --git a/pkg/external/valkey/valkey.go b/pkg/valkey_client.go similarity index 93% rename from pkg/external/valkey/valkey.go rename to pkg/valkey_client.go index c5a4900..87e4681 100644 --- a/pkg/external/valkey/valkey.go +++ b/pkg/valkey_client.go @@ -1,4 +1,4 @@ -package valkey +package pkg import "github.com/valkey-io/valkey-go" diff --git a/proto b/proto index 33856fd..360832a 160000 --- a/proto +++ b/proto @@ -1 +1 @@ -Subproject commit 33856fdad2a50061a942a67354dbd338b9032662 +Subproject commit 360832a5ab10821a76b7df5e23950e217f2c5221