auth services, middleware, and other stuff
This commit is contained in:
parent
70bb4e66b4
commit
e55d419d44
2
.env.example
Normal file
2
.env.example
Normal file
@ -0,0 +1,2 @@
|
||||
POSTGRESQL_CONNECTION_STRING=postgres://todo:password@localhost:5432/todo
|
||||
JWT_SECRET_KEY="supersecretjwtkey"
|
58
cmd/main.go
58
cmd/main.go
@ -6,33 +6,59 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/auth"
|
||||
authhandler "gitea.michaelthomson.dev/mthomson/habits/internal/auth/handler"
|
||||
authservice "gitea.michaelthomson.dev/mthomson/habits/internal/auth/service"
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/logging"
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/middleware"
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/migrate"
|
||||
todohandler "gitea.michaelthomson.dev/mthomson/habits/internal/todo/handler"
|
||||
todorepository "gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository"
|
||||
todoservice "gitea.michaelthomson.dev/mthomson/habits/internal/todo/service"
|
||||
userhandler "gitea.michaelthomson.dev/mthomson/habits/internal/user/handler"
|
||||
userrepository "gitea.michaelthomson.dev/mthomson/habits/internal/user/repository"
|
||||
userservice "gitea.michaelthomson.dev/mthomson/habits/internal/user/service"
|
||||
pgxuuid "github.com/jackc/pgx-gofrs-uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
pgxuuid "github.com/jackc/pgx-gofrs-uuid"
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// create logger
|
||||
logger := logging.NewLogger()
|
||||
|
||||
// create middlewares
|
||||
contextMiddleware := middleware.ContextMiddleware(logger)
|
||||
loggingMiddleware := middleware.LoggingMiddleware(logger)
|
||||
// load env
|
||||
err := godotenv.Load()
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
stack := []middleware.Middleware{
|
||||
contextMiddleware,
|
||||
jwtSecretKey := os.Getenv("JWT_SECRET_KEY")
|
||||
postgresqlConnectionString := os.Getenv("POSTGRESQL_CONNECTION_STRING")
|
||||
|
||||
// create hasher instance
|
||||
argon2IdHash := auth.NewArgon2IdHash(1, 32, 64*1024, 32, 256)
|
||||
|
||||
// create middlewares
|
||||
traceMiddleware := middleware.TraceMiddleware(logger)
|
||||
loggingMiddleware := middleware.LoggingMiddleware(logger)
|
||||
authMiddleware := middleware.AuthMiddleware(logger, []byte(jwtSecretKey))
|
||||
|
||||
unauthenticatedStack := []middleware.Middleware{
|
||||
traceMiddleware,
|
||||
loggingMiddleware,
|
||||
}
|
||||
|
||||
authenticatedStack := []middleware.Middleware{
|
||||
traceMiddleware,
|
||||
loggingMiddleware,
|
||||
authMiddleware,
|
||||
}
|
||||
|
||||
// create db pool
|
||||
postgresUrl := "postgres://todo:password@localhost:5432/todo"
|
||||
dbconfig, err := pgxpool.ParseConfig(postgresUrl)
|
||||
dbconfig, err := pgxpool.ParseConfig(postgresqlConnectionString)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
os.Exit(1)
|
||||
@ -54,18 +80,26 @@ func main() {
|
||||
|
||||
// create repos
|
||||
todoRepository := todorepository.NewTodoRepository(logger, db)
|
||||
userRepository := userrepository.NewUserRepository(logger, db)
|
||||
|
||||
// create services
|
||||
todoService := todoservice.NewTodoService(logger, todoRepository)
|
||||
userService := userservice.NewUserService(logger, userRepository, argon2IdHash)
|
||||
authService := authservice.NewAuthService(logger, []byte(jwtSecretKey), userRepository, argon2IdHash)
|
||||
|
||||
// create mux
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// register handlers
|
||||
mux.Handle("GET /todo/{id}", middleware.CompileMiddleware(todohandler.HandleTodoGet(logger, todoService), stack))
|
||||
mux.Handle("POST /todo", middleware.CompileMiddleware(todohandler.HandleTodoCreate(logger, todoService), stack))
|
||||
mux.Handle("DELETE /todo/{id}", middleware.CompileMiddleware(todohandler.HandleTodoDelete(logger, todoService), stack))
|
||||
mux.Handle("PUT /todo/{id}", middleware.CompileMiddleware(todohandler.HandleTodoUpdate(logger, todoService), stack))
|
||||
// auth
|
||||
mux.Handle("POST /login", middleware.CompileMiddleware(authhandler.HandleLogin(logger, authService), unauthenticatedStack))
|
||||
// users
|
||||
mux.Handle("POST /register", middleware.CompileMiddleware(userhandler.HandleRegisterUser(logger, userService), unauthenticatedStack))
|
||||
// todos
|
||||
mux.Handle("GET /todo/{id}", middleware.CompileMiddleware(todohandler.HandleTodoGet(logger, todoService), authenticatedStack))
|
||||
mux.Handle("POST /todo", middleware.CompileMiddleware(todohandler.HandleTodoCreate(logger, todoService), authenticatedStack))
|
||||
mux.Handle("DELETE /todo/{id}", middleware.CompileMiddleware(todohandler.HandleTodoDelete(logger, todoService), authenticatedStack))
|
||||
mux.Handle("PUT /todo/{id}", middleware.CompileMiddleware(todohandler.HandleTodoUpdate(logger, todoService), authenticatedStack))
|
||||
|
||||
// create server
|
||||
server := &http.Server{
|
||||
|
4
go.mod
4
go.mod
@ -6,10 +6,13 @@ toolchain go1.23.9
|
||||
|
||||
require (
|
||||
github.com/gofrs/uuid/v5 v5.3.2
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2
|
||||
github.com/jackc/pgx-gofrs-uuid v0.0.0-20230224015001-1d428863c2e2
|
||||
github.com/jackc/pgx/v5 v5.7.4
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/testcontainers/testcontainers-go v0.37.0
|
||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.37.0
|
||||
golang.org/x/crypto v0.38.0
|
||||
)
|
||||
|
||||
require (
|
||||
@ -64,7 +67,6 @@ require (
|
||||
go.opentelemetry.io/otel v1.35.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.35.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.35.0 // indirect
|
||||
golang.org/x/crypto v0.38.0 // indirect
|
||||
golang.org/x/sync v0.14.0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
golang.org/x/text v0.25.0 // indirect
|
||||
|
4
go.sum
4
go.sum
@ -43,6 +43,8 @@ github.com/gofrs/uuid/v5 v5.3.2 h1:2jfO8j3XgSwlz/wHqemAEugfnTlikAYHhnqQ8Xh4fE0=
|
||||
github.com/gofrs/uuid/v5 v5.3.2/go.mod h1:CDOjlDMVAtN56jqyRUZh58JT31Tiw7/oQyEXZV+9bD8=
|
||||
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
@ -59,6 +61,8 @@ github.com/jackc/pgx/v5 v5.7.4 h1:9wKznZrhWa2QiHL+NjTSPP6yjl3451BX3imWDnokYlg=
|
||||
github.com/jackc/pgx/v5 v5.7.4/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
|
66
internal/auth/handler/login.go
Normal file
66
internal/auth/handler/login.go
Normal file
@ -0,0 +1,66 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/auth/service"
|
||||
)
|
||||
|
||||
type Loginer interface {
|
||||
Login(ctx context.Context, email string, password string) (string, error)
|
||||
}
|
||||
|
||||
type LoginRequest struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
func HandleLogin(logger *slog.Logger, authService Loginer) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
loginRequest := LoginRequest{}
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
decoder.DisallowUnknownFields()
|
||||
err := decoder.Decode(&loginRequest)
|
||||
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, err.Error())
|
||||
http.Error(w, "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := authService.Login(ctx, loginRequest.Email, loginRequest.Password)
|
||||
|
||||
if err == service.ErrUnauthorized {
|
||||
http.Error(w, "", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if err == service.ErrNotFound {
|
||||
http.Error(w, "", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, err.Error())
|
||||
http.Error(w, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
cookie := http.Cookie{
|
||||
Name: "token",
|
||||
Value: token,
|
||||
Path: "/",
|
||||
MaxAge: 3600,
|
||||
HttpOnly: true,
|
||||
Secure: false,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}
|
||||
http.SetCookie(w, &cookie)
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
}
|
168
internal/auth/handler/login_test.go
Normal file
168
internal/auth/handler/login_test.go
Normal file
@ -0,0 +1,168 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/auth/service"
|
||||
)
|
||||
|
||||
type MockLoginer struct {
|
||||
LoginFunc func(ctx context.Context, email string, password string) (string, error)
|
||||
}
|
||||
|
||||
func (m *MockLoginer) Login(ctx context.Context, email string, password string) (string, error) {
|
||||
return m.LoginFunc(ctx, email, password)
|
||||
}
|
||||
|
||||
func TestLogin(t *testing.T) {
|
||||
logger := slog.Default()
|
||||
t.Run("returns 200 for existing user with correct credentials", func(t *testing.T) {
|
||||
loginRequest := LoginRequest{Email: "test@test.com", Password: "password"}
|
||||
|
||||
token := "examplejwt"
|
||||
|
||||
service := MockLoginer{
|
||||
LoginFunc: func(ctx context.Context, email string, password string) (string, error) {
|
||||
return token, nil
|
||||
},
|
||||
}
|
||||
|
||||
handler := HandleLogin(logger, &service)
|
||||
|
||||
requestBody, err := json.Marshal(loginRequest)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request %+v: %v", loginRequest, err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(requestBody))
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
handler(res, req)
|
||||
|
||||
AssertStatusCodes(t, res.Code, http.StatusOK)
|
||||
AssertToken(t, res, token)
|
||||
})
|
||||
|
||||
t.Run("returns 401 for existing user with incorrect credentials", func(t *testing.T) {
|
||||
loginRequest := LoginRequest{Email: "test@test.com", Password: "password"}
|
||||
|
||||
service := MockLoginer{
|
||||
LoginFunc: func(ctx context.Context, email string, password string) (string, error) {
|
||||
return "", service.ErrUnauthorized
|
||||
},
|
||||
}
|
||||
|
||||
handler := HandleLogin(logger, &service)
|
||||
|
||||
requestBody, err := json.Marshal(loginRequest)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request %+v: %v", loginRequest, err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(requestBody))
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
handler(res, req)
|
||||
|
||||
AssertStatusCodes(t, res.Code, http.StatusUnauthorized)
|
||||
})
|
||||
|
||||
t.Run("returns 401 for non-existing user", func(t *testing.T) {
|
||||
loginRequest := LoginRequest{Email: "test@test.com", Password: "password"}
|
||||
|
||||
service := MockLoginer{
|
||||
LoginFunc: func(ctx context.Context, email string, password string) (string, error) {
|
||||
return "", service.ErrNotFound
|
||||
},
|
||||
}
|
||||
|
||||
handler := HandleLogin(logger, &service)
|
||||
|
||||
requestBody, err := json.Marshal(loginRequest)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request %+v: %v", loginRequest, err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(requestBody))
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
handler(res, req)
|
||||
|
||||
AssertStatusCodes(t, res.Code, http.StatusUnauthorized)
|
||||
})
|
||||
|
||||
t.Run("returns 400 with bad json", func(t *testing.T) {
|
||||
handler := HandleLogin(logger, nil)
|
||||
|
||||
badStruct := struct {
|
||||
Foo string
|
||||
}{
|
||||
Foo: "bar",
|
||||
}
|
||||
|
||||
requestBody, err := json.Marshal(badStruct)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request %+v: %v", badStruct, err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(requestBody))
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
handler(res, req)
|
||||
|
||||
AssertStatusCodes(t, res.Code, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("returns 500 arbitrary errors", func(t *testing.T) {
|
||||
loginRequest := LoginRequest{Email: "test@test.com", Password: "password"}
|
||||
|
||||
service := MockLoginer{
|
||||
LoginFunc: func(ctx context.Context, email string, password string) (string, error) {
|
||||
return "", errors.New("foo bar")
|
||||
},
|
||||
}
|
||||
|
||||
handler := HandleLogin(logger, &service)
|
||||
|
||||
requestBody, err := json.Marshal(loginRequest)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request %+v: %v", loginRequest, err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(requestBody))
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
handler(res, req)
|
||||
|
||||
AssertStatusCodes(t, res.Code, http.StatusInternalServerError)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func AssertStatusCodes(t testing.TB, got, want int) {
|
||||
t.Helper()
|
||||
if got != want {
|
||||
t.Errorf("got status code: %v, want status code: %v", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func AssertToken(t testing.TB, res *httptest.ResponseRecorder, want string) {
|
||||
t.Helper()
|
||||
got := res.Result().Cookies()[0].Value
|
||||
if got != want {
|
||||
t.Errorf("got cookie: %q, want cookie: %q", got, want)
|
||||
}
|
||||
}
|
70
internal/auth/hashing.go
Normal file
70
internal/auth/hashing.go
Normal file
@ -0,0 +1,70 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
)
|
||||
|
||||
type HashSalt struct {
|
||||
Hash, Salt []byte
|
||||
}
|
||||
|
||||
type Argon2IdHash struct {
|
||||
time uint32
|
||||
memory uint32
|
||||
threads uint8
|
||||
keyLen uint32
|
||||
saltLen uint32
|
||||
}
|
||||
|
||||
var (
|
||||
ErrNoMatch error = errors.New("hash doesn't match")
|
||||
)
|
||||
|
||||
func NewArgon2IdHash(time, saltLen uint32, memory uint32, threads uint8, keyLen uint32) *Argon2IdHash {
|
||||
return &Argon2IdHash{
|
||||
time: time,
|
||||
saltLen: saltLen,
|
||||
memory: memory,
|
||||
threads: threads,
|
||||
keyLen: keyLen,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Argon2IdHash) GenerateHash(password, salt []byte) (*HashSalt, error) {
|
||||
var err error
|
||||
if len(salt) == 0 {
|
||||
salt, err = randomSecret(a.saltLen)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hash := argon2.IDKey(password, salt, a.time, a.memory, a.threads, a.keyLen)
|
||||
return &HashSalt{Hash: hash, Salt: salt}, nil
|
||||
}
|
||||
|
||||
func (a *Argon2IdHash) Compare(hash, salt, password []byte) error {
|
||||
hashSalt, err := a.GenerateHash(password, salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !bytes.Equal(hash, hashSalt.Hash) {
|
||||
return ErrNoMatch
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func randomSecret(length uint32) ([]byte, error) {
|
||||
secret := make([]byte, length)
|
||||
|
||||
_, err := rand.Read(secret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return secret, nil
|
||||
}
|
90
internal/auth/service/service.go
Normal file
90
internal/auth/service/service.go
Normal file
@ -0,0 +1,90 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/auth"
|
||||
userrepository "gitea.michaelthomson.dev/mthomson/habits/internal/user/repository"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNotFound error = errors.New("user cannot be found")
|
||||
ErrUnauthorized error = errors.New("user password incorrect")
|
||||
)
|
||||
|
||||
type UserRepository interface {
|
||||
Create(ctx context.Context, user userrepository.UserRow) (userrepository.UserRow, error)
|
||||
GetByEmail(ctx context.Context, email string) (userrepository.UserRow, error)
|
||||
}
|
||||
|
||||
type AuthService struct {
|
||||
logger *slog.Logger
|
||||
jwtKey []byte
|
||||
userRepository UserRepository
|
||||
argon2IdHash *auth.Argon2IdHash
|
||||
}
|
||||
|
||||
func NewAuthService(logger *slog.Logger, jwtKey []byte, userRepository UserRepository, argon2IdHash *auth.Argon2IdHash) *AuthService {
|
||||
return &AuthService{
|
||||
logger: logger,
|
||||
jwtKey: jwtKey,
|
||||
userRepository: userRepository,
|
||||
argon2IdHash: argon2IdHash,
|
||||
}
|
||||
}
|
||||
|
||||
func (a AuthService) Login(ctx context.Context, email string, password string) (string, error) {
|
||||
// get user if exists
|
||||
userRow, err := a.userRepository.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if err == userrepository.ErrNotFound {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
|
||||
a.logger.ErrorContext(ctx, err.Error())
|
||||
return "", err
|
||||
}
|
||||
|
||||
// compare hashed passswords
|
||||
err = a.argon2IdHash.Compare(userRow.HashedPassword, userRow.Salt, []byte(password))
|
||||
|
||||
if err == auth.ErrNoMatch {
|
||||
return "", ErrUnauthorized
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
a.logger.ErrorContext(ctx, err.Error())
|
||||
return "", err
|
||||
}
|
||||
|
||||
// create token and return it
|
||||
token, err := a.CreateToken(ctx, email)
|
||||
if err != nil {
|
||||
a.logger.ErrorContext(ctx, err.Error())
|
||||
return "", err
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (a AuthService) CreateToken(ctx context.Context, email string) (string, error) {
|
||||
// Create a new JWT token with claims
|
||||
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"sub": email,
|
||||
"iss": "todo-app",
|
||||
"aud": "user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
})
|
||||
|
||||
tokenString, err := claims.SignedString(a.jwtKey)
|
||||
if err != nil {
|
||||
a.logger.ErrorContext(ctx, err.Error())
|
||||
}
|
||||
|
||||
return tokenString, nil
|
||||
}
|
81
internal/auth/service/service_test.go
Normal file
81
internal/auth/service/service_test.go
Normal file
@ -0,0 +1,81 @@
|
||||
package service_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"testing"
|
||||
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/auth"
|
||||
authservice "gitea.michaelthomson.dev/mthomson/habits/internal/auth/service"
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/test"
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/user/repository"
|
||||
userservice "gitea.michaelthomson.dev/mthomson/habits/internal/user/service"
|
||||
"github.com/gofrs/uuid/v5"
|
||||
)
|
||||
|
||||
func TestLogin(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
logger := slog.Default()
|
||||
tdb := test.NewTestDatabase(t)
|
||||
defer tdb.TearDown()
|
||||
userRepository := repository.NewUserRepository(logger, tdb.Db)
|
||||
argon2IdHash := auth.NewArgon2IdHash(1, 32, 64*1024, 32, 256)
|
||||
|
||||
userService := userservice.NewUserService(logger, userRepository, argon2IdHash)
|
||||
authService := authservice.NewAuthService(logger, []byte("secretkey"), userRepository, argon2IdHash)
|
||||
|
||||
_, err := userService.Register(ctx, "test@test.com", "supersecurepassword")
|
||||
AssertNoError(t, err)
|
||||
|
||||
t.Run("login existing user with correct credentials", func(t *testing.T) {
|
||||
want, err := authService.CreateToken(ctx, "test@test.com")
|
||||
AssertNoError(t, err)
|
||||
got, err := authService.Login(ctx, "test@test.com", "supersecurepassword")
|
||||
|
||||
AssertNoError(t, err)
|
||||
AssertTokens(t, got, want)
|
||||
})
|
||||
|
||||
t.Run("login existing user with incorrect credentials", func(t *testing.T) {
|
||||
_, err := authService.Login(ctx, "test@test.com", "superwrongpassword")
|
||||
|
||||
AssertErrors(t, err, authservice.ErrUnauthorized)
|
||||
})
|
||||
|
||||
t.Run("login nonexistant user", func(t *testing.T) {
|
||||
_, err := authService.Login(ctx, "foo@test.com", "supersecurepassword")
|
||||
|
||||
AssertErrors(t, err, authservice.ErrNotFound)
|
||||
})
|
||||
}
|
||||
|
||||
func AssertErrors(t testing.TB, got, want error) {
|
||||
t.Helper()
|
||||
if got != want {
|
||||
t.Errorf("got error: %v, want error: %v", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func AssertNoError(t testing.TB, err error) {
|
||||
t.Helper()
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func AssertTokens(t testing.TB, got, want string) {
|
||||
t.Helper()
|
||||
if got != want {
|
||||
t.Errorf("expected matching tokens, got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func NewUUID(t testing.TB) uuid.UUID {
|
||||
t.Helper()
|
||||
uuid, err := uuid.NewV4()
|
||||
if err != nil {
|
||||
t.Errorf("error generation uuid: %v", err)
|
||||
}
|
||||
return uuid
|
||||
}
|
@ -20,7 +20,7 @@ func (h *ContextHandler) Handle(ctx context.Context, r slog.Record) error {
|
||||
}
|
||||
|
||||
func NewLogger() *slog.Logger {
|
||||
baseHandler := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{AddSource: false})
|
||||
baseHandler := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{AddSource: true})
|
||||
customHandler := &ContextHandler{Handler: baseHandler}
|
||||
logger := slog.New(customHandler)
|
||||
|
||||
|
60
internal/middleware/auth.go
Normal file
60
internal/middleware/auth.go
Normal file
@ -0,0 +1,60 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
const EmailKey contextKey = "email"
|
||||
|
||||
func AuthMiddleware(logger *slog.Logger, jwtKey []byte) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie("token")
|
||||
if err == http.ErrNoCookie {
|
||||
logger.WarnContext(r.Context(), "token not provided")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.ErrorContext(r.Context(), err.Error())
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := cookie.Value
|
||||
|
||||
token, err := jwt.Parse(tokenString, func(t *jwt.Token) (any, error) {
|
||||
return jwtKey, nil
|
||||
})
|
||||
|
||||
if !token.Valid {
|
||||
logger.WarnContext(r.Context(), err.Error())
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.ErrorContext(r.Context(), err.Error())
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
email, err := token.Claims.GetSubject()
|
||||
if err != nil {
|
||||
logger.ErrorContext(r.Context(), err.Error())
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), EmailKey, email)
|
||||
newReq := r.WithContext(ctx)
|
||||
|
||||
next.ServeHTTP(w, newReq)
|
||||
})
|
||||
}
|
||||
}
|
@ -8,17 +8,16 @@ import (
|
||||
"github.com/gofrs/uuid/v5"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
const TraceIdKey contextKey = "trace_id"
|
||||
|
||||
func ContextMiddleware(logger *slog.Logger) func(http.Handler) http.Handler {
|
||||
func TraceMiddleware(logger *slog.Logger) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
traceid, err := uuid.NewV4()
|
||||
if err != nil {
|
||||
logger.ErrorContext(r.Context(), err.Error())
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), TraceIdKey, traceid)
|
||||
ctx := context.WithValue(r.Context(), TraceIdKey, traceid.String())
|
||||
newReq := r.WithContext(ctx)
|
||||
|
||||
next.ServeHTTP(w, newReq)
|
@ -4,6 +4,8 @@ import "net/http"
|
||||
|
||||
type Middleware func(http.Handler) http.Handler
|
||||
|
||||
type contextKey string
|
||||
|
||||
func CompileMiddleware(h http.Handler, m []Middleware) http.Handler {
|
||||
if len(m) < 1 {
|
||||
return h
|
||||
|
@ -6,6 +6,7 @@ CREATE TABLE todo(
|
||||
|
||||
CREATE TABLE users(
|
||||
id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
email VARCHAR NOT NULL,
|
||||
hashed_password VARCHAR NOT NULL
|
||||
email VARCHAR NOT NULL UNIQUE,
|
||||
hashed_password bytea NOT NULL,
|
||||
salt bytea NOT NULL
|
||||
);
|
||||
|
@ -7,12 +7,12 @@ import (
|
||||
"time"
|
||||
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/migrate"
|
||||
pgxuuid "github.com/jackc/pgx-gofrs-uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
pgxuuid "github.com/jackc/pgx-gofrs-uuid"
|
||||
)
|
||||
|
||||
type TestDatabase struct {
|
||||
|
@ -39,7 +39,7 @@ func TestGetTodo(t *testing.T) {
|
||||
|
||||
row := repository.TodoRow{Id: 1, Name: "clean dishes", Done: false}
|
||||
_, err := r.Create(ctx, row)
|
||||
AssertNoError(t, err);
|
||||
AssertNoError(t, err)
|
||||
|
||||
todoService := service.NewTodoService(logger, r)
|
||||
|
||||
@ -66,7 +66,7 @@ func TestDeleteTodo(t *testing.T) {
|
||||
|
||||
row := repository.TodoRow{Id: 1, Name: "clean dishes", Done: false}
|
||||
_, err := r.Create(ctx, row)
|
||||
AssertNoError(t, err);
|
||||
AssertNoError(t, err)
|
||||
|
||||
todoService := service.NewTodoService(logger, r)
|
||||
|
||||
@ -93,7 +93,7 @@ func TestUpdateTodo(t *testing.T) {
|
||||
|
||||
row := repository.TodoRow{Id: 1, Name: "clean dishes", Done: false}
|
||||
_, err := r.Create(ctx, row)
|
||||
AssertNoError(t, err);
|
||||
AssertNoError(t, err)
|
||||
|
||||
todoService := service.NewTodoService(logger, r)
|
||||
|
||||
|
67
internal/user/handler/register.go
Normal file
67
internal/user/handler/register.go
Normal file
@ -0,0 +1,67 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/user/service"
|
||||
"github.com/gofrs/uuid/v5"
|
||||
)
|
||||
|
||||
type UserRegisterer interface {
|
||||
Register(ctx context.Context, email string, password string) (uuid.UUID, error)
|
||||
}
|
||||
|
||||
type RegisterRequest struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type RegisterResponse struct {
|
||||
Id string `json:"id"`
|
||||
}
|
||||
|
||||
func HandleRegisterUser(logger *slog.Logger, userService UserRegisterer) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
registerRequest := RegisterRequest{}
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
decoder.DisallowUnknownFields()
|
||||
err := decoder.Decode(®isterRequest)
|
||||
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, err.Error())
|
||||
http.Error(w, "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
uuid, err := userService.Register(ctx, registerRequest.Email, registerRequest.Password)
|
||||
|
||||
if err != nil {
|
||||
if err == service.ErrUserExists {
|
||||
http.Error(w, "", http.StatusConflict)
|
||||
return
|
||||
}
|
||||
|
||||
logger.ErrorContext(ctx, err.Error())
|
||||
http.Error(w, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
response := RegisterResponse{uuid.String()}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
err = json.NewEncoder(w).Encode(response)
|
||||
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, err.Error())
|
||||
http.Error(w, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
}
|
148
internal/user/handler/register_test.go
Normal file
148
internal/user/handler/register_test.go
Normal file
@ -0,0 +1,148 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/user/service"
|
||||
"github.com/gofrs/uuid/v5"
|
||||
)
|
||||
|
||||
type MockUserRegisterer struct {
|
||||
RegisterUserFunc func(ctx context.Context, email string, password string) (uuid.UUID, error)
|
||||
}
|
||||
|
||||
func (m *MockUserRegisterer) Register(ctx context.Context, email string, password string) (uuid.UUID, error) {
|
||||
return m.RegisterUserFunc(ctx, email, password)
|
||||
}
|
||||
|
||||
func TestCreateUser(t *testing.T) {
|
||||
logger := slog.Default()
|
||||
t.Run("create user", func(t *testing.T) {
|
||||
createUserRequest := RegisterRequest{Email: "test@test.com", Password: "password"}
|
||||
|
||||
newUUID := NewUUID(t)
|
||||
|
||||
service := MockUserRegisterer{
|
||||
RegisterUserFunc: func(ctx context.Context, email string, password string) (uuid.UUID, error) {
|
||||
return newUUID, nil
|
||||
},
|
||||
}
|
||||
|
||||
handler := HandleRegisterUser(logger, &service)
|
||||
|
||||
requestBody, err := json.Marshal(createUserRequest)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request %+v: %v", createUserRequest, err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/register", bytes.NewBuffer(requestBody))
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
handler(res, req)
|
||||
|
||||
AssertStatusCodes(t, res.Code, http.StatusOK)
|
||||
})
|
||||
|
||||
t.Run("returns 409 when user exists", func(t *testing.T) {
|
||||
createUserRequest := RegisterRequest{Email: "test@test.com", Password: "password"}
|
||||
|
||||
newUUID := NewUUID(t)
|
||||
|
||||
service := MockUserRegisterer{
|
||||
RegisterUserFunc: func(ctx context.Context, email string, password string) (uuid.UUID, error) {
|
||||
return newUUID, service.ErrUserExists
|
||||
},
|
||||
}
|
||||
|
||||
handler := HandleRegisterUser(logger, &service)
|
||||
|
||||
requestBody, err := json.Marshal(createUserRequest)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request %+v: %v", createUserRequest, err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/register", bytes.NewBuffer(requestBody))
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
handler(res, req)
|
||||
|
||||
AssertStatusCodes(t, res.Code, http.StatusConflict)
|
||||
})
|
||||
|
||||
t.Run("returns 400 with bad json", func(t *testing.T) {
|
||||
handler := HandleRegisterUser(logger, nil)
|
||||
|
||||
badStruct := struct {
|
||||
Foo string
|
||||
}{
|
||||
Foo: "bar",
|
||||
}
|
||||
|
||||
requestBody, err := json.Marshal(badStruct)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request %+v: %v", badStruct, err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/register", bytes.NewBuffer(requestBody))
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
handler(res, req)
|
||||
|
||||
AssertStatusCodes(t, res.Code, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("returns 500 arbitrary errors", func(t *testing.T) {
|
||||
createUserRequest := RegisterRequest{Email: "test@test.com", Password: "password"}
|
||||
|
||||
newUUID := NewUUID(t)
|
||||
|
||||
service := MockUserRegisterer{
|
||||
RegisterUserFunc: func(ctx context.Context, email string, password string) (uuid.UUID, error) {
|
||||
return newUUID, errors.New("foo bar")
|
||||
},
|
||||
}
|
||||
|
||||
handler := HandleRegisterUser(logger, &service)
|
||||
|
||||
requestBody, err := json.Marshal(createUserRequest)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request %+v: %v", createUserRequest, err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/user", bytes.NewBuffer(requestBody))
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
handler(res, req)
|
||||
|
||||
AssertStatusCodes(t, res.Code, http.StatusInternalServerError)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func AssertStatusCodes(t testing.TB, got, want int) {
|
||||
t.Helper()
|
||||
if got != want {
|
||||
t.Errorf("got status code: %v, want status code: %v", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func NewUUID(t testing.TB) uuid.UUID {
|
||||
t.Helper()
|
||||
uuid, err := uuid.NewV4()
|
||||
if err != nil {
|
||||
t.Errorf("error generation uuid: %v", err)
|
||||
}
|
||||
return uuid
|
||||
}
|
@ -1,6 +1,7 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
@ -17,15 +18,12 @@ var (
|
||||
type UserRow struct {
|
||||
Id uuid.UUID
|
||||
Email string
|
||||
HashedPassword string
|
||||
}
|
||||
|
||||
func NewUserRow(id uuid.UUID, email string, hashedPassword string) UserRow {
|
||||
return UserRow{Id: id, Email: email, HashedPassword: hashedPassword}
|
||||
HashedPassword []byte
|
||||
Salt []byte
|
||||
}
|
||||
|
||||
func (u UserRow) Equal(user UserRow) bool {
|
||||
return u.Id == user.Id && u.Email == user.Email && u.HashedPassword == user.HashedPassword
|
||||
return u.Id == user.Id && u.Email == user.Email && bytes.Equal(u.HashedPassword, user.HashedPassword) && bytes.Equal(u.Salt, user.Salt)
|
||||
}
|
||||
|
||||
type UserRepository struct {
|
||||
@ -43,7 +41,24 @@ func NewUserRepository(logger *slog.Logger, db *pgxpool.Pool) *UserRepository {
|
||||
func (r *UserRepository) GetById(ctx context.Context, id uuid.UUID) (UserRow, error) {
|
||||
user := UserRow{}
|
||||
|
||||
err := r.db.QueryRow(ctx, "SELECT * FROM users WHERE id = $1;", id).Scan(&user.Id, &user.Email, &user.HashedPassword)
|
||||
err := r.db.QueryRow(ctx, "SELECT * FROM users WHERE id = $1;", id).Scan(&user.Id, &user.Email, &user.HashedPassword, &user.Salt)
|
||||
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return user, ErrNotFound
|
||||
}
|
||||
|
||||
r.logger.ErrorContext(ctx, err.Error())
|
||||
return user, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) GetByEmail(ctx context.Context, email string) (UserRow, error) {
|
||||
user := UserRow{}
|
||||
|
||||
err := r.db.QueryRow(ctx, "SELECT * FROM users WHERE email = $1;", email).Scan(&user.Id, &user.Email, &user.HashedPassword, &user.Salt)
|
||||
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
@ -60,7 +75,7 @@ func (r *UserRepository) GetById(ctx context.Context, id uuid.UUID) (UserRow, er
|
||||
func (r *UserRepository) Create(ctx context.Context, user UserRow) (UserRow, error) {
|
||||
var result pgx.Row
|
||||
if user.Id.IsNil() {
|
||||
result = r.db.QueryRow(ctx, "INSERT INTO users (email, hashed_password) VALUES ($1, $2) RETURNING id;", user.Email, user.HashedPassword)
|
||||
result = r.db.QueryRow(ctx, "INSERT INTO users (email, hashed_password, salt) VALUES ($1, $2, $3) RETURNING id;", user.Email, user.HashedPassword, user.Salt)
|
||||
|
||||
err := result.Scan(&user.Id)
|
||||
|
||||
@ -69,7 +84,7 @@ func (r *UserRepository) Create(ctx context.Context, user UserRow) (UserRow, err
|
||||
return UserRow{}, err
|
||||
}
|
||||
} else {
|
||||
_, err := r.db.Exec(ctx, "INSERT INTO users (id, email, hashed_password) VALUES ($1, $2, $3);", user.Id, user.Email, user.HashedPassword)
|
||||
_, err := r.db.Exec(ctx, "INSERT INTO users (id, email, hashed_password, salt) VALUES ($1, $2, $3, $4);", user.Id, user.Email, user.HashedPassword, user.Salt)
|
||||
|
||||
if err != nil {
|
||||
r.logger.ErrorContext(ctx, err.Error())
|
||||
@ -81,7 +96,7 @@ func (r *UserRepository) Create(ctx context.Context, user UserRow) (UserRow, err
|
||||
}
|
||||
|
||||
func (r *UserRepository) Update(ctx context.Context, user UserRow) error {
|
||||
result, err := r.db.Exec(ctx, "UPDATE users SET email = $1, hashed_password = $2 WHERE id = $3;", user.Email, user.HashedPassword, user.Id)
|
||||
result, err := r.db.Exec(ctx, "UPDATE users SET email = $1, hashed_password = $2, salt = $3 WHERE id = $4;", user.Email, user.HashedPassword, user.Salt, user.Id)
|
||||
|
||||
if err != nil {
|
||||
r.logger.ErrorContext(ctx, err.Error())
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"log/slog"
|
||||
"testing"
|
||||
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/auth"
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/test"
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/user/repository"
|
||||
"github.com/gofrs/uuid/v5"
|
||||
@ -18,22 +19,36 @@ func TestCRUD(t *testing.T) {
|
||||
defer tdb.TearDown()
|
||||
r := repository.NewUserRepository(logger, tdb.Db)
|
||||
uuid := NewUUID(t)
|
||||
argon2IdHash := auth.NewArgon2IdHash(1, 32, 64*1024, 32, 256)
|
||||
|
||||
hashSalt, err := argon2IdHash.GenerateHash([]byte("supersecurepassword"), []byte("supersecuresalt"))
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("could not generate hash: %v", err)
|
||||
}
|
||||
|
||||
t.Run("creates new user", func(t *testing.T) {
|
||||
newUser := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: "supersecurehash"}
|
||||
newUser := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: hashSalt.Hash, Salt: hashSalt.Salt}
|
||||
_, err := r.Create(ctx, newUser)
|
||||
AssertNoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("gets user", func(t *testing.T) {
|
||||
want := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: "supersecurehash"}
|
||||
t.Run("gets user by id", func(t *testing.T) {
|
||||
want := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: hashSalt.Hash, Salt: hashSalt.Salt}
|
||||
got, err := r.GetById(ctx, uuid)
|
||||
AssertNoError(t, err)
|
||||
AssertUserRows(t, got, want)
|
||||
})
|
||||
|
||||
t.Run("gets user by email", func(t *testing.T) {
|
||||
want := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: hashSalt.Hash, Salt: hashSalt.Salt}
|
||||
got, err := r.GetByEmail(ctx, "test@test.com")
|
||||
AssertNoError(t, err)
|
||||
AssertUserRows(t, got, want)
|
||||
})
|
||||
|
||||
t.Run("updates user", func(t *testing.T) {
|
||||
want := repository.UserRow{Id: uuid, Email: "new@test.com", HashedPassword: "supersecurehash"}
|
||||
want := repository.UserRow{Id: uuid, Email: "new@test.com", HashedPassword: hashSalt.Hash, Salt: hashSalt.Salt}
|
||||
err := r.Update(ctx, want)
|
||||
AssertNoError(t, err)
|
||||
|
||||
|
@ -5,52 +5,48 @@ import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/auth"
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/user/repository"
|
||||
"github.com/gofrs/uuid/v5"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNotFound error = errors.New("user cannot be found")
|
||||
ErrNotFound error = errors.New("user cannot be found")
|
||||
ErrUserExists error = errors.New("user already exists")
|
||||
)
|
||||
|
||||
type User struct {
|
||||
Id uuid.UUID
|
||||
Email string
|
||||
HashedPassword string
|
||||
}
|
||||
|
||||
func NewUser(id uuid.UUID, email string, hashedPassword string) User {
|
||||
return User{Id: id, Email: email, HashedPassword: hashedPassword}
|
||||
Id uuid.UUID
|
||||
Email string
|
||||
}
|
||||
|
||||
func UserFromUserRow(userRow repository.UserRow) User {
|
||||
return User{Id: userRow.Id, Email: userRow.Email, HashedPassword: userRow.HashedPassword}
|
||||
}
|
||||
|
||||
func UserRowFromUser(user User) repository.UserRow {
|
||||
return repository.UserRow{Id: user.Id, Email: user.Email, HashedPassword: user.HashedPassword}
|
||||
return User{Id: userRow.Id, Email: userRow.Email}
|
||||
}
|
||||
|
||||
func (t User) Equal(user User) bool {
|
||||
return t.Id == user.Id && t.Email == user.Email && t.HashedPassword == user.HashedPassword
|
||||
return t.Id == user.Id && t.Email == user.Email
|
||||
}
|
||||
|
||||
type UserRepository interface {
|
||||
Create(ctx context.Context, user repository.UserRow) (repository.UserRow, error)
|
||||
GetById(ctx context.Context, id uuid.UUID) (repository.UserRow, error)
|
||||
GetByEmail(ctx context.Context, email string) (repository.UserRow, error)
|
||||
Update(ctx context.Context, user repository.UserRow) error
|
||||
Delete(ctx context.Context, id uuid.UUID) error
|
||||
}
|
||||
|
||||
type UserService struct {
|
||||
logger *slog.Logger
|
||||
repo UserRepository
|
||||
logger *slog.Logger
|
||||
repo UserRepository
|
||||
argon2IdHash *auth.Argon2IdHash
|
||||
}
|
||||
|
||||
func NewUserService(logger *slog.Logger, userRepo UserRepository) *UserService {
|
||||
func NewUserService(logger *slog.Logger, userRepo UserRepository, argon2IdHash *auth.Argon2IdHash) *UserService {
|
||||
return &UserService{
|
||||
logger: logger,
|
||||
repo: userRepo,
|
||||
logger: logger,
|
||||
repo: userRepo,
|
||||
argon2IdHash: argon2IdHash,
|
||||
}
|
||||
}
|
||||
|
||||
@ -69,17 +65,37 @@ func (s *UserService) GetUser(ctx context.Context, id uuid.UUID) (User, error) {
|
||||
return UserFromUserRow(user), err
|
||||
}
|
||||
|
||||
func (s *UserService) CreateUser(ctx context.Context, user User) (User, error) {
|
||||
userRow := UserRowFromUser(user)
|
||||
|
||||
newUserRow, err := s.repo.Create(ctx, userRow)
|
||||
func (s *UserService) Register(ctx context.Context, email string, password string) (uuid.UUID, error) {
|
||||
uuid, err := uuid.NewV4()
|
||||
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, err.Error())
|
||||
return User{}, err
|
||||
return uuid, err
|
||||
}
|
||||
|
||||
return UserFromUserRow(newUserRow), err
|
||||
_, err = s.repo.GetByEmail(ctx, email)
|
||||
|
||||
if err != repository.ErrNotFound {
|
||||
return uuid, ErrUserExists
|
||||
}
|
||||
|
||||
hashSalt, err := s.argon2IdHash.GenerateHash([]byte(password), nil)
|
||||
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, err.Error())
|
||||
return uuid, err
|
||||
}
|
||||
|
||||
userRow := repository.UserRow{Id: uuid, Email: email, HashedPassword: hashSalt.Hash, Salt: hashSalt.Salt}
|
||||
|
||||
_, err = s.repo.Create(ctx, userRow)
|
||||
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, err.Error())
|
||||
return uuid, err
|
||||
}
|
||||
|
||||
return uuid, err
|
||||
}
|
||||
|
||||
func (s *UserService) DeleteUser(ctx context.Context, id uuid.UUID) error {
|
||||
@ -96,10 +112,66 @@ func (s *UserService) DeleteUser(ctx context.Context, id uuid.UUID) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *UserService) UpdateUser(ctx context.Context, user User) error {
|
||||
userRow := UserRowFromUser(user)
|
||||
func (s *UserService) UpdateUserEmail(ctx context.Context, id uuid.UUID, email string) error {
|
||||
user, err := s.repo.GetById(ctx, id)
|
||||
|
||||
err := s.repo.Update(ctx, userRow)
|
||||
if err == repository.ErrNotFound {
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = s.repo.GetByEmail(ctx, email)
|
||||
|
||||
switch err {
|
||||
case repository.ErrNotFound:
|
||||
user.Email = email
|
||||
|
||||
err = s.repo.Update(ctx, user)
|
||||
|
||||
if err == repository.ErrNotFound {
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, err.Error())
|
||||
}
|
||||
|
||||
return err
|
||||
case nil:
|
||||
return ErrUserExists
|
||||
default:
|
||||
s.logger.ErrorContext(ctx, err.Error())
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UserService) UpdateUserPassword(ctx context.Context, id uuid.UUID, password string) error {
|
||||
user, err := s.repo.GetById(ctx, id)
|
||||
|
||||
if err == repository.ErrNotFound {
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
hashSalt, err := s.argon2IdHash.GenerateHash([]byte(password), nil)
|
||||
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
user.HashedPassword = hashSalt.Hash
|
||||
user.Salt = hashSalt.Salt
|
||||
|
||||
err = s.repo.Update(ctx, user)
|
||||
|
||||
if err == repository.ErrNotFound {
|
||||
return ErrNotFound
|
||||
|
@ -5,27 +5,26 @@ import (
|
||||
"log/slog"
|
||||
"testing"
|
||||
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/auth"
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/test"
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/user/repository"
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/user/service"
|
||||
"github.com/gofrs/uuid/v5"
|
||||
)
|
||||
|
||||
func TestCreateUser(t *testing.T) {
|
||||
func TestRegisterUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
logger := slog.Default()
|
||||
tdb := test.NewTestDatabase(t)
|
||||
defer tdb.TearDown()
|
||||
r := repository.NewUserRepository(logger, tdb.Db)
|
||||
argon2IdHash := auth.NewArgon2IdHash(1, 32, 64*1024, 32, 256)
|
||||
|
||||
userService := service.NewUserService(logger, r)
|
||||
userService := service.NewUserService(logger, r, argon2IdHash)
|
||||
|
||||
t.Run("Create user", func(t *testing.T) {
|
||||
uuid := NewUUID(t)
|
||||
user := service.NewUser(uuid, "test@test.com", "supersecurehash")
|
||||
|
||||
_, err := userService.CreateUser(ctx, user)
|
||||
_, err := userService.Register(ctx, "test@test.com", "supersecurepassword")
|
||||
|
||||
AssertNoError(t, err)
|
||||
})
|
||||
@ -38,13 +37,12 @@ func TestGetUser(t *testing.T) {
|
||||
tdb := test.NewTestDatabase(t)
|
||||
defer tdb.TearDown()
|
||||
r := repository.NewUserRepository(logger, tdb.Db)
|
||||
uuid := NewUUID(t)
|
||||
argon2IdHash := auth.NewArgon2IdHash(1, 32, 64*1024, 32, 256)
|
||||
|
||||
row := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: "supersecurehash"}
|
||||
_, err := r.Create(ctx, row)
|
||||
AssertNoError(t, err);
|
||||
userService := service.NewUserService(logger, r, argon2IdHash)
|
||||
|
||||
userService := service.NewUserService(logger, r)
|
||||
uuid, err := userService.Register(ctx, "test@test.com", "supersecurepassword")
|
||||
AssertNoError(t, err)
|
||||
|
||||
t.Run("Get exisiting user", func(t *testing.T) {
|
||||
_, err := userService.GetUser(ctx, uuid)
|
||||
@ -66,13 +64,12 @@ func TestDeleteUser(t *testing.T) {
|
||||
tdb := test.NewTestDatabase(t)
|
||||
defer tdb.TearDown()
|
||||
r := repository.NewUserRepository(logger, tdb.Db)
|
||||
uuid := NewUUID(t)
|
||||
argon2IdHash := auth.NewArgon2IdHash(1, 32, 64*1024, 32, 256)
|
||||
|
||||
row := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: "supersecurehash"}
|
||||
_, err := r.Create(ctx, row)
|
||||
AssertNoError(t, err);
|
||||
userService := service.NewUserService(logger, r, argon2IdHash)
|
||||
|
||||
userService := service.NewUserService(logger, r)
|
||||
uuid, err := userService.Register(ctx, "test@test.com", "supersecurepassword")
|
||||
AssertNoError(t, err)
|
||||
|
||||
t.Run("Delete exisiting user", func(t *testing.T) {
|
||||
err := userService.DeleteUser(ctx, uuid)
|
||||
@ -87,25 +84,22 @@ func TestDeleteUser(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateUser(t *testing.T) {
|
||||
func TestUpdateUserEmail(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
logger := slog.Default()
|
||||
tdb := test.NewTestDatabase(t)
|
||||
defer tdb.TearDown()
|
||||
r := repository.NewUserRepository(logger, tdb.Db)
|
||||
uuid := NewUUID(t)
|
||||
argon2IdHash := auth.NewArgon2IdHash(1, 32, 64*1024, 32, 256)
|
||||
|
||||
row := repository.UserRow{Id: uuid, Email: "test@test.com", HashedPassword: "supersecurehash"}
|
||||
_, err := r.Create(ctx, row)
|
||||
AssertNoError(t, err);
|
||||
userService := service.NewUserService(logger, r, argon2IdHash)
|
||||
|
||||
userService := service.NewUserService(logger, r)
|
||||
uuid, err := userService.Register(ctx, "test@test.com", "supersecurepassword")
|
||||
AssertNoError(t, err)
|
||||
|
||||
t.Run("Update exisiting user", func(t *testing.T) {
|
||||
user := service.User{uuid, "new@email.com", "supersecurehash"}
|
||||
|
||||
err := userService.UpdateUser(ctx, user)
|
||||
t.Run("Update existing user email", func(t *testing.T) {
|
||||
err := userService.UpdateUserEmail(ctx, uuid, "newemail@test.com")
|
||||
|
||||
AssertNoError(t, err)
|
||||
|
||||
@ -113,13 +107,13 @@ func TestUpdateUser(t *testing.T) {
|
||||
|
||||
AssertNoError(t, err)
|
||||
|
||||
AssertUsers(t, newUser, user)
|
||||
if newUser.Email != "newemail@test.com" {
|
||||
t.Errorf("Emails do not match wanted %q, got %q", "newemail@test.com", newUser.Email)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Update non-existant user", func(t *testing.T) {
|
||||
user := service.User{NewUUID(t), "new@email.com", "supersecurehash"}
|
||||
|
||||
err := userService.UpdateUser(ctx, user)
|
||||
err := userService.UpdateUserEmail(ctx, NewUUID(t), "newemail@test.com")
|
||||
|
||||
AssertErrors(t, err, service.ErrNotFound)
|
||||
})
|
||||
|
Loading…
x
Reference in New Issue
Block a user