Logging, middleware, and context passing #3

Merged
mthomson merged 4 commits from feature/logging into main 2025-05-15 17:47:37 +00:00
24 changed files with 381 additions and 187 deletions

View File

@ -3,51 +3,56 @@ package main
import (
"database/sql"
"log"
"log/slog"
"net/http"
"os"
"gitea.michaelthomson.dev/mthomson/habits/internal/logging"
"gitea.michaelthomson.dev/mthomson/habits/internal/middleware"
"gitea.michaelthomson.dev/mthomson/habits/internal/migrate"
todohandler "gitea.michaelthomson.dev/mthomson/habits/internal/todo/handler"
todorepository "gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository/postgres"
todoservice "gitea.michaelthomson.dev/mthomson/habits/internal/todo/service"
_ "github.com/jackc/pgx/v5/stdlib"
_ "github.com/mattn/go-sqlite3"
)
func main() {
// create logger
httpLogger := slog.New(slog.NewTextHandler(os.Stdout, nil))
logger := logging.NewLogger()
// create middlewares
loggingMiddleware := middleware.LoggingMiddleware(httpLogger)
contextMiddleware := middleware.ContextMiddleware(logger)
loggingMiddleware := middleware.LoggingMiddleware(logger)
stack := []middleware.Middleware{
contextMiddleware,
loggingMiddleware,
}
// create db pool
postgresUrl := "postgres://todo:password@localhost:5432/todo"
db, err := sql.Open("pgx", postgresUrl)
if err != nil {
log.Fatalf("Failed to open db pool: %v", err)
logger.Error(err.Error())
os.Exit(1);
}
defer db.Close()
// run migrations
migrate.Migrate(db)
migrate.Migrate(logger, db)
// create repos
todoRepository := todorepository.NewPostgresTodoRepository(db)
todoRepository := todorepository.NewPostgresTodoRepository(logger, db)
// create services
todoService := todoservice.NewTodoService(todoRepository)
todoService := todoservice.NewTodoService(logger, todoRepository)
// create mux
mux := http.NewServeMux()
// register handlers
mux.Handle("GET /todo/{id}", loggingMiddleware(todohandler.HandleTodoGet(todoService)))
mux.Handle("POST /todo", loggingMiddleware(todohandler.HandleTodoCreate(todoService)))
mux.Handle("DELETE /todo/{id}", loggingMiddleware(todohandler.HandleTodoDelete(todoService)))
mux.Handle("PUT /todo/{id}", loggingMiddleware(todohandler.HandleTodoUpdate(todoService)))
mux.Handle("GET /todo/{id}", middleware.CompileMiddleware(todohandler.HandleTodoGet(logger, todoService), stack))
mux.Handle("POST /todo", middleware.CompileMiddleware(todohandler.HandleTodoCreate(logger, todoService), stack))
mux.Handle("DELETE /todo/{id}", middleware.CompileMiddleware(todohandler.HandleTodoDelete(logger, todoService), stack))
mux.Handle("PUT /todo/{id}", middleware.CompileMiddleware(todohandler.HandleTodoUpdate(logger, todoService), stack))
// create server
server := &http.Server{

10
flake.lock generated
View File

@ -2,12 +2,12 @@
"nodes": {
"nixpkgs": {
"locked": {
"lastModified": 1731676054,
"narHash": "sha256-OZiZ3m8SCMfh3B6bfGC/Bm4x3qc1m2SVEAlkV6iY7Yg=",
"rev": "5e4fbfb6b3de1aa2872b76d49fafc942626e2add",
"revCount": 708622,
"lastModified": 1747179050,
"narHash": "sha256-qhFMmDkeJX9KJwr5H32f1r7Prs7XbQWtO0h3V0a0rFY=",
"rev": "adaa24fbf46737f3f1b5497bf64bae750f82942e",
"revCount": 799423,
"type": "tarball",
"url": "https://api.flakehub.com/f/pinned/NixOS/nixpkgs/0.1.708622%2Brev-5e4fbfb6b3de1aa2872b76d49fafc942626e2add/0193363c-ab27-7bbd-af1d-3e6093ed5e2d/source.tar.gz"
"url": "https://api.flakehub.com/f/pinned/NixOS/nixpkgs/0.1.799423%2Brev-adaa24fbf46737f3f1b5497bf64bae750f82942e/0196d1c3-1974-7bf1-bcf6-06620ac40c8c/source.tar.gz"
},
"original": {
"type": "tarball",

View File

@ -5,7 +5,7 @@
outputs = { self, nixpkgs }:
let
goVersion = 22; # Change this to update the whole stack
goVersion = 23; # Change this to update the whole stack
supportedSystems = [ "x86_64-linux" "aarch64-linux" "x86_64-darwin" "aarch64-darwin" ];
forEachSupportedSystem = f: nixpkgs.lib.genAttrs supportedSystems (system: f {
@ -36,6 +36,7 @@
docker
docker-compose
gopls
];
};
});

View File

@ -0,0 +1,28 @@
package logging
import (
"context"
"log/slog"
"os"
"gitea.michaelthomson.dev/mthomson/habits/internal/middleware"
)
type ContextHandler struct {
slog.Handler
}
func (h *ContextHandler) Handle(ctx context.Context, r slog.Record) error {
if requestID, ok := ctx.Value(middleware.TraceIdKey).(string); ok {
r.AddAttrs(slog.String(string(middleware.TraceIdKey), requestID))
}
return h.Handler.Handle(ctx, r)
}
func NewLogger() *slog.Logger {
baseHandler := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{AddSource: false})
customHandler := &ContextHandler{Handler: baseHandler}
logger := slog.New(customHandler)
return logger
}

View File

@ -0,0 +1,24 @@
package middleware
import (
"context"
"log/slog"
"net/http"
"github.com/google/uuid"
)
type contextKey string
const TraceIdKey contextKey = "trace_id"
func ContextMiddleware(logger *slog.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
traceid := uuid.NewString()
ctx := context.WithValue(r.Context(), TraceIdKey, traceid)
newReq := r.WithContext(ctx)
next.ServeHTTP(w, newReq)
})
}
}

