From af9ab60092fe5433e1703795a3a49cdeeaf009d4 Mon Sep 17 00:00:00 2001 From: Vyacheslav1557 Date: Fri, 23 Aug 2024 03:56:03 +0500 Subject: [PATCH] feat: improve error handling --- internal/lib/errors.go | 129 ++++++++++++++++++++++++++++- internal/services/problem.go | 80 +++++++++++++----- internal/storage/errhandling.go | 12 +-- internal/storage/problem.go | 10 ++- internal/storage/solution.go | 8 +- internal/transport/interceptors.go | 94 +++++++++++++++------ internal/transport/problem.go | 26 +++--- internal/transport/server.go | 21 ++--- 8 files changed, 290 insertions(+), 90 deletions(-) diff --git a/internal/lib/errors.go b/internal/lib/errors.go index 5724610..92be0c4 100644 --- a/internal/lib/errors.go +++ b/internal/lib/errors.go @@ -2,14 +2,87 @@ package lib import ( "errors" + "fmt" + "go.uber.org/zap/zapcore" + "runtime" ) -var ( - ErrInternal = errors.New("internal") - ErrUnexpected = errors.New("unexpected") - ErrNoPermission = errors.New("no permission") +type code uint8 + +const ( + ErrValidationFailed code = 1 + ErrInternal code = 2 + ErrExternal code = 3 + ErrNoPermission code = 4 + ErrUnknown code = 5 + ErrDeadlineExceeded code = 6 + ErrNotFound code = 7 + ErrAlreadyExists code = 8 + ErrConflict code = 9 + ErrUnimplemented code = 10 + ErrBadInput code = 11 + ErrUnauthenticated code = 12 ) +func (c code) String() string { + switch { + case errors.Is(c, ErrValidationFailed): + return "validation error" + case errors.Is(c, ErrInternal): + return "internal error" + case errors.Is(c, ErrExternal): + return "external error" + case errors.Is(c, ErrNoPermission): + return "permission error" + case errors.Is(c, ErrUnknown): + return "unknown error" + case errors.Is(c, ErrDeadlineExceeded): + return "deadline error" + case errors.Is(c, ErrNotFound): + return "not found error" + case errors.Is(c, ErrAlreadyExists): + return "already exists error" + case errors.Is(c, ErrConflict): + return "conflict error" + case errors.Is(c, ErrUnimplemented): + return "unimplemented error" + case errors.Is(c, ErrBadInput): + return "bad input error" + } + + panic("unimplemented") +} + +func (c code) Error() string { + return c.String() +} + +type layer uint8 + +const ( + LayerTransport layer = 1 + LayerService layer = 2 + LayerStorage layer = 3 +) + +func (l layer) String() string { + switch l { + case LayerTransport: + return "transport" + case LayerService: + return "service" + case LayerStorage: + return "storage" + } + + panic("unimplemented") +} + +func location(skip int) string { + _, file, line, _ := runtime.Caller(skip) + return fmt.Sprintf("%s:%d", file, line) +} + var ( ErrBadRole = errors.New("bad role") ) @@ -18,3 +91,51 @@ var ( ErrBadTestingStrategy = errors.New("bad testing strategy") ErrBadResult = errors.New("bad result") ) + +type Error struct { + src error + layer layer + code code + msg string + loc string +} + +func wrap(src error, layer layer, class code, msg string, loc string) *Error { + return &Error{ + src: src, + layer: layer, + code: class, + msg: msg, + loc: loc, + } +} + +func (e *Error) Unwrap() []error { + return []error{e.src, e.code} +} + +func (e *Error) Error() string { + return fmt.Sprintf("%s: %s", e.code.String(), e.msg) +} + +func (e *Error) MarshalLogObject(encoder zapcore.ObjectEncoder) error { + if e.src != nil { + encoder.AddString("src", e.src.Error()) + } + encoder.AddString("layer", e.layer.String()) + encoder.AddString("code", e.code.String()) + encoder.AddString("msg", e.msg) + return nil +} + +func TransportError(src error, code code, msg string) error { + return wrap(src, LayerTransport, code, msg, location(2)) +} + +func ServiceError(src error, code code, msg string) error { + return wrap(src, LayerService, code, msg, location(2)) +} + +func StorageError(src error, code code, msg string) error { + return wrap(src, LayerStorage, code, msg, location(2)) +} diff --git a/internal/services/problem.go b/internal/services/problem.go index 7b175ae..e78da40 100644 --- a/internal/services/problem.go +++ b/internal/services/problem.go @@ -2,6 +2,7 @@ package services import ( "context" + "git.sch9.ru/new_gate/ms-tester/internal/lib" "git.sch9.ru/new_gate/ms-tester/internal/models" ) @@ -16,9 +17,14 @@ type PandocClient interface { ConvertLatexToHtml5(ctx context.Context, text string) (string, error) } +type IPermissionService interface { + Allowed(ctx context.Context, user *models.User, action string) bool +} + type ProblemService struct { - problemStorage ProblemStorage - pandocClient PandocClient + problemStorage ProblemStorage + pandocClient PandocClient + permissionService IPermissionService } func NewProblemService( @@ -31,30 +37,66 @@ func NewProblemService( } } -func (service *ProblemService) CreateProblem(ctx context.Context, problem *models.Problem, ch <-chan []byte) (int32, error) { - //userId := ctx.Value("user_id").(int32) - //html, err := service.pandocClient.ConvertLatexToHtml5(*problem.Description) - //if err != nil { - // return 0, err - //} - panic("access control is not implemented yet") - //return service.problemStorage.CreateProblem(ctx, problem) +func extractUser(ctx context.Context) *models.User { + return ctx.Value("user").(*models.User) +} + +func (service *ProblemService) CanCreateProblem(ctx context.Context) error { + if !service.permissionService.Allowed(ctx, extractUser(ctx), "create") { + return lib.ServiceError(nil, lib.ErrNoPermission, "permission denied") + } + return nil +} + +func (service *ProblemService) CanReadProblemById(ctx context.Context) error { + if !service.permissionService.Allowed(ctx, extractUser(ctx), "read") { + return lib.ServiceError(nil, lib.ErrNoPermission, "permission denied") + } + return nil +} + +func (service *ProblemService) CanUpdateProblem(ctx context.Context) error { + if !service.permissionService.Allowed(ctx, extractUser(ctx), "update") { + return lib.ServiceError(nil, lib.ErrNoPermission, "permission denied") + } + return nil +} + +func (service *ProblemService) CanDeleteProblem(ctx context.Context) error { + if !service.permissionService.Allowed(ctx, extractUser(ctx), "delete") { + return lib.ServiceError(nil, lib.ErrNoPermission, "permission denied") + } + return nil +} + +func (service *ProblemService) CreateProblem(ctx context.Context, problem *models.Problem) (int32, error) { + if err := service.CanCreateProblem(ctx); err != nil { + return 0, err + } + _, err := service.pandocClient.ConvertLatexToHtml5(ctx, *problem.Description) + if err != nil { + return 0, err + } + return service.problemStorage.CreateProblem(ctx, problem, nil) } func (service *ProblemService) ReadProblemById(ctx context.Context, id int32) (*models.Problem, error) { - //userId := ctx.Value("user_id").(int32) - panic("access control is not implemented yet") - //return service.problemStorage.ReadProblemById(ctx, id) + if err := service.CanReadProblemById(ctx); err != nil { + return nil, err + } + return service.problemStorage.ReadProblemById(ctx, id) } func (service *ProblemService) UpdateProblem(ctx context.Context, problem *models.Problem) error { - //userId := ctx.Value("user_id").(int32) - panic("access control is not implemented yet") - //return service.problemStorage.UpdateProblem(ctx, problem) + if err := service.CanUpdateProblem(ctx); err != nil { + return err + } + return service.problemStorage.UpdateProblem(ctx, problem) } func (service *ProblemService) DeleteProblem(ctx context.Context, id int32) error { - //userId := ctx.Value("user_id").(int32) - panic("access control is not implemented yet") - //return service.problemStorage.DeleteProblem(ctx, id) + if err := service.CanDeleteProblem(ctx); err != nil { + return err + } + return service.problemStorage.DeleteProblem(ctx, id) } diff --git a/internal/storage/errhandling.go b/internal/storage/errhandling.go index af8d419..ff720e3 100644 --- a/internal/storage/errhandling.go +++ b/internal/storage/errhandling.go @@ -10,12 +10,14 @@ import ( func handlePgErr(err error) error { var pgErr *pgconn.PgError if !errors.As(err, &pgErr) { - //storage.logger.DPanic("unexpected error from postgres", zap.String("err", err.Error())) - return lib.ErrUnexpected + return lib.StorageError(err, lib.ErrUnknown, "unexpected error from postgres") } if pgerrcode.IsIntegrityConstraintViolation(pgErr.Code) { - return errors.New("unique key violation") // FIXME + // TODO: probably should specify which constraint + return lib.StorageError(err, lib.ErrConflict, pgErr.Message) } - //storage.logger.DPanic("unexpected internal error from postgres", zap.String("err", err.Error())) - return lib.ErrInternal + if pgerrcode.IsNoData(pgErr.Code) { + return lib.StorageError(err, lib.ErrNotFound, pgErr.Message) + } + return lib.StorageError(err, lib.ErrUnimplemented, "unimplemented error") } diff --git a/internal/storage/problem.go b/internal/storage/problem.go index 2de1fa0..10bf2c8 100644 --- a/internal/storage/problem.go +++ b/internal/storage/problem.go @@ -2,6 +2,7 @@ package storage import ( "context" + "errors" "git.sch9.ru/new_gate/ms-tester/internal/models" "github.com/jmoiron/sqlx" "go.uber.org/zap" @@ -21,6 +22,9 @@ func NewProblemStorage(db *sqlx.DB, logger *zap.Logger) *ProblemStorage { func (storage *ProblemStorage) CreateProblem(ctx context.Context, problem *models.Problem, testGroupData []models.TestGroupData) (int32, error) { tx, err := storage.db.Beginx() + if err != nil { + return 0, handlePgErr(err) + } query := tx.Rebind(` INSERT INTO problems (name,description,time_limit,memory_limit) @@ -36,7 +40,7 @@ RETURNING id problem.MemoryLimit, ) if err != nil { - return 0, handlePgErr(err) + return 0, handlePgErr(errors.Join(err, tx.Rollback())) } for _, tgd := range testGroupData { query := tx.Rebind(` @@ -47,7 +51,7 @@ RETURNING id `) rows, err = tx.QueryxContext(ctx, query, tgd.Ts) if err != nil { - return 0, handlePgErr(err) + return 0, handlePgErr(errors.Join(err, tx.Rollback())) } var i int32 = 0 for ; i < tgd.TestAmount; i++ { @@ -59,7 +63,7 @@ RETURNING id `) rows, err = tx.QueryxContext(ctx, query, tgd.Ts) if err != nil { - return 0, handlePgErr(err) + return 0, handlePgErr(errors.Join(err, tx.Rollback())) } } } diff --git a/internal/storage/solution.go b/internal/storage/solution.go index b738860..4b7eae5 100644 --- a/internal/storage/solution.go +++ b/internal/storage/solution.go @@ -86,11 +86,11 @@ func (storage *SolutionStorage) RejudgeSolution(ctx context.Context, id int32) e return handlePgErr(err) } query := tx.Rebind("UPDATE solutions SET result = ? WHERE id = ?") - tx.QueryxContext(ctx, query, models.NotTested, id) + tx.QueryxContext(ctx, query, models.NotTested, id) // FIXME query = tx.Rebind("UPDATE subtaskruns SET result = ?,score = 0 WHERE solution_id = ?") - tx.QueryxContext(ctx, query, models.NotTested, id) + tx.QueryxContext(ctx, query, models.NotTested, id) // FIXME query = tx.Rebind("UPDATE testruns SET result = ?, score = 0 WHERE testgrouprun_id IN (SELECT id FROM tesgrouprun WHERE solution_id = ?)") - tx.QueryxContext(ctx, query, models.NotTested, id) + tx.QueryxContext(ctx, query, models.NotTested, id) // FIXME err = tx.Commit() var solution models.Solution query = storage.db.Rebind("SELECT * from solutions WHERE id=? LIMIT 1") @@ -98,7 +98,7 @@ func (storage *SolutionStorage) RejudgeSolution(ctx context.Context, id int32) e if err != nil { return handlePgErr(err) } - storage.updateResult(ctx, *solution.ParticipantId, *solution.TaskId) + storage.updateResult(ctx, *solution.ParticipantId, *solution.TaskId) // FIXME return nil } diff --git a/internal/transport/interceptors.go b/internal/transport/interceptors.go index a508f19..882d3ef 100644 --- a/internal/transport/interceptors.go +++ b/internal/transport/interceptors.go @@ -18,22 +18,18 @@ var defaultUser = &models.User{ UpdatedAt: nil, } -func extractToken(ctx context.Context) (string, error) { +func extractToken(ctx context.Context) string { md, ok := metadata.FromIncomingContext(ctx) if !ok { - return "", errors.New("no metadata") // FIXME + return "" } tokens := md.Get("token") if len(tokens) == 0 { - return "", errors.New("no token in metadata") // FIXME + return "" } - token := tokens[0] - if token == "" { - return "", errors.New("empty token in metadata") // FIXME - } - return token, nil + return tokens[0] } func (s *TesterServer) readSessionAndReadUser(ctx context.Context, token string) (*models.User, error) { @@ -41,20 +37,22 @@ func (s *TesterServer) readSessionAndReadUser(ctx context.Context, token string) // FIXME: maybe use single connection instead of multiple requests userId, err := s.sessionClient.Read(ctx, &sessionv1.ReadSessionRequest{Token: token}) if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "") // FIXME + return nil, err } user, err := s.userService.ReadUserById(ctx, userId.GetUserId()) // FIXME: must be cached! if err != nil { - // FIXME: if error is "not found" (when error codes module is written) - // means user has no record, so we should create it - user = &models.User{ - UserId: lib.AsInt32P(userId.GetUserId()), - Role: models.RoleParticipant.AsPointer(), - } - err = s.userService.CreateUser(ctx, user) - if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "") // FIXME + if errors.Is(err, lib.ErrNotFound) { + user = &models.User{ + UserId: lib.AsInt32P(userId.GetUserId()), + Role: models.RoleParticipant.AsPointer(), + } + err = s.userService.CreateUser(ctx, user) + if err != nil { + return nil, err + } + } else { + return nil, err } } @@ -65,14 +63,10 @@ func insertUser(ctx context.Context, user *models.User) context.Context { return context.WithValue(ctx, "user", user) } -func extractUser(ctx context.Context) *models.User { - return ctx.Value("user").(*models.User) -} - func (s *TesterServer) AuthUnaryInterceptor() grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - token, err := extractToken(ctx) - if err != nil { + token := extractToken(ctx) + if token == "" { return handler(insertUser(ctx, defaultUser), req) } @@ -98,8 +92,8 @@ func (s *TesterServer) AuthStreamInterceptor() grpc.StreamServerInterceptor { return func(server interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { ctx := ss.Context() - token, err := extractToken(ctx) - if err != nil { + token := extractToken(ctx) + if token == "" { return handler(server, &ssWrapper{ServerStream: ss, ctx: insertUser(ctx, defaultUser)}) } @@ -111,3 +105,51 @@ func (s *TesterServer) AuthStreamInterceptor() grpc.StreamServerInterceptor { return handler(server, &ssWrapper{ServerStream: ss, ctx: insertUser(ctx, user)}) } } + +func ToGrpcError(err error) error { + if err == nil { + return nil + } + + // should I use map instead? + switch { + case errors.Is(err, lib.ErrValidationFailed): + return status.Error(codes.InvalidArgument, err.Error()) + case errors.Is(err, lib.ErrInternal): + return status.Error(codes.Internal, err.Error()) + case errors.Is(err, lib.ErrExternal): + return status.Error(codes.Unavailable, err.Error()) + case errors.Is(err, lib.ErrNoPermission): + return status.Error(codes.PermissionDenied, err.Error()) + case errors.Is(err, lib.ErrUnknown): + return status.Error(codes.Unknown, err.Error()) + case errors.Is(err, lib.ErrDeadlineExceeded): + return status.Error(codes.DeadlineExceeded, err.Error()) + case errors.Is(err, lib.ErrNotFound): + return status.Error(codes.NotFound, err.Error()) + case errors.Is(err, lib.ErrAlreadyExists): + return status.Error(codes.AlreadyExists, err.Error()) + case errors.Is(err, lib.ErrConflict): + return status.Error(codes.Unimplemented, err.Error()) + case errors.Is(err, lib.ErrUnimplemented): + return status.Error(codes.Unimplemented, err.Error()) + case errors.Is(err, lib.ErrUnauthenticated): + return status.Error(codes.Unauthenticated, err.Error()) + default: + return status.Error(codes.Unknown, err.Error()) + } +} + +func (s *TesterServer) ErrUnwrappingUnaryInterceptor() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + resp, err := handler(ctx, req) + return resp, ToGrpcError(err) + } +} + +func (s *TesterServer) ErrUnwrappingStreamInterceptor() grpc.StreamServerInterceptor { + return func(server interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + err := handler(server, ss) + return ToGrpcError(err) + } +} diff --git a/internal/transport/problem.go b/internal/transport/problem.go index 0771dad..7f12a5c 100644 --- a/internal/transport/problem.go +++ b/internal/transport/problem.go @@ -15,17 +15,17 @@ import ( func (s *TesterServer) CreateProblem(server problemv1.ProblemService_CreateProblemServer) error { ctx := server.Context() - if !s.permissionService.Allowed(ctx, extractUser(ctx), "create") { - return status.Errorf(codes.PermissionDenied, "") // FIXME + if err := s.problemService.CanCreateProblem(ctx); err != nil { + return err } req, err := server.Recv() // receive problem if err != nil { - return err // FIXME + return lib.TransportError(err, lib.ErrBadInput, "can't receive problem") } problem := req.GetProblem() if problem == nil { - return status.Errorf(codes.Unknown, "") // FIXME + return lib.TransportError(nil, lib.ErrBadInput, "empty problem") } p := &models.Problem{ @@ -42,22 +42,23 @@ func (s *TesterServer) CreateProblem(server problemv1.ProblemService_CreateProbl return err // FIXME } - id, err := s.problemService.CreateProblem(ctx, p, nil) // FIXME + id, err := s.problemService.CreateProblem(ctx, p) // FIXME if err != nil { - return status.Errorf(codes.Unknown, "") // FIXME + return err } err = server.SendAndClose(&problemv1.CreateProblemResponse{ Id: id, }) if err != nil { - return err // FIXME + return lib.TransportError(err, lib.ErrBadInput, "can't send response") } return nil } func writeChunks(ctx context.Context, chunks <-chan []byte) error { + // use s3 // FIXME: use ctx? f, err := os.Create("out.txt") // FIXME: uuidv4 as initial temp name? if err != nil { @@ -113,13 +114,9 @@ func readChunks(ctx context.Context, server problemv1.ProblemService_CreateProbl } func (s *TesterServer) ReadProblem(ctx context.Context, req *problemv1.ReadProblemRequest) (*problemv1.ReadProblemResponse, error) { - if !s.permissionService.Allowed(ctx, extractUser(ctx), "read") { - return nil, status.Errorf(codes.PermissionDenied, "") // FIXME - } - problem, err := s.problemService.ReadProblemById(ctx, req.GetId()) if err != nil { - return nil, status.Errorf(codes.Unknown, err.Error()) // FIXME + return nil, err } return &problemv1.ReadProblemResponse{ Problem: &problemv1.ReadProblemResponse_Problem{ @@ -157,12 +154,9 @@ func (s *TesterServer) ReadProblem(ctx context.Context, req *problemv1.ReadProbl //} func (s *TesterServer) DeleteProblem(ctx context.Context, req *problemv1.DeleteProblemRequest) (*emptypb.Empty, error) { - if !s.permissionService.Allowed(ctx, extractUser(ctx), "delete") { - return nil, status.Errorf(codes.PermissionDenied, "") // FIXME - } err := s.problemService.DeleteProblem(ctx, req.GetId()) if err != nil { - return nil, status.Errorf(codes.Unknown, err.Error()) // FIXME + return nil, err } return &emptypb.Empty{}, nil } diff --git a/internal/transport/server.go b/internal/transport/server.go index e83d2ea..a82d0f5 100644 --- a/internal/transport/server.go +++ b/internal/transport/server.go @@ -14,7 +14,8 @@ import ( ) type ProblemService interface { - CreateProblem(ctx context.Context, problem *models.Problem, ch <-chan []byte) (int32, error) + CanCreateProblem(ctx context.Context) error + CreateProblem(ctx context.Context, problem *models.Problem) (int32, error) ReadProblemById(ctx context.Context, id int32) (*models.Problem, error) UpdateProblem(ctx context.Context, problem *models.Problem) error DeleteProblem(ctx context.Context, id int32) error @@ -32,10 +33,6 @@ type UserService interface { ReadUserById(ctx context.Context, userId int32) (*models.User, error) } -type PermissionService interface { - Allowed(ctx context.Context, user *models.User, action string) bool -} - type TesterServer struct { problemv1.UnimplementedProblemServiceServer problemService ProblemService @@ -43,8 +40,6 @@ type TesterServer struct { sessionClient SessionClient userService UserService - permissionService PermissionService - grpcServer *grpc.Server logger *zap.Logger } @@ -52,19 +47,19 @@ type TesterServer struct { func NewTesterServer( problemService ProblemService, sessionClient SessionClient, - permissionService PermissionService, userService UserService, logger *zap.Logger, ) *TesterServer { server := &TesterServer{ - problemService: problemService, - sessionClient: sessionClient, - permissionService: permissionService, - userService: userService, - logger: logger, + problemService: problemService, + sessionClient: sessionClient, + userService: userService, + logger: logger, } grpcServer := grpc.NewServer( + grpc.UnaryInterceptor(server.ErrUnwrappingUnaryInterceptor()), + grpc.StreamInterceptor(server.ErrUnwrappingStreamInterceptor()), grpc.UnaryInterceptor(server.AuthUnaryInterceptor()), grpc.StreamInterceptor(server.AuthStreamInterceptor()), )