sql driver to pgx driver

This commit is contained in:
Michael Thomson 2025-05-15 14:47:46 -04:00
parent ab0e40c695
commit f1bbd06ef7
Signed by: mthomson
GPG Key ID: B6CA05EE5F436C79
7 changed files with 27 additions and 186 deletions

View File

@ -1,7 +1,7 @@
package main package main
import ( import (
"database/sql" "context"
"log" "log"
"net/http" "net/http"
"os" "os"
@ -12,7 +12,7 @@ import (
todohandler "gitea.michaelthomson.dev/mthomson/habits/internal/todo/handler" todohandler "gitea.michaelthomson.dev/mthomson/habits/internal/todo/handler"
todorepository "gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository/postgres" todorepository "gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository/postgres"
todoservice "gitea.michaelthomson.dev/mthomson/habits/internal/todo/service" todoservice "gitea.michaelthomson.dev/mthomson/habits/internal/todo/service"
_ "github.com/jackc/pgx/v5/stdlib" "github.com/jackc/pgx/v5/pgxpool"
) )
func main() { func main() {
@ -30,7 +30,8 @@ func main() {
// create db pool // create db pool
postgresUrl := "postgres://todo:password@localhost:5432/todo" postgresUrl := "postgres://todo:password@localhost:5432/todo"
db, err := sql.Open("pgx", postgresUrl) // db, err := sql.Open("pgx", postgresUrl)
db, err := pgxpool.New(context.Background(), postgresUrl)
if err != nil { if err != nil {
logger.Error(err.Error()) logger.Error(err.Error())
os.Exit(1); os.Exit(1);

View File

@ -1,13 +1,14 @@
package migrate package migrate
import ( import (
"database/sql" "context"
"embed" "embed"
"fmt" "fmt"
"log" "log"
"log/slog" "log/slog"
_ "github.com/jackc/pgx/v5/stdlib" "github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
) )
//go:embed migrations/*.sql //go:embed migrations/*.sql
@ -18,7 +19,7 @@ type Migration struct {
Name string Name string
} }
func Migrate(logger *slog.Logger, db *sql.DB) { func Migrate(logger *slog.Logger, db *pgxpool.Pool) {
logger.Info("Running migrations...") logger.Info("Running migrations...")
migrationTableSql := ` migrationTableSql := `
CREATE TABLE IF NOT EXISTS migrations( CREATE TABLE IF NOT EXISTS migrations(
@ -26,7 +27,7 @@ func Migrate(logger *slog.Logger, db *sql.DB) {
name VARCHAR(50) name VARCHAR(50)
);` );`
_, err := db.Exec(migrationTableSql) _, err := db.Exec(context.Background(), migrationTableSql)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@ -38,21 +39,21 @@ func Migrate(logger *slog.Logger, db *sql.DB) {
for _, file := range files { for _, file := range files {
var migration Migration var migration Migration
row := db.QueryRow("SELECT * FROM migrations WHERE name = $1;", file.Name()) row := db.QueryRow(context.Background(), "SELECT * FROM migrations WHERE name = $1;", file.Name())
err = row.Scan(&migration.Version, &migration.Name) err = row.Scan(&migration.Version, &migration.Name)
if err == sql.ErrNoRows { if err == pgx.ErrNoRows {
logger.Info(fmt.Sprintf("Running migration: %s", file.Name())) logger.Info(fmt.Sprintf("Running migration: %s", file.Name()))
migrationSql, err := migrations.ReadFile(fmt.Sprintf("migrations/%s", file.Name())) migrationSql, err := migrations.ReadFile(fmt.Sprintf("migrations/%s", file.Name()))
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
_, err = db.Exec(string(migrationSql)) _, err = db.Exec(context.Background(), string(migrationSql))
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
_, err = db.Exec("INSERT INTO migrations(name) VALUES($1);", file.Name()) _, err = db.Exec(context.Background(), "INSERT INTO migrations(name) VALUES($1);", file.Name())
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }

View File

@ -2,19 +2,19 @@ package test
import ( import (
"context" "context"
"database/sql"
"log/slog" "log/slog"
"testing" "testing"
"time" "time"
"gitea.michaelthomson.dev/mthomson/habits/internal/migrate" "gitea.michaelthomson.dev/mthomson/habits/internal/migrate"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/postgres" "github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait" "github.com/testcontainers/testcontainers-go/wait"
) )
type TestDatabase struct { type TestDatabase struct {
Db *sql.DB Db *pgxpool.Pool
container testcontainers.Container container testcontainers.Container
} }
@ -43,7 +43,7 @@ func NewTestDatabase(tb testing.TB) *TestDatabase {
} }
// create db pool // create db pool
db, err := sql.Open("pgx", connectionString) db, err := pgxpool.New(context.Background(), connectionString)
if err != nil { if err != nil {
tb.Fatalf("Failed to open db pool: %v", err) tb.Fatalf("Failed to open db pool: %v", err)
} }

View File

@ -1,59 +0,0 @@
package inmemory
import (
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository"
)
type InMemoryTodoRepository struct {
Db map[int64]repository.TodoRow
id int64
}
func NewInMemoryTodoRepository() InMemoryTodoRepository {
return InMemoryTodoRepository{
Db: make(map[int64]repository.TodoRow),
id: 1,
}
}
func (r *InMemoryTodoRepository) GetById(id int64) (repository.TodoRow, error) {
todo, found := r.Db[id]
if !found {
return todo, repository.ErrNotFound
}
return todo, nil
}
func (r *InMemoryTodoRepository) Create(todo repository.TodoRow) (repository.TodoRow, error) {
todo.Id = r.id
r.Db[r.id] = todo
r.id++
return todo, nil
}
func (r *InMemoryTodoRepository) Update(todo repository.TodoRow) error {
_, found := r.Db[todo.Id]
if !found {
return repository.ErrNotFound
}
r.Db[todo.Id] = todo
return nil
}
func (r *InMemoryTodoRepository) Delete(id int64) error {
_, found := r.Db[id]
if !found {
return repository.ErrNotFound
}
delete(r.Db, id)
return nil
}

View File

@ -2,18 +2,19 @@ package postgres
import ( import (
"context" "context"
"database/sql"
"log/slog" "log/slog"
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository" "gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
) )
type PostgresTodoRepository struct { type PostgresTodoRepository struct {
logger *slog.Logger logger *slog.Logger
db *sql.DB db *pgxpool.Pool
} }
func NewPostgresTodoRepository(logger *slog.Logger, db *sql.DB) *PostgresTodoRepository { func NewPostgresTodoRepository(logger *slog.Logger, db *pgxpool.Pool) *PostgresTodoRepository {
return &PostgresTodoRepository{ return &PostgresTodoRepository{
logger: logger, logger: logger,
db: db, db: db,
@ -23,10 +24,10 @@ func NewPostgresTodoRepository(logger *slog.Logger, db *sql.DB) *PostgresTodoRep
func (r *PostgresTodoRepository) GetById(ctx context.Context, id int64) (repository.TodoRow, error) { func (r *PostgresTodoRepository) GetById(ctx context.Context, id int64) (repository.TodoRow, error) {
todo := repository.TodoRow{} todo := repository.TodoRow{}
err := r.db.QueryRow("SELECT * FROM todo WHERE id = $1;", id).Scan(&todo.Id, &todo.Name, &todo.Done) err := r.db.QueryRow(ctx, "SELECT * FROM todo WHERE id = $1;", id).Scan(&todo.Id, &todo.Name, &todo.Done)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == pgx.ErrNoRows {
return todo, repository.ErrNotFound return todo, repository.ErrNotFound
} }
@ -38,7 +39,7 @@ func (r *PostgresTodoRepository) GetById(ctx context.Context, id int64) (reposit
} }
func (r *PostgresTodoRepository) Create(ctx context.Context, todo repository.TodoRow) (repository.TodoRow, error) { func (r *PostgresTodoRepository) Create(ctx context.Context, todo repository.TodoRow) (repository.TodoRow, error) {
result := r.db.QueryRow("INSERT INTO todo (name, done) VALUES ($1, $2) RETURNING id;", todo.Name, todo.Done) result := r.db.QueryRow(ctx, "INSERT INTO todo (name, done) VALUES ($1, $2) RETURNING id;", todo.Name, todo.Done)
err := result.Scan(&todo.Id) err := result.Scan(&todo.Id)
if err != nil { if err != nil {
@ -50,19 +51,14 @@ func (r *PostgresTodoRepository) Create(ctx context.Context, todo repository.Tod
} }
func (r *PostgresTodoRepository) Update(ctx context.Context, todo repository.TodoRow) error { func (r *PostgresTodoRepository) Update(ctx context.Context, todo repository.TodoRow) error {
result, err := r.db.Exec("UPDATE todo SET name = $1, done = $2 WHERE id = $3;", todo.Name, todo.Done, todo.Id) result, err := r.db.Exec(ctx, "UPDATE todo SET name = $1, done = $2 WHERE id = $3;", todo.Name, todo.Done, todo.Id)
if err != nil { if err != nil {
r.logger.ErrorContext(ctx, err.Error()) r.logger.ErrorContext(ctx, err.Error())
return err return err
} }
rowsAffected, err := result.RowsAffected() rowsAffected := result.RowsAffected()
if err != nil {
r.logger.ErrorContext(ctx, err.Error())
return err
}
if rowsAffected == 0 { if rowsAffected == 0 {
return repository.ErrNotFound return repository.ErrNotFound
@ -72,19 +68,14 @@ func (r *PostgresTodoRepository) Update(ctx context.Context, todo repository.Tod
} }
func (r *PostgresTodoRepository) Delete(ctx context.Context, id int64) error { func (r *PostgresTodoRepository) Delete(ctx context.Context, id int64) error {
result, err := r.db.Exec("DELETE FROM todo WHERE id = $1;", id) result, err := r.db.Exec(ctx, "DELETE FROM todo WHERE id = $1;", id)
if err != nil { if err != nil {
r.logger.ErrorContext(ctx, err.Error()) r.logger.ErrorContext(ctx, err.Error())
return err return err
} }
rowsAffected, err := result.RowsAffected() rowsAffected := result.RowsAffected()
if err != nil {
r.logger.ErrorContext(ctx, err.Error())
return err
}
if rowsAffected == 0 { if rowsAffected == 0 {
return repository.ErrNotFound return repository.ErrNotFound

View File

@ -8,7 +8,6 @@ import (
"gitea.michaelthomson.dev/mthomson/habits/internal/test" "gitea.michaelthomson.dev/mthomson/habits/internal/test"
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository" "gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository"
_ "github.com/jackc/pgx/v5/stdlib"
) )
func TestCRUD(t *testing.T) { func TestCRUD(t *testing.T) {

View File

@ -1,92 +0,0 @@
package sqlite
import (
"database/sql"
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository"
)
type SqliteTodoRepository struct {
db *sql.DB
}
func NewSqliteTodoRepository(db *sql.DB) *SqliteTodoRepository {
return &SqliteTodoRepository{
db: db,
}
}
func (r *SqliteTodoRepository) GetById(id int64) (repository.TodoRow, error) {
todo := repository.TodoRow{}
row := r.db.QueryRow("SELECT * FROM todo WHERE id = ?;", id)
err := row.Scan(&todo.Id, &todo.Name, &todo.Done)
if err != nil {
if err == sql.ErrNoRows {
return todo, repository.ErrNotFound
}
return todo, err
}
return todo, nil
}
func (r *SqliteTodoRepository) Create(todo repository.TodoRow) (repository.TodoRow, error) {
result, err := r.db.Exec("INSERT INTO todo (name, done) VALUES (?, ?)", todo.Name, todo.Done)
if err != nil {
return repository.TodoRow{}, err
}
id, err := result.LastInsertId()
if err != nil {
return repository.TodoRow{}, err
}
todo.Id = id
return todo, nil
}
func (r *SqliteTodoRepository) Update(todo repository.TodoRow) error {
result, err := r.db.Exec("UPDATE todo SET name = ?, done = ? WHERE id = ?", todo.Name, todo.Done, todo.Id)
if err != nil {
return err
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return err
}
if rowsAffected == 0 {
return repository.ErrNotFound
}
return nil
}
func (r *SqliteTodoRepository) Delete(id int64) error {
result, err := r.db.Exec("DELETE FROM todo WHERE id = ?", id)
if err != nil {
return err
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return err
}
if rowsAffected == 0 {
return repository.ErrNotFound
}
return nil
}