131 lines
2.9 KiB
Go
131 lines
2.9 KiB
Go
package repository
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"log/slog"
|
|
|
|
"github.com/gofrs/uuid/v5"
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
)
|
|
|
|
var (
|
|
ErrNotFound error = errors.New("user cannot be found")
|
|
)
|
|
|
|
type UserRow struct {
|
|
Id uuid.UUID
|
|
Email string
|
|
HashedPassword []byte
|
|
Salt []byte
|
|
}
|
|
|
|
func (u UserRow) Equal(user UserRow) bool {
|
|
return u.Id == user.Id && u.Email == user.Email && bytes.Equal(u.HashedPassword, user.HashedPassword) && bytes.Equal(u.Salt, user.Salt)
|
|
}
|
|
|
|
type UserRepository struct {
|
|
logger *slog.Logger
|
|
db *pgxpool.Pool
|
|
}
|
|
|
|
func NewUserRepository(logger *slog.Logger, db *pgxpool.Pool) *UserRepository {
|
|
return &UserRepository{
|
|
logger: logger,
|
|
db: db,
|
|
}
|
|
}
|
|
|
|
func (r *UserRepository) GetById(ctx context.Context, id uuid.UUID) (UserRow, error) {
|
|
user := UserRow{}
|
|
|
|
err := r.db.QueryRow(ctx, "SELECT * FROM users WHERE id = $1;", id).Scan(&user.Id, &user.Email, &user.HashedPassword, &user.Salt)
|
|
|
|
if err != nil {
|
|
if err == pgx.ErrNoRows {
|
|
return user, ErrNotFound
|
|
}
|
|
|
|
r.logger.ErrorContext(ctx, err.Error())
|
|
return user, err
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
func (r *UserRepository) GetByEmail(ctx context.Context, email string) (UserRow, error) {
|
|
user := UserRow{}
|
|
|
|
err := r.db.QueryRow(ctx, "SELECT * FROM users WHERE email = $1;", email).Scan(&user.Id, &user.Email, &user.HashedPassword, &user.Salt)
|
|
|
|
if err != nil {
|
|
if err == pgx.ErrNoRows {
|
|
return user, ErrNotFound
|
|
}
|
|
|
|
r.logger.ErrorContext(ctx, err.Error())
|
|
return user, err
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
func (r *UserRepository) Create(ctx context.Context, user UserRow) (UserRow, error) {
|
|
var result pgx.Row
|
|
if user.Id.IsNil() {
|
|
result = r.db.QueryRow(ctx, "INSERT INTO users (email, hashed_password, salt) VALUES ($1, $2, $3) RETURNING id;", user.Email, user.HashedPassword, user.Salt)
|
|
|
|
err := result.Scan(&user.Id)
|
|
|
|
if err != nil {
|
|
r.logger.ErrorContext(ctx, err.Error())
|
|
return UserRow{}, err
|
|
}
|
|
} else {
|
|
_, err := r.db.Exec(ctx, "INSERT INTO users (id, email, hashed_password, salt) VALUES ($1, $2, $3, $4);", user.Id, user.Email, user.HashedPassword, user.Salt)
|
|
|
|
if err != nil {
|
|
r.logger.ErrorContext(ctx, err.Error())
|
|
return UserRow{}, err
|
|
}
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
func (r *UserRepository) Update(ctx context.Context, user UserRow) error {
|
|
result, err := r.db.Exec(ctx, "UPDATE users SET email = $1, hashed_password = $2, salt = $3 WHERE id = $4;", user.Email, user.HashedPassword, user.Salt, user.Id)
|
|
|
|
if err != nil {
|
|
r.logger.ErrorContext(ctx, err.Error())
|
|
return err
|
|
}
|
|
|
|
rowsAffected := result.RowsAffected()
|
|
|
|
if rowsAffected == 0 {
|
|
return ErrNotFound
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *UserRepository) Delete(ctx context.Context, id uuid.UUID) error {
|
|
result, err := r.db.Exec(ctx, "DELETE FROM users WHERE id = $1;", id)
|
|
|
|
if err != nil {
|
|
r.logger.ErrorContext(ctx, err.Error())
|
|
return err
|
|
}
|
|
|
|
rowsAffected := result.RowsAffected()
|
|
|
|
if rowsAffected == 0 {
|
|
return ErrNotFound
|
|
}
|
|
|
|
return nil
|
|
}
|