auth services, middleware, and other stuff
All checks were successful
ci/woodpecker/pr/build Pipeline was successful
ci/woodpecker/pr/lint Pipeline was successful
ci/woodpecker/pr/test Pipeline was successful

This commit is contained in:
2025-05-22 13:55:43 -04:00
parent 70bb4e66b4
commit e55d419d44
22 changed files with 985 additions and 95 deletions

View File

@@ -0,0 +1,67 @@
package handler
import (
"context"
"encoding/json"
"log/slog"
"net/http"
"gitea.michaelthomson.dev/mthomson/habits/internal/user/service"
"github.com/gofrs/uuid/v5"
)
type UserRegisterer interface {
Register(ctx context.Context, email string, password string) (uuid.UUID, error)
}
type RegisterRequest struct {
Email string `json:"email"`
Password string `json:"password"`
}
type RegisterResponse struct {
Id string `json:"id"`
}
func HandleRegisterUser(logger *slog.Logger, userService UserRegisterer) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
registerRequest := RegisterRequest{}
decoder := json.NewDecoder(r.Body)
decoder.DisallowUnknownFields()
err := decoder.Decode(&registerRequest)
if err != nil {
logger.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusBadRequest)
return
}
uuid, err := userService.Register(ctx, registerRequest.Email, registerRequest.Password)
if err != nil {
if err == service.ErrUserExists {
http.Error(w, "", http.StatusConflict)
return
}
logger.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusInternalServerError)
return
}
response := RegisterResponse{uuid.String()}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
err = json.NewEncoder(w).Encode(response)
if err != nil {
logger.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusInternalServerError)
return
}
}
}

View File

@@ -0,0 +1,148 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"log/slog"
"errors"
"net/http"
"net/http/httptest"
"testing"
"gitea.michaelthomson.dev/mthomson/habits/internal/user/service"
"github.com/gofrs/uuid/v5"
)
type MockUserRegisterer struct {
RegisterUserFunc func(ctx context.Context, email string, password string) (uuid.UUID, error)
}
func (m *MockUserRegisterer) Register(ctx context.Context, email string, password string) (uuid.UUID, error) {
return m.RegisterUserFunc(ctx, email, password)
}
func TestCreateUser(t *testing.T) {
logger := slog.Default()
t.Run("create user", func(t *testing.T) {
createUserRequest := RegisterRequest{Email: "test@test.com", Password: "password"}
newUUID := NewUUID(t)
service := MockUserRegisterer{
RegisterUserFunc: func(ctx context.Context, email string, password string) (uuid.UUID, error) {
return newUUID, nil
},
}
handler := HandleRegisterUser(logger, &service)
requestBody, err := json.Marshal(createUserRequest)
if err != nil {
t.Fatalf("Failed to marshal request %+v: %v", createUserRequest, err)
}
req := httptest.NewRequest(http.MethodPost, "/register", bytes.NewBuffer(requestBody))
res := httptest.NewRecorder()
handler(res, req)
AssertStatusCodes(t, res.Code, http.StatusOK)
})
t.Run("returns 409 when user exists", func(t *testing.T) {
createUserRequest := RegisterRequest{Email: "test@test.com", Password: "password"}
newUUID := NewUUID(t)
service := MockUserRegisterer{
RegisterUserFunc: func(ctx context.Context, email string, password string) (uuid.UUID, error) {
return newUUID, service.ErrUserExists
},
}
handler := HandleRegisterUser(logger, &service)
requestBody, err := json.Marshal(createUserRequest)
if err != nil {
t.Fatalf("Failed to marshal request %+v: %v", createUserRequest, err)
}
req := httptest.NewRequest(http.MethodPost, "/register", bytes.NewBuffer(requestBody))
res := httptest.NewRecorder()
handler(res, req)
AssertStatusCodes(t, res.Code, http.StatusConflict)
})
t.Run("returns 400 with bad json", func(t *testing.T) {
handler := HandleRegisterUser(logger, nil)
badStruct := struct {
Foo string
}{
Foo: "bar",
}
requestBody, err := json.Marshal(badStruct)
if err != nil {
t.Fatalf("Failed to marshal request %+v: %v", badStruct, err)
}
req := httptest.NewRequest(http.MethodPost, "/register", bytes.NewBuffer(requestBody))
res := httptest.NewRecorder()
handler(res, req)
AssertStatusCodes(t, res.Code, http.StatusBadRequest)
})
t.Run("returns 500 arbitrary errors", func(t *testing.T) {
createUserRequest := RegisterRequest{Email: "test@test.com", Password: "password"}
newUUID := NewUUID(t)
service := MockUserRegisterer{
RegisterUserFunc: func(ctx context.Context, email string, password string) (uuid.UUID, error) {
return newUUID, errors.New("foo bar")
},
}
handler := HandleRegisterUser(logger, &service)
requestBody, err := json.Marshal(createUserRequest)
if err != nil {
t.Fatalf("Failed to marshal request %+v: %v", createUserRequest, err)
}
req := httptest.NewRequest(http.MethodPost, "/user", bytes.NewBuffer(requestBody))
res := httptest.NewRecorder()
handler(res, req)
AssertStatusCodes(t, res.Code, http.StatusInternalServerError)
})
}
func AssertStatusCodes(t testing.TB, got, want int) {
t.Helper()
if got != want {
t.Errorf("got status code: %v, want status code: %v", want, got)
}
}
func NewUUID(t testing.TB) uuid.UUID {
t.Helper()
uuid, err := uuid.NewV4()
if err != nil {
t.Errorf("error generation uuid: %v", err)
}
return uuid
}

