auth services, middleware, and other stuff
All checks were successful
ci/woodpecker/pr/build Pipeline was successful
ci/woodpecker/pr/lint Pipeline was successful
ci/woodpecker/pr/test Pipeline was successful

This commit is contained in:
Michael Thomson 2025-05-22 13:55:43 -04:00
parent 70bb4e66b4
commit e55d419d44
Signed by: mthomson
GPG Key ID: B6CA05EE5F436C79
22 changed files with 985 additions and 95 deletions

2
.env.example Normal file
View File

@ -0,0 +1,2 @@
POSTGRESQL_CONNECTION_STRING=postgres://todo:password@localhost:5432/todo
JWT_SECRET_KEY="supersecretjwtkey"

View File

@ -6,33 +6,59 @@ import (
"net/http" "net/http"
"os" "os"
"gitea.michaelthomson.dev/mthomson/habits/internal/auth"
authhandler "gitea.michaelthomson.dev/mthomson/habits/internal/auth/handler"
authservice "gitea.michaelthomson.dev/mthomson/habits/internal/auth/service"
"gitea.michaelthomson.dev/mthomson/habits/internal/logging" "gitea.michaelthomson.dev/mthomson/habits/internal/logging"
"gitea.michaelthomson.dev/mthomson/habits/internal/middleware" "gitea.michaelthomson.dev/mthomson/habits/internal/middleware"
"gitea.michaelthomson.dev/mthomson/habits/internal/migrate" "gitea.michaelthomson.dev/mthomson/habits/internal/migrate"
todohandler "gitea.michaelthomson.dev/mthomson/habits/internal/todo/handler" todohandler "gitea.michaelthomson.dev/mthomson/habits/internal/todo/handler"
todorepository "gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository" todorepository "gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository"
todoservice "gitea.michaelthomson.dev/mthomson/habits/internal/todo/service" todoservice "gitea.michaelthomson.dev/mthomson/habits/internal/todo/service"
userhandler "gitea.michaelthomson.dev/mthomson/habits/internal/user/handler"
userrepository "gitea.michaelthomson.dev/mthomson/habits/internal/user/repository"
userservice "gitea.michaelthomson.dev/mthomson/habits/internal/user/service"
pgxuuid "github.com/jackc/pgx-gofrs-uuid"
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
pgxuuid "github.com/jackc/pgx-gofrs-uuid" "github.com/joho/godotenv"
) )
func main() { func main() {
// create logger // create logger
logger := logging.NewLogger() logger := logging.NewLogger()
// create middlewares // load env
contextMiddleware := middleware.ContextMiddleware(logger) err := godotenv.Load()
loggingMiddleware := middleware.LoggingMiddleware(logger) if err != nil {
logger.Error(err.Error())
os.Exit(1)
}
stack := []middleware.Middleware{ jwtSecretKey := os.Getenv("JWT_SECRET_KEY")
contextMiddleware, postgresqlConnectionString := os.Getenv("POSTGRESQL_CONNECTION_STRING")
// create hasher instance
argon2IdHash := auth.NewArgon2IdHash(1, 32, 64*1024, 32, 256)
// create middlewares
traceMiddleware := middleware.TraceMiddleware(logger)
loggingMiddleware := middleware.LoggingMiddleware(logger)
authMiddleware := middleware.AuthMiddleware(logger, []byte(jwtSecretKey))
unauthenticatedStack := []middleware.Middleware{
traceMiddleware,
loggingMiddleware, loggingMiddleware,
} }
authenticatedStack := []middleware.Middleware{
traceMiddleware,
loggingMiddleware,
authMiddleware,
}
// create db pool // create db pool
postgresUrl := "postgres://todo:password@localhost:5432/todo" dbconfig, err := pgxpool.ParseConfig(postgresqlConnectionString)
dbconfig, err := pgxpool.ParseConfig(postgresUrl)
if err != nil { if err != nil {
logger.Error(err.Error()) logger.Error(err.Error())
os.Exit(1) os.Exit(1)
@ -54,18 +80,26 @@ func main() {
// create repos // create repos
todoRepository := todorepository.NewTodoRepository(logger, db) todoRepository := todorepository.NewTodoRepository(logger, db)
userRepository := userrepository.NewUserRepository(logger, db)
// create services // create services
todoService := todoservice.NewTodoService(logger, todoRepository) todoService := todoservice.NewTodoService(logger, todoRepository)
userService := userservice.NewUserService(logger, userRepository, argon2IdHash)
authService := authservice.NewAuthService(logger, []byte(jwtSecretKey), userRepository, argon2IdHash)
// create mux // create mux
mux := http.NewServeMux() mux := http.NewServeMux()
// register handlers // register handlers
mux.Handle("GET /todo/{id}", middleware.CompileMiddleware(todohandler.HandleTodoGet(logger, todoService), stack)) // auth
mux.Handle("POST /todo", middleware.CompileMiddleware(todohandler.HandleTodoCreate(logger, todoService), stack)) mux.Handle("POST /login", middleware.CompileMiddleware(authhandler.HandleLogin(logger, authService), unauthenticatedStack))
mux.Handle("DELETE /todo/{id}", middleware.CompileMiddleware(todohandler.HandleTodoDelete(logger, todoService), stack)) // users
mux.Handle("PUT /todo/{id}", middleware.CompileMiddleware(todohandler.HandleTodoUpdate(logger, todoService), stack)) mux.Handle("POST /register", middleware.CompileMiddleware(userhandler.HandleRegisterUser(logger, userService), unauthenticatedStack))
// todos
mux.Handle("GET /todo/{id}", middleware.CompileMiddleware(todohandler.HandleTodoGet(logger, todoService), authenticatedStack))
mux.Handle("POST /todo", middleware.CompileMiddleware(todohandler.HandleTodoCreate(logger, todoService), authenticatedStack))
mux.Handle("DELETE /todo/{id}", middleware.CompileMiddleware(todohandler.HandleTodoDelete(logger, todoService), authenticatedStack))
mux.Handle("PUT /todo/{id}", middleware.CompileMiddleware(todohandler.HandleTodoUpdate(logger, todoService), authenticatedStack))
// create server // create server
server := &http.Server{ server := &http.Server{

4
go.mod
View File

@ -6,10 +6,13 @@ toolchain go1.23.9
require ( require (
github.com/gofrs/uuid/v5 v5.3.2 github.com/gofrs/uuid/v5 v5.3.2
github.com/golang-jwt/jwt/v5 v5.2.2
github.com/jackc/pgx-gofrs-uuid v0.0.0-20230224015001-1d428863c2e2 github.com/jackc/pgx-gofrs-uuid v0.0.0-20230224015001-1d428863c2e2
github.com/jackc/pgx/v5 v5.7.4 github.com/jackc/pgx/v5 v5.7.4
github.com/joho/godotenv v1.5.1
github.com/testcontainers/testcontainers-go v0.37.0 github.com/testcontainers/testcontainers-go v0.37.0
github.com/testcontainers/testcontainers-go/modules/postgres v0.37.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.37.0
golang.org/x/crypto v0.38.0
) )
require ( require (
@ -64,7 +67,6 @@ require (
go.opentelemetry.io/otel v1.35.0 // indirect go.opentelemetry.io/otel v1.35.0 // indirect
go.opentelemetry.io/otel/metric v1.35.0 // indirect go.opentelemetry.io/otel/metric v1.35.0 // indirect
go.opentelemetry.io/otel/trace v1.35.0 // indirect go.opentelemetry.io/otel/trace v1.35.0 // indirect
golang.org/x/crypto v0.38.0 // indirect
golang.org/x/sync v0.14.0 // indirect golang.org/x/sync v0.14.0 // indirect
golang.org/x/sys v0.33.0 // indirect golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.25.0 // indirect golang.org/x/text v0.25.0 // indirect

4
go.sum
View File

@ -43,6 +43,8 @@ github.com/gofrs/uuid/v5 v5.3.2 h1:2jfO8j3XgSwlz/wHqemAEugfnTlikAYHhnqQ8Xh4fE0=
github.com/gofrs/uuid/v5 v5.3.2/go.mod h1:CDOjlDMVAtN56jqyRUZh58JT31Tiw7/oQyEXZV+9bD8= github.com/gofrs/uuid/v5 v5.3.2/go.mod h1:CDOjlDMVAtN56jqyRUZh58JT31Tiw7/oQyEXZV+9bD8=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
@ -59,6 +61,8 @@ github.com/jackc/pgx/v5 v5.7.4 h1:9wKznZrhWa2QiHL+NjTSPP6yjl3451BX3imWDnokYlg=
github.com/jackc/pgx/v5 v5.7.4/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ= github.com/jackc/pgx/v5 v5.7.4/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=

View File

@ -0,0 +1,66 @@
package handler
import (
"context"
"encoding/json"
"log/slog"
"net/http"
"gitea.michaelthomson.dev/mthomson/habits/internal/auth/service"
)
type Loginer interface {
Login(ctx context.Context, email string, password string) (string, error)
}
type LoginRequest struct {
Email string `json:"email"`
Password string `json:"password"`
}
func HandleLogin(logger *slog.Logger, authService Loginer) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
loginRequest := LoginRequest{}
decoder := json.NewDecoder(r.Body)
decoder.DisallowUnknownFields()
err := decoder.Decode(&loginRequest)
if err != nil {
logger.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusBadRequest)
return
}
token, err := authService.Login(ctx, loginRequest.Email, loginRequest.Password)
if err == service.ErrUnauthorized {
http.Error(w, "", http.StatusUnauthorized)
return
}
if err == service.ErrNotFound {
http.Error(w, "", http.StatusUnauthorized)
return
}
if err != nil {
logger.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusInternalServerError)
return
}
cookie := http.Cookie{
Name: "token",
Value: token,
Path: "/",
MaxAge: 3600,
HttpOnly: true,
Secure: false,
SameSite: http.SameSiteLaxMode,
}
http.SetCookie(w, &cookie)
w.WriteHeader(http.StatusOK)
}
}

View File

@ -0,0 +1,168 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"log/slog"
"errors"
"net/http"
"net/http/httptest"
"testing"
"gitea.michaelthomson.dev/mthomson/habits/internal/auth/service"
)
type MockLoginer struct {
LoginFunc func(ctx context.Context, email string, password string) (string, error)
}
func (m *MockLoginer) Login(ctx context.Context, email string, password string) (string, error) {
return m.LoginFunc(ctx, email, password)
}
func TestLogin(t *testing.T) {
logger := slog.Default()
t.Run("returns 200 for existing user with correct credentials", func(t *testing.T) {
loginRequest := LoginRequest{Email: "test@test.com", Password: "password"}
token := "examplejwt"
service := MockLoginer{
LoginFunc: func(ctx context.Context, email string, password string) (string, error) {
return token, nil
},
}
handler := HandleLogin(logger, &service)
requestBody, err := json.Marshal(loginRequest)
if err != nil {
t.Fatalf("Failed to marshal request %+v: %v", loginRequest, err)
}
req := httptest.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(requestBody))
res := httptest.NewRecorder()
handler(res, req)
AssertStatusCodes(t, res.Code, http.StatusOK)
AssertToken(t, res, token)
})
t.Run("returns 401 for existing user with incorrect credentials", func(t *testing.T) {
loginRequest := LoginRequest{Email: "test@test.com", Password: "password"}
service := MockLoginer{
LoginFunc: func(ctx context.Context, email string, password string) (string, error) {
return "", service.ErrUnauthorized
},
}
handler := HandleLogin(logger, &service)
requestBody, err := json.Marshal(loginRequest)
if err != nil {
t.Fatalf("Failed to marshal request %+v: %v", loginRequest, err)
}
req := httptest.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(requestBody))
res := httptest.NewRecorder()
handler(res, req)
AssertStatusCodes(t, res.Code, http.StatusUnauthorized)
})
t.Run("returns 401 for non-existing user", func(t *testing.T) {
loginRequest := LoginRequest{Email: "test@test.com", Password: "password"}
service := MockLoginer{
LoginFunc: func(ctx context.Context, email string, password string) (string, error) {
return "", service.ErrNotFound
},
}
handler := HandleLogin(logger, &service)
requestBody, err := json.Marshal(loginRequest)
if err != nil {
t.Fatalf("Failed to marshal request %+v: %v", loginRequest, err)
}
req := httptest.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(requestBody))
res := httptest.NewRecorder()
handler(res, req)
AssertStatusCodes(t, res.Code, http.StatusUnauthorized)
})
t.Run("returns 400 with bad json", func(t *testing.T) {
handler := HandleLogin(logger, nil)
badStruct := struct {
Foo string
}{
Foo: "bar",
}
requestBody, err := json.Marshal(badStruct)
if err != nil {
t.Fatalf("Failed to marshal request %+v: %v", badStruct, err)
}
req := httptest.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(requestBody))
res := httptest.NewRecorder()
handler(res, req)
AssertStatusCodes(t, res.Code, http.StatusBadRequest)
})
t.Run("returns 500 arbitrary errors", func(t *testing.T) {
loginRequest := LoginRequest{Email: "test@test.com", Password: "password"}
service := MockLoginer{
LoginFunc: func(ctx context.Context, email string, password string) (string, error) {
return "", errors.New("foo bar")
},
}
handler := HandleLogin(logger, &service)
requestBody, err := json.Marshal(loginRequest)
if err != nil {
t.Fatalf("Failed to marshal request %+v: %v", loginRequest, err)
}
req := httptest.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(requestBody))
res := httptest.NewRecorder()
handler(res, req)
AssertStatusCodes(t, res.Code, http.StatusInternalServerError)
})
}
func AssertStatusCodes(t testing.TB, got, want int) {
t.Helper()
if got != want {
t.Errorf("got status code: %v, want status code: %v", want, got)
}
}
func AssertToken(t testing.TB, res *httptest.ResponseRecorder, want string) {
t.Helper()
got := res.Result().Cookies()[0].Value
if got != want {
t.Errorf("got cookie: %q, want cookie: %q", got, want)
}
}

