Merge pull request 'Logging, middleware, and context passing' (#3) from feature/logging into main
All checks were successful
ci/woodpecker/push/build Pipeline was successful
ci/woodpecker/push/lint Pipeline was successful
ci/woodpecker/push/test Pipeline was successful

Reviewed-on: #3
This commit is contained in:
Michael Thomson 2025-05-15 17:47:37 +00:00
commit 60698001fd
24 changed files with 381 additions and 187 deletions

View File

@ -3,51 +3,56 @@ package main
import ( import (
"database/sql" "database/sql"
"log" "log"
"log/slog"
"net/http" "net/http"
"os" "os"
"gitea.michaelthomson.dev/mthomson/habits/internal/logging"
"gitea.michaelthomson.dev/mthomson/habits/internal/middleware" "gitea.michaelthomson.dev/mthomson/habits/internal/middleware"
"gitea.michaelthomson.dev/mthomson/habits/internal/migrate" "gitea.michaelthomson.dev/mthomson/habits/internal/migrate"
todohandler "gitea.michaelthomson.dev/mthomson/habits/internal/todo/handler" todohandler "gitea.michaelthomson.dev/mthomson/habits/internal/todo/handler"
todorepository "gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository/postgres" todorepository "gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository/postgres"
todoservice "gitea.michaelthomson.dev/mthomson/habits/internal/todo/service" todoservice "gitea.michaelthomson.dev/mthomson/habits/internal/todo/service"
_ "github.com/jackc/pgx/v5/stdlib" _ "github.com/jackc/pgx/v5/stdlib"
_ "github.com/mattn/go-sqlite3"
) )
func main() { func main() {
// create logger // create logger
httpLogger := slog.New(slog.NewTextHandler(os.Stdout, nil)) logger := logging.NewLogger()
// create middlewares // create middlewares
loggingMiddleware := middleware.LoggingMiddleware(httpLogger) contextMiddleware := middleware.ContextMiddleware(logger)
loggingMiddleware := middleware.LoggingMiddleware(logger)
stack := []middleware.Middleware{
contextMiddleware,
loggingMiddleware,
}
// create db pool // create db pool
postgresUrl := "postgres://todo:password@localhost:5432/todo" postgresUrl := "postgres://todo:password@localhost:5432/todo"
db, err := sql.Open("pgx", postgresUrl) db, err := sql.Open("pgx", postgresUrl)
if err != nil { if err != nil {
log.Fatalf("Failed to open db pool: %v", err) logger.Error(err.Error())
os.Exit(1);
} }
defer db.Close()
// run migrations // run migrations
migrate.Migrate(db) migrate.Migrate(logger, db)
// create repos // create repos
todoRepository := todorepository.NewPostgresTodoRepository(db) todoRepository := todorepository.NewPostgresTodoRepository(logger, db)
// create services // create services
todoService := todoservice.NewTodoService(todoRepository) todoService := todoservice.NewTodoService(logger, todoRepository)
// create mux // create mux
mux := http.NewServeMux() mux := http.NewServeMux()
// register handlers // register handlers
mux.Handle("GET /todo/{id}", loggingMiddleware(todohandler.HandleTodoGet(todoService))) mux.Handle("GET /todo/{id}", middleware.CompileMiddleware(todohandler.HandleTodoGet(logger, todoService), stack))
mux.Handle("POST /todo", loggingMiddleware(todohandler.HandleTodoCreate(todoService))) mux.Handle("POST /todo", middleware.CompileMiddleware(todohandler.HandleTodoCreate(logger, todoService), stack))
mux.Handle("DELETE /todo/{id}", loggingMiddleware(todohandler.HandleTodoDelete(todoService))) mux.Handle("DELETE /todo/{id}", middleware.CompileMiddleware(todohandler.HandleTodoDelete(logger, todoService), stack))
mux.Handle("PUT /todo/{id}", loggingMiddleware(todohandler.HandleTodoUpdate(todoService))) mux.Handle("PUT /todo/{id}", middleware.CompileMiddleware(todohandler.HandleTodoUpdate(logger, todoService), stack))
// create server // create server
server := &http.Server{ server := &http.Server{

10
flake.lock generated
View File

@ -2,12 +2,12 @@
"nodes": { "nodes": {
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1731676054, "lastModified": 1747179050,
"narHash": "sha256-OZiZ3m8SCMfh3B6bfGC/Bm4x3qc1m2SVEAlkV6iY7Yg=", "narHash": "sha256-qhFMmDkeJX9KJwr5H32f1r7Prs7XbQWtO0h3V0a0rFY=",
"rev": "5e4fbfb6b3de1aa2872b76d49fafc942626e2add", "rev": "adaa24fbf46737f3f1b5497bf64bae750f82942e",
"revCount": 708622, "revCount": 799423,
"type": "tarball", "type": "tarball",
"url": "https://api.flakehub.com/f/pinned/NixOS/nixpkgs/0.1.708622%2Brev-5e4fbfb6b3de1aa2872b76d49fafc942626e2add/0193363c-ab27-7bbd-af1d-3e6093ed5e2d/source.tar.gz" "url": "https://api.flakehub.com/f/pinned/NixOS/nixpkgs/0.1.799423%2Brev-adaa24fbf46737f3f1b5497bf64bae750f82942e/0196d1c3-1974-7bf1-bcf6-06620ac40c8c/source.tar.gz"
}, },
"original": { "original": {
"type": "tarball", "type": "tarball",

View File

@ -5,7 +5,7 @@
outputs = { self, nixpkgs }: outputs = { self, nixpkgs }:
let let
goVersion = 22; # Change this to update the whole stack goVersion = 23; # Change this to update the whole stack
supportedSystems = [ "x86_64-linux" "aarch64-linux" "x86_64-darwin" "aarch64-darwin" ]; supportedSystems = [ "x86_64-linux" "aarch64-linux" "x86_64-darwin" "aarch64-darwin" ];
forEachSupportedSystem = f: nixpkgs.lib.genAttrs supportedSystems (system: f { forEachSupportedSystem = f: nixpkgs.lib.genAttrs supportedSystems (system: f {
@ -36,6 +36,7 @@
docker docker
docker-compose docker-compose
gopls
]; ];
}; };
}); });

View File

@ -0,0 +1,28 @@
package logging
import (
"context"
"log/slog"
"os"
"gitea.michaelthomson.dev/mthomson/habits/internal/middleware"
)
type ContextHandler struct {
slog.Handler
}
func (h *ContextHandler) Handle(ctx context.Context, r slog.Record) error {
if requestID, ok := ctx.Value(middleware.TraceIdKey).(string); ok {
r.AddAttrs(slog.String(string(middleware.TraceIdKey), requestID))
}
return h.Handler.Handle(ctx, r)
}
func NewLogger() *slog.Logger {
baseHandler := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{AddSource: false})
customHandler := &ContextHandler{Handler: baseHandler}
logger := slog.New(customHandler)
return logger
}

View File

@ -0,0 +1,24 @@
package middleware
import (
"context"
"log/slog"
"net/http"
"github.com/google/uuid"
)
type contextKey string
const TraceIdKey contextKey = "trace_id"
func ContextMiddleware(logger *slog.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
traceid := uuid.NewString()
ctx := context.WithValue(r.Context(), TraceIdKey, traceid)
newReq := r.WithContext(ctx)
next.ServeHTTP(w, newReq)
})
}
}