View File

@@ -1,6 +1,7 @@
package repository
import (
"bytes"
"context"
"errors"
"log/slog"
@@ -17,15 +18,12 @@ var (
type UserRow struct {
Id uuid.UUID
Email string
HashedPassword string
}
func NewUserRow(id uuid.UUID, email string, hashedPassword string) UserRow {
return UserRow{Id: id, Email: email, HashedPassword: hashedPassword}
HashedPassword []byte
Salt []byte
}
func (u UserRow) Equal(user UserRow) bool {
return u.Id == user.Id && u.Email == user.Email && u.HashedPassword == user.HashedPassword
return u.Id == user.Id && u.Email == user.Email && bytes.Equal(u.HashedPassword, user.HashedPassword) && bytes.Equal(u.Salt, user.Salt)
}
type UserRepository struct {
@@ -43,7 +41,24 @@ func NewUserRepository(logger *slog.Logger, db *pgxpool.Pool) *UserRepository {
func (r *UserRepository) GetById(ctx context.Context, id uuid.UUID) (UserRow, error) {
user := UserRow{}
err := r.db.QueryRow(ctx, "SELECT * FROM users WHERE id = $1;", id).Scan(&user.Id, &user.Email, &user.HashedPassword)
err := r.db.QueryRow(ctx, "SELECT * FROM users WHERE id = $1;", id).Scan(&user.Id, &user.Email, &user.HashedPassword, &user.Salt)
if err != nil {
if err == pgx.ErrNoRows {
return user, ErrNotFound
}
r.logger.ErrorContext(ctx, err.Error())
return user, err
}
return user, nil
}
func (r *UserRepository) GetByEmail(ctx context.Context, email string) (UserRow, error) {
user := UserRow{}
err := r.db.QueryRow(ctx, "SELECT * FROM users WHERE email = $1;", email).Scan(&user.Id, &user.Email, &user.HashedPassword, &user.Salt)
if err != nil {
if err == pgx.ErrNoRows {
@@ -60,7 +75,7 @@ func (r *UserRepository) GetById(ctx context.Context, id uuid.UUID) (UserRow, er
func (r *UserRepository) Create(ctx context.Context, user UserRow) (UserRow, error) {
var result pgx.Row
if user.Id.IsNil() {
result = r.db.QueryRow(ctx, "INSERT INTO users (email, hashed_password) VALUES ($1, $2) RETURNING id;", user.Email, user.HashedPassword)
result = r.db.QueryRow(ctx, "INSERT INTO users (email, hashed_password, salt) VALUES ($1, $2, $3) RETURNING id;", user.Email, user.HashedPassword, user.Salt)
err := result.Scan(&user.Id)
@@ -69,7 +84,7 @@ func (r *UserRepository) Create(ctx context.Context, user UserRow) (UserRow, err
return UserRow{}, err
}
} else {
_, err := r.db.Exec(ctx, "INSERT INTO users (id, email, hashed_password) VALUES ($1, $2, $3);", user.Id, user.Email, user.HashedPassword)
_, err := r.db.Exec(ctx, "INSERT INTO users (id, email, hashed_password, salt) VALUES ($1, $2, $3, $4);", user.Id, user.Email, user.HashedPassword, user.Salt)
if err != nil {
r.logger.ErrorContext(ctx, err.Error())
@@ -81,7 +96,7 @@ func (r *UserRepository) Create(ctx context.Context, user UserRow) (UserRow, err
}
func (r *UserRepository) Update(ctx context.Context, user UserRow) error {
result, err := r.db.Exec(ctx, "UPDATE users SET email = $1, hashed_password = $2 WHERE id = $3;", user.Email, user.HashedPassword, user.Id)
result, err := r.db.Exec(ctx, "UPDATE users SET email = $1, hashed_password = $2, salt = $3 WHERE id = $4;", user.Email, user.HashedPassword, user.Salt, user.Id)
if err != nil {
r.logger.ErrorContext(ctx, err.Error())

View File

@@ -6,6 +6,7 @@ import (
"log/slog"
"testing"
"gitea.michaelthomson.dev/mthomson/habits/internal/auth"
"gitea.michaelthomson.dev/mthomson/habits/internal/test"
"gitea.michaelthomson.dev/mthomson/habits/internal/user/repository"
"github.com/gofrs/uuid/v5"
@@ -18,22 +19,36 @@ func TestCRUD(t *testing.T) {
defer tdb.TearDown()
r := repository.NewUserRepository(logger, tdb.Db)
uuid := NewUUID(t)
argon2IdHash := auth.NewArgon2IdHash(1, 32, 64*1024, 32, 256)
hashSalt, err := argon2IdHash.GenerateHash([]byte("supersecurepassword"), []byte("supersecuresalt"))
if err != nil {
t.Errorf("could not generate hash: %v", err)
}
t.Run("creates new user", func(t *testing.T) {
newUser := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: "supersecurehash"}
newUser := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: hashSalt.Hash, Salt: hashSalt.Salt}
_, err := r.Create(ctx, newUser)
AssertNoError(t, err)
})
t.Run("gets user", func(t *testing.T) {
want := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: "supersecurehash"}
t.Run("gets user by id", func(t *testing.T) {
want := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: hashSalt.Hash, Salt: hashSalt.Salt}
got, err := r.GetById(ctx, uuid)
AssertNoError(t, err)
AssertUserRows(t, got, want)
})
t.Run("gets user by email", func(t *testing.T) {
want := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: hashSalt.Hash, Salt: hashSalt.Salt}
got, err := r.GetByEmail(ctx, "test@test.com")
AssertNoError(t, err)
AssertUserRows(t, got, want)
})
t.Run("updates user", func(t *testing.T) {
want := repository.UserRow{Id: uuid, Email: "new@test.com", HashedPassword: "supersecurehash"}
want := repository.UserRow{Id: uuid, Email: "new@test.com", HashedPassword: hashSalt.Hash, Salt: hashSalt.Salt}
err := r.Update(ctx, want)
AssertNoError(t, err)

View File

@@ -5,52 +5,48 @@ import (
"errors"
"log/slog"
"gitea.michaelthomson.dev/mthomson/habits/internal/auth"
"gitea.michaelthomson.dev/mthomson/habits/internal/user/repository"
"github.com/gofrs/uuid/v5"
)
var (
ErrNotFound error = errors.New("user cannot be found")
ErrNotFound error = errors.New("user cannot be found")
ErrUserExists error = errors.New("user already exists")
)
type User struct {
Id uuid.UUID
Email string
HashedPassword string
}
func NewUser(id uuid.UUID, email string, hashedPassword string) User {
return User{Id: id, Email: email, HashedPassword: hashedPassword}
Id uuid.UUID
Email string
}
func UserFromUserRow(userRow repository.UserRow) User {
return User{Id: userRow.Id, Email: userRow.Email, HashedPassword: userRow.HashedPassword}
}
func UserRowFromUser(user User) repository.UserRow {
return repository.UserRow{Id: user.Id, Email: user.Email, HashedPassword: user.HashedPassword}
return User{Id: userRow.Id, Email: userRow.Email}
}
func (t User) Equal(user User) bool {
return t.Id == user.Id && t.Email == user.Email && t.HashedPassword == user.HashedPassword
return t.Id == user.Id && t.Email == user.Email
}
type UserRepository interface {
Create(ctx context.Context, user repository.UserRow) (repository.UserRow, error)
GetById(ctx context.Context, id uuid.UUID) (repository.UserRow, error)
GetByEmail(ctx context.Context, email string) (repository.UserRow, error)
Update(ctx context.Context, user repository.UserRow) error
Delete(ctx context.Context, id uuid.UUID) error
}
type UserService struct {
logger *slog.Logger
repo UserRepository
logger *slog.Logger
repo UserRepository
argon2IdHash *auth.Argon2IdHash
}
func NewUserService(logger *slog.Logger, userRepo UserRepository) *UserService {
func NewUserService(logger *slog.Logger, userRepo UserRepository, argon2IdHash *auth.Argon2IdHash) *UserService {
return &UserService{
logger: logger,
repo: userRepo,
logger: logger,
repo: userRepo,
argon2IdHash: argon2IdHash,
}
}
@@ -69,17 +65,37 @@ func (s *UserService) GetUser(ctx context.Context, id uuid.UUID) (User, error) {
return UserFromUserRow(user), err
}
func (s *UserService) CreateUser(ctx context.Context, user User) (User, error) {
userRow := UserRowFromUser(user)
newUserRow, err := s.repo.Create(ctx, userRow)
func (s *UserService) Register(ctx context.Context, email string, password string) (uuid.UUID, error) {
uuid, err := uuid.NewV4()
if err != nil {
s.logger.ErrorContext(ctx, err.Error())
return User{}, err
return uuid, err
}
return UserFromUserRow(newUserRow), err
_, err = s.repo.GetByEmail(ctx, email)
if err != repository.ErrNotFound {
return uuid, ErrUserExists
}
hashSalt, err := s.argon2IdHash.GenerateHash([]byte(password), nil)
if err != nil {
s.logger.ErrorContext(ctx, err.Error())
return uuid, err
}
userRow := repository.UserRow{Id: uuid, Email: email, HashedPassword: hashSalt.Hash, Salt: hashSalt.Salt}
_, err = s.repo.Create(ctx, userRow)
if err != nil {
s.logger.ErrorContext(ctx, err.Error())
return uuid, err
}
return uuid, err
}
func (s *UserService) DeleteUser(ctx context.Context, id uuid.UUID) error {
@@ -96,10 +112,66 @@ func (s *UserService) DeleteUser(ctx context.Context, id uuid.UUID) error {
return err
}
func (s *UserService) UpdateUser(ctx context.Context, user User) error {
userRow := UserRowFromUser(user)
func (s *UserService) UpdateUserEmail(ctx context.Context, id uuid.UUID, email string) error {
user, err := s.repo.GetById(ctx, id)
err := s.repo.Update(ctx, userRow)
if err == repository.ErrNotFound {
return ErrNotFound
}
if err != nil {
s.logger.ErrorContext(ctx, err.Error())
return err
}
_, err = s.repo.GetByEmail(ctx, email)
switch err {
case repository.ErrNotFound:
user.Email = email
err = s.repo.Update(ctx, user)
if err == repository.ErrNotFound {
return ErrNotFound
}
if err != nil {
s.logger.ErrorContext(ctx, err.Error())
}
return err
case nil:
return ErrUserExists
default:
s.logger.ErrorContext(ctx, err.Error())
return err
}
}
func (s *UserService) UpdateUserPassword(ctx context.Context, id uuid.UUID, password string) error {
user, err := s.repo.GetById(ctx, id)
if err == repository.ErrNotFound {
return ErrNotFound
}
if err != nil {
s.logger.ErrorContext(ctx, err.Error())
return err
}
hashSalt, err := s.argon2IdHash.GenerateHash([]byte(password), nil)
if err != nil {
s.logger.ErrorContext(ctx, err.Error())
return err
}
user.HashedPassword = hashSalt.Hash
user.Salt = hashSalt.Salt
err = s.repo.Update(ctx, user)
if err == repository.ErrNotFound {
return ErrNotFound

View File

@@ -5,27 +5,26 @@ import (
"log/slog"
"testing"
"gitea.michaelthomson.dev/mthomson/habits/internal/auth"
"gitea.michaelthomson.dev/mthomson/habits/internal/test"
"gitea.michaelthomson.dev/mthomson/habits/internal/user/repository"
"gitea.michaelthomson.dev/mthomson/habits/internal/user/service"
"github.com/gofrs/uuid/v5"
)
func TestCreateUser(t *testing.T) {
func TestRegisterUser(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slog.Default()
tdb := test.NewTestDatabase(t)
defer tdb.TearDown()
r := repository.NewUserRepository(logger, tdb.Db)
argon2IdHash := auth.NewArgon2IdHash(1, 32, 64*1024, 32, 256)
userService := service.NewUserService(logger, r)
userService := service.NewUserService(logger, r, argon2IdHash)
t.Run("Create user", func(t *testing.T) {
uuid := NewUUID(t)
user := service.NewUser(uuid, "test@test.com", "supersecurehash")
_, err := userService.CreateUser(ctx, user)
_, err := userService.Register(ctx, "test@test.com", "supersecurepassword")
AssertNoError(t, err)
})
@@ -38,13 +37,12 @@ func TestGetUser(t *testing.T) {
tdb := test.NewTestDatabase(t)
defer tdb.TearDown()
r := repository.NewUserRepository(logger, tdb.Db)
uuid := NewUUID(t)
argon2IdHash := auth.NewArgon2IdHash(1, 32, 64*1024, 32, 256)
row := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: "supersecurehash"}
_, err := r.Create(ctx, row)
AssertNoError(t, err);
userService := service.NewUserService(logger, r, argon2IdHash)
userService := service.NewUserService(logger, r)
uuid, err := userService.Register(ctx, "test@test.com", "supersecurepassword")
AssertNoError(t, err)
t.Run("Get exisiting user", func(t *testing.T) {
_, err := userService.GetUser(ctx, uuid)
@@ -66,13 +64,12 @@ func TestDeleteUser(t *testing.T) {
tdb := test.NewTestDatabase(t)
defer tdb.TearDown()
r := repository.NewUserRepository(logger, tdb.Db)
uuid := NewUUID(t)
argon2IdHash := auth.NewArgon2IdHash(1, 32, 64*1024, 32, 256)
row := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: "supersecurehash"}
_, err := r.Create(ctx, row)
AssertNoError(t, err);
userService := service.NewUserService(logger, r, argon2IdHash)
userService := service.NewUserService(logger, r)
uuid, err := userService.Register(ctx, "test@test.com", "supersecurepassword")
AssertNoError(t, err)
t.Run("Delete exisiting user", func(t *testing.T) {
err := userService.DeleteUser(ctx, uuid)
@@ -87,25 +84,22 @@ func TestDeleteUser(t *testing.T) {
})
}
func TestUpdateUser(t *testing.T) {
func TestUpdateUserEmail(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slog.Default()
tdb := test.NewTestDatabase(t)
defer tdb.TearDown()
r := repository.NewUserRepository(logger, tdb.Db)
uuid := NewUUID(t)
argon2IdHash := auth.NewArgon2IdHash(1, 32, 64*1024, 32, 256)
row := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: "supersecurehash"}
_, err := r.Create(ctx, row)
AssertNoError(t, err);
userService := service.NewUserService(logger, r, argon2IdHash)
userService := service.NewUserService(logger, r)
uuid, err := userService.Register(ctx, "test@test.com", "supersecurepassword")
AssertNoError(t, err)
t.Run("Update exisiting user", func(t *testing.T) {
user := service.User{uuid, "new@email.com", "supersecurehash"}
err := userService.UpdateUser(ctx, user)
t.Run("Update existing user email", func(t *testing.T) {
err := userService.UpdateUserEmail(ctx, uuid, "newemail@test.com")
AssertNoError(t, err)
@@ -113,13 +107,13 @@ func TestUpdateUser(t *testing.T) {
AssertNoError(t, err)
AssertUsers(t, newUser, user)
if newUser.Email != "newemail@test.com" {
t.Errorf("Emails do not match wanted %q, got %q", "newemail@test.com", newUser.Email)
}
})
t.Run("Update non-existant user", func(t *testing.T) {
user := service.User{NewUUID(t), "new@email.com", "supersecurehash"}
err := userService.UpdateUser(ctx, user)
err := userService.UpdateUserEmail(ctx, NewUUID(t), "newemail@test.com")
AssertErrors(t, err, service.ErrNotFound)
})