70
internal/auth/hashing.go Normal file
View File

@ -0,0 +1,70 @@
package auth
import (
"bytes"
"crypto/rand"
"errors"
"golang.org/x/crypto/argon2"
)
type HashSalt struct {
Hash, Salt []byte
}
type Argon2IdHash struct {
time uint32
memory uint32
threads uint8
keyLen uint32
saltLen uint32
}
var (
ErrNoMatch error = errors.New("hash doesn't match")
)
func NewArgon2IdHash(time, saltLen uint32, memory uint32, threads uint8, keyLen uint32) *Argon2IdHash {
return &Argon2IdHash{
time: time,
saltLen: saltLen,
memory: memory,
threads: threads,
keyLen: keyLen,
}
}
func (a *Argon2IdHash) GenerateHash(password, salt []byte) (*HashSalt, error) {
var err error
if len(salt) == 0 {
salt, err = randomSecret(a.saltLen)
}
if err != nil {
return nil, err
}
hash := argon2.IDKey(password, salt, a.time, a.memory, a.threads, a.keyLen)
return &HashSalt{Hash: hash, Salt: salt}, nil
}
func (a *Argon2IdHash) Compare(hash, salt, password []byte) error {
hashSalt, err := a.GenerateHash(password, salt)
if err != nil {
return err
}
if !bytes.Equal(hash, hashSalt.Hash) {
return ErrNoMatch
}
return nil
}
func randomSecret(length uint32) ([]byte, error) {
secret := make([]byte, length)
_, err := rand.Read(secret)
if err != nil {
return nil, err
}
return secret, nil
}

