package repository import ( "context" "database/sql" "errors" "git.sch9.ru/new_gate/ms-auth/internal/models" "git.sch9.ru/new_gate/ms-auth/internal/users" "git.sch9.ru/new_gate/ms-auth/pkg" "github.com/jackc/pgerrcode" "github.com/jackc/pgx/v5/pgconn" "github.com/jmoiron/sqlx" "golang.org/x/crypto/bcrypt" "net/mail" ) type UsersRepository struct { db *sqlx.DB } func NewUserRepository(db *sqlx.DB) *UsersRepository { return &UsersRepository{ db: db, } } func (r *UsersRepository) BeginTx(ctx context.Context) (users.TxCaller, error) { const op = "UsersRepository.BeginTx" tx, err := r.db.BeginTxx(ctx, nil) if err != nil { return nil, pkg.Wrap(pkg.ErrInternal, err, op, "database error") } return &TxCaller{ Caller: Caller{db: tx}, db: tx, }, nil } func (r *UsersRepository) C() users.Caller { return &Caller{db: r.db} } type TxOrDB interface { Rebind(query string) string GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) } type Caller struct { db TxOrDB } type TxCaller struct { Caller db *sqlx.Tx } func (c *TxCaller) Commit() error { const op = "TxCaller.Commit" err := c.db.Commit() if err != nil { return pkg.Wrap(pkg.ErrInternal, err, op, "database error") } return nil } func (c *TxCaller) Rollback() error { const op = "TxCaller.Rollback" err := c.db.Rollback() if err != nil { return pkg.Wrap(pkg.ErrInternal, err, op, "database error") } return nil } const createUser = ` INSERT INTO users (username, hashed_pwd, role) VALUES (trim(lower(?)), ?, ?) RETURNING id ` func (c *Caller) CreateUser(ctx context.Context, username, password string, role models.Role) (int32, error) { const op = "Caller.CreateUser" if err := ValidUsername(username); err != nil { return 0, pkg.Wrap(pkg.ErrBadInput, err, op, "username validation") } if err := ValidPassword(password); err != nil { return 0, pkg.Wrap(pkg.ErrBadInput, err, op, "password validation") } if err := ValidRole(role); err != nil { return 0, pkg.Wrap(pkg.ErrBadInput, err, op, "role validation") } hpwd, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { return 0, pkg.Wrap(pkg.ErrBadInput, err, op, "password validation") } query := c.db.Rebind(createUser) rows, err := c.db.QueryxContext( ctx, query, username, string(hpwd), role, ) if err != nil { return 0, handlePgErr(err, op) } defer rows.Close() var id int32 rows.Next() err = rows.Scan(&id) if err != nil { return 0, handlePgErr(err, op) } return id, nil } const readUserByUsername = "SELECT * from users WHERE username=? LIMIT 1" func (c *Caller) ReadUserByUsername(ctx context.Context, username string) (*models.User, error) { const op = "Caller.ReadUserByUsername" var user models.User query := c.db.Rebind(readUserByUsername) err := c.db.GetContext(ctx, &user, query, username) if err != nil { return nil, handlePgErr(err, op) } return &user, nil } const readUserById = "SELECT * from users WHERE id=? LIMIT 1" func (c *Caller) ReadUserById(ctx context.Context, id int32) (*models.User, error) { const op = "Caller.ReadUserById" var user models.User query := c.db.Rebind(readUserById) err := c.db.GetContext(ctx, &user, query, id) if err != nil { return nil, handlePgErr(err, op) } return &user, nil } const updateUser = ` UPDATE users SET username = COALESCE(?, trim(lower(username))), role = COALESCE(?, role) WHERE id = ? ` func (c *Caller) UpdateUser(ctx context.Context, id int32, username *string, role *models.Role) error { const op = "Caller.UpdateUser" var err error if username != nil { if err = ValidUsername(*username); err != nil { return pkg.Wrap(pkg.ErrBadInput, err, op, "username validation") } } query := c.db.Rebind(updateUser) _, err = c.db.ExecContext( ctx, query, username, role, id, ) if err != nil { return handlePgErr(err, op) } return nil } const deleteUser = "DELETE FROM users WHERE id = ?" func (c *Caller) DeleteUser(ctx context.Context, id int32) error { const op = "Caller.DeleteUser" query := c.db.Rebind(deleteUser) _, err := c.db.ExecContext(ctx, query, id) if err != nil { return handlePgErr(err, op) } return nil } func handlePgErr(err error, op string) error { var pgErr *pgconn.PgError if errors.As(err, &pgErr) { if pgerrcode.IsIntegrityConstraintViolation(pgErr.Code) { return pkg.Wrap(pkg.ErrBadInput, err, op, pgErr.Message) } if pgerrcode.IsNoData(pgErr.Code) { return pkg.Wrap(pkg.ErrNotFound, err, op, pgErr.Message) } } return pkg.Wrap(pkg.ErrUnhandled, err, op, "unexpected error") } func ValidEmail(str string) error { emailAddress, err := mail.ParseAddress(str) if err != nil || emailAddress.Address != str { return errors.New("invalid email") } return nil } func ValidUsername(str string) error { if len(str) < 5 { return errors.New("too short username") } if len(str) > 70 { return errors.New("too long username") } if err := ValidEmail(str); err == nil { return errors.New("username cannot be an email") } return nil } func ValidPassword(str string) error { if len(str) < 5 { return errors.New("too short password") } if len(str) > 70 { return errors.New("too long password") } return nil } func ValidRole(role models.Role) error { switch role { case models.RoleAdmin: return nil case models.RoleModerator: return nil case models.RoleParticipant: return nil } return errors.New("invalid role") }