ms-auth/internal/users/repository/pg_repository.go

260 lines
5.6 KiB
Go
Raw Normal View History

2024-10-09 17:07:38 +00:00
package repository
2024-08-14 10:36:43 +00:00
import (
"context"
2024-12-30 15:04:26 +00:00
"database/sql"
2024-08-14 10:36:43 +00:00
"errors"
2024-08-14 15:24:57 +00:00
"git.sch9.ru/new_gate/ms-auth/internal/models"
2024-12-30 15:04:26 +00:00
"git.sch9.ru/new_gate/ms-auth/internal/users"
"git.sch9.ru/new_gate/ms-auth/pkg"
2024-08-14 10:36:43 +00:00
"github.com/jackc/pgerrcode"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jmoiron/sqlx"
2024-12-30 15:04:26 +00:00
"golang.org/x/crypto/bcrypt"
"net/mail"
2024-08-14 10:36:43 +00:00
)
2024-10-09 17:07:38 +00:00
type UsersRepository struct {
2024-12-30 15:04:26 +00:00
db *sqlx.DB
2024-08-14 10:36:43 +00:00
}
2024-12-30 15:04:26 +00:00
func NewUserRepository(db *sqlx.DB) *UsersRepository {
2024-10-09 17:07:38 +00:00
return &UsersRepository{
2024-12-30 15:04:26 +00:00
db: db,
2024-08-14 10:36:43 +00:00
}
}
2024-12-30 15:04:26 +00:00
// BeginTx opens a transaction with default options and returns a TxCaller
// whose queries all run inside it; it should be finished with Commit or
// Rollback. A failure to begin is wrapped as pkg.ErrInternal.
func (r *UsersRepository) BeginTx(ctx context.Context) (users.TxCaller, error) {
	const op = "UsersRepository.BeginTx"

	tx, err := r.db.BeginTxx(ctx, nil)
	if err != nil {
		return nil, pkg.Wrap(pkg.ErrInternal, err, op, "database error")
	}

	// The same *sqlx.Tx backs both the query side (embedded Caller) and the
	// commit/rollback side (TxCaller.db).
	txc := &TxCaller{db: tx}
	txc.Caller = Caller{db: tx}
	return txc, nil
}
// C returns a Caller that runs each query directly on the pool, outside of
// any transaction.
func (r *UsersRepository) C() users.Caller {
	caller := Caller{db: r.db}
	return &caller
}
// TxOrDB is the subset of query methods that both *sqlx.DB and *sqlx.Tx
// provide, so Caller can run identical code with or without a transaction.
type TxOrDB interface {
	// Rebind rewrites '?' placeholders into the driver's bindvar syntax.
	Rebind(query string) string

	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
	GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
	QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error)
}
// Caller runs the user queries against whichever TxOrDB it wraps: the pool
// (from UsersRepository.C) or a transaction (from UsersRepository.BeginTx).
type Caller struct {
	db TxOrDB
}
// TxCaller is a Caller bound to an open transaction. The embedded Caller
// issues the queries; the separate db field keeps the concrete *sqlx.Tx so
// Commit and Rollback can finish it.
type TxCaller struct {
	Caller
	db *sqlx.Tx
}
// Commit commits the transaction, wrapping any failure as pkg.ErrInternal.
func (c *TxCaller) Commit() error {
	const op = "TxCaller.Commit"

	if err := c.db.Commit(); err != nil {
		return pkg.Wrap(pkg.ErrInternal, err, op, "database error")
	}
	return nil
}
// Rollback aborts the transaction, wrapping any failure as pkg.ErrInternal.
// NOTE(review): database/sql returns sql.ErrTxDone when the transaction was
// already committed or rolled back; that case is wrapped as internal too.
func (c *TxCaller) Rollback() error {
	const op = "TxCaller.Rollback"

	if err := c.db.Rollback(); err != nil {
		return pkg.Wrap(pkg.ErrInternal, err, op, "database error")
	}
	return nil
}
2024-08-14 10:36:43 +00:00
2024-10-09 17:07:38 +00:00
const createUser = `
INSERT INTO users
2024-12-30 15:04:26 +00:00
(username, hashed_pwd, role)
VALUES (trim(lower(?)), ?, ?)
2024-10-09 17:07:38 +00:00
RETURNING id
`
2024-12-30 15:04:26 +00:00
func (c *Caller) CreateUser(ctx context.Context, username, password string, role models.Role) (int32, error) {
const op = "Caller.CreateUser"
if err := ValidUsername(username); err != nil {
return 0, pkg.Wrap(pkg.ErrBadInput, err, op, "username validation")
2024-08-14 10:36:43 +00:00
}
2024-12-30 15:04:26 +00:00
if err := ValidPassword(password); err != nil {
return 0, pkg.Wrap(pkg.ErrBadInput, err, op, "password validation")
2024-08-14 10:36:43 +00:00
}
2024-12-30 15:04:26 +00:00
if err := ValidRole(role); err != nil {
return 0, pkg.Wrap(pkg.ErrBadInput, err, op, "role validation")
2024-08-14 10:36:43 +00:00
}
2024-12-30 15:04:26 +00:00
hpwd, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return 0, pkg.Wrap(pkg.ErrBadInput, err, op, "password validation")
}
2024-10-09 17:07:38 +00:00
2024-12-30 15:04:26 +00:00
query := c.db.Rebind(createUser)
2024-08-14 10:36:43 +00:00
2024-12-30 15:04:26 +00:00
rows, err := c.db.QueryxContext(
2024-08-14 10:36:43 +00:00
ctx,
query,
2024-12-30 15:04:26 +00:00
username,
string(hpwd),
role,
2024-08-14 10:36:43 +00:00
)
if err != nil {
2024-12-30 15:04:26 +00:00
return 0, handlePgErr(err, op)
2024-08-14 10:36:43 +00:00
}
defer rows.Close()
var id int32
2024-10-09 17:07:38 +00:00
rows.Next()
err = rows.Scan(&id)
2024-08-14 10:36:43 +00:00
if err != nil {
2024-12-30 15:04:26 +00:00
return 0, handlePgErr(err, op)
2024-08-14 10:36:43 +00:00
}
return id, nil
}
2024-10-09 17:07:38 +00:00
const readUserByUsername = "SELECT * from users WHERE username=? LIMIT 1"
2024-12-30 15:04:26 +00:00
func (c *Caller) ReadUserByUsername(ctx context.Context, username string) (*models.User, error) {
const op = "Caller.ReadUserByUsername"
2024-08-14 10:36:43 +00:00
var user models.User
2024-12-30 15:04:26 +00:00
query := c.db.Rebind(readUserByUsername)
err := c.db.GetContext(ctx, &user, query, username)
2024-08-14 10:36:43 +00:00
if err != nil {
2024-12-30 15:04:26 +00:00
return nil, handlePgErr(err, op)
2024-08-14 10:36:43 +00:00
}
return &user, nil
}
2024-10-09 17:07:38 +00:00
const readUserById = "SELECT * from users WHERE id=? LIMIT 1"
2024-12-30 15:04:26 +00:00
func (c *Caller) ReadUserById(ctx context.Context, id int32) (*models.User, error) {
const op = "Caller.ReadUserById"
2024-08-14 10:36:43 +00:00
var user models.User
2024-12-30 15:04:26 +00:00
query := c.db.Rebind(readUserById)
err := c.db.GetContext(ctx, &user, query, id)
2024-08-14 10:36:43 +00:00
if err != nil {
2024-12-30 15:04:26 +00:00
return nil, handlePgErr(err, op)
2024-08-14 10:36:43 +00:00
}
return &user, nil
}
2024-10-09 17:07:38 +00:00
const updateUser = `
UPDATE users
2024-12-30 15:04:26 +00:00
SET username = COALESCE(?, trim(lower(username))),
2024-10-09 17:07:38 +00:00
role = COALESCE(?, role)
WHERE id = ?
`
2024-12-30 15:04:26 +00:00
func (c *Caller) UpdateUser(ctx context.Context, id int32, username *string, role *models.Role) error {
const op = "Caller.UpdateUser"
2024-08-14 10:36:43 +00:00
var err error
2024-12-30 15:04:26 +00:00
if username != nil {
if err = ValidUsername(*username); err != nil {
return pkg.Wrap(pkg.ErrBadInput, err, op, "username validation")
2024-08-14 10:36:43 +00:00
}
}
2024-12-30 15:04:26 +00:00
query := c.db.Rebind(updateUser)
_, err = c.db.ExecContext(
2024-08-14 10:36:43 +00:00
ctx,
query,
2024-12-30 15:04:26 +00:00
username,
role,
id,
2024-08-14 10:36:43 +00:00
)
if err != nil {
2024-12-30 15:04:26 +00:00
return handlePgErr(err, op)
2024-08-14 10:36:43 +00:00
}
return nil
}
2024-10-09 17:07:38 +00:00
2024-12-30 15:04:26 +00:00
const deleteUser = "DELETE FROM users WHERE id = ?"
func (c *Caller) DeleteUser(ctx context.Context, id int32) error {
const op = "Caller.DeleteUser"
2024-10-09 17:07:38 +00:00
2024-12-30 15:04:26 +00:00
query := c.db.Rebind(deleteUser)
_, err := c.db.ExecContext(ctx, query, id)
2024-08-14 10:36:43 +00:00
if err != nil {
2024-12-30 15:04:26 +00:00
return handlePgErr(err, op)
2024-08-14 10:36:43 +00:00
}
return nil
}
2024-12-30 15:04:26 +00:00
// handlePgErr translates a database error into one of the project's error
// kinds and tags it with op:
//   - sql.ErrNoRows (a single-row Get found nothing) -> pkg.ErrNotFound;
//   - PostgreSQL integrity-constraint violations (SQLSTATE class 23)
//     -> pkg.ErrBadInput, carrying the server's message;
//   - PostgreSQL "no data" conditions (SQLSTATE class 02) -> pkg.ErrNotFound;
//   - anything else -> pkg.ErrUnhandled.
func handlePgErr(err error, op string) error {
	// An empty result from GetContext is reported by database/sql as
	// sql.ErrNoRows, not as a *pgconn.PgError, so check it separately.
	// Otherwise "user not found" would surface as an unhandled error.
	if errors.Is(err, sql.ErrNoRows) {
		return pkg.Wrap(pkg.ErrNotFound, err, op, "not found")
	}

	var pgErr *pgconn.PgError
	if errors.As(err, &pgErr) {
		switch {
		case pgerrcode.IsIntegrityConstraintViolation(pgErr.Code):
			return pkg.Wrap(pkg.ErrBadInput, err, op, pgErr.Message)
		case pgerrcode.IsNoData(pgErr.Code):
			return pkg.Wrap(pkg.ErrNotFound, err, op, pgErr.Message)
		}
	}
	return pkg.Wrap(pkg.ErrUnhandled, err, op, "unexpected error")
}
// ValidEmail accepts str only if net/mail parses it as an address and the
// extracted address is exactly str. Display-name forms such as
// "Bob <bob@example.com>" therefore parse but are rejected.
func ValidEmail(str string) error {
	addr, err := mail.ParseAddress(str)
	if err == nil && addr.Address == str {
		return nil
	}
	return errors.New("invalid email")
}
// ValidUsername requires str to be 5 to 70 bytes long (inclusive) and not to
// be a valid email address as judged by ValidEmail.
func ValidUsername(str string) error {
	switch n := len(str); {
	case n < 5:
		return errors.New("too short username")
	case n > 70:
		return errors.New("too long username")
	}

	if ValidEmail(str) == nil {
		return errors.New("username cannot be an email")
	}
	return nil
}
func ValidPassword(str string) error {
if len(str) < 5 {
return errors.New("too short password")
}
if len(str) > 70 {
return errors.New("too long password")
}
return nil
}
func ValidRole(role models.Role) error {
switch role {
case models.RoleAdmin:
return nil
case models.RoleModerator:
return nil
case models.RoleParticipant:
return nil
2024-08-14 10:36:43 +00:00
}
2024-12-30 15:04:26 +00:00
return errors.New("invalid role")
2024-08-14 10:36:43 +00:00
}