View File

@ -0,0 +1,90 @@
package service
import (
"context"
"errors"
"log/slog"
"time"
"gitea.michaelthomson.dev/mthomson/habits/internal/auth"
userrepository "gitea.michaelthomson.dev/mthomson/habits/internal/user/repository"
"github.com/golang-jwt/jwt/v5"
)
var (
ErrNotFound error = errors.New("user cannot be found")
ErrUnauthorized error = errors.New("user password incorrect")
)
type UserRepository interface {
Create(ctx context.Context, user userrepository.UserRow) (userrepository.UserRow, error)
GetByEmail(ctx context.Context, email string) (userrepository.UserRow, error)
}
type AuthService struct {
logger *slog.Logger
jwtKey []byte
userRepository UserRepository
argon2IdHash *auth.Argon2IdHash
}
func NewAuthService(logger *slog.Logger, jwtKey []byte, userRepository UserRepository, argon2IdHash *auth.Argon2IdHash) *AuthService {
return &AuthService{
logger: logger,
jwtKey: jwtKey,
userRepository: userRepository,
argon2IdHash: argon2IdHash,
}
}
func (a AuthService) Login(ctx context.Context, email string, password string) (string, error) {
// get user if exists
userRow, err := a.userRepository.GetByEmail(ctx, email)
if err != nil {
if err == userrepository.ErrNotFound {
return "", ErrNotFound
}
a.logger.ErrorContext(ctx, err.Error())
return "", err
}
// compare hashed passswords
err = a.argon2IdHash.Compare(userRow.HashedPassword, userRow.Salt, []byte(password))
if err == auth.ErrNoMatch {
return "", ErrUnauthorized
}
if err != nil {
a.logger.ErrorContext(ctx, err.Error())
return "", err
}
// create token and return it
token, err := a.CreateToken(ctx, email)
if err != nil {
a.logger.ErrorContext(ctx, err.Error())
return "", err
}
return token, nil
}
func (a AuthService) CreateToken(ctx context.Context, email string) (string, error) {
// Create a new JWT token with claims
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"sub": email,
"iss": "todo-app",
"aud": "user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
})
tokenString, err := claims.SignedString(a.jwtKey)
if err != nil {
a.logger.ErrorContext(ctx, err.Error())
}
return tokenString, nil
}

View File

