package transport

import (
	"context"
	"errors"

	"git.sch9.ru/new_gate/ms-tester/internal/lib"
	"git.sch9.ru/new_gate/ms-tester/internal/models"
	sessionv1 "git.sch9.ru/new_gate/ms-tester/pkg/go/gen/proto/session/v1"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
)

// defaultUser is the identity attached to a request when no token is supplied
// or when the token cannot be resolved to a user. It has no user id and the
// spectator role.
//
// The same pointer is shared by every such request, so it must be treated as
// read-only.
var defaultUser = &models.User{
	UserId:    nil,
	Role:      models.RoleSpectator.AsPointer(),
	UpdatedAt: nil,
}

// extractToken returns the first value of the "token" key in the incoming gRPC
// metadata of ctx, or "" when the metadata or the key is absent.
func extractToken(ctx context.Context) string {
	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return ""
	}
	tokens := md.Get("token")
	if len(tokens) == 0 {
		return ""
	}
	return tokens[0]
}

// readSessionAndReadUser resolves token to a session through the session
// client and then loads the user referenced by that session.
//
// When the user lookup fails with lib.ErrNotFound, a new user with the session's
// user id and the participant role is created and returned. Any other error
// (from the session client, the lookup, or the creation) is returned as is.
func (s *TesterServer) readSessionAndReadUser(ctx context.Context, token string) (*models.User, error) {
	// FIXME: possible bottleneck: should we cache it? (think of it in future)
	// FIXME: maybe use single connection instead of multiple requests
	session, err := s.sessionClient.Read(ctx, &sessionv1.ReadSessionRequest{Token: token})
	if err != nil {
		return nil, err
	}

	user, err := s.userService.ReadUserById(ctx, session.GetUserId()) // FIXME: must be cached!
	if err == nil {
		return user, nil
	}
	if !errors.Is(err, lib.ErrNotFound) {
		return nil, err
	}

	// The session is valid but the user is not known locally yet: register it.
	user = &models.User{
		UserId: lib.AsInt32P(session.GetUserId()),
		Role:   models.RoleParticipant.AsPointer(),
	}
	if err := s.userService.CreateUser(ctx, user); err != nil {
		return nil, err
	}
	return user, nil
}

// insertUser returns a copy of ctx that carries user under the key "user".
//
// NOTE(review): a plain string key is flagged by static analysis (collisions
// between packages are possible). It is kept because code outside this file
// presumably reads the value with the same "user" key; confirm before changing.
func insertUser(ctx context.Context, user *models.User) context.Context {
	return context.WithValue(ctx, "user", user)
}

// AuthUnaryInterceptor returns a unary interceptor that attaches a user to the
// request context before invoking the handler.
//
// The user is resolved from the "token" metadata value. If the token is missing
// or resolving it fails, the request is not rejected: it proceeds with
// defaultUser and the resolution error is discarded.
func (s *TesterServer) AuthUnaryInterceptor() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		user := defaultUser
		if token := extractToken(ctx); token != "" {
			if resolved, err := s.readSessionAndReadUser(ctx, token); err == nil {
				user = resolved
			}
		}
		return handler(insertUser(ctx, user), req)
	}
}

// ssWrapper is a grpc.ServerStream whose Context method returns a replacement
// context instead of the embedded stream's own context.
type ssWrapper struct {
	grpc.ServerStream
	ctx context.Context
}

// Context returns the context stored in the wrapper.
func (s *ssWrapper) Context() context.Context {
	return s.ctx
}

// AuthStreamInterceptor returns a stream interceptor that attaches a user to
// the stream context before invoking the handler.
//
// It follows the same rules as AuthUnaryInterceptor: a missing token or a
// failed resolution results in defaultUser rather than an error.
func (s *TesterServer) AuthStreamInterceptor() grpc.StreamServerInterceptor {
	return func(server interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		ctx := ss.Context()

		user := defaultUser
		if token := extractToken(ctx); token != "" {
			if resolved, err := s.readSessionAndReadUser(ctx, token); err == nil {
				user = resolved
			}
		}
		return handler(server, &ssWrapper{ServerStream: ss, ctx: insertUser(ctx, user)})
	}
}

// ToGrpcError converts err into a gRPC status error whose code is chosen from
// the lib sentinel error that err matches (via errors.Is). The status message
// is err.Error(). A nil err yields nil; an error matching no sentinel yields
// codes.Unknown.
//
// Cases are checked in order, so an error that matches several sentinels gets
// the code of the first matching case.
func ToGrpcError(err error) error {
	if err == nil {
		return nil
	}

	// should I use map instead?
	code := codes.Unknown
	switch {
	case errors.Is(err, lib.ErrValidationFailed):
		code = codes.InvalidArgument
	case errors.Is(err, lib.ErrInternal):
		code = codes.Internal
	case errors.Is(err, lib.ErrExternal):
		code = codes.Unavailable
	case errors.Is(err, lib.ErrNoPermission):
		code = codes.PermissionDenied
	case errors.Is(err, lib.ErrUnknown):
		code = codes.Unknown
	case errors.Is(err, lib.ErrDeadlineExceeded):
		code = codes.DeadlineExceeded
	case errors.Is(err, lib.ErrNotFound):
		code = codes.NotFound
	case errors.Is(err, lib.ErrAlreadyExists):
		code = codes.AlreadyExists
	case errors.Is(err, lib.ErrConflict):
		// A conflict is not "unimplemented"; Aborted is the gRPC code for
		// operations rejected because of a conflicting state.
		code = codes.Aborted
	case errors.Is(err, lib.ErrUnimplemented):
		code = codes.Unimplemented
	case errors.Is(err, lib.ErrUnauthenticated):
		code = codes.Unauthenticated
	}
	return status.Error(code, err.Error())
}

// ErrUnwrappingUnaryInterceptor returns a unary interceptor that passes the
// handler's error through ToGrpcError and leaves the response untouched.
func (s *TesterServer) ErrUnwrappingUnaryInterceptor() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		resp, err := handler(ctx, req)
		return resp, ToGrpcError(err)
	}
}

// ErrUnwrappingStreamInterceptor returns a stream interceptor that passes the
// handler's error through ToGrpcError.
func (s *TesterServer) ErrUnwrappingStreamInterceptor() grpc.StreamServerInterceptor {
	return func(server interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		return ToGrpcError(handler(server, ss))
	}
}