View File

@ -1,23 +1,39 @@
package middleware package middleware
import ( import (
"context"
"log/slog" "log/slog"
"net/http" "net/http"
) )
type LoggingResponseWriter struct {
http.ResponseWriter
statusCode int
}
func NewLoggingResponseWriter(w http.ResponseWriter) *LoggingResponseWriter {
return &LoggingResponseWriter{w, http.StatusOK}
}
func (lrw *LoggingResponseWriter) WriteHeader(code int) {
lrw.statusCode = code
lrw.ResponseWriter.WriteHeader(code)
}
func LoggingMiddleware(logger *slog.Logger) func(http.Handler) http.Handler { func LoggingMiddleware(logger *slog.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logger.LogAttrs( logger.InfoContext(r.Context(), "Incoming request",
context.Background(),
slog.LevelInfo,
"Incoming request",
slog.String("method", r.Method), slog.String("method", r.Method),
slog.String("path", r.URL.String()), slog.String("path", r.URL.String()),
) )
next.ServeHTTP(w, r) lrw := NewLoggingResponseWriter(w)
next.ServeHTTP(lrw, r)
logger.InfoContext(r.Context(), "Sent response",
slog.Int("code", lrw.statusCode),
slog.String("message", http.StatusText(lrw.statusCode)),
)
}) })
} }
} }

View File

@ -0,0 +1,19 @@
package middleware
import "net/http"
type Middleware func(http.Handler) http.Handler
func CompileMiddleware(h http.Handler, m []Middleware) http.Handler {
if len(m) < 1 {
return h
}
wrapped := h
for i := len(m) - 1; i >= 0; i-- {
wrapped = m[i](wrapped)
}
return wrapped
}

View File

@ -7,7 +7,7 @@ import (
"log" "log"
"log/slog" "log/slog"
_ "github.com/mattn/go-sqlite3" _ "github.com/jackc/pgx/v5/stdlib"
) )
//go:embed migrations/*.sql //go:embed migrations/*.sql
@ -18,8 +18,8 @@ type Migration struct {
Name string Name string
} }
func Migrate(db *sql.DB) { func Migrate(logger *slog.Logger, db *sql.DB) {
slog.Info("Running migrations...") logger.Info("Running migrations...")
migrationTableSql := ` migrationTableSql := `
CREATE TABLE IF NOT EXISTS migrations( CREATE TABLE IF NOT EXISTS migrations(
version SERIAL PRIMARY KEY, version SERIAL PRIMARY KEY,
@ -41,7 +41,7 @@ func Migrate(db *sql.DB) {
row := db.QueryRow("SELECT * FROM migrations WHERE name = $1;", file.Name()) row := db.QueryRow("SELECT * FROM migrations WHERE name = $1;", file.Name())
err = row.Scan(&migration.Version, &migration.Name) err = row.Scan(&migration.Version, &migration.Name)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
slog.Info(fmt.Sprintf("Running migration: %s", file.Name())) logger.Info(fmt.Sprintf("Running migration: %s", file.Name()))
migrationSql, err := migrations.ReadFile(fmt.Sprintf("migrations/%s", file.Name())) migrationSql, err := migrations.ReadFile(fmt.Sprintf("migrations/%s", file.Name()))
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
@ -58,5 +58,5 @@ func Migrate(db *sql.DB) {
} }
} }
} }
slog.Info("Migrations completed") logger.Info("Migrations completed")
} }

View File

@ -0,0 +1,61 @@
package test
import (
"context"
"database/sql"
"log/slog"
"testing"
"time"
"gitea.michaelthomson.dev/mthomson/habits/internal/migrate"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
)
type TestDatabase struct {
Db *sql.DB
container testcontainers.Container
}
func NewTestDatabase(tb testing.TB) *TestDatabase {
tb.Helper()
ctx := context.Background()
// create container
postgresContainer, err := postgres.Run(ctx,
"postgres:16-alpine",
postgres.WithDatabase("todo"),
postgres.WithUsername("todo"),
postgres.WithPassword("password"),
testcontainers.WithWaitStrategy(
wait.ForLog("database system is ready to accept connections").
WithOccurrence(2).
WithStartupTimeout(5*time.Second)),
)
if err != nil {
tb.Fatalf("Failed to create postgres container, %v", err)
}
connectionString, err := postgresContainer.ConnectionString(ctx)
if err != nil {
tb.Fatalf("Failed to get connection string: %v", err)
}
// create db pool
db, err := sql.Open("pgx", connectionString)
if err != nil {
tb.Fatalf("Failed to open db pool: %v", err)
}
migrate.Migrate(slog.Default(), db)
return &TestDatabase{
Db: db,
container: postgresContainer,
}
}
func (tdb *TestDatabase) TearDown() {
_ = tdb.container.Terminate(context.Background())
}

View File