@ -0,0 +1,81 @@
package service_test
import (
"context"
"log/slog"
"testing"
"gitea.michaelthomson.dev/mthomson/habits/internal/auth"
authservice "gitea.michaelthomson.dev/mthomson/habits/internal/auth/service"
"gitea.michaelthomson.dev/mthomson/habits/internal/test"
"gitea.michaelthomson.dev/mthomson/habits/internal/user/repository"
userservice "gitea.michaelthomson.dev/mthomson/habits/internal/user/service"
"github.com/gofrs/uuid/v5"
)
func TestLogin(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slog.Default()
tdb := test.NewTestDatabase(t)
defer tdb.TearDown()
userRepository := repository.NewUserRepository(logger, tdb.Db)
argon2IdHash := auth.NewArgon2IdHash(1, 32, 64*1024, 32, 256)
userService := userservice.NewUserService(logger, userRepository, argon2IdHash)
authService := authservice.NewAuthService(logger, []byte("secretkey"), userRepository, argon2IdHash)
_, err := userService.Register(ctx, "test@test.com", "supersecurepassword")
AssertNoError(t, err)
t.Run("login existing user with correct credentials", func(t *testing.T) {
want, err := authService.CreateToken(ctx, "test@test.com")
AssertNoError(t, err)
got, err := authService.Login(ctx, "test@test.com", "supersecurepassword")
AssertNoError(t, err)
AssertTokens(t, got, want)
})
t.Run("login existing user with incorrect credentials", func(t *testing.T) {
_, err := authService.Login(ctx, "test@test.com", "superwrongpassword")
AssertErrors(t, err, authservice.ErrUnauthorized)
})
t.Run("login nonexistant user", func(t *testing.T) {
_, err := authService.Login(ctx, "foo@test.com", "supersecurepassword")
AssertErrors(t, err, authservice.ErrNotFound)
})
}
func AssertErrors(t testing.TB, got, want error) {
t.Helper()
if got != want {
t.Errorf("got error: %v, want error: %v", want, got)
}
}
func AssertNoError(t testing.TB, err error) {
t.Helper()
if err != nil {
t.Errorf("expected no error, got %v", err)
}
}
func AssertTokens(t testing.TB, got, want string) {
t.Helper()
if got != want {
t.Errorf("expected matching tokens, got %q, want %q", got, want)
}
}
func NewUUID(t testing.TB) uuid.UUID {
t.Helper()
uuid, err := uuid.NewV4()
if err != nil {
t.Errorf("error generation uuid: %v", err)
}
return uuid
}

View File

@ -20,7 +20,7 @@ func (h *ContextHandler) Handle(ctx context.Context, r slog.Record) error {
} }
func NewLogger() *slog.Logger { func NewLogger() *slog.Logger {
baseHandler := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{AddSource: false}) baseHandler := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{AddSource: true})
customHandler := &ContextHandler{Handler: baseHandler} customHandler := &ContextHandler{Handler: baseHandler}
logger := slog.New(customHandler) logger := slog.New(customHandler)

View File

@ -0,0 +1,60 @@
package middleware
import (
"context"
"log/slog"
"net/http"
"github.com/golang-jwt/jwt/v5"
)
const EmailKey contextKey = "email"
func AuthMiddleware(logger *slog.Logger, jwtKey []byte) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie("token")
if err == http.ErrNoCookie {
logger.WarnContext(r.Context(), "token not provided")
w.WriteHeader(http.StatusUnauthorized)
return
}
if err != nil {
logger.ErrorContext(r.Context(), err.Error())
w.WriteHeader(http.StatusInternalServerError)
return
}
tokenString := cookie.Value
token, err := jwt.Parse(tokenString, func(t *jwt.Token) (any, error) {
return jwtKey, nil
})
if !token.Valid {
logger.WarnContext(r.Context(), err.Error())
w.WriteHeader(http.StatusUnauthorized)
return
}
if err != nil {
logger.ErrorContext(r.Context(), err.Error())
w.WriteHeader(http.StatusInternalServerError)
return
}
email, err := token.Claims.GetSubject()
if err != nil {
logger.ErrorContext(r.Context(), err.Error())
w.WriteHeader(http.StatusInternalServerError)
return
}
ctx := context.WithValue(r.Context(), EmailKey, email)
newReq := r.WithContext(ctx)
next.ServeHTTP(w, newReq)
})
}
}

View File

