From f1bbd06ef7cfdc9cf38d3ba007d42e29d43d282b Mon Sep 17 00:00:00 2001 From: Michael Thomson Date: Thu, 15 May 2025 14:47:46 -0400 Subject: [PATCH] sql driver to pgx driver --- cmd/main.go | 7 +- internal/migrate/migrate.go | 17 ++-- internal/test/test_database.go | 6 +- internal/todo/repository/inmemory/inmemory.go | 59 ------------ internal/todo/repository/postgres/postgres.go | 31 +++---- .../todo/repository/postgres/postgres_test.go | 1 - internal/todo/repository/sqlite/sqlite.go | 92 ------------------- 7 files changed, 27 insertions(+), 186 deletions(-) delete mode 100644 internal/todo/repository/inmemory/inmemory.go delete mode 100644 internal/todo/repository/sqlite/sqlite.go diff --git a/cmd/main.go b/cmd/main.go index 5541279..ea9288f 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -1,7 +1,7 @@ package main import ( - "database/sql" + "context" "log" "net/http" "os" @@ -12,7 +12,7 @@ import ( todohandler "gitea.michaelthomson.dev/mthomson/habits/internal/todo/handler" todorepository "gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository/postgres" todoservice "gitea.michaelthomson.dev/mthomson/habits/internal/todo/service" - _ "github.com/jackc/pgx/v5/stdlib" + "github.com/jackc/pgx/v5/pgxpool" ) func main() { @@ -30,7 +30,8 @@ func main() { // create db pool postgresUrl := "postgres://todo:password@localhost:5432/todo" - db, err := sql.Open("pgx", postgresUrl) + // db, err := sql.Open("pgx", postgresUrl) + db, err := pgxpool.New(context.Background(), postgresUrl) if err != nil { logger.Error(err.Error()) os.Exit(1); diff --git a/internal/migrate/migrate.go b/internal/migrate/migrate.go index a4a4f26..cf8a159 100644 --- a/internal/migrate/migrate.go +++ b/internal/migrate/migrate.go @@ -1,13 +1,14 @@ package migrate import ( - "database/sql" + "context" "embed" "fmt" "log" "log/slog" - _ "github.com/jackc/pgx/v5/stdlib" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" ) //go:embed migrations/*.sql @@ -18,7 +19,7 @@ type Migration struct { Name string } -func Migrate(logger *slog.Logger, db *sql.DB) { +func Migrate(logger *slog.Logger, db *pgxpool.Pool) { logger.Info("Running migrations...") migrationTableSql := ` CREATE TABLE IF NOT EXISTS migrations( @@ -26,7 +27,7 @@ func Migrate(logger *slog.Logger, db *sql.DB) { name VARCHAR(50) );` - _, err := db.Exec(migrationTableSql) + _, err := db.Exec(context.Background(), migrationTableSql) if err != nil { log.Fatal(err) } @@ -38,21 +39,21 @@ func Migrate(logger *slog.Logger, db *sql.DB) { for _, file := range files { var migration Migration - row := db.QueryRow("SELECT * FROM migrations WHERE name = $1;", file.Name()) + row := db.QueryRow(context.Background(), "SELECT * FROM migrations WHERE name = $1;", file.Name()) err = row.Scan(&migration.Version, &migration.Name) - if err == sql.ErrNoRows { + if err == pgx.ErrNoRows { logger.Info(fmt.Sprintf("Running migration: %s", file.Name())) migrationSql, err := migrations.ReadFile(fmt.Sprintf("migrations/%s", file.Name())) if err != nil { log.Fatal(err) } - _, err = db.Exec(string(migrationSql)) + _, err = db.Exec(context.Background(), string(migrationSql)) if err != nil { log.Fatal(err) } - _, err = db.Exec("INSERT INTO migrations(name) VALUES($1);", file.Name()) + _, err = db.Exec(context.Background(), "INSERT INTO migrations(name) VALUES($1);", file.Name()) if err != nil { log.Fatal(err) } diff --git a/internal/test/test_database.go b/internal/test/test_database.go index a6919e0..4496b5e 100644 --- a/internal/test/test_database.go +++ b/internal/test/test_database.go @@ -2,19 +2,19 @@ package test import ( "context" - "database/sql" "log/slog" "testing" "time" "gitea.michaelthomson.dev/mthomson/habits/internal/migrate" + "github.com/jackc/pgx/v5/pgxpool" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/modules/postgres" "github.com/testcontainers/testcontainers-go/wait" ) type TestDatabase struct { - Db *sql.DB + Db *pgxpool.Pool container testcontainers.Container } @@ -43,7 +43,7 @@ func NewTestDatabase(tb testing.TB) *TestDatabase { } // create db pool - db, err := sql.Open("pgx", connectionString) + db, err := pgxpool.New(context.Background(), connectionString) if err != nil { tb.Fatalf("Failed to open db pool: %v", err) } diff --git a/internal/todo/repository/inmemory/inmemory.go b/internal/todo/repository/inmemory/inmemory.go deleted file mode 100644 index 8d4f109..0000000 --- a/internal/todo/repository/inmemory/inmemory.go +++ /dev/null @@ -1,59 +0,0 @@ -package inmemory - -import ( - "gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository" -) - -type InMemoryTodoRepository struct { - Db map[int64]repository.TodoRow - id int64 -} - -func NewInMemoryTodoRepository() InMemoryTodoRepository { - return InMemoryTodoRepository{ - Db: make(map[int64]repository.TodoRow), - id: 1, - } -} - -func (r *InMemoryTodoRepository) GetById(id int64) (repository.TodoRow, error) { - todo, found := r.Db[id] - - if !found { - return todo, repository.ErrNotFound - } - - return todo, nil -} - -func (r *InMemoryTodoRepository) Create(todo repository.TodoRow) (repository.TodoRow, error) { - todo.Id = r.id - r.Db[r.id] = todo - r.id++ - - return todo, nil -} - -func (r *InMemoryTodoRepository) Update(todo repository.TodoRow) error { - _, found := r.Db[todo.Id] - - if !found { - return repository.ErrNotFound - } - - r.Db[todo.Id] = todo - - return nil -} - -func (r *InMemoryTodoRepository) Delete(id int64) error { - _, found := r.Db[id] - - if !found { - return repository.ErrNotFound - } - - delete(r.Db, id) - - return nil -} diff --git a/internal/todo/repository/postgres/postgres.go b/internal/todo/repository/postgres/postgres.go index c1c834f..bdf03cb 100644 --- a/internal/todo/repository/postgres/postgres.go +++ b/internal/todo/repository/postgres/postgres.go @@ -2,18 +2,19 @@ package postgres import ( "context" - "database/sql" "log/slog" "gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" ) type PostgresTodoRepository struct { logger *slog.Logger - db *sql.DB + db *pgxpool.Pool } -func NewPostgresTodoRepository(logger *slog.Logger, db *sql.DB) *PostgresTodoRepository { +func NewPostgresTodoRepository(logger *slog.Logger, db *pgxpool.Pool) *PostgresTodoRepository { return &PostgresTodoRepository{ logger: logger, db: db, @@ -23,10 +24,10 @@ func NewPostgresTodoRepository(logger *slog.Logger, db *sql.DB) *PostgresTodoRep func (r *PostgresTodoRepository) GetById(ctx context.Context, id int64) (repository.TodoRow, error) { todo := repository.TodoRow{} - err := r.db.QueryRow("SELECT * FROM todo WHERE id = $1;", id).Scan(&todo.Id, &todo.Name, &todo.Done) + err := r.db.QueryRow(ctx, "SELECT * FROM todo WHERE id = $1;", id).Scan(&todo.Id, &todo.Name, &todo.Done) if err != nil { - if err == sql.ErrNoRows { + if err == pgx.ErrNoRows { return todo, repository.ErrNotFound } @@ -38,7 +39,7 @@ func (r *PostgresTodoRepository) GetById(ctx context.Context, id int64) (reposit } func (r *PostgresTodoRepository) Create(ctx context.Context, todo repository.TodoRow) (repository.TodoRow, error) { - result := r.db.QueryRow("INSERT INTO todo (name, done) VALUES ($1, $2) RETURNING id;", todo.Name, todo.Done) + result := r.db.QueryRow(ctx, "INSERT INTO todo (name, done) VALUES ($1, $2) RETURNING id;", todo.Name, todo.Done) err := result.Scan(&todo.Id) if err != nil { @@ -50,19 +51,14 @@ func (r *PostgresTodoRepository) Create(ctx context.Context, todo repository.Tod } func (r *PostgresTodoRepository) Update(ctx context.Context, todo repository.TodoRow) error { - result, err := r.db.Exec("UPDATE todo SET name = $1, done = $2 WHERE id = $3;", todo.Name, todo.Done, todo.Id) + result, err := r.db.Exec(ctx, "UPDATE todo SET name = $1, done = $2 WHERE id = $3;", todo.Name, todo.Done, todo.Id) if err != nil { r.logger.ErrorContext(ctx, err.Error()) return err } - rowsAffected, err := result.RowsAffected() - - if err != nil { - r.logger.ErrorContext(ctx, err.Error()) - return err - } + rowsAffected := result.RowsAffected() if rowsAffected == 0 { return repository.ErrNotFound @@ -72,19 +68,14 @@ func (r *PostgresTodoRepository) Update(ctx context.Context, todo repository.Tod } func (r *PostgresTodoRepository) Delete(ctx context.Context, id int64) error { - result, err := r.db.Exec("DELETE FROM todo WHERE id = $1;", id) + result, err := r.db.Exec(ctx, "DELETE FROM todo WHERE id = $1;", id) if err != nil { r.logger.ErrorContext(ctx, err.Error()) return err } - rowsAffected, err := result.RowsAffected() - - if err != nil { - r.logger.ErrorContext(ctx, err.Error()) - return err - } + rowsAffected := result.RowsAffected() if rowsAffected == 0 { return repository.ErrNotFound diff --git a/internal/todo/repository/postgres/postgres_test.go b/internal/todo/repository/postgres/postgres_test.go index a858954..7cb35c2 100644 --- a/internal/todo/repository/postgres/postgres_test.go +++ b/internal/todo/repository/postgres/postgres_test.go @@ -8,7 +8,6 @@ import ( "gitea.michaelthomson.dev/mthomson/habits/internal/test" "gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository" - _ "github.com/jackc/pgx/v5/stdlib" ) func TestCRUD(t *testing.T) { diff --git a/internal/todo/repository/sqlite/sqlite.go b/internal/todo/repository/sqlite/sqlite.go deleted file mode 100644 index 5988b5c..0000000 --- a/internal/todo/repository/sqlite/sqlite.go +++ /dev/null @@ -1,92 +0,0 @@ -package sqlite - -import ( - "database/sql" - - "gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository" -) - -type SqliteTodoRepository struct { - db *sql.DB -} - -func NewSqliteTodoRepository(db *sql.DB) *SqliteTodoRepository { - return &SqliteTodoRepository{ - db: db, - } -} - -func (r *SqliteTodoRepository) GetById(id int64) (repository.TodoRow, error) { - todo := repository.TodoRow{} - - row := r.db.QueryRow("SELECT * FROM todo WHERE id = ?;", id) - err := row.Scan(&todo.Id, &todo.Name, &todo.Done) - - if err != nil { - if err == sql.ErrNoRows { - return todo, repository.ErrNotFound - } - - return todo, err - } - - return todo, nil -} - -func (r *SqliteTodoRepository) Create(todo repository.TodoRow) (repository.TodoRow, error) { - result, err := r.db.Exec("INSERT INTO todo (name, done) VALUES (?, ?)", todo.Name, todo.Done) - - if err != nil { - return repository.TodoRow{}, err - } - - id, err := result.LastInsertId() - - if err != nil { - return repository.TodoRow{}, err - } - - todo.Id = id - - return todo, nil -} - -func (r *SqliteTodoRepository) Update(todo repository.TodoRow) error { - result, err := r.db.Exec("UPDATE todo SET name = ?, done = ? WHERE id = ?", todo.Name, todo.Done, todo.Id) - - if err != nil { - return err - } - - rowsAffected, err := result.RowsAffected() - - if err != nil { - return err - } - - if rowsAffected == 0 { - return repository.ErrNotFound - } - - return nil -} - -func (r *SqliteTodoRepository) Delete(id int64) error { - result, err := r.db.Exec("DELETE FROM todo WHERE id = ?", id) - - if err != nil { - return err - } - - rowsAffected, err := result.RowsAffected() - - if err != nil { - return err - } - - if rowsAffected == 0 { - return repository.ErrNotFound - } - - return nil -}