@ -3,6 +3,7 @@ package handler
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log/slog"
"net/http" "net/http"
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/service" "gitea.michaelthomson.dev/mthomson/habits/internal/todo/service"
@ -27,19 +28,21 @@ func CreateTodoResponseFromTodo(todo service.Todo) CreateTodoResponse {
return CreateTodoResponse{Id: todo.Id, Name: todo.Name, Done: todo.Done} return CreateTodoResponse{Id: todo.Id, Name: todo.Name, Done: todo.Done}
} }
func HandleTodoCreate(todoService TodoCreater) http.HandlerFunc { func HandleTodoCreate(logger *slog.Logger, todoService TodoCreater) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
createTodoRequest := CreateTodoRequest{} createTodoRequest := CreateTodoRequest{}
decoder := json.NewDecoder(r.Body) decoder := json.NewDecoder(r.Body)
decoder.DisallowUnknownFields() decoder.DisallowUnknownFields()
err := decoder.Decode(&createTodoRequest) err := decoder.Decode(&createTodoRequest)
if err != nil { if err != nil {
logger.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusBadRequest) http.Error(w, "", http.StatusBadRequest)
return return
} }
todo, err := todoService.CreateTodo(TodoFromCreateTodoRequest(createTodoRequest)) todo, err := todoService.CreateTodo(ctx, TodoFromCreateTodoRequest(createTodoRequest))
if err != nil { if err != nil {
if err == service.ErrNotFound { if err == service.ErrNotFound {
@ -47,6 +50,7 @@ func HandleTodoCreate(todoService TodoCreater) http.HandlerFunc {
return return
} }
logger.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusInternalServerError) http.Error(w, "", http.StatusInternalServerError)
return return
} }
@ -60,6 +64,7 @@ func HandleTodoCreate(todoService TodoCreater) http.HandlerFunc {
err = json.NewEncoder(w).Encode(response) err = json.NewEncoder(w).Encode(response)
if err != nil { if err != nil {
logger.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusInternalServerError) http.Error(w, "", http.StatusInternalServerError)
return return
} }

View File