@ -8,17 +8,16 @@ import (
"github.com/gofrs/uuid/v5" "github.com/gofrs/uuid/v5"
) )
type contextKey string
const TraceIdKey contextKey = "trace_id" const TraceIdKey contextKey = "trace_id"
func ContextMiddleware(logger *slog.Logger) func(http.Handler) http.Handler { func TraceMiddleware(logger *slog.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
traceid, err := uuid.NewV4() traceid, err := uuid.NewV4()
if err != nil { if err != nil {
logger.ErrorContext(r.Context(), err.Error()) logger.ErrorContext(r.Context(), err.Error())
} }
ctx := context.WithValue(r.Context(), TraceIdKey, traceid) ctx := context.WithValue(r.Context(), TraceIdKey, traceid.String())
newReq := r.WithContext(ctx) newReq := r.WithContext(ctx)
next.ServeHTTP(w, newReq) next.ServeHTTP(w, newReq)

View File

@ -4,6 +4,8 @@ import "net/http"
type Middleware func(http.Handler) http.Handler type Middleware func(http.Handler) http.Handler
type contextKey string
func CompileMiddleware(h http.Handler, m []Middleware) http.Handler { func CompileMiddleware(h http.Handler, m []Middleware) http.Handler {
if len(m) < 1 { if len(m) < 1 {
return h return h

View File

@ -6,6 +6,7 @@ CREATE TABLE todo(
CREATE TABLE users( CREATE TABLE users(
id uuid PRIMARY KEY DEFAULT gen_random_uuid(), id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
email VARCHAR NOT NULL, email VARCHAR NOT NULL UNIQUE,
hashed_password VARCHAR NOT NULL hashed_password bytea NOT NULL,
salt bytea NOT NULL
); );

View File

@ -7,12 +7,12 @@ import (
"time" "time"
"gitea.michaelthomson.dev/mthomson/habits/internal/migrate" "gitea.michaelthomson.dev/mthomson/habits/internal/migrate"
pgxuuid "github.com/jackc/pgx-gofrs-uuid"
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
"github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/postgres" "github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait" "github.com/testcontainers/testcontainers-go/wait"
pgxuuid "github.com/jackc/pgx-gofrs-uuid"
) )
type TestDatabase struct { type TestDatabase struct {

View File

@ -39,7 +39,7 @@ func TestGetTodo(t *testing.T) {
row := repository.TodoRow{Id: 1, Name: "clean dishes", Done: false} row := repository.TodoRow{Id: 1, Name: "clean dishes", Done: false}
_, err := r.Create(ctx, row) _, err := r.Create(ctx, row)
AssertNoError(t, err); AssertNoError(t, err)
todoService := service.NewTodoService(logger, r) todoService := service.NewTodoService(logger, r)
@ -66,7 +66,7 @@ func TestDeleteTodo(t *testing.T) {
row := repository.TodoRow{Id: 1, Name: "clean dishes", Done: false} row := repository.TodoRow{Id: 1, Name: "clean dishes", Done: false}
_, err := r.Create(ctx, row) _, err := r.Create(ctx, row)
AssertNoError(t, err); AssertNoError(t, err)
todoService := service.NewTodoService(logger, r) todoService := service.NewTodoService(logger, r)
@ -93,7 +93,7 @@ func TestUpdateTodo(t *testing.T) {
row := repository.TodoRow{Id: 1, Name: "clean dishes", Done: false} row := repository.TodoRow{Id: 1, Name: "clean dishes", Done: false}
_, err := r.Create(ctx, row) _, err := r.Create(ctx, row)
AssertNoError(t, err); AssertNoError(t, err)
todoService := service.NewTodoService(logger, r) todoService := service.NewTodoService(logger, r)

View File

@ -0,0 +1,67 @@
package handler
import (
"context"
"encoding/json"
"log/slog"
"net/http"
"gitea.michaelthomson.dev/mthomson/habits/internal/user/service"
"github.com/gofrs/uuid/v5"
)
type UserRegisterer interface {
Register(ctx context.Context, email string, password string) (uuid.UUID, error)
}
type RegisterRequest struct {
Email string `json:"email"`
Password string `json:"password"`
}
type RegisterResponse struct {
Id string `json:"id"`
}
func HandleRegisterUser(logger *slog.Logger, userService UserRegisterer) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
registerRequest := RegisterRequest{}
decoder := json.NewDecoder(r.Body)
decoder.DisallowUnknownFields()
err := decoder.Decode(&registerRequest)
if err != nil {
logger.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusBadRequest)
return
}
uuid, err := userService.Register(ctx, registerRequest.Email, registerRequest.Password)
if err != nil {
if err == service.ErrUserExists {
http.Error(w, "", http.StatusConflict)
return
}
logger.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusInternalServerError)
return
}
response := RegisterResponse{uuid.String()}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
err = json.NewEncoder(w).Encode(response)
if err != nil {
logger.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusInternalServerError)
return
}
}
}

View File

@ -0,0 +1,148 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"log/slog"
"errors"
"net/http"
"net/http/httptest"
"testing"
"gitea.michaelthomson.dev/mthomson/habits/internal/user/service"
"github.com/gofrs/uuid/v5"
)
type MockUserRegisterer struct {
RegisterUserFunc func(ctx context.Context, email string, password string) (uuid.UUID, error)
}
func (m *MockUserRegisterer) Register(ctx context.Context, email string, password string) (uuid.UUID, error) {
return m.RegisterUserFunc(ctx, email, password)
}
func TestCreateUser(t *testing.T) {
logger := slog.Default()
t.Run("create user", func(t *testing.T) {
createUserRequest := RegisterRequest{Email: "test@test.com", Password: "password"}
newUUID := NewUUID(t)
service := MockUserRegisterer{
RegisterUserFunc: func(ctx context.Context, email string, password string) (uuid.UUID, error) {
return newUUID, nil
},
}
handler := HandleRegisterUser(logger, &service)
requestBody, err := json.Marshal(createUserRequest)
if err != nil {
t.Fatalf("Failed to marshal request %+v: %v", createUserRequest, err)
}
req := httptest.NewRequest(http.MethodPost, "/register", bytes.NewBuffer(requestBody))
res := httptest.NewRecorder()
handler(res, req)
AssertStatusCodes(t, res.Code, http.StatusOK)
})
t.Run("returns 409 when user exists", func(t *testing.T) {
createUserRequest := RegisterRequest{Email: "test@test.com", Password: "password"}
newUUID := NewUUID(t)
service := MockUserRegisterer{
RegisterUserFunc: func(ctx context.Context, email string, password string) (uuid.UUID, error) {
return newUUID, service.ErrUserExists
},
}
handler := HandleRegisterUser(logger, &service)
requestBody, err := json.Marshal(createUserRequest)
if err != nil {
t.Fatalf("Failed to marshal request %+v: %v", createUserRequest, err)
}
req := httptest.NewRequest(http.MethodPost, "/register", bytes.NewBuffer(requestBody))
res := httptest.NewRecorder()
handler(res, req)
AssertStatusCodes(t, res.Code, http.StatusConflict)
})
t.Run("returns 400 with bad json", func(t *testing.T) {
handler := HandleRegisterUser(logger, nil)
badStruct := struct {
Foo string
}{
Foo: "bar",
}
requestBody, err := json.Marshal(badStruct)
if err != nil {
t.Fatalf("Failed to marshal request %+v: %v", badStruct, err)
}
req := httptest.NewRequest(http.MethodPost, "/register", bytes.NewBuffer(requestBody))
res := httptest.NewRecorder()
handler(res, req)
AssertStatusCodes(t, res.Code, http.StatusBadRequest)
})
t.Run("returns 500 arbitrary errors", func(t *testing.T) {
createUserRequest := RegisterRequest{Email: "test@test.com", Password: "password"}
newUUID := NewUUID(t)
service := MockUserRegisterer{
RegisterUserFunc: func(ctx context.Context, email string, password string) (uuid.UUID, error) {
return newUUID, errors.New("foo bar")
},
}
handler := HandleRegisterUser(logger, &service)
requestBody, err := json.Marshal(createUserRequest)
if err != nil {
t.Fatalf("Failed to marshal request %+v: %v", createUserRequest, err)
}
req := httptest.NewRequest(http.MethodPost, "/user", bytes.NewBuffer(requestBody))
res := httptest.NewRecorder()
handler(res, req)
AssertStatusCodes(t, res.Code, http.StatusInternalServerError)
})
}
func AssertStatusCodes(t testing.TB, got, want int) {
t.Helper()
if got != want {
t.Errorf("got status code: %v, want status code: %v", want, got)
}
}
func NewUUID(t testing.TB) uuid.UUID {
t.Helper()
uuid, err := uuid.NewV4()
if err != nil {
t.Errorf("error generation uuid: %v", err)
}
return uuid
}

View File

@ -1,6 +1,7 @@
package repository package repository
import ( import (
"bytes"
"context" "context"
"errors" "errors"
"log/slog" "log/slog"
@ -17,15 +18,12 @@ var (
type UserRow struct { type UserRow struct {
Id uuid.UUID Id uuid.UUID
Email string Email string
HashedPassword string HashedPassword []byte
} Salt []byte
func NewUserRow(id uuid.UUID, email string, hashedPassword string) UserRow {
return UserRow{Id: id, Email: email, HashedPassword: hashedPassword}
} }
func (u UserRow) Equal(user UserRow) bool { func (u UserRow) Equal(user UserRow) bool {
return u.Id == user.Id && u.Email == user.Email && u.HashedPassword == user.HashedPassword return u.Id == user.Id && u.Email == user.Email && bytes.Equal(u.HashedPassword, user.HashedPassword) && bytes.Equal(u.Salt, user.Salt)
} }
type UserRepository struct { type UserRepository struct {
@ -43,7 +41,24 @@ func NewUserRepository(logger *slog.Logger, db *pgxpool.Pool) *UserRepository {
func (r *UserRepository) GetById(ctx context.Context, id uuid.UUID) (UserRow, error) { func (r *UserRepository) GetById(ctx context.Context, id uuid.UUID) (UserRow, error) {
user := UserRow{} user := UserRow{}
err := r.db.QueryRow(ctx, "SELECT * FROM users WHERE id = $1;", id).Scan(&user.Id, &user.Email, &user.HashedPassword) err := r.db.QueryRow(ctx, "SELECT * FROM users WHERE id = $1;", id).Scan(&user.Id, &user.Email, &user.HashedPassword, &user.Salt)
if err != nil {
if err == pgx.ErrNoRows {
return user, ErrNotFound
}
r.logger.ErrorContext(ctx, err.Error())
return user, err
}
return user, nil
}
func (r *UserRepository) GetByEmail(ctx context.Context, email string) (UserRow, error) {
user := UserRow{}
err := r.db.QueryRow(ctx, "SELECT * FROM users WHERE email = $1;", email).Scan(&user.Id, &user.Email, &user.HashedPassword, &user.Salt)
if err != nil { if err != nil {
if err == pgx.ErrNoRows { if err == pgx.ErrNoRows {
@ -60,7 +75,7 @@ func (r *UserRepository) GetById(ctx context.Context, id uuid.UUID) (UserRow, er
func (r *UserRepository) Create(ctx context.Context, user UserRow) (UserRow, error) { func (r *UserRepository) Create(ctx context.Context, user UserRow) (UserRow, error) {
var result pgx.Row var result pgx.Row
if user.Id.IsNil() { if user.Id.IsNil() {
result = r.db.QueryRow(ctx, "INSERT INTO users (email, hashed_password) VALUES ($1, $2) RETURNING id;", user.Email, user.HashedPassword) result = r.db.QueryRow(ctx, "INSERT INTO users (email, hashed_password, salt) VALUES ($1, $2, $3) RETURNING id;", user.Email, user.HashedPassword, user.Salt)
err := result.Scan(&user.Id) err := result.Scan(&user.Id)
@ -69,7 +84,7 @@ func (r *UserRepository) Create(ctx context.Context, user UserRow) (UserRow, err
return UserRow{}, err return UserRow{}, err
} }
} else { } else {
_, err := r.db.Exec(ctx, "INSERT INTO users (id, email, hashed_password) VALUES ($1, $2, $3);", user.Id, user.Email, user.HashedPassword) _, err := r.db.Exec(ctx, "INSERT INTO users (id, email, hashed_password, salt) VALUES ($1, $2, $3, $4);", user.Id, user.Email, user.HashedPassword, user.Salt)
if err != nil { if err != nil {
r.logger.ErrorContext(ctx, err.Error()) r.logger.ErrorContext(ctx, err.Error())
@ -81,7 +96,7 @@ func (r *UserRepository) Create(ctx context.Context, user UserRow) (UserRow, err
} }
func (r *UserRepository) Update(ctx context.Context, user UserRow) error { func (r *UserRepository) Update(ctx context.Context, user UserRow) error {
result, err := r.db.Exec(ctx, "UPDATE users SET email = $1, hashed_password = $2 WHERE id = $3;", user.Email, user.HashedPassword, user.Id) result, err := r.db.Exec(ctx, "UPDATE users SET email = $1, hashed_password = $2, salt = $3 WHERE id = $4;", user.Email, user.HashedPassword, user.Salt, user.Id)
if err != nil { if err != nil {
r.logger.ErrorContext(ctx, err.Error()) r.logger.ErrorContext(ctx, err.Error())

View File

@ -6,6 +6,7 @@ import (
"log/slog" "log/slog"
"testing" "testing"
"gitea.michaelthomson.dev/mthomson/habits/internal/auth"
"gitea.michaelthomson.dev/mthomson/habits/internal/test" "gitea.michaelthomson.dev/mthomson/habits/internal/test"
"gitea.michaelthomson.dev/mthomson/habits/internal/user/repository" "gitea.michaelthomson.dev/mthomson/habits/internal/user/repository"
"github.com/gofrs/uuid/v5" "github.com/gofrs/uuid/v5"
@ -18,22 +19,36 @@ func TestCRUD(t *testing.T) {
defer tdb.TearDown() defer tdb.TearDown()
r := repository.NewUserRepository(logger, tdb.Db) r := repository.NewUserRepository(logger, tdb.Db)
uuid := NewUUID(t) uuid := NewUUID(t)
argon2IdHash := auth.NewArgon2IdHash(1, 32, 64*1024, 32, 256)
hashSalt, err := argon2IdHash.GenerateHash([]byte("supersecurepassword"), []byte("supersecuresalt"))
if err != nil {
t.Errorf("could not generate hash: %v", err)
}
t.Run("creates new user", func(t *testing.T) { t.Run("creates new user", func(t *testing.T) {
newUser := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: "supersecurehash"} newUser := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: hashSalt.Hash, Salt: hashSalt.Salt}
_, err := r.Create(ctx, newUser) _, err := r.Create(ctx, newUser)
AssertNoError(t, err) AssertNoError(t, err)
}) })
t.Run("gets user", func(t *testing.T) { t.Run("gets user by id", func(t *testing.T) {
want := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: "supersecurehash"} want := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: hashSalt.Hash, Salt: hashSalt.Salt}
got, err := r.GetById(ctx, uuid) got, err := r.GetById(ctx, uuid)
AssertNoError(t, err) AssertNoError(t, err)
AssertUserRows(t, got, want) AssertUserRows(t, got, want)
}) })
t.Run("gets user by email", func(t *testing.T) {
want := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: hashSalt.Hash, Salt: hashSalt.Salt}
got, err := r.GetByEmail(ctx, "test@test.com")
AssertNoError(t, err)
AssertUserRows(t, got, want)
})
t.Run("updates user", func(t *testing.T) { t.Run("updates user", func(t *testing.T) {
want := repository.UserRow{Id: uuid, Email: "new@test.com", HashedPassword: "supersecurehash"} want := repository.UserRow{Id: uuid, Email: "new@test.com", HashedPassword: hashSalt.Hash, Salt: hashSalt.Salt}
err := r.Update(ctx, want) err := r.Update(ctx, want)
AssertNoError(t, err) AssertNoError(t, err)

View File

@ -5,52 +5,48 @@ import (
"errors" "errors"
"log/slog" "log/slog"
"gitea.michaelthomson.dev/mthomson/habits/internal/auth"
"gitea.michaelthomson.dev/mthomson/habits/internal/user/repository" "gitea.michaelthomson.dev/mthomson/habits/internal/user/repository"
"github.com/gofrs/uuid/v5" "github.com/gofrs/uuid/v5"
) )
var ( var (
ErrNotFound error = errors.New("user cannot be found") ErrNotFound error = errors.New("user cannot be found")
ErrUserExists error = errors.New("user already exists")
) )
type User struct { type User struct {
Id uuid.UUID Id uuid.UUID
Email string Email string
HashedPassword string
}
func NewUser(id uuid.UUID, email string, hashedPassword string) User {
return User{Id: id, Email: email, HashedPassword: hashedPassword}
} }
func UserFromUserRow(userRow repository.UserRow) User { func UserFromUserRow(userRow repository.UserRow) User {
return User{Id: userRow.Id, Email: userRow.Email, HashedPassword: userRow.HashedPassword} return User{Id: userRow.Id, Email: userRow.Email}
}
func UserRowFromUser(user User) repository.UserRow {
return repository.UserRow{Id: user.Id, Email: user.Email, HashedPassword: user.HashedPassword}
} }
func (t User) Equal(user User) bool { func (t User) Equal(user User) bool {
return t.Id == user.Id && t.Email == user.Email && t.HashedPassword == user.HashedPassword return t.Id == user.Id && t.Email == user.Email
} }
type UserRepository interface { type UserRepository interface {
Create(ctx context.Context, user repository.UserRow) (repository.UserRow, error) Create(ctx context.Context, user repository.UserRow) (repository.UserRow, error)
GetById(ctx context.Context, id uuid.UUID) (repository.UserRow, error) GetById(ctx context.Context, id uuid.UUID) (repository.UserRow, error)
GetByEmail(ctx context.Context, email string) (repository.UserRow, error)
Update(ctx context.Context, user repository.UserRow) error Update(ctx context.Context, user repository.UserRow) error
Delete(ctx context.Context, id uuid.UUID) error Delete(ctx context.Context, id uuid.UUID) error
} }
type UserService struct { type UserService struct {
logger *slog.Logger logger *slog.Logger
repo UserRepository repo UserRepository
argon2IdHash *auth.Argon2IdHash
} }
func NewUserService(logger *slog.Logger, userRepo UserRepository) *UserService { func NewUserService(logger *slog.Logger, userRepo UserRepository, argon2IdHash *auth.Argon2IdHash) *UserService {
return &UserService{ return &UserService{
logger: logger, logger: logger,
repo: userRepo, repo: userRepo,
argon2IdHash: argon2IdHash,
} }
} }
@ -69,17 +65,37 @@ func (s *UserService) GetUser(ctx context.Context, id uuid.UUID) (User, error) {
return UserFromUserRow(user), err return UserFromUserRow(user), err
} }
func (s *UserService) CreateUser(ctx context.Context, user User) (User, error) { func (s *UserService) Register(ctx context.Context, email string, password string) (uuid.UUID, error) {
userRow := UserRowFromUser(user) uuid, err := uuid.NewV4()
newUserRow, err := s.repo.Create(ctx, userRow)
if err != nil { if err != nil {
s.logger.ErrorContext(ctx, err.Error()) s.logger.ErrorContext(ctx, err.Error())
return User{}, err return uuid, err
} }
return UserFromUserRow(newUserRow), err _, err = s.repo.GetByEmail(ctx, email)
if err != repository.ErrNotFound {
return uuid, ErrUserExists
}
hashSalt, err := s.argon2IdHash.GenerateHash([]byte(password), nil)
if err != nil {
s.logger.ErrorContext(ctx, err.Error())
return uuid, err
}
userRow := repository.UserRow{Id: uuid, Email: email, HashedPassword: hashSalt.Hash, Salt: hashSalt.Salt}
_, err = s.repo.Create(ctx, userRow)
if err != nil {
s.logger.ErrorContext(ctx, err.Error())
return uuid, err
}
return uuid, err
} }
func (s *UserService) DeleteUser(ctx context.Context, id uuid.UUID) error { func (s *UserService) DeleteUser(ctx context.Context, id uuid.UUID) error {
@ -96,10 +112,66 @@ func (s *UserService) DeleteUser(ctx context.Context, id uuid.UUID) error {
return err return err
} }
func (s *UserService) UpdateUser(ctx context.Context, user User) error { func (s *UserService) UpdateUserEmail(ctx context.Context, id uuid.UUID, email string) error {
userRow := UserRowFromUser(user) user, err := s.repo.GetById(ctx, id)
err := s.repo.Update(ctx, userRow) if err == repository.ErrNotFound {
return ErrNotFound
}
if err != nil {
s.logger.ErrorContext(ctx, err.Error())
return err
}
_, err = s.repo.GetByEmail(ctx, email)
switch err {
case repository.ErrNotFound:
user.Email = email
err = s.repo.Update(ctx, user)
if err == repository.ErrNotFound {
return ErrNotFound
}
if err != nil {
s.logger.ErrorContext(ctx, err.Error())
}
return err
case nil:
return ErrUserExists
default:
s.logger.ErrorContext(ctx, err.Error())
return err
}
}
func (s *UserService) UpdateUserPassword(ctx context.Context, id uuid.UUID, password string) error {
user, err := s.repo.GetById(ctx, id)
if err == repository.ErrNotFound {
return ErrNotFound
}
if err != nil {
s.logger.ErrorContext(ctx, err.Error())
return err
}
hashSalt, err := s.argon2IdHash.GenerateHash([]byte(password), nil)
if err != nil {
s.logger.ErrorContext(ctx, err.Error())
return err
}
user.HashedPassword = hashSalt.Hash
user.Salt = hashSalt.Salt
err = s.repo.Update(ctx, user)
if err == repository.ErrNotFound { if err == repository.ErrNotFound {
return ErrNotFound return ErrNotFound

View File

@ -5,27 +5,26 @@ import (
"log/slog" "log/slog"
"testing" "testing"
"gitea.michaelthomson.dev/mthomson/habits/internal/auth"
"gitea.michaelthomson.dev/mthomson/habits/internal/test" "gitea.michaelthomson.dev/mthomson/habits/internal/test"
"gitea.michaelthomson.dev/mthomson/habits/internal/user/repository" "gitea.michaelthomson.dev/mthomson/habits/internal/user/repository"
"gitea.michaelthomson.dev/mthomson/habits/internal/user/service" "gitea.michaelthomson.dev/mthomson/habits/internal/user/service"
"github.com/gofrs/uuid/v5" "github.com/gofrs/uuid/v5"
) )
func TestCreateUser(t *testing.T) { func TestRegisterUser(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background() ctx := context.Background()
logger := slog.Default() logger := slog.Default()
tdb := test.NewTestDatabase(t) tdb := test.NewTestDatabase(t)
defer tdb.TearDown() defer tdb.TearDown()
r := repository.NewUserRepository(logger, tdb.Db) r := repository.NewUserRepository(logger, tdb.Db)
argon2IdHash := auth.NewArgon2IdHash(1, 32, 64*1024, 32, 256)
userService := service.NewUserService(logger, r) userService := service.NewUserService(logger, r, argon2IdHash)
t.Run("Create user", func(t *testing.T) { t.Run("Create user", func(t *testing.T) {
uuid := NewUUID(t) _, err := userService.Register(ctx, "test@test.com", "supersecurepassword")
user := service.NewUser(uuid, "test@test.com", "supersecurehash")
_, err := userService.CreateUser(ctx, user)
AssertNoError(t, err) AssertNoError(t, err)
}) })
@ -38,13 +37,12 @@ func TestGetUser(t *testing.T) {
tdb := test.NewTestDatabase(t) tdb := test.NewTestDatabase(t)
defer tdb.TearDown() defer tdb.TearDown()
r := repository.NewUserRepository(logger, tdb.Db) r := repository.NewUserRepository(logger, tdb.Db)
uuid := NewUUID(t) argon2IdHash := auth.NewArgon2IdHash(1, 32, 64*1024, 32, 256)
row := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: "supersecurehash"} userService := service.NewUserService(logger, r, argon2IdHash)
_, err := r.Create(ctx, row)
AssertNoError(t, err);
userService := service.NewUserService(logger, r) uuid, err := userService.Register(ctx, "test@test.com", "supersecurepassword")
AssertNoError(t, err)
t.Run("Get exisiting user", func(t *testing.T) { t.Run("Get exisiting user", func(t *testing.T) {
_, err := userService.GetUser(ctx, uuid) _, err := userService.GetUser(ctx, uuid)
@ -66,13 +64,12 @@ func TestDeleteUser(t *testing.T) {
tdb := test.NewTestDatabase(t) tdb := test.NewTestDatabase(t)
defer tdb.TearDown() defer tdb.TearDown()
r := repository.NewUserRepository(logger, tdb.Db) r := repository.NewUserRepository(logger, tdb.Db)
uuid := NewUUID(t) argon2IdHash := auth.NewArgon2IdHash(1, 32, 64*1024, 32, 256)
row := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: "supersecurehash"} userService := service.NewUserService(logger, r, argon2IdHash)
_, err := r.Create(ctx, row)
AssertNoError(t, err);
userService := service.NewUserService(logger, r) uuid, err := userService.Register(ctx, "test@test.com", "supersecurepassword")
AssertNoError(t, err)
t.Run("Delete exisiting user", func(t *testing.T) { t.Run("Delete exisiting user", func(t *testing.T) {
err := userService.DeleteUser(ctx, uuid) err := userService.DeleteUser(ctx, uuid)
@ -87,25 +84,22 @@ func TestDeleteUser(t *testing.T) {
}) })
} }
func TestUpdateUser(t *testing.T) { func TestUpdateUserEmail(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background() ctx := context.Background()
logger := slog.Default() logger := slog.Default()
tdb := test.NewTestDatabase(t) tdb := test.NewTestDatabase(t)
defer tdb.TearDown() defer tdb.TearDown()
r := repository.NewUserRepository(logger, tdb.Db) r := repository.NewUserRepository(logger, tdb.Db)
uuid := NewUUID(t) argon2IdHash := auth.NewArgon2IdHash(1, 32, 64*1024, 32, 256)
row := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: "supersecurehash"} userService := service.NewUserService(logger, r, argon2IdHash)
_, err := r.Create(ctx, row)
AssertNoError(t, err);
userService := service.NewUserService(logger, r) uuid, err := userService.Register(ctx, "test@test.com", "supersecurepassword")
AssertNoError(t, err)
t.Run("Update exisiting user", func(t *testing.T) { t.Run("Update existing user email", func(t *testing.T) {
user := service.User{uuid, "new@email.com", "supersecurehash"} err := userService.UpdateUserEmail(ctx, uuid, "newemail@test.com")
err := userService.UpdateUser(ctx, user)
AssertNoError(t, err) AssertNoError(t, err)
@ -113,13 +107,13 @@ func TestUpdateUser(t *testing.T) {
AssertNoError(t, err) AssertNoError(t, err)
AssertUsers(t, newUser, user) if newUser.Email != "newemail@test.com" {
t.Errorf("Emails do not match wanted %q, got %q", "newemail@test.com", newUser.Email)
}
}) })
t.Run("Update non-existant user", func(t *testing.T) { t.Run("Update non-existant user", func(t *testing.T) {
user := service.User{NewUUID(t), "new@email.com", "supersecurehash"} err := userService.UpdateUserEmail(ctx, NewUUID(t), "newemail@test.com")
err := userService.UpdateUser(ctx, user)
AssertErrors(t, err, service.ErrNotFound) AssertErrors(t, err, service.ErrNotFound)
}) })