View File

@ -1,23 +1,39 @@
package middleware
import (
"context"
"log/slog"
"net/http"
)
type LoggingResponseWriter struct {
http.ResponseWriter
statusCode int
}
func NewLoggingResponseWriter(w http.ResponseWriter) *LoggingResponseWriter {
return &LoggingResponseWriter{w, http.StatusOK}
}
func (lrw *LoggingResponseWriter) WriteHeader(code int) {
lrw.statusCode = code
lrw.ResponseWriter.WriteHeader(code)
}
func LoggingMiddleware(logger *slog.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logger.LogAttrs(
context.Background(),
slog.LevelInfo,
"Incoming request",
logger.InfoContext(r.Context(), "Incoming request",
slog.String("method", r.Method),
slog.String("path", r.URL.String()),
)
next.ServeHTTP(w, r)
lrw := NewLoggingResponseWriter(w)
next.ServeHTTP(lrw, r)
logger.InfoContext(r.Context(), "Sent response",
slog.Int("code", lrw.statusCode),
slog.String("message", http.StatusText(lrw.statusCode)),
)
})
}
}

View File

@ -0,0 +1,19 @@
package middleware
import "net/http"
type Middleware func(http.Handler) http.Handler
func CompileMiddleware(h http.Handler, m []Middleware) http.Handler {
if len(m) < 1 {
return h
}
wrapped := h
for i := len(m) - 1; i >= 0; i-- {
wrapped = m[i](wrapped)
}
return wrapped
}

View File