@ -2,7 +2,9 @@ package handler
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"log/slog"
"errors" "errors"
"net/http" "net/http"
@ -14,26 +16,27 @@ import (
) )
type MockTodoCreater struct { type MockTodoCreater struct {
CreateTodoFunc func(todo service.Todo) (service.Todo, error) CreateTodoFunc func(cxt context.Context, todo service.Todo) (service.Todo, error)
} }
func (tg *MockTodoCreater) CreateTodo(todo service.Todo) (service.Todo, error) { func (tg *MockTodoCreater) CreateTodo(ctx context.Context, todo service.Todo) (service.Todo, error) {
return tg.CreateTodoFunc(todo) return tg.CreateTodoFunc(ctx, todo)
} }
func TestCreateTodo(t *testing.T) { func TestCreateTodo(t *testing.T) {
logger := slog.Default()
t.Run("create todo", func(t *testing.T) { t.Run("create todo", func(t *testing.T) {
createTodoRequest := CreateTodoRequest{Name: "clean dishes", Done: false} createTodoRequest := CreateTodoRequest{Name: "clean dishes", Done: false}
createdTodo := service.Todo{Id: 1, Name: "clean dishes", Done: false} createdTodo := service.Todo{Id: 1, Name: "clean dishes", Done: false}
want := CreateTodoResponse{Id: 1, Name: "clean dishes", Done: false} want := CreateTodoResponse{Id: 1, Name: "clean dishes", Done: false}
service := MockTodoCreater{ service := MockTodoCreater{
CreateTodoFunc: func(todo service.Todo) (service.Todo, error) { CreateTodoFunc: func(ctx context.Context, todo service.Todo) (service.Todo, error) {
return createdTodo, nil return createdTodo, nil
}, },
} }
handler := HandleTodoCreate(&service) handler := HandleTodoCreate(logger, &service)
requestBody, err := json.Marshal(createTodoRequest) requestBody, err := json.Marshal(createTodoRequest)
@ -67,7 +70,7 @@ func TestCreateTodo(t *testing.T) {
}) })
t.Run("returns 400 with bad json", func(t *testing.T) { t.Run("returns 400 with bad json", func(t *testing.T) {
handler := HandleTodoCreate(nil) handler := HandleTodoCreate(logger, nil)
badStruct := struct { badStruct := struct {
Foo string Foo string
@ -95,12 +98,12 @@ func TestCreateTodo(t *testing.T) {
createTodoRequest := CreateTodoRequest{Name: "clean dishes", Done: false} createTodoRequest := CreateTodoRequest{Name: "clean dishes", Done: false}
service := MockTodoCreater{ service := MockTodoCreater{
CreateTodoFunc: func(todo service.Todo) (service.Todo, error) { CreateTodoFunc: func(ctx context.Context, todo service.Todo) (service.Todo, error) {
return service.Todo{}, errors.New("foo bar") return service.Todo{}, errors.New("foo bar")
}, },
} }
handler := HandleTodoCreate(&service) handler := HandleTodoCreate(logger, &service)
requestBody, err := json.Marshal(createTodoRequest) requestBody, err := json.Marshal(createTodoRequest)

View File

@ -1,24 +1,27 @@
package handler package handler
import ( import (
"log/slog"
"net/http" "net/http"
"strconv" "strconv"
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/service" "gitea.michaelthomson.dev/mthomson/habits/internal/todo/service"
) )
func HandleTodoDelete(todoService TodoDeleter) http.HandlerFunc { func HandleTodoDelete(logger *slog.Logger, todoService TodoDeleter) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
idString := r.PathValue("id") idString := r.PathValue("id")
id, err := strconv.ParseInt(idString, 10, 64) id, err := strconv.ParseInt(idString, 10, 64)
if err != nil { if err != nil {
slog.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusBadRequest) http.Error(w, "", http.StatusBadRequest)
return return
} }
err = todoService.DeleteTodo(id) err = todoService.DeleteTodo(ctx, id)
if err != nil { if err != nil {
if err == service.ErrNotFound { if err == service.ErrNotFound {
@ -26,6 +29,7 @@ func HandleTodoDelete(todoService TodoDeleter) http.HandlerFunc {
return return
} }
slog.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusInternalServerError) http.Error(w, "", http.StatusInternalServerError)
return return
} }

View File

@ -1,7 +1,9 @@
package handler package handler
import ( import (
"context"
"errors" "errors"
"log/slog"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -10,22 +12,23 @@ import (
) )
type MockTodoDeleter struct { type MockTodoDeleter struct {
DeleteTodoFunc func(id int64) error DeleteTodoFunc func(ctx context.Context, id int64) error
} }
func (tg *MockTodoDeleter) DeleteTodo(id int64) error { func (tg *MockTodoDeleter) DeleteTodo(ctx context.Context, id int64) error {
return tg.DeleteTodoFunc(id) return tg.DeleteTodoFunc(ctx, id)
} }
func TestDeleteTodo(t *testing.T) { func TestDeleteTodo(t *testing.T) {
logger := slog.Default()
t.Run("deletes existing todo", func(t *testing.T) { t.Run("deletes existing todo", func(t *testing.T) {
service := MockTodoDeleter{ service := MockTodoDeleter{
DeleteTodoFunc: func(id int64) error { DeleteTodoFunc: func(ctx context.Context, id int64) error {
return nil return nil
}, },
} }
handler := HandleTodoDelete(&service) handler := HandleTodoDelete(logger, &service)
req := httptest.NewRequest(http.MethodDelete, "/todo/1", nil) req := httptest.NewRequest(http.MethodDelete, "/todo/1", nil)
res := httptest.NewRecorder() res := httptest.NewRecorder()
@ -39,7 +42,7 @@ func TestDeleteTodo(t *testing.T) {
}) })
t.Run("returns 400 with bad id", func(t *testing.T) { t.Run("returns 400 with bad id", func(t *testing.T) {
handler := HandleTodoDelete(nil) handler := HandleTodoDelete(logger, nil)
req := httptest.NewRequest(http.MethodDelete, "/todo/hello", nil) req := httptest.NewRequest(http.MethodDelete, "/todo/hello", nil)
res := httptest.NewRecorder() res := httptest.NewRecorder()
@ -54,12 +57,12 @@ func TestDeleteTodo(t *testing.T) {
t.Run("returns 404 for not found todo", func(t *testing.T) { t.Run("returns 404 for not found todo", func(t *testing.T) {
service := MockTodoDeleter{ service := MockTodoDeleter{
DeleteTodoFunc: func(id int64) error { DeleteTodoFunc: func(ctx context.Context, id int64) error {
return service.ErrNotFound return service.ErrNotFound
}, },
} }
handler := HandleTodoDelete(&service) handler := HandleTodoDelete(logger, &service)
req := httptest.NewRequest(http.MethodDelete, "/todo/1", nil) req := httptest.NewRequest(http.MethodDelete, "/todo/1", nil)
res := httptest.NewRecorder() res := httptest.NewRecorder()
@ -74,12 +77,12 @@ func TestDeleteTodo(t *testing.T) {
t.Run("returns 500 for arbitrary errors", func(t *testing.T) { t.Run("returns 500 for arbitrary errors", func(t *testing.T) {
service := MockTodoDeleter{ service := MockTodoDeleter{
DeleteTodoFunc: func(id int64) error { DeleteTodoFunc: func(ctx context.Context, id int64) error {
return errors.New("foo bar") return errors.New("foo bar")
}, },
} }
handler := HandleTodoDelete(&service) handler := HandleTodoDelete(logger, &service)
req := httptest.NewRequest(http.MethodDelete, "/todo/1", nil) req := httptest.NewRequest(http.MethodDelete, "/todo/1", nil)
res := httptest.NewRecorder() res := httptest.NewRecorder()

View File

@ -2,6 +2,7 @@ package handler
import ( import (
"encoding/json" "encoding/json"
"log/slog"
"net/http" "net/http"
"strconv" "strconv"
@ -18,18 +19,20 @@ func GetTodoResponseFromTodo(todo service.Todo) GetTodoResponse {
return GetTodoResponse{Id: todo.Id, Name: todo.Name, Done: todo.Done} return GetTodoResponse{Id: todo.Id, Name: todo.Name, Done: todo.Done}
} }
func HandleTodoGet(todoService TodoGetter) http.HandlerFunc { func HandleTodoGet(logger *slog.Logger, todoService TodoGetter) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
idString := r.PathValue("id") idString := r.PathValue("id")
id, err := strconv.ParseInt(idString, 10, 64) id, err := strconv.ParseInt(idString, 10, 64)
if err != nil { if err != nil {
logger.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusBadRequest) http.Error(w, "", http.StatusBadRequest)
return return
} }
todo, err := todoService.GetTodo(id) todo, err := todoService.GetTodo(ctx, id)
if err != nil { if err != nil {
if err == service.ErrNotFound { if err == service.ErrNotFound {
@ -37,6 +40,7 @@ func HandleTodoGet(todoService TodoGetter) http.HandlerFunc {
return return
} }
logger.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusInternalServerError) http.Error(w, "", http.StatusInternalServerError)
return return
} }
@ -49,6 +53,7 @@ func HandleTodoGet(todoService TodoGetter) http.HandlerFunc {
err = json.NewEncoder(w).Encode(response) err = json.NewEncoder(w).Encode(response)
if err != nil { if err != nil {
logger.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusInternalServerError) http.Error(w, "", http.StatusInternalServerError)
return return
} }

View File

@ -1,9 +1,11 @@
package handler package handler
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"log/slog"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect" "reflect"
@ -13,25 +15,26 @@ import (
) )
type MockTodoGetter struct { type MockTodoGetter struct {
GetTodoFunc func(id int64) (service.Todo, error) GetTodoFunc func(ctx context.Context, id int64) (service.Todo, error)
} }
func (tg *MockTodoGetter) GetTodo(id int64) (service.Todo, error) { func (tg *MockTodoGetter) GetTodo(ctx context.Context, id int64) (service.Todo, error) {
return tg.GetTodoFunc(id) return tg.GetTodoFunc(ctx, id)
} }
func TestGetTodo(t *testing.T) { func TestGetTodo(t *testing.T) {
logger := slog.Default()
t.Run("gets existing todo", func(t *testing.T) { t.Run("gets existing todo", func(t *testing.T) {
todo := service.Todo{Id: 1, Name: "clean dishes", Done: false} todo := service.Todo{Id: 1, Name: "clean dishes", Done: false}
wantedTodo := GetTodoResponse{Id: 1, Name: "clean dishes", Done: false} wantedTodo := GetTodoResponse{Id: 1, Name: "clean dishes", Done: false}
service := MockTodoGetter{ service := MockTodoGetter{
GetTodoFunc: func(id int64) (service.Todo, error) { GetTodoFunc: func(ctx context.Context, id int64) (service.Todo, error) {
return todo, nil return todo, nil
}, },
} }
handler := HandleTodoGet(&service) handler := HandleTodoGet(logger, &service)
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/todo/%d", todo.Id), nil) req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/todo/%d", todo.Id), nil)
res := httptest.NewRecorder() res := httptest.NewRecorder()
@ -56,7 +59,7 @@ func TestGetTodo(t *testing.T) {
}) })
t.Run("returns 400 with bad id", func(t *testing.T) { t.Run("returns 400 with bad id", func(t *testing.T) {
handler := HandleTodoGet(nil) handler := HandleTodoGet(logger, nil)
req := httptest.NewRequest(http.MethodGet, "/todo/hello", nil) req := httptest.NewRequest(http.MethodGet, "/todo/hello", nil)
res := httptest.NewRecorder() res := httptest.NewRecorder()
@ -71,12 +74,12 @@ func TestGetTodo(t *testing.T) {
t.Run("returns 404 for not found todo", func(t *testing.T) { t.Run("returns 404 for not found todo", func(t *testing.T) {
service := MockTodoGetter{ service := MockTodoGetter{
GetTodoFunc: func(id int64) (service.Todo, error) { GetTodoFunc: func(ctx context.Context, id int64) (service.Todo, error) {
return service.Todo{}, service.ErrNotFound return service.Todo{}, service.ErrNotFound
}, },
} }
handler := HandleTodoGet(&service) handler := HandleTodoGet(logger, &service)
req := httptest.NewRequest(http.MethodGet, "/todo/1", nil) req := httptest.NewRequest(http.MethodGet, "/todo/1", nil)
res := httptest.NewRecorder() res := httptest.NewRecorder()
@ -91,12 +94,12 @@ func TestGetTodo(t *testing.T) {
t.Run("returns 500 for arbitrary errors", func(t *testing.T) { t.Run("returns 500 for arbitrary errors", func(t *testing.T) {
service := MockTodoGetter{ service := MockTodoGetter{
GetTodoFunc: func(id int64) (service.Todo, error) { GetTodoFunc: func(ctx context.Context, id int64) (service.Todo, error) {
return service.Todo{}, errors.New("foo bar") return service.Todo{}, errors.New("foo bar")
}, },
} }
handler := HandleTodoGet(&service) handler := HandleTodoGet(logger, &service)
req := httptest.NewRequest(http.MethodGet, "/todo/1", nil) req := httptest.NewRequest(http.MethodGet, "/todo/1", nil)
res := httptest.NewRecorder() res := httptest.NewRecorder()

View File

@ -1,21 +1,23 @@
package handler package handler
import ( import (
"context"
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/service" "gitea.michaelthomson.dev/mthomson/habits/internal/todo/service"
) )
type TodoGetter interface { type TodoGetter interface {
GetTodo(id int64) (service.Todo, error) GetTodo(ctx context.Context, id int64) (service.Todo, error)
} }
type TodoCreater interface { type TodoCreater interface {
CreateTodo(todo service.Todo) (service.Todo, error) CreateTodo(ctx context.Context, todo service.Todo) (service.Todo, error)
} }
type TodoDeleter interface { type TodoDeleter interface {
DeleteTodo(id int64) error DeleteTodo(ctx context.Context, id int64) error
} }
type TodoUpdater interface { type TodoUpdater interface {
UpdateTodo(todo service.Todo) error UpdateTodo(ctx context.Context, todo service.Todo) error
} }

View File

@ -2,6 +2,7 @@ package handler
import ( import (
"encoding/json" "encoding/json"
"log/slog"
"net/http" "net/http"
"strconv" "strconv"
@ -17,14 +18,16 @@ func TodoFromUpdateTodoRequest(todo UpdateTodoRequest, id int64) service.Todo {
return service.Todo{Id: id, Name: todo.Name, Done: todo.Done} return service.Todo{Id: id, Name: todo.Name, Done: todo.Done}
} }
func HandleTodoUpdate(todoService TodoUpdater) http.HandlerFunc { func HandleTodoUpdate(logger *slog.Logger, todoService TodoUpdater) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
updateTodoRequest := UpdateTodoRequest{} updateTodoRequest := UpdateTodoRequest{}
decoder := json.NewDecoder(r.Body) decoder := json.NewDecoder(r.Body)
decoder.DisallowUnknownFields() decoder.DisallowUnknownFields()
err := decoder.Decode(&updateTodoRequest) err := decoder.Decode(&updateTodoRequest)
if err != nil { if err != nil {
logger.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusBadRequest) http.Error(w, "", http.StatusBadRequest)
return return
} }
@ -34,11 +37,12 @@ func HandleTodoUpdate(todoService TodoUpdater) http.HandlerFunc {
id, err := strconv.ParseInt(idString, 10, 64) id, err := strconv.ParseInt(idString, 10, 64)
if err != nil { if err != nil {
logger.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusBadRequest) http.Error(w, "", http.StatusBadRequest)
return return
} }
err = todoService.UpdateTodo(TodoFromUpdateTodoRequest(updateTodoRequest, id)) err = todoService.UpdateTodo(ctx, TodoFromUpdateTodoRequest(updateTodoRequest, id))
if err != nil { if err != nil {
if err == service.ErrNotFound { if err == service.ErrNotFound {
@ -46,6 +50,7 @@ func HandleTodoUpdate(todoService TodoUpdater) http.HandlerFunc {
return return
} }
logger.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusInternalServerError) http.Error(w, "", http.StatusInternalServerError)
return return
} }

View File

@ -2,7 +2,9 @@ package handler
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"log/slog"
"errors" "errors"
"net/http" "net/http"
@ -13,24 +15,25 @@ import (
) )
type MockTodoUpdater struct { type MockTodoUpdater struct {
UpdateTodoFunc func(todo service.Todo) error UpdateTodoFunc func(ctx context.Context, todo service.Todo) error
} }
func (tg *MockTodoUpdater) UpdateTodo(todo service.Todo) error { func (tg *MockTodoUpdater) UpdateTodo(ctx context.Context, todo service.Todo) error {
return tg.UpdateTodoFunc(todo) return tg.UpdateTodoFunc(ctx, todo)
} }
func TestUpdateTodo(t *testing.T) { func TestUpdateTodo(t *testing.T) {
logger := slog.Default()
t.Run("update todo", func(t *testing.T) { t.Run("update todo", func(t *testing.T) {
updateTodoRequest := UpdateTodoRequest{Name: "clean dishes", Done: false} updateTodoRequest := UpdateTodoRequest{Name: "clean dishes", Done: false}
service := MockTodoUpdater{ service := MockTodoUpdater{
UpdateTodoFunc: func(todo service.Todo) error { UpdateTodoFunc: func(ctx context.Context, todo service.Todo) error {
return nil return nil
}, },
} }
handler := HandleTodoUpdate(&service) handler := HandleTodoUpdate(logger, &service)
requestBody, err := json.Marshal(updateTodoRequest) requestBody, err := json.Marshal(updateTodoRequest)
@ -50,7 +53,7 @@ func TestUpdateTodo(t *testing.T) {
}) })
t.Run("returns 400 with bad json", func(t *testing.T) { t.Run("returns 400 with bad json", func(t *testing.T) {
handler := HandleTodoUpdate(nil) handler := HandleTodoUpdate(logger, nil)
badStruct := struct { badStruct := struct {
Foo string Foo string
@ -75,7 +78,7 @@ func TestUpdateTodo(t *testing.T) {
}) })
t.Run("returns 400 with bad id", func(t *testing.T) { t.Run("returns 400 with bad id", func(t *testing.T) {
handler := HandleTodoUpdate(nil) handler := HandleTodoUpdate(logger, nil)
req := httptest.NewRequest(http.MethodPut, "/todo/hello", nil) req := httptest.NewRequest(http.MethodPut, "/todo/hello", nil)
res := httptest.NewRecorder() res := httptest.NewRecorder()
@ -92,12 +95,12 @@ func TestUpdateTodo(t *testing.T) {
updateTodoRequest := UpdateTodoRequest{Name: "clean dishes", Done: false} updateTodoRequest := UpdateTodoRequest{Name: "clean dishes", Done: false}
service := MockTodoUpdater{ service := MockTodoUpdater{
UpdateTodoFunc: func(todo service.Todo) error { UpdateTodoFunc: func(ctx context.Context, todo service.Todo) error {
return errors.New("foo bar") return errors.New("foo bar")
}, },
} }
handler := HandleTodoUpdate(&service) handler := HandleTodoUpdate(logger, &service)
requestBody, err := json.Marshal(updateTodoRequest) requestBody, err := json.Marshal(updateTodoRequest)

View File

@ -1,22 +1,26 @@
package postgres package postgres
import ( import (
"context"
"database/sql" "database/sql"
"log/slog"
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository" "gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository"
) )
type PostgresTodoRepository struct { type PostgresTodoRepository struct {
db *sql.DB logger *slog.Logger
db *sql.DB
} }
func NewPostgresTodoRepository(db *sql.DB) *PostgresTodoRepository { func NewPostgresTodoRepository(logger *slog.Logger, db *sql.DB) *PostgresTodoRepository {
return &PostgresTodoRepository{ return &PostgresTodoRepository{
db: db, logger: logger,
db: db,
} }
} }
func (r *PostgresTodoRepository) GetById(id int64) (repository.TodoRow, error) { func (r *PostgresTodoRepository) GetById(ctx context.Context, id int64) (repository.TodoRow, error) {
todo := repository.TodoRow{} todo := repository.TodoRow{}
err := r.db.QueryRow("SELECT * FROM todo WHERE id = $1;", id).Scan(&todo.Id, &todo.Name, &todo.Done) err := r.db.QueryRow("SELECT * FROM todo WHERE id = $1;", id).Scan(&todo.Id, &todo.Name, &todo.Done)
@ -26,33 +30,37 @@ func (r *PostgresTodoRepository) GetById(id int64) (repository.TodoRow, error) {
return todo, repository.ErrNotFound return todo, repository.ErrNotFound
} }
r.logger.ErrorContext(ctx, err.Error())
return todo, err return todo, err
} }
return todo, nil return todo, nil
} }
func (r *PostgresTodoRepository) Create(todo repository.TodoRow) (repository.TodoRow, error) { func (r *PostgresTodoRepository) Create(ctx context.Context, todo repository.TodoRow) (repository.TodoRow, error) {
result := r.db.QueryRow("INSERT INTO todo (name, done) VALUES ($1, $2) RETURNING id;", todo.Name, todo.Done) result := r.db.QueryRow("INSERT INTO todo (name, done) VALUES ($1, $2) RETURNING id;", todo.Name, todo.Done)
err := result.Scan(&todo.Id) err := result.Scan(&todo.Id)
if err != nil { if err != nil {
r.logger.ErrorContext(ctx, err.Error())
return repository.TodoRow{}, err return repository.TodoRow{}, err
} }
return todo, nil return todo, nil
} }
func (r *PostgresTodoRepository) Update(todo repository.TodoRow) error { func (r *PostgresTodoRepository) Update(ctx context.Context, todo repository.TodoRow) error {
result, err := r.db.Exec("UPDATE todo SET name = $1, done = $2 WHERE id = $3;", todo.Name, todo.Done, todo.Id) result, err := r.db.Exec("UPDATE todo SET name = $1, done = $2 WHERE id = $3;", todo.Name, todo.Done, todo.Id)
if err != nil { if err != nil {
r.logger.ErrorContext(ctx, err.Error())
return err return err
} }
rowsAffected, err := result.RowsAffected() rowsAffected, err := result.RowsAffected()
if err != nil { if err != nil {
r.logger.ErrorContext(ctx, err.Error())
return err return err
} }
@ -63,16 +71,18 @@ func (r *PostgresTodoRepository) Update(todo repository.TodoRow) error {
return nil return nil
} }
func (r *PostgresTodoRepository) Delete(id int64) error { func (r *PostgresTodoRepository) Delete(ctx context.Context, id int64) error {
result, err := r.db.Exec("DELETE FROM todo WHERE id = $1;", id) result, err := r.db.Exec("DELETE FROM todo WHERE id = $1;", id)
if err != nil { if err != nil {
r.logger.ErrorContext(ctx, err.Error())
return err return err
} }
rowsAffected, err := result.RowsAffected() rowsAffected, err := result.RowsAffected()
if err != nil { if err != nil {
r.logger.ErrorContext(ctx, err.Error())
return err return err
} }

View File

@ -2,103 +2,53 @@ package postgres
import ( import (
"context" "context"
"database/sql"
"errors" "errors"
"log/slog"
"testing" "testing"
"time"
"gitea.michaelthomson.dev/mthomson/habits/internal/migrate" "gitea.michaelthomson.dev/mthomson/habits/internal/test"
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository" "gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository"
_ "github.com/jackc/pgx/v5/stdlib" _ "github.com/jackc/pgx/v5/stdlib"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
) )
type TestDatabase struct {
Db *sql.DB
container testcontainers.Container
}
func NewTestDatabase(tb testing.TB) *TestDatabase {
tb.Helper()
ctx := context.Background()
// create container
postgresContainer, err := postgres.Run(ctx,
"postgres:16-alpine",
postgres.WithDatabase("todo"),
postgres.WithUsername("todo"),
postgres.WithPassword("password"),
testcontainers.WithWaitStrategy(
wait.ForLog("database system is ready to accept connections").
WithOccurrence(2).
WithStartupTimeout(5*time.Second)),
)
if err != nil {
tb.Fatalf("Failed to create postgres container, %v", err)
}
connectionString, err := postgresContainer.ConnectionString(ctx)
if err != nil {
tb.Fatalf("Failed to get connection string: %v", err)
}
// create db pool
db, err := sql.Open("pgx", connectionString)
if err != nil {
tb.Fatalf("Failed to open db pool: %v", err)
}
migrate.Migrate(db)
return &TestDatabase{
Db: db,
container: postgresContainer,
}
}
func (tdb *TestDatabase) TearDown() {
tdb.Db.Close()
_ = tdb.container.Terminate(context.Background())
}
func TestCRUD(t *testing.T) { func TestCRUD(t *testing.T) {
tdb := NewTestDatabase(t) ctx := context.Background()
logger := slog.Default()
tdb := test.NewTestDatabase(t)
defer tdb.TearDown() defer tdb.TearDown()
r := NewPostgresTodoRepository(tdb.Db) r := NewPostgresTodoRepository(logger, tdb.Db)
t.Run("creates new todo", func(t *testing.T) { t.Run("creates new todo", func(t *testing.T) {
want := repository.TodoRow{Id: 1, Name: "clean dishes", Done: false} want := repository.TodoRow{Id: 1, Name: "clean dishes", Done: false}
newTodo := repository.TodoRow{Name: "clean dishes", Done: false} newTodo := repository.TodoRow{Name: "clean dishes", Done: false}
got, err := r.Create(newTodo) got, err := r.Create(ctx, newTodo)
AssertNoError(t, err) AssertNoError(t, err)
AssertTodoRows(t, got, want) AssertTodoRows(t, got, want)
}) })
t.Run("gets todo", func(t *testing.T) { t.Run("gets todo", func(t *testing.T) {
want := repository.TodoRow{Id: 1, Name: "clean dishes", Done: false} want := repository.TodoRow{Id: 1, Name: "clean dishes", Done: false}
got, err := r.GetById(1) got, err := r.GetById(ctx, 1)
AssertNoError(t, err) AssertNoError(t, err)
AssertTodoRows(t, got, want) AssertTodoRows(t, got, want)
}) })
t.Run("updates todo", func(t *testing.T) { t.Run("updates todo", func(t *testing.T) {
want := repository.TodoRow{Id: 1, Name: "clean dishes", Done: true} want := repository.TodoRow{Id: 1, Name: "clean dishes", Done: true}
err := r.Update(want) err := r.Update(ctx, want)
AssertNoError(t, err) AssertNoError(t, err)
got, err := r.GetById(1) got, err := r.GetById(ctx, 1)
AssertNoError(t, err) AssertNoError(t, err)
AssertTodoRows(t, got, want) AssertTodoRows(t, got, want)
}) })
t.Run("deletes todo", func(t *testing.T) { t.Run("deletes todo", func(t *testing.T) {
err := r.Delete(1) err := r.Delete(ctx, 1)
AssertNoError(t, err) AssertNoError(t, err)
want := repository.ErrNotFound want := repository.ErrNotFound
_, got := r.GetById(1) _, got := r.GetById(ctx, 1)
AssertErrors(t, got, want) AssertErrors(t, got, want)
}) })

View File

@ -5,7 +5,7 @@ import (
) )
var ( var (
ErrNotFound error = errors.New("Todo cannot be found") ErrNotFound error = errors.New("todo cannot be found")
) )
type TodoRow struct { type TodoRow struct {

View File

@ -1,14 +1,15 @@
package service package service
import ( import (
"context"
"errors" "errors"
"log" "log/slog"
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository" "gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository"
) )
var ( var (
ErrNotFound error = errors.New("Todo cannot be found") ErrNotFound error = errors.New("todo cannot be found")
) )
type Todo struct { type Todo struct {
@ -34,65 +35,78 @@ func (t Todo) Equal(todo Todo) bool {
} }
type TodoRepository interface { type TodoRepository interface {
Create(todo repository.TodoRow) (repository.TodoRow, error) Create(ctx context.Context, todo repository.TodoRow) (repository.TodoRow, error)
GetById(id int64) (repository.TodoRow, error) GetById(ctx context.Context, id int64) (repository.TodoRow, error)
Update(todo repository.TodoRow) error Update(ctx context.Context, todo repository.TodoRow) error
Delete(id int64) error Delete(ctx context.Context, id int64) error
} }
type TodoService struct { type TodoService struct {
repo TodoRepository logger *slog.Logger
repo TodoRepository
} }
func NewTodoService(todoRepo TodoRepository) *TodoService { func NewTodoService(logger *slog.Logger, todoRepo TodoRepository) *TodoService {
return &TodoService{todoRepo} return &TodoService{
logger: logger,
repo: todoRepo,
}
} }
func (s *TodoService) GetTodo(id int64) (Todo, error) { func (s *TodoService) GetTodo(ctx context.Context, id int64) (Todo, error) {
todo, err := s.repo.GetById(id) todo, err := s.repo.GetById(ctx, id)
if err != nil { if err != nil {
if err == repository.ErrNotFound { if err == repository.ErrNotFound {
return Todo{}, ErrNotFound return Todo{}, ErrNotFound
} }
s.logger.ErrorContext(ctx, err.Error())
return Todo{}, err return Todo{}, err
} }
return TodoFromTodoRow(todo), err return TodoFromTodoRow(todo), err
} }
func (s *TodoService) CreateTodo(todo Todo) (Todo, error) { func (s *TodoService) CreateTodo(ctx context.Context, todo Todo) (Todo, error) {
todoRow := TodoRowFromTodo(todo) todoRow := TodoRowFromTodo(todo)
newTodoRow, err := s.repo.Create(todoRow) newTodoRow, err := s.repo.Create(ctx, todoRow)
if err != nil { if err != nil {
log.Print(err) s.logger.ErrorContext(ctx, err.Error())
return Todo{}, err return Todo{}, err
} }
return TodoFromTodoRow(newTodoRow), err return TodoFromTodoRow(newTodoRow), err
} }
func (s *TodoService) DeleteTodo(id int64) error { func (s *TodoService) DeleteTodo(ctx context.Context, id int64) error {
err := s.repo.Delete(id) err := s.repo.Delete(ctx, id)
if err == repository.ErrNotFound { if err == repository.ErrNotFound {
return ErrNotFound return ErrNotFound
} }
if err != nil {
s.logger.ErrorContext(ctx, err.Error())
}
return err return err
} }
func (s *TodoService) UpdateTodo(todo Todo) error { func (s *TodoService) UpdateTodo(ctx context.Context, todo Todo) error {
todoRow := TodoRowFromTodo(todo) todoRow := TodoRowFromTodo(todo)
err := s.repo.Update(todoRow) err := s.repo.Update(ctx, todoRow)
if err == repository.ErrNotFound { if err == repository.ErrNotFound {
return ErrNotFound return ErrNotFound
} }
if err != nil {
s.logger.ErrorContext(ctx, err.Error())
}
return err return err
} }

View File

@ -1,82 +1,112 @@
package service_test package service_test
import ( import (
"context"
"log/slog"
"testing" "testing"
"gitea.michaelthomson.dev/mthomson/habits/internal/test"
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository" "gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository"
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository/inmemory" "gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository/postgres"
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/service" "gitea.michaelthomson.dev/mthomson/habits/internal/todo/service"
_ "github.com/jackc/pgx/v5/stdlib"
) )
func TestCreateTodo(t *testing.T) { func TestCreateTodo(t *testing.T) {
todoRepository := inmemory.NewInMemoryTodoRepository() t.Parallel()
ctx := context.Background()
logger := slog.Default()
tdb := test.NewTestDatabase(t)
defer tdb.TearDown()
r := postgres.NewPostgresTodoRepository(logger, tdb.Db)
todoService := service.NewTodoService(&todoRepository) todoService := service.NewTodoService(logger, r)
t.Run("Create todo", func(t *testing.T) { t.Run("Create todo", func(t *testing.T) {
todo := service.NewTodo("clean dishes", false) todo := service.NewTodo("clean dishes", false)
_, err := todoService.CreateTodo(todo) _, err := todoService.CreateTodo(ctx, todo)
AssertNoError(t, err) AssertNoError(t, err)
}) })
} }
func TestGetTodo(t *testing.T) { func TestGetTodo(t *testing.T) {
todoRepository := inmemory.NewInMemoryTodoRepository() t.Parallel()
ctx := context.Background()
logger := slog.Default()
tdb := test.NewTestDatabase(t)
defer tdb.TearDown()
r := postgres.NewPostgresTodoRepository(logger, tdb.Db)
todoRepository.Db[1] = repository.TodoRow{Id: 1, Name: "clean dishes", Done: false} row := repository.TodoRow{Id: 1, Name: "clean dishes", Done: false}
_, err := r.Create(ctx, row)
AssertNoError(t, err);
todoService := service.NewTodoService(&todoRepository) todoService := service.NewTodoService(logger, r)
t.Run("Get exisiting todo", func(t *testing.T) { t.Run("Get exisiting todo", func(t *testing.T) {
_, err := todoService.GetTodo(1) _, err := todoService.GetTodo(ctx, 1)
AssertNoError(t, err) AssertNoError(t, err)
}) })
t.Run("Get non-existant todo", func(t *testing.T) { t.Run("Get non-existant todo", func(t *testing.T) {
_, err := todoService.GetTodo(2) _, err := todoService.GetTodo(ctx, 2)
AssertErrors(t, err, service.ErrNotFound) AssertErrors(t, err, service.ErrNotFound)
}) })
} }
func TestDeleteTodo(t *testing.T) { func TestDeleteTodo(t *testing.T) {
todoRepository := inmemory.NewInMemoryTodoRepository() t.Parallel()
ctx := context.Background()
logger := slog.Default()
tdb := test.NewTestDatabase(t)
defer tdb.TearDown()
r := postgres.NewPostgresTodoRepository(logger, tdb.Db)
todoRepository.Db[1] = repository.TodoRow{Id: 1, Name: "clean dishes", Done: false} row := repository.TodoRow{Id: 1, Name: "clean dishes", Done: false}
_, err := r.Create(ctx, row)
AssertNoError(t, err);
todoService := service.NewTodoService(&todoRepository) todoService := service.NewTodoService(logger, r)
t.Run("Delete exisiting todo", func(t *testing.T) { t.Run("Delete exisiting todo", func(t *testing.T) {
err := todoService.DeleteTodo(1) err := todoService.DeleteTodo(ctx, 1)
AssertNoError(t, err) AssertNoError(t, err)
}) })
t.Run("Delete non-existant todo", func(t *testing.T) { t.Run("Delete non-existant todo", func(t *testing.T) {
err := todoService.DeleteTodo(1) err := todoService.DeleteTodo(ctx, 1)
AssertErrors(t, err, service.ErrNotFound) AssertErrors(t, err, service.ErrNotFound)
}) })
} }
func TestUpdateTodo(t *testing.T) { func TestUpdateTodo(t *testing.T) {
todoRepository := inmemory.NewInMemoryTodoRepository() t.Parallel()
ctx := context.Background()
logger := slog.Default()
tdb := test.NewTestDatabase(t)
defer tdb.TearDown()
r := postgres.NewPostgresTodoRepository(logger, tdb.Db)
todoRepository.Db[1] = repository.TodoRow{Id: 1, Name: "clean dishes", Done: false} row := repository.TodoRow{Id: 1, Name: "clean dishes", Done: false}
_, err := r.Create(ctx, row)
AssertNoError(t, err);
todoService := service.NewTodoService(&todoRepository) todoService := service.NewTodoService(logger, r)
t.Run("Update exisiting todo", func(t *testing.T) { t.Run("Update exisiting todo", func(t *testing.T) {
todo := service.Todo{1, "clean dishes", true} todo := service.Todo{1, "clean dishes", true}
err := todoService.UpdateTodo(todo) err := todoService.UpdateTodo(ctx, todo)
AssertNoError(t, err) AssertNoError(t, err)
newTodo, err := todoService.GetTodo(1) newTodo, err := todoService.GetTodo(ctx, 1)
AssertNoError(t, err) AssertNoError(t, err)
@ -86,7 +116,7 @@ func TestUpdateTodo(t *testing.T) {
t.Run("Update non-existant todo", func(t *testing.T) { t.Run("Update non-existant todo", func(t *testing.T) {
todo := service.Todo{2, "clean dishes", true} todo := service.Todo{2, "clean dishes", true}
err := todoService.UpdateTodo(todo) err := todoService.UpdateTodo(ctx, todo)
AssertErrors(t, err, service.ErrNotFound) AssertErrors(t, err, service.ErrNotFound)
}) })

View File

@ -7,5 +7,8 @@ build:
test: test:
go test ./... go test ./...
format:
go fmt ./...
clean: clean:
rm tmp/main habits.db rm tmp/main habits.db