From e55d419d44f6d9a0fc70c902039bc1485becd2b0 Mon Sep 17 00:00:00 2001 From: Michael Thomson Date: Thu, 22 May 2025 13:55:43 -0400 Subject: [PATCH] auth services, middleware, and other stuff --- .env.example | 2 + cmd/main.go | 58 +++++-- go.mod | 4 +- go.sum | 4 + internal/auth/handler/login.go | 66 ++++++++ internal/auth/handler/login_test.go | 168 +++++++++++++++++++ internal/auth/hashing.go | 70 ++++++++ internal/auth/service/service.go | 90 ++++++++++ internal/auth/service/service_test.go | 81 +++++++++ internal/logging/logging.go | 2 +- internal/middleware/auth.go | 60 +++++++ internal/middleware/{context.go => trace.go} | 5 +- internal/middleware/util.go | 2 + internal/migrate/migrations/1-init_db.sql | 5 +- internal/test/test_database.go | 2 +- internal/todo/service/service_test.go | 6 +- internal/user/handler/register.go | 67 ++++++++ internal/user/handler/register_test.go | 148 ++++++++++++++++ internal/user/repository/repository.go | 35 ++-- internal/user/repository/repository_test.go | 23 ++- internal/user/service/service.go | 128 ++++++++++---- internal/user/service/service_test.go | 54 +++--- 22 files changed, 985 insertions(+), 95 deletions(-) create mode 100644 .env.example create mode 100644 internal/auth/handler/login.go create mode 100644 internal/auth/handler/login_test.go create mode 100644 internal/auth/hashing.go create mode 100644 internal/auth/service/service.go create mode 100644 internal/auth/service/service_test.go create mode 100644 internal/middleware/auth.go rename internal/middleware/{context.go => trace.go} (72%) create mode 100644 internal/user/handler/register.go create mode 100644 internal/user/handler/register_test.go diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..b4044f1 --- /dev/null +++ b/.env.example @@ -0,0 +1,2 @@ +POSTGRESQL_CONNECTION_STRING=postgres://todo:password@localhost:5432/todo +JWT_SECRET_KEY="supersecretjwtkey" diff --git a/cmd/main.go b/cmd/main.go index 4875051..97d9da7 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -6,33 +6,59 @@ import ( "net/http" "os" + "gitea.michaelthomson.dev/mthomson/habits/internal/auth" + authhandler "gitea.michaelthomson.dev/mthomson/habits/internal/auth/handler" + authservice "gitea.michaelthomson.dev/mthomson/habits/internal/auth/service" "gitea.michaelthomson.dev/mthomson/habits/internal/logging" "gitea.michaelthomson.dev/mthomson/habits/internal/middleware" "gitea.michaelthomson.dev/mthomson/habits/internal/migrate" todohandler "gitea.michaelthomson.dev/mthomson/habits/internal/todo/handler" todorepository "gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository" todoservice "gitea.michaelthomson.dev/mthomson/habits/internal/todo/service" + userhandler "gitea.michaelthomson.dev/mthomson/habits/internal/user/handler" + userrepository "gitea.michaelthomson.dev/mthomson/habits/internal/user/repository" + userservice "gitea.michaelthomson.dev/mthomson/habits/internal/user/service" + pgxuuid "github.com/jackc/pgx-gofrs-uuid" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" - pgxuuid "github.com/jackc/pgx-gofrs-uuid" + "github.com/joho/godotenv" ) func main() { // create logger logger := logging.NewLogger() - // create middlewares - contextMiddleware := middleware.ContextMiddleware(logger) - loggingMiddleware := middleware.LoggingMiddleware(logger) + // load env + err := godotenv.Load() + if err != nil { + logger.Error(err.Error()) + os.Exit(1) + } - stack := []middleware.Middleware{ - contextMiddleware, + jwtSecretKey := os.Getenv("JWT_SECRET_KEY") + postgresqlConnectionString := os.Getenv("POSTGRESQL_CONNECTION_STRING") + + // create hasher instance + argon2IdHash := auth.NewArgon2IdHash(1, 32, 64*1024, 32, 256) + + // create middlewares + traceMiddleware := middleware.TraceMiddleware(logger) + loggingMiddleware := middleware.LoggingMiddleware(logger) + authMiddleware := middleware.AuthMiddleware(logger, []byte(jwtSecretKey)) + + unauthenticatedStack := []middleware.Middleware{ + traceMiddleware, loggingMiddleware, } + authenticatedStack := []middleware.Middleware{ + traceMiddleware, + loggingMiddleware, + authMiddleware, + } + // create db pool - postgresUrl := "postgres://todo:password@localhost:5432/todo" - dbconfig, err := pgxpool.ParseConfig(postgresUrl) + dbconfig, err := pgxpool.ParseConfig(postgresqlConnectionString) if err != nil { logger.Error(err.Error()) os.Exit(1) @@ -54,18 +80,26 @@ func main() { // create repos todoRepository := todorepository.NewTodoRepository(logger, db) + userRepository := userrepository.NewUserRepository(logger, db) // create services todoService := todoservice.NewTodoService(logger, todoRepository) + userService := userservice.NewUserService(logger, userRepository, argon2IdHash) + authService := authservice.NewAuthService(logger, []byte(jwtSecretKey), userRepository, argon2IdHash) // create mux mux := http.NewServeMux() // register handlers - mux.Handle("GET /todo/{id}", middleware.CompileMiddleware(todohandler.HandleTodoGet(logger, todoService), stack)) - mux.Handle("POST /todo", middleware.CompileMiddleware(todohandler.HandleTodoCreate(logger, todoService), stack)) - mux.Handle("DELETE /todo/{id}", middleware.CompileMiddleware(todohandler.HandleTodoDelete(logger, todoService), stack)) - mux.Handle("PUT /todo/{id}", middleware.CompileMiddleware(todohandler.HandleTodoUpdate(logger, todoService), stack)) + // auth + mux.Handle("POST /login", middleware.CompileMiddleware(authhandler.HandleLogin(logger, authService), unauthenticatedStack)) + // users + mux.Handle("POST /register", middleware.CompileMiddleware(userhandler.HandleRegisterUser(logger, userService), unauthenticatedStack)) + // todos + mux.Handle("GET /todo/{id}", middleware.CompileMiddleware(todohandler.HandleTodoGet(logger, todoService), authenticatedStack)) + mux.Handle("POST /todo", middleware.CompileMiddleware(todohandler.HandleTodoCreate(logger, todoService), authenticatedStack)) + mux.Handle("DELETE /todo/{id}", middleware.CompileMiddleware(todohandler.HandleTodoDelete(logger, todoService), authenticatedStack)) + mux.Handle("PUT /todo/{id}", middleware.CompileMiddleware(todohandler.HandleTodoUpdate(logger, todoService), authenticatedStack)) // create server server := &http.Server{ diff --git a/go.mod b/go.mod index 016e2e7..3375f33 100644 --- a/go.mod +++ b/go.mod @@ -6,10 +6,13 @@ toolchain go1.23.9 require ( github.com/gofrs/uuid/v5 v5.3.2 + github.com/golang-jwt/jwt/v5 v5.2.2 github.com/jackc/pgx-gofrs-uuid v0.0.0-20230224015001-1d428863c2e2 github.com/jackc/pgx/v5 v5.7.4 + github.com/joho/godotenv v1.5.1 github.com/testcontainers/testcontainers-go v0.37.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.37.0 + golang.org/x/crypto v0.38.0 ) require ( @@ -64,7 +67,6 @@ require ( go.opentelemetry.io/otel v1.35.0 // indirect go.opentelemetry.io/otel/metric v1.35.0 // indirect go.opentelemetry.io/otel/trace v1.35.0 // indirect - golang.org/x/crypto v0.38.0 // indirect golang.org/x/sync v0.14.0 // indirect golang.org/x/sys v0.33.0 // indirect golang.org/x/text v0.25.0 // indirect diff --git a/go.sum b/go.sum index 4d00626..2dcb4f7 100644 --- a/go.sum +++ b/go.sum @@ -43,6 +43,8 @@ github.com/gofrs/uuid/v5 v5.3.2 h1:2jfO8j3XgSwlz/wHqemAEugfnTlikAYHhnqQ8Xh4fE0= github.com/gofrs/uuid/v5 v5.3.2/go.mod h1:CDOjlDMVAtN56jqyRUZh58JT31Tiw7/oQyEXZV+9bD8= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -59,6 +61,8 @@ github.com/jackc/pgx/v5 v5.7.4 h1:9wKznZrhWa2QiHL+NjTSPP6yjl3451BX3imWDnokYlg= github.com/jackc/pgx/v5 v5.7.4/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= diff --git a/internal/auth/handler/login.go b/internal/auth/handler/login.go new file mode 100644 index 0000000..03fac61 --- /dev/null +++ b/internal/auth/handler/login.go @@ -0,0 +1,66 @@ +package handler + +import ( + "context" + "encoding/json" + "log/slog" + "net/http" + + "gitea.michaelthomson.dev/mthomson/habits/internal/auth/service" +) + +type Loginer interface { + Login(ctx context.Context, email string, password string) (string, error) +} + +type LoginRequest struct { + Email string `json:"email"` + Password string `json:"password"` +} + +func HandleLogin(logger *slog.Logger, authService Loginer) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + loginRequest := LoginRequest{} + decoder := json.NewDecoder(r.Body) + decoder.DisallowUnknownFields() + err := decoder.Decode(&loginRequest) + + if err != nil { + logger.ErrorContext(ctx, err.Error()) + http.Error(w, "", http.StatusBadRequest) + return + } + + token, err := authService.Login(ctx, loginRequest.Email, loginRequest.Password) + + if err == service.ErrUnauthorized { + http.Error(w, "", http.StatusUnauthorized) + return + } + + if err == service.ErrNotFound { + http.Error(w, "", http.StatusUnauthorized) + return + } + + if err != nil { + logger.ErrorContext(ctx, err.Error()) + http.Error(w, "", http.StatusInternalServerError) + return + } + + cookie := http.Cookie{ + Name: "token", + Value: token, + Path: "/", + MaxAge: 3600, + HttpOnly: true, + Secure: false, + SameSite: http.SameSiteLaxMode, + } + http.SetCookie(w, &cookie) + + w.WriteHeader(http.StatusOK) + } +} diff --git a/internal/auth/handler/login_test.go b/internal/auth/handler/login_test.go new file mode 100644 index 0000000..0d3c389 --- /dev/null +++ b/internal/auth/handler/login_test.go @@ -0,0 +1,168 @@ +package handler + +import ( + "bytes" + "context" + "encoding/json" + "log/slog" + + "errors" + "net/http" + "net/http/httptest" + "testing" + + "gitea.michaelthomson.dev/mthomson/habits/internal/auth/service" +) + +type MockLoginer struct { + LoginFunc func(ctx context.Context, email string, password string) (string, error) +} + +func (m *MockLoginer) Login(ctx context.Context, email string, password string) (string, error) { + return m.LoginFunc(ctx, email, password) +} + +func TestLogin(t *testing.T) { + logger := slog.Default() + t.Run("returns 200 for existing user with correct credentials", func(t *testing.T) { + loginRequest := LoginRequest{Email: "test@test.com", Password: "password"} + + token := "examplejwt" + + service := MockLoginer{ + LoginFunc: func(ctx context.Context, email string, password string) (string, error) { + return token, nil + }, + } + + handler := HandleLogin(logger, &service) + + requestBody, err := json.Marshal(loginRequest) + + if err != nil { + t.Fatalf("Failed to marshal request %+v: %v", loginRequest, err) + } + + req := httptest.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(requestBody)) + res := httptest.NewRecorder() + + handler(res, req) + + AssertStatusCodes(t, res.Code, http.StatusOK) + AssertToken(t, res, token) + }) + + t.Run("returns 401 for existing user with incorrect credentials", func(t *testing.T) { + loginRequest := LoginRequest{Email: "test@test.com", Password: "password"} + + service := MockLoginer{ + LoginFunc: func(ctx context.Context, email string, password string) (string, error) { + return "", service.ErrUnauthorized + }, + } + + handler := HandleLogin(logger, &service) + + requestBody, err := json.Marshal(loginRequest) + + if err != nil { + t.Fatalf("Failed to marshal request %+v: %v", loginRequest, err) + } + + req := httptest.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(requestBody)) + res := httptest.NewRecorder() + + handler(res, req) + + AssertStatusCodes(t, res.Code, http.StatusUnauthorized) + }) + + t.Run("returns 401 for non-existing user", func(t *testing.T) { + loginRequest := LoginRequest{Email: "test@test.com", Password: "password"} + + service := MockLoginer{ + LoginFunc: func(ctx context.Context, email string, password string) (string, error) { + return "", service.ErrNotFound + }, + } + + handler := HandleLogin(logger, &service) + + requestBody, err := json.Marshal(loginRequest) + + if err != nil { + t.Fatalf("Failed to marshal request %+v: %v", loginRequest, err) + } + + req := httptest.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(requestBody)) + res := httptest.NewRecorder() + + handler(res, req) + + AssertStatusCodes(t, res.Code, http.StatusUnauthorized) + }) + + t.Run("returns 400 with bad json", func(t *testing.T) { + handler := HandleLogin(logger, nil) + + badStruct := struct { + Foo string + }{ + Foo: "bar", + } + + requestBody, err := json.Marshal(badStruct) + + if err != nil { + t.Fatalf("Failed to marshal request %+v: %v", badStruct, err) + } + + req := httptest.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(requestBody)) + res := httptest.NewRecorder() + + handler(res, req) + + AssertStatusCodes(t, res.Code, http.StatusBadRequest) + }) + + t.Run("returns 500 arbitrary errors", func(t *testing.T) { + loginRequest := LoginRequest{Email: "test@test.com", Password: "password"} + + service := MockLoginer{ + LoginFunc: func(ctx context.Context, email string, password string) (string, error) { + return "", errors.New("foo bar") + }, + } + + handler := HandleLogin(logger, &service) + + requestBody, err := json.Marshal(loginRequest) + + if err != nil { + t.Fatalf("Failed to marshal request %+v: %v", loginRequest, err) + } + + req := httptest.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(requestBody)) + res := httptest.NewRecorder() + + handler(res, req) + + AssertStatusCodes(t, res.Code, http.StatusInternalServerError) + }) + +} + +func AssertStatusCodes(t testing.TB, got, want int) { + t.Helper() + if got != want { + t.Errorf("got status code: %v, want status code: %v", want, got) + } +} + +func AssertToken(t testing.TB, res *httptest.ResponseRecorder, want string) { + t.Helper() + got := res.Result().Cookies()[0].Value + if got != want { + t.Errorf("got cookie: %q, want cookie: %q", got, want) + } +} diff --git a/internal/auth/hashing.go b/internal/auth/hashing.go new file mode 100644 index 0000000..9610c9b --- /dev/null +++ b/internal/auth/hashing.go @@ -0,0 +1,70 @@ +package auth + +import ( + "bytes" + "crypto/rand" + "errors" + + "golang.org/x/crypto/argon2" +) + +type HashSalt struct { + Hash, Salt []byte +} + +type Argon2IdHash struct { + time uint32 + memory uint32 + threads uint8 + keyLen uint32 + saltLen uint32 +} + +var ( + ErrNoMatch error = errors.New("hash doesn't match") +) + +func NewArgon2IdHash(time, saltLen uint32, memory uint32, threads uint8, keyLen uint32) *Argon2IdHash { + return &Argon2IdHash{ + time: time, + saltLen: saltLen, + memory: memory, + threads: threads, + keyLen: keyLen, + } +} + +func (a *Argon2IdHash) GenerateHash(password, salt []byte) (*HashSalt, error) { + var err error + if len(salt) == 0 { + salt, err = randomSecret(a.saltLen) + } + if err != nil { + return nil, err + } + hash := argon2.IDKey(password, salt, a.time, a.memory, a.threads, a.keyLen) + return &HashSalt{Hash: hash, Salt: salt}, nil +} + +func (a *Argon2IdHash) Compare(hash, salt, password []byte) error { + hashSalt, err := a.GenerateHash(password, salt) + if err != nil { + return err + } + + if !bytes.Equal(hash, hashSalt.Hash) { + return ErrNoMatch + } + return nil +} + +func randomSecret(length uint32) ([]byte, error) { + secret := make([]byte, length) + + _, err := rand.Read(secret) + if err != nil { + return nil, err + } + + return secret, nil +} diff --git a/internal/auth/service/service.go b/internal/auth/service/service.go new file mode 100644 index 0000000..1dabbae --- /dev/null +++ b/internal/auth/service/service.go @@ -0,0 +1,90 @@ +package service + +import ( + "context" + "errors" + "log/slog" + "time" + + "gitea.michaelthomson.dev/mthomson/habits/internal/auth" + userrepository "gitea.michaelthomson.dev/mthomson/habits/internal/user/repository" + "github.com/golang-jwt/jwt/v5" +) + +var ( + ErrNotFound error = errors.New("user cannot be found") + ErrUnauthorized error = errors.New("user password incorrect") +) + +type UserRepository interface { + Create(ctx context.Context, user userrepository.UserRow) (userrepository.UserRow, error) + GetByEmail(ctx context.Context, email string) (userrepository.UserRow, error) +} + +type AuthService struct { + logger *slog.Logger + jwtKey []byte + userRepository UserRepository + argon2IdHash *auth.Argon2IdHash +} + +func NewAuthService(logger *slog.Logger, jwtKey []byte, userRepository UserRepository, argon2IdHash *auth.Argon2IdHash) *AuthService { + return &AuthService{ + logger: logger, + jwtKey: jwtKey, + userRepository: userRepository, + argon2IdHash: argon2IdHash, + } +} + +func (a AuthService) Login(ctx context.Context, email string, password string) (string, error) { + // get user if exists + userRow, err := a.userRepository.GetByEmail(ctx, email) + if err != nil { + if err == userrepository.ErrNotFound { + return "", ErrNotFound + } + + a.logger.ErrorContext(ctx, err.Error()) + return "", err + } + + // compare hashed passswords + err = a.argon2IdHash.Compare(userRow.HashedPassword, userRow.Salt, []byte(password)) + + if err == auth.ErrNoMatch { + return "", ErrUnauthorized + } + + if err != nil { + a.logger.ErrorContext(ctx, err.Error()) + return "", err + } + + // create token and return it + token, err := a.CreateToken(ctx, email) + if err != nil { + a.logger.ErrorContext(ctx, err.Error()) + return "", err + } + + return token, nil +} + +func (a AuthService) CreateToken(ctx context.Context, email string) (string, error) { + // Create a new JWT token with claims + claims := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "sub": email, + "iss": "todo-app", + "aud": "user", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + }) + + tokenString, err := claims.SignedString(a.jwtKey) + if err != nil { + a.logger.ErrorContext(ctx, err.Error()) + } + + return tokenString, nil +} diff --git a/internal/auth/service/service_test.go b/internal/auth/service/service_test.go new file mode 100644 index 0000000..cffa456 --- /dev/null +++ b/internal/auth/service/service_test.go @@ -0,0 +1,81 @@ +package service_test + +import ( + "context" + "log/slog" + "testing" + + "gitea.michaelthomson.dev/mthomson/habits/internal/auth" + authservice "gitea.michaelthomson.dev/mthomson/habits/internal/auth/service" + "gitea.michaelthomson.dev/mthomson/habits/internal/test" + "gitea.michaelthomson.dev/mthomson/habits/internal/user/repository" + userservice "gitea.michaelthomson.dev/mthomson/habits/internal/user/service" + "github.com/gofrs/uuid/v5" +) + +func TestLogin(t *testing.T) { + t.Parallel() + ctx := context.Background() + logger := slog.Default() + tdb := test.NewTestDatabase(t) + defer tdb.TearDown() + userRepository := repository.NewUserRepository(logger, tdb.Db) + argon2IdHash := auth.NewArgon2IdHash(1, 32, 64*1024, 32, 256) + + userService := userservice.NewUserService(logger, userRepository, argon2IdHash) + authService := authservice.NewAuthService(logger, []byte("secretkey"), userRepository, argon2IdHash) + + _, err := userService.Register(ctx, "test@test.com", "supersecurepassword") + AssertNoError(t, err) + + t.Run("login existing user with correct credentials", func(t *testing.T) { + want, err := authService.CreateToken(ctx, "test@test.com") + AssertNoError(t, err) + got, err := authService.Login(ctx, "test@test.com", "supersecurepassword") + + AssertNoError(t, err) + AssertTokens(t, got, want) + }) + + t.Run("login existing user with incorrect credentials", func(t *testing.T) { + _, err := authService.Login(ctx, "test@test.com", "superwrongpassword") + + AssertErrors(t, err, authservice.ErrUnauthorized) + }) + + t.Run("login nonexistant user", func(t *testing.T) { + _, err := authService.Login(ctx, "foo@test.com", "supersecurepassword") + + AssertErrors(t, err, authservice.ErrNotFound) + }) +} + +func AssertErrors(t testing.TB, got, want error) { + t.Helper() + if got != want { + t.Errorf("got error: %v, want error: %v", want, got) + } +} + +func AssertNoError(t testing.TB, err error) { + t.Helper() + if err != nil { + t.Errorf("expected no error, got %v", err) + } +} + +func AssertTokens(t testing.TB, got, want string) { + t.Helper() + if got != want { + t.Errorf("expected matching tokens, got %q, want %q", got, want) + } +} + +func NewUUID(t testing.TB) uuid.UUID { + t.Helper() + uuid, err := uuid.NewV4() + if err != nil { + t.Errorf("error generation uuid: %v", err) + } + return uuid +} diff --git a/internal/logging/logging.go b/internal/logging/logging.go index df6f811..d967cc6 100644 --- a/internal/logging/logging.go +++ b/internal/logging/logging.go @@ -20,7 +20,7 @@ func (h *ContextHandler) Handle(ctx context.Context, r slog.Record) error { } func NewLogger() *slog.Logger { - baseHandler := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{AddSource: false}) + baseHandler := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{AddSource: true}) customHandler := &ContextHandler{Handler: baseHandler} logger := slog.New(customHandler) diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go new file mode 100644 index 0000000..3728a95 --- /dev/null +++ b/internal/middleware/auth.go @@ -0,0 +1,60 @@ +package middleware + +import ( + "context" + "log/slog" + "net/http" + + "github.com/golang-jwt/jwt/v5" +) + +const EmailKey contextKey = "email" + +func AuthMiddleware(logger *slog.Logger, jwtKey []byte) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie("token") + if err == http.ErrNoCookie { + logger.WarnContext(r.Context(), "token not provided") + w.WriteHeader(http.StatusUnauthorized) + return + } + + if err != nil { + logger.ErrorContext(r.Context(), err.Error()) + w.WriteHeader(http.StatusInternalServerError) + return + } + + tokenString := cookie.Value + + token, err := jwt.Parse(tokenString, func(t *jwt.Token) (any, error) { + return jwtKey, nil + }) + + if !token.Valid { + logger.WarnContext(r.Context(), err.Error()) + w.WriteHeader(http.StatusUnauthorized) + return + } + + if err != nil { + logger.ErrorContext(r.Context(), err.Error()) + w.WriteHeader(http.StatusInternalServerError) + return + } + + email, err := token.Claims.GetSubject() + if err != nil { + logger.ErrorContext(r.Context(), err.Error()) + w.WriteHeader(http.StatusInternalServerError) + return + } + + ctx := context.WithValue(r.Context(), EmailKey, email) + newReq := r.WithContext(ctx) + + next.ServeHTTP(w, newReq) + }) + } +} diff --git a/internal/middleware/context.go b/internal/middleware/trace.go similarity index 72% rename from internal/middleware/context.go rename to internal/middleware/trace.go index 89f4b6e..3029add 100644 --- a/internal/middleware/context.go +++ b/internal/middleware/trace.go @@ -8,17 +8,16 @@ import ( "github.com/gofrs/uuid/v5" ) -type contextKey string const TraceIdKey contextKey = "trace_id" -func ContextMiddleware(logger *slog.Logger) func(http.Handler) http.Handler { +func TraceMiddleware(logger *slog.Logger) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { traceid, err := uuid.NewV4() if err != nil { logger.ErrorContext(r.Context(), err.Error()) } - ctx := context.WithValue(r.Context(), TraceIdKey, traceid) + ctx := context.WithValue(r.Context(), TraceIdKey, traceid.String()) newReq := r.WithContext(ctx) next.ServeHTTP(w, newReq) diff --git a/internal/middleware/util.go b/internal/middleware/util.go index 50ed4fb..cf06793 100644 --- a/internal/middleware/util.go +++ b/internal/middleware/util.go @@ -4,6 +4,8 @@ import "net/http" type Middleware func(http.Handler) http.Handler +type contextKey string + func CompileMiddleware(h http.Handler, m []Middleware) http.Handler { if len(m) < 1 { return h diff --git a/internal/migrate/migrations/1-init_db.sql b/internal/migrate/migrations/1-init_db.sql index bc2f837..79d92cc 100644 --- a/internal/migrate/migrations/1-init_db.sql +++ b/internal/migrate/migrations/1-init_db.sql @@ -6,6 +6,7 @@ CREATE TABLE todo( CREATE TABLE users( id uuid PRIMARY KEY DEFAULT gen_random_uuid(), - email VARCHAR NOT NULL, - hashed_password VARCHAR NOT NULL + email VARCHAR NOT NULL UNIQUE, + hashed_password bytea NOT NULL, + salt bytea NOT NULL ); diff --git a/internal/test/test_database.go b/internal/test/test_database.go index 126e9d2..1ca7562 100644 --- a/internal/test/test_database.go +++ b/internal/test/test_database.go @@ -7,12 +7,12 @@ import ( "time" "gitea.michaelthomson.dev/mthomson/habits/internal/migrate" + pgxuuid "github.com/jackc/pgx-gofrs-uuid" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/modules/postgres" "github.com/testcontainers/testcontainers-go/wait" - pgxuuid "github.com/jackc/pgx-gofrs-uuid" ) type TestDatabase struct { diff --git a/internal/todo/service/service_test.go b/internal/todo/service/service_test.go index 5a274d2..4f9ea5a 100644 --- a/internal/todo/service/service_test.go +++ b/internal/todo/service/service_test.go @@ -39,7 +39,7 @@ func TestGetTodo(t *testing.T) { row := repository.TodoRow{Id: 1, Name: "clean dishes", Done: false} _, err := r.Create(ctx, row) - AssertNoError(t, err); + AssertNoError(t, err) todoService := service.NewTodoService(logger, r) @@ -66,7 +66,7 @@ func TestDeleteTodo(t *testing.T) { row := repository.TodoRow{Id: 1, Name: "clean dishes", Done: false} _, err := r.Create(ctx, row) - AssertNoError(t, err); + AssertNoError(t, err) todoService := service.NewTodoService(logger, r) @@ -93,7 +93,7 @@ func TestUpdateTodo(t *testing.T) { row := repository.TodoRow{Id: 1, Name: "clean dishes", Done: false} _, err := r.Create(ctx, row) - AssertNoError(t, err); + AssertNoError(t, err) todoService := service.NewTodoService(logger, r) diff --git a/internal/user/handler/register.go b/internal/user/handler/register.go new file mode 100644 index 0000000..98d9f1b --- /dev/null +++ b/internal/user/handler/register.go @@ -0,0 +1,67 @@ +package handler + +import ( + "context" + "encoding/json" + "log/slog" + "net/http" + + "gitea.michaelthomson.dev/mthomson/habits/internal/user/service" + "github.com/gofrs/uuid/v5" +) + +type UserRegisterer interface { + Register(ctx context.Context, email string, password string) (uuid.UUID, error) +} + +type RegisterRequest struct { + Email string `json:"email"` + Password string `json:"password"` +} + +type RegisterResponse struct { + Id string `json:"id"` +} + +func HandleRegisterUser(logger *slog.Logger, userService UserRegisterer) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + registerRequest := RegisterRequest{} + decoder := json.NewDecoder(r.Body) + decoder.DisallowUnknownFields() + err := decoder.Decode(®isterRequest) + + if err != nil { + logger.ErrorContext(ctx, err.Error()) + http.Error(w, "", http.StatusBadRequest) + return + } + + uuid, err := userService.Register(ctx, registerRequest.Email, registerRequest.Password) + + if err != nil { + if err == service.ErrUserExists { + http.Error(w, "", http.StatusConflict) + return + } + + logger.ErrorContext(ctx, err.Error()) + http.Error(w, "", http.StatusInternalServerError) + return + } + + response := RegisterResponse{uuid.String()} + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + err = json.NewEncoder(w).Encode(response) + + if err != nil { + logger.ErrorContext(ctx, err.Error()) + http.Error(w, "", http.StatusInternalServerError) + return + } + + } +} diff --git a/internal/user/handler/register_test.go b/internal/user/handler/register_test.go new file mode 100644 index 0000000..4cd1c43 --- /dev/null +++ b/internal/user/handler/register_test.go @@ -0,0 +1,148 @@ +package handler + +import ( + "bytes" + "context" + "encoding/json" + "log/slog" + + "errors" + "net/http" + "net/http/httptest" + "testing" + + "gitea.michaelthomson.dev/mthomson/habits/internal/user/service" + "github.com/gofrs/uuid/v5" +) + +type MockUserRegisterer struct { + RegisterUserFunc func(ctx context.Context, email string, password string) (uuid.UUID, error) +} + +func (m *MockUserRegisterer) Register(ctx context.Context, email string, password string) (uuid.UUID, error) { + return m.RegisterUserFunc(ctx, email, password) +} + +func TestCreateUser(t *testing.T) { + logger := slog.Default() + t.Run("create user", func(t *testing.T) { + createUserRequest := RegisterRequest{Email: "test@test.com", Password: "password"} + + newUUID := NewUUID(t) + + service := MockUserRegisterer{ + RegisterUserFunc: func(ctx context.Context, email string, password string) (uuid.UUID, error) { + return newUUID, nil + }, + } + + handler := HandleRegisterUser(logger, &service) + + requestBody, err := json.Marshal(createUserRequest) + + if err != nil { + t.Fatalf("Failed to marshal request %+v: %v", createUserRequest, err) + } + + req := httptest.NewRequest(http.MethodPost, "/register", bytes.NewBuffer(requestBody)) + res := httptest.NewRecorder() + + handler(res, req) + + AssertStatusCodes(t, res.Code, http.StatusOK) + }) + + t.Run("returns 409 when user exists", func(t *testing.T) { + createUserRequest := RegisterRequest{Email: "test@test.com", Password: "password"} + + newUUID := NewUUID(t) + + service := MockUserRegisterer{ + RegisterUserFunc: func(ctx context.Context, email string, password string) (uuid.UUID, error) { + return newUUID, service.ErrUserExists + }, + } + + handler := HandleRegisterUser(logger, &service) + + requestBody, err := json.Marshal(createUserRequest) + + if err != nil { + t.Fatalf("Failed to marshal request %+v: %v", createUserRequest, err) + } + + req := httptest.NewRequest(http.MethodPost, "/register", bytes.NewBuffer(requestBody)) + res := httptest.NewRecorder() + + handler(res, req) + + AssertStatusCodes(t, res.Code, http.StatusConflict) + }) + + t.Run("returns 400 with bad json", func(t *testing.T) { + handler := HandleRegisterUser(logger, nil) + + badStruct := struct { + Foo string + }{ + Foo: "bar", + } + + requestBody, err := json.Marshal(badStruct) + + if err != nil { + t.Fatalf("Failed to marshal request %+v: %v", badStruct, err) + } + + req := httptest.NewRequest(http.MethodPost, "/register", bytes.NewBuffer(requestBody)) + res := httptest.NewRecorder() + + handler(res, req) + + AssertStatusCodes(t, res.Code, http.StatusBadRequest) + }) + + t.Run("returns 500 arbitrary errors", func(t *testing.T) { + createUserRequest := RegisterRequest{Email: "test@test.com", Password: "password"} + + newUUID := NewUUID(t) + + service := MockUserRegisterer{ + RegisterUserFunc: func(ctx context.Context, email string, password string) (uuid.UUID, error) { + return newUUID, errors.New("foo bar") + }, + } + + handler := HandleRegisterUser(logger, &service) + + requestBody, err := json.Marshal(createUserRequest) + + if err != nil { + t.Fatalf("Failed to marshal request %+v: %v", createUserRequest, err) + } + + req := httptest.NewRequest(http.MethodPost, "/user", bytes.NewBuffer(requestBody)) + res := httptest.NewRecorder() + + handler(res, req) + + AssertStatusCodes(t, res.Code, http.StatusInternalServerError) + }) + +} + +func AssertStatusCodes(t testing.TB, got, want int) { + t.Helper() + if got != want { + t.Errorf("got status code: %v, want status code: %v", want, got) + } +} + +func NewUUID(t testing.TB) uuid.UUID { + t.Helper() + uuid, err := uuid.NewV4() + if err != nil { + t.Errorf("error generation uuid: %v", err) + } + return uuid +} diff --git a/internal/user/repository/repository.go b/internal/user/repository/repository.go index 0dc657d..72dc086 100644 --- a/internal/user/repository/repository.go +++ b/internal/user/repository/repository.go @@ -1,6 +1,7 @@ package repository import ( + "bytes" "context" "errors" "log/slog" @@ -17,15 +18,12 @@ var ( type UserRow struct { Id uuid.UUID Email string - HashedPassword string -} - -func NewUserRow(id uuid.UUID, email string, hashedPassword string) UserRow { - return UserRow{Id: id, Email: email, HashedPassword: hashedPassword} + HashedPassword []byte + Salt []byte } func (u UserRow) Equal(user UserRow) bool { - return u.Id == user.Id && u.Email == user.Email && u.HashedPassword == user.HashedPassword + return u.Id == user.Id && u.Email == user.Email && bytes.Equal(u.HashedPassword, user.HashedPassword) && bytes.Equal(u.Salt, user.Salt) } type UserRepository struct { @@ -43,7 +41,24 @@ func NewUserRepository(logger *slog.Logger, db *pgxpool.Pool) *UserRepository { func (r *UserRepository) GetById(ctx context.Context, id uuid.UUID) (UserRow, error) { user := UserRow{} - err := r.db.QueryRow(ctx, "SELECT * FROM users WHERE id = $1;", id).Scan(&user.Id, &user.Email, &user.HashedPassword) + err := r.db.QueryRow(ctx, "SELECT * FROM users WHERE id = $1;", id).Scan(&user.Id, &user.Email, &user.HashedPassword, &user.Salt) + + if err != nil { + if err == pgx.ErrNoRows { + return user, ErrNotFound + } + + r.logger.ErrorContext(ctx, err.Error()) + return user, err + } + + return user, nil +} + +func (r *UserRepository) GetByEmail(ctx context.Context, email string) (UserRow, error) { + user := UserRow{} + + err := r.db.QueryRow(ctx, "SELECT * FROM users WHERE email = $1;", email).Scan(&user.Id, &user.Email, &user.HashedPassword, &user.Salt) if err != nil { if err == pgx.ErrNoRows { @@ -60,7 +75,7 @@ func (r *UserRepository) GetById(ctx context.Context, id uuid.UUID) (UserRow, er func (r *UserRepository) Create(ctx context.Context, user UserRow) (UserRow, error) { var result pgx.Row if user.Id.IsNil() { - result = r.db.QueryRow(ctx, "INSERT INTO users (email, hashed_password) VALUES ($1, $2) RETURNING id;", user.Email, user.HashedPassword) + result = r.db.QueryRow(ctx, "INSERT INTO users (email, hashed_password, salt) VALUES ($1, $2, $3) RETURNING id;", user.Email, user.HashedPassword, user.Salt) err := result.Scan(&user.Id) @@ -69,7 +84,7 @@ func (r *UserRepository) Create(ctx context.Context, user UserRow) (UserRow, err return UserRow{}, err } } else { - _, err := r.db.Exec(ctx, "INSERT INTO users (id, email, hashed_password) VALUES ($1, $2, $3);", user.Id, user.Email, user.HashedPassword) + _, err := r.db.Exec(ctx, "INSERT INTO users (id, email, hashed_password, salt) VALUES ($1, $2, $3, $4);", user.Id, user.Email, user.HashedPassword, user.Salt) if err != nil { r.logger.ErrorContext(ctx, err.Error()) @@ -81,7 +96,7 @@ func (r *UserRepository) Create(ctx context.Context, user UserRow) (UserRow, err } func (r *UserRepository) Update(ctx context.Context, user UserRow) error { - result, err := r.db.Exec(ctx, "UPDATE users SET email = $1, hashed_password = $2 WHERE id = $3;", user.Email, user.HashedPassword, user.Id) + result, err := r.db.Exec(ctx, "UPDATE users SET email = $1, hashed_password = $2, salt = $3 WHERE id = $4;", user.Email, user.HashedPassword, user.Salt, user.Id) if err != nil { r.logger.ErrorContext(ctx, err.Error()) diff --git a/internal/user/repository/repository_test.go b/internal/user/repository/repository_test.go index 0e8ccdc..79d3933 100644 --- a/internal/user/repository/repository_test.go +++ b/internal/user/repository/repository_test.go @@ -6,6 +6,7 @@ import ( "log/slog" "testing" + "gitea.michaelthomson.dev/mthomson/habits/internal/auth" "gitea.michaelthomson.dev/mthomson/habits/internal/test" "gitea.michaelthomson.dev/mthomson/habits/internal/user/repository" "github.com/gofrs/uuid/v5" @@ -18,22 +19,36 @@ func TestCRUD(t *testing.T) { defer tdb.TearDown() r := repository.NewUserRepository(logger, tdb.Db) uuid := NewUUID(t) + argon2IdHash := auth.NewArgon2IdHash(1, 32, 64*1024, 32, 256) + + hashSalt, err := argon2IdHash.GenerateHash([]byte("supersecurepassword"), []byte("supersecuresalt")) + + if err != nil { + t.Errorf("could not generate hash: %v", err) + } t.Run("creates new user", func(t *testing.T) { - newUser := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: "supersecurehash"} + newUser := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: hashSalt.Hash, Salt: hashSalt.Salt} _, err := r.Create(ctx, newUser) AssertNoError(t, err) }) - t.Run("gets user", func(t *testing.T) { - want := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: "supersecurehash"} + t.Run("gets user by id", func(t *testing.T) { + want := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: hashSalt.Hash, Salt: hashSalt.Salt} got, err := r.GetById(ctx, uuid) AssertNoError(t, err) AssertUserRows(t, got, want) }) + t.Run("gets user by email", func(t *testing.T) { + want := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: hashSalt.Hash, Salt: hashSalt.Salt} + got, err := r.GetByEmail(ctx, "test@test.com") + AssertNoError(t, err) + AssertUserRows(t, got, want) + }) + t.Run("updates user", func(t *testing.T) { - want := repository.UserRow{Id: uuid, Email: "new@test.com", HashedPassword: "supersecurehash"} + want := repository.UserRow{Id: uuid, Email: "new@test.com", HashedPassword: hashSalt.Hash, Salt: hashSalt.Salt} err := r.Update(ctx, want) AssertNoError(t, err) diff --git a/internal/user/service/service.go b/internal/user/service/service.go index 22bed69..fac8c2e 100644 --- a/internal/user/service/service.go +++ b/internal/user/service/service.go @@ -5,52 +5,48 @@ import ( "errors" "log/slog" + "gitea.michaelthomson.dev/mthomson/habits/internal/auth" "gitea.michaelthomson.dev/mthomson/habits/internal/user/repository" "github.com/gofrs/uuid/v5" ) var ( - ErrNotFound error = errors.New("user cannot be found") + ErrNotFound error = errors.New("user cannot be found") + ErrUserExists error = errors.New("user already exists") ) type User struct { - Id uuid.UUID - Email string - HashedPassword string -} - -func NewUser(id uuid.UUID, email string, hashedPassword string) User { - return User{Id: id, Email: email, HashedPassword: hashedPassword} + Id uuid.UUID + Email string } func UserFromUserRow(userRow repository.UserRow) User { - return User{Id: userRow.Id, Email: userRow.Email, HashedPassword: userRow.HashedPassword} -} - -func UserRowFromUser(user User) repository.UserRow { - return repository.UserRow{Id: user.Id, Email: user.Email, HashedPassword: user.HashedPassword} + return User{Id: userRow.Id, Email: userRow.Email} } func (t User) Equal(user User) bool { - return t.Id == user.Id && t.Email == user.Email && t.HashedPassword == user.HashedPassword + return t.Id == user.Id && t.Email == user.Email } type UserRepository interface { Create(ctx context.Context, user repository.UserRow) (repository.UserRow, error) GetById(ctx context.Context, id uuid.UUID) (repository.UserRow, error) + GetByEmail(ctx context.Context, email string) (repository.UserRow, error) Update(ctx context.Context, user repository.UserRow) error Delete(ctx context.Context, id uuid.UUID) error } type UserService struct { - logger *slog.Logger - repo UserRepository + logger *slog.Logger + repo UserRepository + argon2IdHash *auth.Argon2IdHash } -func NewUserService(logger *slog.Logger, userRepo UserRepository) *UserService { +func NewUserService(logger *slog.Logger, userRepo UserRepository, argon2IdHash *auth.Argon2IdHash) *UserService { return &UserService{ - logger: logger, - repo: userRepo, + logger: logger, + repo: userRepo, + argon2IdHash: argon2IdHash, } } @@ -69,17 +65,37 @@ func (s *UserService) GetUser(ctx context.Context, id uuid.UUID) (User, error) { return UserFromUserRow(user), err } -func (s *UserService) CreateUser(ctx context.Context, user User) (User, error) { - userRow := UserRowFromUser(user) - - newUserRow, err := s.repo.Create(ctx, userRow) +func (s *UserService) Register(ctx context.Context, email string, password string) (uuid.UUID, error) { + uuid, err := uuid.NewV4() if err != nil { s.logger.ErrorContext(ctx, err.Error()) - return User{}, err + return uuid, err } - return UserFromUserRow(newUserRow), err + _, err = s.repo.GetByEmail(ctx, email) + + if err != repository.ErrNotFound { + return uuid, ErrUserExists + } + + hashSalt, err := s.argon2IdHash.GenerateHash([]byte(password), nil) + + if err != nil { + s.logger.ErrorContext(ctx, err.Error()) + return uuid, err + } + + userRow := repository.UserRow{Id: uuid, Email: email, HashedPassword: hashSalt.Hash, Salt: hashSalt.Salt} + + _, err = s.repo.Create(ctx, userRow) + + if err != nil { + s.logger.ErrorContext(ctx, err.Error()) + return uuid, err + } + + return uuid, err } func (s *UserService) DeleteUser(ctx context.Context, id uuid.UUID) error { @@ -96,10 +112,66 @@ func (s *UserService) DeleteUser(ctx context.Context, id uuid.UUID) error { return err } -func (s *UserService) UpdateUser(ctx context.Context, user User) error { - userRow := UserRowFromUser(user) +func (s *UserService) UpdateUserEmail(ctx context.Context, id uuid.UUID, email string) error { + user, err := s.repo.GetById(ctx, id) - err := s.repo.Update(ctx, userRow) + if err == repository.ErrNotFound { + return ErrNotFound + } + + if err != nil { + s.logger.ErrorContext(ctx, err.Error()) + return err + } + + _, err = s.repo.GetByEmail(ctx, email) + + switch err { + case repository.ErrNotFound: + user.Email = email + + err = s.repo.Update(ctx, user) + + if err == repository.ErrNotFound { + return ErrNotFound + } + + if err != nil { + s.logger.ErrorContext(ctx, err.Error()) + } + + return err + case nil: + return ErrUserExists + default: + s.logger.ErrorContext(ctx, err.Error()) + return err + } +} + +func (s *UserService) UpdateUserPassword(ctx context.Context, id uuid.UUID, password string) error { + user, err := s.repo.GetById(ctx, id) + + if err == repository.ErrNotFound { + return ErrNotFound + } + + if err != nil { + s.logger.ErrorContext(ctx, err.Error()) + return err + } + + hashSalt, err := s.argon2IdHash.GenerateHash([]byte(password), nil) + + if err != nil { + s.logger.ErrorContext(ctx, err.Error()) + return err + } + + user.HashedPassword = hashSalt.Hash + user.Salt = hashSalt.Salt + + err = s.repo.Update(ctx, user) if err == repository.ErrNotFound { return ErrNotFound diff --git a/internal/user/service/service_test.go b/internal/user/service/service_test.go index 4e5ccbc..d63f67e 100644 --- a/internal/user/service/service_test.go +++ b/internal/user/service/service_test.go @@ -5,27 +5,26 @@ import ( "log/slog" "testing" + "gitea.michaelthomson.dev/mthomson/habits/internal/auth" "gitea.michaelthomson.dev/mthomson/habits/internal/test" "gitea.michaelthomson.dev/mthomson/habits/internal/user/repository" "gitea.michaelthomson.dev/mthomson/habits/internal/user/service" "github.com/gofrs/uuid/v5" ) -func TestCreateUser(t *testing.T) { +func TestRegisterUser(t *testing.T) { t.Parallel() ctx := context.Background() logger := slog.Default() tdb := test.NewTestDatabase(t) defer tdb.TearDown() r := repository.NewUserRepository(logger, tdb.Db) + argon2IdHash := auth.NewArgon2IdHash(1, 32, 64*1024, 32, 256) - userService := service.NewUserService(logger, r) + userService := service.NewUserService(logger, r, argon2IdHash) t.Run("Create user", func(t *testing.T) { - uuid := NewUUID(t) - user := service.NewUser(uuid, "test@test.com", "supersecurehash") - - _, err := userService.CreateUser(ctx, user) + _, err := userService.Register(ctx, "test@test.com", "supersecurepassword") AssertNoError(t, err) }) @@ -38,13 +37,12 @@ func TestGetUser(t *testing.T) { tdb := test.NewTestDatabase(t) defer tdb.TearDown() r := repository.NewUserRepository(logger, tdb.Db) - uuid := NewUUID(t) + argon2IdHash := auth.NewArgon2IdHash(1, 32, 64*1024, 32, 256) - row := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: "supersecurehash"} - _, err := r.Create(ctx, row) - AssertNoError(t, err); + userService := service.NewUserService(logger, r, argon2IdHash) - userService := service.NewUserService(logger, r) + uuid, err := userService.Register(ctx, "test@test.com", "supersecurepassword") + AssertNoError(t, err) t.Run("Get exisiting user", func(t *testing.T) { _, err := userService.GetUser(ctx, uuid) @@ -66,13 +64,12 @@ func TestDeleteUser(t *testing.T) { tdb := test.NewTestDatabase(t) defer tdb.TearDown() r := repository.NewUserRepository(logger, tdb.Db) - uuid := NewUUID(t) + argon2IdHash := auth.NewArgon2IdHash(1, 32, 64*1024, 32, 256) - row := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: "supersecurehash"} - _, err := r.Create(ctx, row) - AssertNoError(t, err); + userService := service.NewUserService(logger, r, argon2IdHash) - userService := service.NewUserService(logger, r) + uuid, err := userService.Register(ctx, "test@test.com", "supersecurepassword") + AssertNoError(t, err) t.Run("Delete exisiting user", func(t *testing.T) { err := userService.DeleteUser(ctx, uuid) @@ -87,25 +84,22 @@ func TestDeleteUser(t *testing.T) { }) } -func TestUpdateUser(t *testing.T) { +func TestUpdateUserEmail(t *testing.T) { t.Parallel() ctx := context.Background() logger := slog.Default() tdb := test.NewTestDatabase(t) defer tdb.TearDown() r := repository.NewUserRepository(logger, tdb.Db) - uuid := NewUUID(t) + argon2IdHash := auth.NewArgon2IdHash(1, 32, 64*1024, 32, 256) - row := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: "supersecurehash"} - _, err := r.Create(ctx, row) - AssertNoError(t, err); + userService := service.NewUserService(logger, r, argon2IdHash) - userService := service.NewUserService(logger, r) + uuid, err := userService.Register(ctx, "test@test.com", "supersecurepassword") + AssertNoError(t, err) - t.Run("Update exisiting user", func(t *testing.T) { - user := service.User{uuid, "new@email.com", "supersecurehash"} - - err := userService.UpdateUser(ctx, user) + t.Run("Update existing user email", func(t *testing.T) { + err := userService.UpdateUserEmail(ctx, uuid, "newemail@test.com") AssertNoError(t, err) @@ -113,13 +107,13 @@ func TestUpdateUser(t *testing.T) { AssertNoError(t, err) - AssertUsers(t, newUser, user) + if newUser.Email != "newemail@test.com" { + t.Errorf("Emails do not match wanted %q, got %q", "newemail@test.com", newUser.Email) + } }) t.Run("Update non-existant user", func(t *testing.T) { - user := service.User{NewUUID(t), "new@email.com", "supersecurehash"} - - err := userService.UpdateUser(ctx, user) + err := userService.UpdateUserEmail(ctx, NewUUID(t), "newemail@test.com") AssertErrors(t, err, service.ErrNotFound) })