@ -7,7 +7,7 @@ import (
"log"
"log/slog"
_ "github.com/mattn/go-sqlite3"
_ "github.com/jackc/pgx/v5/stdlib"
)
//go:embed migrations/*.sql
@ -18,8 +18,8 @@ type Migration struct {
Name string
}
func Migrate(db *sql.DB) {
slog.Info("Running migrations...")
func Migrate(logger *slog.Logger, db *sql.DB) {
logger.Info("Running migrations...")
migrationTableSql := `
CREATE TABLE IF NOT EXISTS migrations(
version SERIAL PRIMARY KEY,
@ -41,7 +41,7 @@ func Migrate(db *sql.DB) {
row := db.QueryRow("SELECT * FROM migrations WHERE name = $1;", file.Name())
err = row.Scan(&migration.Version, &migration.Name)
if err == sql.ErrNoRows {
slog.Info(fmt.Sprintf("Running migration: %s", file.Name()))
logger.Info(fmt.Sprintf("Running migration: %s", file.Name()))
migrationSql, err := migrations.ReadFile(fmt.Sprintf("migrations/%s", file.Name()))
if err != nil {
log.Fatal(err)
@ -58,5 +58,5 @@ func Migrate(db *sql.DB) {
}
}
}
slog.Info("Migrations completed")
logger.Info("Migrations completed")
}

View File

@ -0,0 +1,61 @@
package test
import (
"context"
"database/sql"
"log/slog"
"testing"
"time"
"gitea.michaelthomson.dev/mthomson/habits/internal/migrate"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
)
type TestDatabase struct {
Db *sql.DB
container testcontainers.Container
}
func NewTestDatabase(tb testing.TB) *TestDatabase {
tb.Helper()
ctx := context.Background()
// create container
postgresContainer, err := postgres.Run(ctx,
"postgres:16-alpine",
postgres.WithDatabase("todo"),
postgres.WithUsername("todo"),
postgres.WithPassword("password"),
testcontainers.WithWaitStrategy(
wait.ForLog("database system is ready to accept connections").
WithOccurrence(2).
WithStartupTimeout(5*time.Second)),
)
if err != nil {
tb.Fatalf("Failed to create postgres container, %v", err)
}
connectionString, err := postgresContainer.ConnectionString(ctx)
if err != nil {
tb.Fatalf("Failed to get connection string: %v", err)
}
// create db pool
db, err := sql.Open("pgx", connectionString)
if err != nil {
tb.Fatalf("Failed to open db pool: %v", err)
}
migrate.Migrate(slog.Default(), db)
return &TestDatabase{
Db: db,
container: postgresContainer,
}
}
func (tdb *TestDatabase) TearDown() {
_ = tdb.container.Terminate(context.Background())
}

View File

@ -3,6 +3,7 @@ package handler
import (
"encoding/json"
"fmt"
"log/slog"
"net/http"
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/service"
@ -27,19 +28,21 @@ func CreateTodoResponseFromTodo(todo service.Todo) CreateTodoResponse {
return CreateTodoResponse{Id: todo.Id, Name: todo.Name, Done: todo.Done}
}
func HandleTodoCreate(todoService TodoCreater) http.HandlerFunc {
func HandleTodoCreate(logger *slog.Logger, todoService TodoCreater) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
createTodoRequest := CreateTodoRequest{}
decoder := json.NewDecoder(r.Body)
decoder.DisallowUnknownFields()
err := decoder.Decode(&createTodoRequest)
if err != nil {
logger.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusBadRequest)
return
}
todo, err := todoService.CreateTodo(TodoFromCreateTodoRequest(createTodoRequest))
todo, err := todoService.CreateTodo(ctx, TodoFromCreateTodoRequest(createTodoRequest))
if err != nil {
if err == service.ErrNotFound {
@ -47,6 +50,7 @@ func HandleTodoCreate(todoService TodoCreater) http.HandlerFunc {
return
}
logger.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusInternalServerError)
return
}
@ -60,6 +64,7 @@ func HandleTodoCreate(todoService TodoCreater) http.HandlerFunc {
err = json.NewEncoder(w).Encode(response)
if err != nil {
logger.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusInternalServerError)
return
}

View File

@ -2,7 +2,9 @@ package handler
import (
"bytes"
"context"
"encoding/json"
"log/slog"
"errors"
"net/http"
@ -14,26 +16,27 @@ import (
)
type MockTodoCreater struct {
CreateTodoFunc func(todo service.Todo) (service.Todo, error)
CreateTodoFunc func(cxt context.Context, todo service.Todo) (service.Todo, error)
}
func (tg *MockTodoCreater) CreateTodo(todo service.Todo) (service.Todo, error) {
return tg.CreateTodoFunc(todo)
func (tg *MockTodoCreater) CreateTodo(ctx context.Context, todo service.Todo) (service.Todo, error) {
return tg.CreateTodoFunc(ctx, todo)
}
func TestCreateTodo(t *testing.T) {
logger := slog.Default()
t.Run("create todo", func(t *testing.T) {
createTodoRequest := CreateTodoRequest{Name: "clean dishes", Done: false}
createdTodo := service.Todo{Id: 1, Name: "clean dishes", Done: false}
want := CreateTodoResponse{Id: 1, Name: "clean dishes", Done: false}
service := MockTodoCreater{
CreateTodoFunc: func(todo service.Todo) (service.Todo, error) {
CreateTodoFunc: func(ctx context.Context, todo service.Todo) (service.Todo, error) {
return createdTodo, nil
},
}
handler := HandleTodoCreate(&service)
handler := HandleTodoCreate(logger, &service)
requestBody, err := json.Marshal(createTodoRequest)
@ -67,7 +70,7 @@ func TestCreateTodo(t *testing.T) {
})
t.Run("returns 400 with bad json", func(t *testing.T) {
handler := HandleTodoCreate(nil)
handler := HandleTodoCreate(logger, nil)
badStruct := struct {
Foo string
@ -95,12 +98,12 @@ func TestCreateTodo(t *testing.T) {
createTodoRequest := CreateTodoRequest{Name: "clean dishes", Done: false}
service := MockTodoCreater{
CreateTodoFunc: func(todo service.Todo) (service.Todo, error) {
CreateTodoFunc: func(ctx context.Context, todo service.Todo) (service.Todo, error) {
return service.Todo{}, errors.New("foo bar")
},
}
handler := HandleTodoCreate(&service)
handler := HandleTodoCreate(logger, &service)
requestBody, err := json.Marshal(createTodoRequest)

View File

@ -1,24 +1,27 @@
package handler
import (
"log/slog"
"net/http"
"strconv"
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/service"
)
func HandleTodoDelete(todoService TodoDeleter) http.HandlerFunc {
func HandleTodoDelete(logger *slog.Logger, todoService TodoDeleter) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
idString := r.PathValue("id")
id, err := strconv.ParseInt(idString, 10, 64)
if err != nil {
slog.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusBadRequest)
return
}
err = todoService.DeleteTodo(id)
err = todoService.DeleteTodo(ctx, id)
if err != nil {
if err == service.ErrNotFound {
@ -26,6 +29,7 @@ func HandleTodoDelete(todoService TodoDeleter) http.HandlerFunc {
return
}
slog.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusInternalServerError)
return
}

View File

@ -1,7 +1,9 @@
package handler
import (
"context"
"errors"
"log/slog"
"net/http"
"net/http/httptest"
"testing"
@ -10,22 +12,23 @@ import (
)
type MockTodoDeleter struct {
DeleteTodoFunc func(id int64) error
DeleteTodoFunc func(ctx context.Context, id int64) error
}
func (tg *MockTodoDeleter) DeleteTodo(id int64) error {
return tg.DeleteTodoFunc(id)
func (tg *MockTodoDeleter) DeleteTodo(ctx context.Context, id int64) error {
return tg.DeleteTodoFunc(ctx, id)
}
func TestDeleteTodo(t *testing.T) {
logger := slog.Default()
t.Run("deletes existing todo", func(t *testing.T) {
service := MockTodoDeleter{
DeleteTodoFunc: func(id int64) error {
DeleteTodoFunc: func(ctx context.Context, id int64) error {
return nil
},
}
handler := HandleTodoDelete(&service)
handler := HandleTodoDelete(logger, &service)
req := httptest.NewRequest(http.MethodDelete, "/todo/1", nil)
res := httptest.NewRecorder()
@ -39,7 +42,7 @@ func TestDeleteTodo(t *testing.T) {
})
t.Run("returns 400 with bad id", func(t *testing.T) {
handler := HandleTodoDelete(nil)
handler := HandleTodoDelete(logger, nil)
req := httptest.NewRequest(http.MethodDelete, "/todo/hello", nil)
res := httptest.NewRecorder()
@ -54,12 +57,12 @@ func TestDeleteTodo(t *testing.T) {
t.Run("returns 404 for not found todo", func(t *testing.T) {
service := MockTodoDeleter{
DeleteTodoFunc: func(id int64) error {
DeleteTodoFunc: func(ctx context.Context, id int64) error {
return service.ErrNotFound
},
}
handler := HandleTodoDelete(&service)
handler := HandleTodoDelete(logger, &service)
req := httptest.NewRequest(http.MethodDelete, "/todo/1", nil)
res := httptest.NewRecorder()
@ -74,12 +77,12 @@ func TestDeleteTodo(t *testing.T) {
t.Run("returns 500 for arbitrary errors", func(t *testing.T) {
service := MockTodoDeleter{
DeleteTodoFunc: func(id int64) error {
DeleteTodoFunc: func(ctx context.Context, id int64) error {
return errors.New("foo bar")
},
}
handler := HandleTodoDelete(&service)
handler := HandleTodoDelete(logger, &service)
req := httptest.NewRequest(http.MethodDelete, "/todo/1", nil)
res := httptest.NewRecorder()

View File

@ -2,6 +2,7 @@ package handler
import (
"encoding/json"
"log/slog"
"net/http"
"strconv"
@ -18,18 +19,20 @@ func GetTodoResponseFromTodo(todo service.Todo) GetTodoResponse {
return GetTodoResponse{Id: todo.Id, Name: todo.Name, Done: todo.Done}
}
func HandleTodoGet(todoService TodoGetter) http.HandlerFunc {
func HandleTodoGet(logger *slog.Logger, todoService TodoGetter) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
idString := r.PathValue("id")
id, err := strconv.ParseInt(idString, 10, 64)
if err != nil {
logger.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusBadRequest)
return
}
todo, err := todoService.GetTodo(id)
todo, err := todoService.GetTodo(ctx, id)
if err != nil {
if err == service.ErrNotFound {
@ -37,6 +40,7 @@ func HandleTodoGet(todoService TodoGetter) http.HandlerFunc {
return
}
logger.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusInternalServerError)
return
}
@ -49,6 +53,7 @@ func HandleTodoGet(todoService TodoGetter) http.HandlerFunc {
err = json.NewEncoder(w).Encode(response)
if err != nil {
logger.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusInternalServerError)
return
}

View File

@ -1,9 +1,11 @@
package handler
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"net/http/httptest"
"reflect"
@ -13,25 +15,26 @@ import (
)
type MockTodoGetter struct {
GetTodoFunc func(id int64) (service.Todo, error)
GetTodoFunc func(ctx context.Context, id int64) (service.Todo, error)
}
func (tg *MockTodoGetter) GetTodo(id int64) (service.Todo, error) {
return tg.GetTodoFunc(id)
func (tg *MockTodoGetter) GetTodo(ctx context.Context, id int64) (service.Todo, error) {
return tg.GetTodoFunc(ctx, id)
}
func TestGetTodo(t *testing.T) {
logger := slog.Default()
t.Run("gets existing todo", func(t *testing.T) {
todo := service.Todo{Id: 1, Name: "clean dishes", Done: false}
wantedTodo := GetTodoResponse{Id: 1, Name: "clean dishes", Done: false}
service := MockTodoGetter{
GetTodoFunc: func(id int64) (service.Todo, error) {
GetTodoFunc: func(ctx context.Context, id int64) (service.Todo, error) {
return todo, nil
},
}
handler := HandleTodoGet(&service)
handler := HandleTodoGet(logger, &service)
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/todo/%d", todo.Id), nil)
res := httptest.NewRecorder()
@ -56,7 +59,7 @@ func TestGetTodo(t *testing.T) {
})
t.Run("returns 400 with bad id", func(t *testing.T) {
handler := HandleTodoGet(nil)
handler := HandleTodoGet(logger, nil)
req := httptest.NewRequest(http.MethodGet, "/todo/hello", nil)
res := httptest.NewRecorder()
@ -71,12 +74,12 @@ func TestGetTodo(t *testing.T) {
t.Run("returns 404 for not found todo", func(t *testing.T) {
service := MockTodoGetter{
GetTodoFunc: func(id int64) (service.Todo, error) {
GetTodoFunc: func(ctx context.Context, id int64) (service.Todo, error) {
return service.Todo{}, service.ErrNotFound
},
}
handler := HandleTodoGet(&service)
handler := HandleTodoGet(logger, &service)
req := httptest.NewRequest(http.MethodGet, "/todo/1", nil)
res := httptest.NewRecorder()
@ -91,12 +94,12 @@ func TestGetTodo(t *testing.T) {
t.Run("returns 500 for arbitrary errors", func(t *testing.T) {
service := MockTodoGetter{
GetTodoFunc: func(id int64) (service.Todo, error) {
GetTodoFunc: func(ctx context.Context, id int64) (service.Todo, error) {
return service.Todo{}, errors.New("foo bar")
},
}
handler := HandleTodoGet(&service)
handler := HandleTodoGet(logger, &service)
req := httptest.NewRequest(http.MethodGet, "/todo/1", nil)
res := httptest.NewRecorder()

View File

@ -1,21 +1,23 @@
package handler
import (
"context"
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/service"
)
type TodoGetter interface {
GetTodo(id int64) (service.Todo, error)
GetTodo(ctx context.Context, id int64) (service.Todo, error)
}
type TodoCreater interface {
CreateTodo(todo service.Todo) (service.Todo, error)
CreateTodo(ctx context.Context, todo service.Todo) (service.Todo, error)
}
type TodoDeleter interface {
DeleteTodo(id int64) error
DeleteTodo(ctx context.Context, id int64) error
}
type TodoUpdater interface {
UpdateTodo(todo service.Todo) error
UpdateTodo(ctx context.Context, todo service.Todo) error
}

View File

@ -2,6 +2,7 @@ package handler
import (
"encoding/json"
"log/slog"
"net/http"
"strconv"
@ -17,14 +18,16 @@ func TodoFromUpdateTodoRequest(todo UpdateTodoRequest, id int64) service.Todo {
return service.Todo{Id: id, Name: todo.Name, Done: todo.Done}
}
func HandleTodoUpdate(todoService TodoUpdater) http.HandlerFunc {
func HandleTodoUpdate(logger *slog.Logger, todoService TodoUpdater) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
updateTodoRequest := UpdateTodoRequest{}
decoder := json.NewDecoder(r.Body)
decoder.DisallowUnknownFields()
err := decoder.Decode(&updateTodoRequest)
if err != nil {
logger.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusBadRequest)
return
}
@ -34,11 +37,12 @@ func HandleTodoUpdate(todoService TodoUpdater) http.HandlerFunc {
id, err := strconv.ParseInt(idString, 10, 64)
if err != nil {
logger.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusBadRequest)
return
}
err = todoService.UpdateTodo(TodoFromUpdateTodoRequest(updateTodoRequest, id))
err = todoService.UpdateTodo(ctx, TodoFromUpdateTodoRequest(updateTodoRequest, id))
if err != nil {
if err == service.ErrNotFound {
@ -46,6 +50,7 @@ func HandleTodoUpdate(todoService TodoUpdater) http.HandlerFunc {
return
}
logger.ErrorContext(ctx, err.Error())
http.Error(w, "", http.StatusInternalServerError)
return
}

View File

@ -2,7 +2,9 @@ package handler
import (
"bytes"
"context"
"encoding/json"
"log/slog"
"errors"
"net/http"
@ -13,24 +15,25 @@ import (
)
type MockTodoUpdater struct {
UpdateTodoFunc func(todo service.Todo) error
UpdateTodoFunc func(ctx context.Context, todo service.Todo) error
}
func (tg *MockTodoUpdater) UpdateTodo(todo service.Todo) error {
return tg.UpdateTodoFunc(todo)
func (tg *MockTodoUpdater) UpdateTodo(ctx context.Context, todo service.Todo) error {
return tg.UpdateTodoFunc(ctx, todo)
}
func TestUpdateTodo(t *testing.T) {
logger := slog.Default()
t.Run("update todo", func(t *testing.T) {
updateTodoRequest := UpdateTodoRequest{Name: "clean dishes", Done: false}
service := MockTodoUpdater{
UpdateTodoFunc: func(todo service.Todo) error {
UpdateTodoFunc: func(ctx context.Context, todo service.Todo) error {
return nil
},
}
handler := HandleTodoUpdate(&service)
handler := HandleTodoUpdate(logger, &service)
requestBody, err := json.Marshal(updateTodoRequest)
@ -50,7 +53,7 @@ func TestUpdateTodo(t *testing.T) {
})
t.Run("returns 400 with bad json", func(t *testing.T) {
handler := HandleTodoUpdate(nil)
handler := HandleTodoUpdate(logger, nil)
badStruct := struct {
Foo string
@ -75,7 +78,7 @@ func TestUpdateTodo(t *testing.T) {
})
t.Run("returns 400 with bad id", func(t *testing.T) {
handler := HandleTodoUpdate(nil)
handler := HandleTodoUpdate(logger, nil)
req := httptest.NewRequest(http.MethodPut, "/todo/hello", nil)
res := httptest.NewRecorder()
@ -92,12 +95,12 @@ func TestUpdateTodo(t *testing.T) {
updateTodoRequest := UpdateTodoRequest{Name: "clean dishes", Done: false}
service := MockTodoUpdater{
UpdateTodoFunc: func(todo service.Todo) error {
UpdateTodoFunc: func(ctx context.Context, todo service.Todo) error {
return errors.New("foo bar")
},
}
handler := HandleTodoUpdate(&service)
handler := HandleTodoUpdate(logger, &service)
requestBody, err := json.Marshal(updateTodoRequest)

View File

@ -1,22 +1,26 @@
package postgres
import (
"context"
"database/sql"
"log/slog"
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository"
)
type PostgresTodoRepository struct {
db *sql.DB
logger *slog.Logger
db *sql.DB
}
func NewPostgresTodoRepository(db *sql.DB) *PostgresTodoRepository {
func NewPostgresTodoRepository(logger *slog.Logger, db *sql.DB) *PostgresTodoRepository {
return &PostgresTodoRepository{
db: db,
logger: logger,
db: db,
}
}
func (r *PostgresTodoRepository) GetById(id int64) (repository.TodoRow, error) {
func (r *PostgresTodoRepository) GetById(ctx context.Context, id int64) (repository.TodoRow, error) {
todo := repository.TodoRow{}
err := r.db.QueryRow("SELECT * FROM todo WHERE id = $1;", id).Scan(&todo.Id, &todo.Name, &todo.Done)
@ -26,33 +30,37 @@ func (r *PostgresTodoRepository) GetById(id int64) (repository.TodoRow, error) {
return todo, repository.ErrNotFound
}
r.logger.ErrorContext(ctx, err.Error())
return todo, err
}
return todo, nil
}
func (r *PostgresTodoRepository) Create(todo repository.TodoRow) (repository.TodoRow, error) {
func (r *PostgresTodoRepository) Create(ctx context.Context, todo repository.TodoRow) (repository.TodoRow, error) {
result := r.db.QueryRow("INSERT INTO todo (name, done) VALUES ($1, $2) RETURNING id;", todo.Name, todo.Done)
err := result.Scan(&todo.Id)
if err != nil {
r.logger.ErrorContext(ctx, err.Error())
return repository.TodoRow{}, err
}
return todo, nil
}
func (r *PostgresTodoRepository) Update(todo repository.TodoRow) error {
func (r *PostgresTodoRepository) Update(ctx context.Context, todo repository.TodoRow) error {
result, err := r.db.Exec("UPDATE todo SET name = $1, done = $2 WHERE id = $3;", todo.Name, todo.Done, todo.Id)
if err != nil {
r.logger.ErrorContext(ctx, err.Error())
return err
}
rowsAffected, err := result.RowsAffected()
if err != nil {
r.logger.ErrorContext(ctx, err.Error())
return err
}
@ -63,16 +71,18 @@ func (r *PostgresTodoRepository) Update(todo repository.TodoRow) error {
return nil
}
func (r *PostgresTodoRepository) Delete(id int64) error {
func (r *PostgresTodoRepository) Delete(ctx context.Context, id int64) error {
result, err := r.db.Exec("DELETE FROM todo WHERE id = $1;", id)
if err != nil {
r.logger.ErrorContext(ctx, err.Error())
return err
}
rowsAffected, err := result.RowsAffected()
if err != nil {
r.logger.ErrorContext(ctx, err.Error())
return err
}

View File

@ -2,103 +2,53 @@ package postgres
import (
"context"
"database/sql"
"errors"
"log/slog"
"testing"
"time"
"gitea.michaelthomson.dev/mthomson/habits/internal/migrate"
"gitea.michaelthomson.dev/mthomson/habits/internal/test"
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
)
type TestDatabase struct {
Db *sql.DB
container testcontainers.Container
}
func NewTestDatabase(tb testing.TB) *TestDatabase {
tb.Helper()
ctx := context.Background()
// create container
postgresContainer, err := postgres.Run(ctx,
"postgres:16-alpine",
postgres.WithDatabase("todo"),
postgres.WithUsername("todo"),
postgres.WithPassword("password"),
testcontainers.WithWaitStrategy(
wait.ForLog("database system is ready to accept connections").
WithOccurrence(2).
WithStartupTimeout(5*time.Second)),
)
if err != nil {
tb.Fatalf("Failed to create postgres container, %v", err)
}
connectionString, err := postgresContainer.ConnectionString(ctx)
if err != nil {
tb.Fatalf("Failed to get connection string: %v", err)
}
// create db pool
db, err := sql.Open("pgx", connectionString)
if err != nil {
tb.Fatalf("Failed to open db pool: %v", err)
}
migrate.Migrate(db)
return &TestDatabase{
Db: db,
container: postgresContainer,
}
}
func (tdb *TestDatabase) TearDown() {
tdb.Db.Close()
_ = tdb.container.Terminate(context.Background())
}
func TestCRUD(t *testing.T) {
tdb := NewTestDatabase(t)
ctx := context.Background()
logger := slog.Default()
tdb := test.NewTestDatabase(t)
defer tdb.TearDown()
r := NewPostgresTodoRepository(tdb.Db)
r := NewPostgresTodoRepository(logger, tdb.Db)
t.Run("creates new todo", func(t *testing.T) {
want := repository.TodoRow{Id: 1, Name: "clean dishes", Done: false}
newTodo := repository.TodoRow{Name: "clean dishes", Done: false}
got, err := r.Create(newTodo)
got, err := r.Create(ctx, newTodo)
AssertNoError(t, err)
AssertTodoRows(t, got, want)
})
t.Run("gets todo", func(t *testing.T) {
want := repository.TodoRow{Id: 1, Name: "clean dishes", Done: false}
got, err := r.GetById(1)
got, err := r.GetById(ctx, 1)
AssertNoError(t, err)
AssertTodoRows(t, got, want)
})
t.Run("updates todo", func(t *testing.T) {
want := repository.TodoRow{Id: 1, Name: "clean dishes", Done: true}
err := r.Update(want)
err := r.Update(ctx, want)
AssertNoError(t, err)
got, err := r.GetById(1)
got, err := r.GetById(ctx, 1)
AssertNoError(t, err)
AssertTodoRows(t, got, want)
})
t.Run("deletes todo", func(t *testing.T) {
err := r.Delete(1)
err := r.Delete(ctx, 1)
AssertNoError(t, err)
want := repository.ErrNotFound
_, got := r.GetById(1)
_, got := r.GetById(ctx, 1)
AssertErrors(t, got, want)
})

View File

@ -5,7 +5,7 @@ import (
)
var (
ErrNotFound error = errors.New("Todo cannot be found")
ErrNotFound error = errors.New("todo cannot be found")
)
type TodoRow struct {

View File

@ -1,14 +1,15 @@
package service
import (
"context"
"errors"
"log"
"log/slog"
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository"
)
var (
ErrNotFound error = errors.New("Todo cannot be found")
ErrNotFound error = errors.New("todo cannot be found")
)
type Todo struct {
@ -34,65 +35,78 @@ func (t Todo) Equal(todo Todo) bool {
}
type TodoRepository interface {
Create(todo repository.TodoRow) (repository.TodoRow, error)
GetById(id int64) (repository.TodoRow, error)
Update(todo repository.TodoRow) error
Delete(id int64) error
Create(ctx context.Context, todo repository.TodoRow) (repository.TodoRow, error)
GetById(ctx context.Context, id int64) (repository.TodoRow, error)
Update(ctx context.Context, todo repository.TodoRow) error
Delete(ctx context.Context, id int64) error
}
type TodoService struct {
repo TodoRepository
logger *slog.Logger
repo TodoRepository
}
func NewTodoService(todoRepo TodoRepository) *TodoService {
return &TodoService{todoRepo}
func NewTodoService(logger *slog.Logger, todoRepo TodoRepository) *TodoService {
return &TodoService{
logger: logger,
repo: todoRepo,
}
}
func (s *TodoService) GetTodo(id int64) (Todo, error) {
todo, err := s.repo.GetById(id)
func (s *TodoService) GetTodo(ctx context.Context, id int64) (Todo, error) {
todo, err := s.repo.GetById(ctx, id)
if err != nil {
if err == repository.ErrNotFound {
return Todo{}, ErrNotFound
}
s.logger.ErrorContext(ctx, err.Error())
return Todo{}, err
}
return TodoFromTodoRow(todo), err
}
func (s *TodoService) CreateTodo(todo Todo) (Todo, error) {
func (s *TodoService) CreateTodo(ctx context.Context, todo Todo) (Todo, error) {
todoRow := TodoRowFromTodo(todo)
newTodoRow, err := s.repo.Create(todoRow)
newTodoRow, err := s.repo.Create(ctx, todoRow)
if err != nil {
log.Print(err)
s.logger.ErrorContext(ctx, err.Error())
return Todo{}, err
}
return TodoFromTodoRow(newTodoRow), err
}
func (s *TodoService) DeleteTodo(id int64) error {
err := s.repo.Delete(id)
func (s *TodoService) DeleteTodo(ctx context.Context, id int64) error {
err := s.repo.Delete(ctx, id)
if err == repository.ErrNotFound {
return ErrNotFound
}
if err != nil {
s.logger.ErrorContext(ctx, err.Error())
}
return err
}
func (s *TodoService) UpdateTodo(todo Todo) error {
func (s *TodoService) UpdateTodo(ctx context.Context, todo Todo) error {
todoRow := TodoRowFromTodo(todo)
err := s.repo.Update(todoRow)
err := s.repo.Update(ctx, todoRow)
if err == repository.ErrNotFound {
return ErrNotFound
}
if err != nil {
s.logger.ErrorContext(ctx, err.Error())
}
return err
}

View File

@ -1,82 +1,112 @@
package service_test
import (
"context"
"log/slog"
"testing"
"gitea.michaelthomson.dev/mthomson/habits/internal/test"
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository"
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository/inmemory"
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository/postgres"
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/service"
_ "github.com/jackc/pgx/v5/stdlib"
)
func TestCreateTodo(t *testing.T) {
todoRepository := inmemory.NewInMemoryTodoRepository()
t.Parallel()
ctx := context.Background()
logger := slog.Default()
tdb := test.NewTestDatabase(t)
defer tdb.TearDown()
r := postgres.NewPostgresTodoRepository(logger, tdb.Db)
todoService := service.NewTodoService(&todoRepository)
todoService := service.NewTodoService(logger, r)
t.Run("Create todo", func(t *testing.T) {
todo := service.NewTodo("clean dishes", false)
_, err := todoService.CreateTodo(todo)
_, err := todoService.CreateTodo(ctx, todo)
AssertNoError(t, err)
})
}
func TestGetTodo(t *testing.T) {
todoRepository := inmemory.NewInMemoryTodoRepository()
t.Parallel()
ctx := context.Background()
logger := slog.Default()
tdb := test.NewTestDatabase(t)
defer tdb.TearDown()
r := postgres.NewPostgresTodoRepository(logger, tdb.Db)
todoRepository.Db[1] = repository.TodoRow{Id: 1, Name: "clean dishes", Done: false}
row := repository.TodoRow{Id: 1, Name: "clean dishes", Done: false}
_, err := r.Create(ctx, row)
AssertNoError(t, err);
todoService := service.NewTodoService(&todoRepository)
todoService := service.NewTodoService(logger, r)
t.Run("Get exisiting todo", func(t *testing.T) {
_, err := todoService.GetTodo(1)
_, err := todoService.GetTodo(ctx, 1)
AssertNoError(t, err)
})
t.Run("Get non-existant todo", func(t *testing.T) {
_, err := todoService.GetTodo(2)
_, err := todoService.GetTodo(ctx, 2)
AssertErrors(t, err, service.ErrNotFound)
})
}
func TestDeleteTodo(t *testing.T) {
todoRepository := inmemory.NewInMemoryTodoRepository()
t.Parallel()
ctx := context.Background()
logger := slog.Default()
tdb := test.NewTestDatabase(t)
defer tdb.TearDown()
r := postgres.NewPostgresTodoRepository(logger, tdb.Db)
todoRepository.Db[1] = repository.TodoRow{Id: 1, Name: "clean dishes", Done: false}
row := repository.TodoRow{Id: 1, Name: "clean dishes", Done: false}
_, err := r.Create(ctx, row)
AssertNoError(t, err);
todoService := service.NewTodoService(&todoRepository)
todoService := service.NewTodoService(logger, r)
t.Run("Delete exisiting todo", func(t *testing.T) {
err := todoService.DeleteTodo(1)
err := todoService.DeleteTodo(ctx, 1)
AssertNoError(t, err)
})
t.Run("Delete non-existant todo", func(t *testing.T) {
err := todoService.DeleteTodo(1)
err := todoService.DeleteTodo(ctx, 1)
AssertErrors(t, err, service.ErrNotFound)
})
}
func TestUpdateTodo(t *testing.T) {
todoRepository := inmemory.NewInMemoryTodoRepository()
t.Parallel()
ctx := context.Background()
logger := slog.Default()
tdb := test.NewTestDatabase(t)
defer tdb.TearDown()
r := postgres.NewPostgresTodoRepository(logger, tdb.Db)
todoRepository.Db[1] = repository.TodoRow{Id: 1, Name: "clean dishes", Done: false}
row := repository.TodoRow{Id: 1, Name: "clean dishes", Done: false}
_, err := r.Create(ctx, row)
AssertNoError(t, err);
todoService := service.NewTodoService(&todoRepository)
todoService := service.NewTodoService(logger, r)
t.Run("Update exisiting todo", func(t *testing.T) {
todo := service.Todo{1, "clean dishes", true}
err := todoService.UpdateTodo(todo)
err := todoService.UpdateTodo(ctx, todo)
AssertNoError(t, err)
newTodo, err := todoService.GetTodo(1)
newTodo, err := todoService.GetTodo(ctx, 1)
AssertNoError(t, err)
@ -86,7 +116,7 @@ func TestUpdateTodo(t *testing.T) {
t.Run("Update non-existant todo", func(t *testing.T) {
todo := service.Todo{2, "clean dishes", true}
err := todoService.UpdateTodo(todo)
err := todoService.UpdateTodo(ctx, todo)
AssertErrors(t, err, service.ErrNotFound)
})

View File

@ -7,5 +7,8 @@ build:
test:
go test ./...
format:
go fmt ./...
clean:
rm tmp/main habits.db