sql driver to pgx driver
This commit is contained in:
parent
ab0e40c695
commit
f1bbd06ef7
@ -1,7 +1,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"context"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
@ -12,7 +12,7 @@ import (
|
|||||||
todohandler "gitea.michaelthomson.dev/mthomson/habits/internal/todo/handler"
|
todohandler "gitea.michaelthomson.dev/mthomson/habits/internal/todo/handler"
|
||||||
todorepository "gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository/postgres"
|
todorepository "gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository/postgres"
|
||||||
todoservice "gitea.michaelthomson.dev/mthomson/habits/internal/todo/service"
|
todoservice "gitea.michaelthomson.dev/mthomson/habits/internal/todo/service"
|
||||||
_ "github.com/jackc/pgx/v5/stdlib"
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@ -30,7 +30,8 @@ func main() {
|
|||||||
|
|
||||||
// create db pool
|
// create db pool
|
||||||
postgresUrl := "postgres://todo:password@localhost:5432/todo"
|
postgresUrl := "postgres://todo:password@localhost:5432/todo"
|
||||||
db, err := sql.Open("pgx", postgresUrl)
|
// db, err := sql.Open("pgx", postgresUrl)
|
||||||
|
db, err := pgxpool.New(context.Background(), postgresUrl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(err.Error())
|
logger.Error(err.Error())
|
||||||
os.Exit(1);
|
os.Exit(1);
|
||||||
|
@ -1,13 +1,14 @@
|
|||||||
package migrate
|
package migrate
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"context"
|
||||||
"embed"
|
"embed"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
|
||||||
_ "github.com/jackc/pgx/v5/stdlib"
|
"github.com/jackc/pgx/v5"
|
||||||
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:embed migrations/*.sql
|
//go:embed migrations/*.sql
|
||||||
@ -18,7 +19,7 @@ type Migration struct {
|
|||||||
Name string
|
Name string
|
||||||
}
|
}
|
||||||
|
|
||||||
func Migrate(logger *slog.Logger, db *sql.DB) {
|
func Migrate(logger *slog.Logger, db *pgxpool.Pool) {
|
||||||
logger.Info("Running migrations...")
|
logger.Info("Running migrations...")
|
||||||
migrationTableSql := `
|
migrationTableSql := `
|
||||||
CREATE TABLE IF NOT EXISTS migrations(
|
CREATE TABLE IF NOT EXISTS migrations(
|
||||||
@ -26,7 +27,7 @@ func Migrate(logger *slog.Logger, db *sql.DB) {
|
|||||||
name VARCHAR(50)
|
name VARCHAR(50)
|
||||||
);`
|
);`
|
||||||
|
|
||||||
_, err := db.Exec(migrationTableSql)
|
_, err := db.Exec(context.Background(), migrationTableSql)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -38,21 +39,21 @@ func Migrate(logger *slog.Logger, db *sql.DB) {
|
|||||||
|
|
||||||
for _, file := range files {
|
for _, file := range files {
|
||||||
var migration Migration
|
var migration Migration
|
||||||
row := db.QueryRow("SELECT * FROM migrations WHERE name = $1;", file.Name())
|
row := db.QueryRow(context.Background(), "SELECT * FROM migrations WHERE name = $1;", file.Name())
|
||||||
err = row.Scan(&migration.Version, &migration.Name)
|
err = row.Scan(&migration.Version, &migration.Name)
|
||||||
if err == sql.ErrNoRows {
|
if err == pgx.ErrNoRows {
|
||||||
logger.Info(fmt.Sprintf("Running migration: %s", file.Name()))
|
logger.Info(fmt.Sprintf("Running migration: %s", file.Name()))
|
||||||
migrationSql, err := migrations.ReadFile(fmt.Sprintf("migrations/%s", file.Name()))
|
migrationSql, err := migrations.ReadFile(fmt.Sprintf("migrations/%s", file.Name()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = db.Exec(string(migrationSql))
|
_, err = db.Exec(context.Background(), string(migrationSql))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = db.Exec("INSERT INTO migrations(name) VALUES($1);", file.Name())
|
_, err = db.Exec(context.Background(), "INSERT INTO migrations(name) VALUES($1);", file.Name())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -2,19 +2,19 @@ package test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gitea.michaelthomson.dev/mthomson/habits/internal/migrate"
|
"gitea.michaelthomson.dev/mthomson/habits/internal/migrate"
|
||||||
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
"github.com/testcontainers/testcontainers-go"
|
"github.com/testcontainers/testcontainers-go"
|
||||||
"github.com/testcontainers/testcontainers-go/modules/postgres"
|
"github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||||
"github.com/testcontainers/testcontainers-go/wait"
|
"github.com/testcontainers/testcontainers-go/wait"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TestDatabase struct {
|
type TestDatabase struct {
|
||||||
Db *sql.DB
|
Db *pgxpool.Pool
|
||||||
container testcontainers.Container
|
container testcontainers.Container
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -43,7 +43,7 @@ func NewTestDatabase(tb testing.TB) *TestDatabase {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create db pool
|
// create db pool
|
||||||
db, err := sql.Open("pgx", connectionString)
|
db, err := pgxpool.New(context.Background(), connectionString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tb.Fatalf("Failed to open db pool: %v", err)
|
tb.Fatalf("Failed to open db pool: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -1,59 +0,0 @@
|
|||||||
package inmemory
|
|
||||||
|
|
||||||
import (
|
|
||||||
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository"
|
|
||||||
)
|
|
||||||
|
|
||||||
type InMemoryTodoRepository struct {
|
|
||||||
Db map[int64]repository.TodoRow
|
|
||||||
id int64
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewInMemoryTodoRepository() InMemoryTodoRepository {
|
|
||||||
return InMemoryTodoRepository{
|
|
||||||
Db: make(map[int64]repository.TodoRow),
|
|
||||||
id: 1,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *InMemoryTodoRepository) GetById(id int64) (repository.TodoRow, error) {
|
|
||||||
todo, found := r.Db[id]
|
|
||||||
|
|
||||||
if !found {
|
|
||||||
return todo, repository.ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
return todo, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *InMemoryTodoRepository) Create(todo repository.TodoRow) (repository.TodoRow, error) {
|
|
||||||
todo.Id = r.id
|
|
||||||
r.Db[r.id] = todo
|
|
||||||
r.id++
|
|
||||||
|
|
||||||
return todo, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *InMemoryTodoRepository) Update(todo repository.TodoRow) error {
|
|
||||||
_, found := r.Db[todo.Id]
|
|
||||||
|
|
||||||
if !found {
|
|
||||||
return repository.ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
r.Db[todo.Id] = todo
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *InMemoryTodoRepository) Delete(id int64) error {
|
|
||||||
_, found := r.Db[id]
|
|
||||||
|
|
||||||
if !found {
|
|
||||||
return repository.ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(r.Db, id)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
@ -2,18 +2,19 @@ package postgres
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
|
||||||
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository"
|
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository"
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
)
|
)
|
||||||
|
|
||||||
type PostgresTodoRepository struct {
|
type PostgresTodoRepository struct {
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
db *sql.DB
|
db *pgxpool.Pool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPostgresTodoRepository(logger *slog.Logger, db *sql.DB) *PostgresTodoRepository {
|
func NewPostgresTodoRepository(logger *slog.Logger, db *pgxpool.Pool) *PostgresTodoRepository {
|
||||||
return &PostgresTodoRepository{
|
return &PostgresTodoRepository{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
db: db,
|
db: db,
|
||||||
@ -23,10 +24,10 @@ func NewPostgresTodoRepository(logger *slog.Logger, db *sql.DB) *PostgresTodoRep
|
|||||||
func (r *PostgresTodoRepository) GetById(ctx context.Context, id int64) (repository.TodoRow, error) {
|
func (r *PostgresTodoRepository) GetById(ctx context.Context, id int64) (repository.TodoRow, error) {
|
||||||
todo := repository.TodoRow{}
|
todo := repository.TodoRow{}
|
||||||
|
|
||||||
err := r.db.QueryRow("SELECT * FROM todo WHERE id = $1;", id).Scan(&todo.Id, &todo.Name, &todo.Done)
|
err := r.db.QueryRow(ctx, "SELECT * FROM todo WHERE id = $1;", id).Scan(&todo.Id, &todo.Name, &todo.Done)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == pgx.ErrNoRows {
|
||||||
return todo, repository.ErrNotFound
|
return todo, repository.ErrNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -38,7 +39,7 @@ func (r *PostgresTodoRepository) GetById(ctx context.Context, id int64) (reposit
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *PostgresTodoRepository) Create(ctx context.Context, todo repository.TodoRow) (repository.TodoRow, error) {
|
func (r *PostgresTodoRepository) Create(ctx context.Context, todo repository.TodoRow) (repository.TodoRow, error) {
|
||||||
result := r.db.QueryRow("INSERT INTO todo (name, done) VALUES ($1, $2) RETURNING id;", todo.Name, todo.Done)
|
result := r.db.QueryRow(ctx, "INSERT INTO todo (name, done) VALUES ($1, $2) RETURNING id;", todo.Name, todo.Done)
|
||||||
err := result.Scan(&todo.Id)
|
err := result.Scan(&todo.Id)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -50,19 +51,14 @@ func (r *PostgresTodoRepository) Create(ctx context.Context, todo repository.Tod
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *PostgresTodoRepository) Update(ctx context.Context, todo repository.TodoRow) error {
|
func (r *PostgresTodoRepository) Update(ctx context.Context, todo repository.TodoRow) error {
|
||||||
result, err := r.db.Exec("UPDATE todo SET name = $1, done = $2 WHERE id = $3;", todo.Name, todo.Done, todo.Id)
|
result, err := r.db.Exec(ctx, "UPDATE todo SET name = $1, done = $2 WHERE id = $3;", todo.Name, todo.Done, todo.Id)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.logger.ErrorContext(ctx, err.Error())
|
r.logger.ErrorContext(ctx, err.Error())
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
rowsAffected, err := result.RowsAffected()
|
rowsAffected := result.RowsAffected()
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
r.logger.ErrorContext(ctx, err.Error())
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if rowsAffected == 0 {
|
if rowsAffected == 0 {
|
||||||
return repository.ErrNotFound
|
return repository.ErrNotFound
|
||||||
@ -72,19 +68,14 @@ func (r *PostgresTodoRepository) Update(ctx context.Context, todo repository.Tod
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *PostgresTodoRepository) Delete(ctx context.Context, id int64) error {
|
func (r *PostgresTodoRepository) Delete(ctx context.Context, id int64) error {
|
||||||
result, err := r.db.Exec("DELETE FROM todo WHERE id = $1;", id)
|
result, err := r.db.Exec(ctx, "DELETE FROM todo WHERE id = $1;", id)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.logger.ErrorContext(ctx, err.Error())
|
r.logger.ErrorContext(ctx, err.Error())
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
rowsAffected, err := result.RowsAffected()
|
rowsAffected := result.RowsAffected()
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
r.logger.ErrorContext(ctx, err.Error())
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if rowsAffected == 0 {
|
if rowsAffected == 0 {
|
||||||
return repository.ErrNotFound
|
return repository.ErrNotFound
|
||||||
|
@ -8,7 +8,6 @@ import (
|
|||||||
|
|
||||||
"gitea.michaelthomson.dev/mthomson/habits/internal/test"
|
"gitea.michaelthomson.dev/mthomson/habits/internal/test"
|
||||||
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository"
|
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository"
|
||||||
_ "github.com/jackc/pgx/v5/stdlib"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCRUD(t *testing.T) {
|
func TestCRUD(t *testing.T) {
|
||||||
|
@ -1,92 +0,0 @@
|
|||||||
package sqlite
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
|
|
||||||
"gitea.michaelthomson.dev/mthomson/habits/internal/todo/repository"
|
|
||||||
)
|
|
||||||
|
|
||||||
type SqliteTodoRepository struct {
|
|
||||||
db *sql.DB
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSqliteTodoRepository(db *sql.DB) *SqliteTodoRepository {
|
|
||||||
return &SqliteTodoRepository{
|
|
||||||
db: db,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *SqliteTodoRepository) GetById(id int64) (repository.TodoRow, error) {
|
|
||||||
todo := repository.TodoRow{}
|
|
||||||
|
|
||||||
row := r.db.QueryRow("SELECT * FROM todo WHERE id = ?;", id)
|
|
||||||
err := row.Scan(&todo.Id, &todo.Name, &todo.Done)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return todo, repository.ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
return todo, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return todo, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *SqliteTodoRepository) Create(todo repository.TodoRow) (repository.TodoRow, error) {
|
|
||||||
result, err := r.db.Exec("INSERT INTO todo (name, done) VALUES (?, ?)", todo.Name, todo.Done)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return repository.TodoRow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
id, err := result.LastInsertId()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return repository.TodoRow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
todo.Id = id
|
|
||||||
|
|
||||||
return todo, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *SqliteTodoRepository) Update(todo repository.TodoRow) error {
|
|
||||||
result, err := r.db.Exec("UPDATE todo SET name = ?, done = ? WHERE id = ?", todo.Name, todo.Done, todo.Id)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
rowsAffected, err := result.RowsAffected()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if rowsAffected == 0 {
|
|
||||||
return repository.ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *SqliteTodoRepository) Delete(id int64) error {
|
|
||||||
result, err := r.db.Exec("DELETE FROM todo WHERE id = ?", id)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
rowsAffected, err := result.RowsAffected()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if rowsAffected == 0 {
|
|
||||||
return repository.ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
Loading…
x
Reference in New Issue
Block a user