sql driver to pgx driver
This commit is contained in:
parent
ab0e40c695
commit
f1bbd06ef7
@ -1,7 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
@ -12,7 +12,7 @@ import (
|
||||
todohandler "gitea.michaelthomson.dev/mthomson/habits/internal/todo/handler"
|
||||
todorepository "gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository/postgres"
|
||||
todoservice "gitea.michaelthomson.dev/mthomson/habits/internal/todo/service"
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@ -30,7 +30,8 @@ func main() {
|
||||
|
||||
// create db pool
|
||||
postgresUrl := "postgres://todo:password@localhost:5432/todo"
|
||||
db, err := sql.Open("pgx", postgresUrl)
|
||||
// db, err := sql.Open("pgx", postgresUrl)
|
||||
db, err := pgxpool.New(context.Background(), postgresUrl)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
os.Exit(1);
|
||||
|
@ -1,13 +1,14 @@
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"context"
|
||||
"embed"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
//go:embed migrations/*.sql
|
||||
@ -18,7 +19,7 @@ type Migration struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
func Migrate(logger *slog.Logger, db *sql.DB) {
|
||||
func Migrate(logger *slog.Logger, db *pgxpool.Pool) {
|
||||
logger.Info("Running migrations...")
|
||||
migrationTableSql := `
|
||||
CREATE TABLE IF NOT EXISTS migrations(
|
||||
@ -26,7 +27,7 @@ func Migrate(logger *slog.Logger, db *sql.DB) {
|
||||
name VARCHAR(50)
|
||||
);`
|
||||
|
||||
_, err := db.Exec(migrationTableSql)
|
||||
_, err := db.Exec(context.Background(), migrationTableSql)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
@ -38,21 +39,21 @@ func Migrate(logger *slog.Logger, db *sql.DB) {
|
||||
|
||||
for _, file := range files {
|
||||
var migration Migration
|
||||
row := db.QueryRow("SELECT * FROM migrations WHERE name = $1;", file.Name())
|
||||
row := db.QueryRow(context.Background(), "SELECT * FROM migrations WHERE name = $1;", file.Name())
|
||||
err = row.Scan(&migration.Version, &migration.Name)
|
||||
if err == sql.ErrNoRows {
|
||||
if err == pgx.ErrNoRows {
|
||||
logger.Info(fmt.Sprintf("Running migration: %s", file.Name()))
|
||||
migrationSql, err := migrations.ReadFile(fmt.Sprintf("migrations/%s", file.Name()))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(string(migrationSql))
|
||||
_, err = db.Exec(context.Background(), string(migrationSql))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = db.Exec("INSERT INTO migrations(name) VALUES($1);", file.Name())
|
||||
_, err = db.Exec(context.Background(), "INSERT INTO migrations(name) VALUES($1);", file.Name())
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -2,19 +2,19 @@ package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"log/slog"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/migrate"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
)
|
||||
|
||||
type TestDatabase struct {
|
||||
Db *sql.DB
|
||||
Db *pgxpool.Pool
|
||||
container testcontainers.Container
|
||||
}
|
||||
|
||||
@ -43,7 +43,7 @@ func NewTestDatabase(tb testing.TB) *TestDatabase {
|
||||
}
|
||||
|
||||
// create db pool
|
||||
db, err := sql.Open("pgx", connectionString)
|
||||
db, err := pgxpool.New(context.Background(), connectionString)
|
||||
if err != nil {
|
||||
tb.Fatalf("Failed to open db pool: %v", err)
|
||||
}
|
||||
|
@ -1,59 +0,0 @@
|
||||
package inmemory
|
||||
|
||||
import (
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository"
|
||||
)
|
||||
|
||||
type InMemoryTodoRepository struct {
|
||||
Db map[int64]repository.TodoRow
|
||||
id int64
|
||||
}
|
||||
|
||||
func NewInMemoryTodoRepository() InMemoryTodoRepository {
|
||||
return InMemoryTodoRepository{
|
||||
Db: make(map[int64]repository.TodoRow),
|
||||
id: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *InMemoryTodoRepository) GetById(id int64) (repository.TodoRow, error) {
|
||||
todo, found := r.Db[id]
|
||||
|
||||
if !found {
|
||||
return todo, repository.ErrNotFound
|
||||
}
|
||||
|
||||
return todo, nil
|
||||
}
|
||||
|
||||
func (r *InMemoryTodoRepository) Create(todo repository.TodoRow) (repository.TodoRow, error) {
|
||||
todo.Id = r.id
|
||||
r.Db[r.id] = todo
|
||||
r.id++
|
||||
|
||||
return todo, nil
|
||||
}
|
||||
|
||||
func (r *InMemoryTodoRepository) Update(todo repository.TodoRow) error {
|
||||
_, found := r.Db[todo.Id]
|
||||
|
||||
if !found {
|
||||
return repository.ErrNotFound
|
||||
}
|
||||
|
||||
r.Db[todo.Id] = todo
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *InMemoryTodoRepository) Delete(id int64) error {
|
||||
_, found := r.Db[id]
|
||||
|
||||
if !found {
|
||||
return repository.ErrNotFound
|
||||
}
|
||||
|
||||
delete(r.Db, id)
|
||||
|
||||
return nil
|
||||
}
|
@ -2,18 +2,19 @@ package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"log/slog"
|
||||
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
type PostgresTodoRepository struct {
|
||||
logger *slog.Logger
|
||||
db *sql.DB
|
||||
db *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewPostgresTodoRepository(logger *slog.Logger, db *sql.DB) *PostgresTodoRepository {
|
||||
func NewPostgresTodoRepository(logger *slog.Logger, db *pgxpool.Pool) *PostgresTodoRepository {
|
||||
return &PostgresTodoRepository{
|
||||
logger: logger,
|
||||
db: db,
|
||||
@ -23,10 +24,10 @@ func NewPostgresTodoRepository(logger *slog.Logger, db *sql.DB) *PostgresTodoRep
|
||||
func (r *PostgresTodoRepository) GetById(ctx context.Context, id int64) (repository.TodoRow, error) {
|
||||
todo := repository.TodoRow{}
|
||||
|
||||
err := r.db.QueryRow("SELECT * FROM todo WHERE id = $1;", id).Scan(&todo.Id, &todo.Name, &todo.Done)
|
||||
err := r.db.QueryRow(ctx, "SELECT * FROM todo WHERE id = $1;", id).Scan(&todo.Id, &todo.Name, &todo.Done)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
if err == pgx.ErrNoRows {
|
||||
return todo, repository.ErrNotFound
|
||||
}
|
||||
|
||||
@ -38,7 +39,7 @@ func (r *PostgresTodoRepository) GetById(ctx context.Context, id int64) (reposit
|
||||
}
|
||||
|
||||
func (r *PostgresTodoRepository) Create(ctx context.Context, todo repository.TodoRow) (repository.TodoRow, error) {
|
||||
result := r.db.QueryRow("INSERT INTO todo (name, done) VALUES ($1, $2) RETURNING id;", todo.Name, todo.Done)
|
||||
result := r.db.QueryRow(ctx, "INSERT INTO todo (name, done) VALUES ($1, $2) RETURNING id;", todo.Name, todo.Done)
|
||||
err := result.Scan(&todo.Id)
|
||||
|
||||
if err != nil {
|
||||
@ -50,19 +51,14 @@ func (r *PostgresTodoRepository) Create(ctx context.Context, todo repository.Tod
|
||||
}
|
||||
|
||||
func (r *PostgresTodoRepository) Update(ctx context.Context, todo repository.TodoRow) error {
|
||||
result, err := r.db.Exec("UPDATE todo SET name = $1, done = $2 WHERE id = $3;", todo.Name, todo.Done, todo.Id)
|
||||
result, err := r.db.Exec(ctx, "UPDATE todo SET name = $1, done = $2 WHERE id = $3;", todo.Name, todo.Done, todo.Id)
|
||||
|
||||
if err != nil {
|
||||
r.logger.ErrorContext(ctx, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
|
||||
if err != nil {
|
||||
r.logger.ErrorContext(ctx, err.Error())
|
||||
return err
|
||||
}
|
||||
rowsAffected := result.RowsAffected()
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return repository.ErrNotFound
|
||||
@ -72,19 +68,14 @@ func (r *PostgresTodoRepository) Update(ctx context.Context, todo repository.Tod
|
||||
}
|
||||
|
||||
func (r *PostgresTodoRepository) Delete(ctx context.Context, id int64) error {
|
||||
result, err := r.db.Exec("DELETE FROM todo WHERE id = $1;", id)
|
||||
result, err := r.db.Exec(ctx, "DELETE FROM todo WHERE id = $1;", id)
|
||||
|
||||
if err != nil {
|
||||
r.logger.ErrorContext(ctx, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
|
||||
if err != nil {
|
||||
r.logger.ErrorContext(ctx, err.Error())
|
||||
return err
|
||||
}
|
||||
rowsAffected := result.RowsAffected()
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return repository.ErrNotFound
|
||||
|
@ -8,7 +8,6 @@ import (
|
||||
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/test"
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository"
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
)
|
||||
|
||||
func TestCRUD(t *testing.T) {
|
||||
|
@ -1,92 +0,0 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository"
|
||||
)
|
||||
|
||||
type SqliteTodoRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewSqliteTodoRepository(db *sql.DB) *SqliteTodoRepository {
|
||||
return &SqliteTodoRepository{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *SqliteTodoRepository) GetById(id int64) (repository.TodoRow, error) {
|
||||
todo := repository.TodoRow{}
|
||||
|
||||
row := r.db.QueryRow("SELECT * FROM todo WHERE id = ?;", id)
|
||||
err := row.Scan(&todo.Id, &todo.Name, &todo.Done)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return todo, repository.ErrNotFound
|
||||
}
|
||||
|
||||
return todo, err
|
||||
}
|
||||
|
||||
return todo, nil
|
||||
}
|
||||
|
||||
func (r *SqliteTodoRepository) Create(todo repository.TodoRow) (repository.TodoRow, error) {
|
||||
result, err := r.db.Exec("INSERT INTO todo (name, done) VALUES (?, ?)", todo.Name, todo.Done)
|
||||
|
||||
if err != nil {
|
||||
return repository.TodoRow{}, err
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
|
||||
if err != nil {
|
||||
return repository.TodoRow{}, err
|
||||
}
|
||||
|
||||
todo.Id = id
|
||||
|
||||
return todo, nil
|
||||
}
|
||||
|
||||
func (r *SqliteTodoRepository) Update(todo repository.TodoRow) error {
|
||||
result, err := r.db.Exec("UPDATE todo SET name = ?, done = ? WHERE id = ?", todo.Name, todo.Done, todo.Id)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return repository.ErrNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SqliteTodoRepository) Delete(id int64) error {
|
||||
result, err := r.db.Exec("DELETE FROM todo WHERE id = ?", id)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return repository.ErrNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user