auth services, middleware, and other stuff
This commit is contained in:
66
internal/auth/handler/login.go
Normal file
66
internal/auth/handler/login.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/auth/service"
|
||||
)
|
||||
|
||||
type Loginer interface {
|
||||
Login(ctx context.Context, email string, password string) (string, error)
|
||||
}
|
||||
|
||||
type LoginRequest struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
func HandleLogin(logger *slog.Logger, authService Loginer) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
loginRequest := LoginRequest{}
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
decoder.DisallowUnknownFields()
|
||||
err := decoder.Decode(&loginRequest)
|
||||
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, err.Error())
|
||||
http.Error(w, "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := authService.Login(ctx, loginRequest.Email, loginRequest.Password)
|
||||
|
||||
if err == service.ErrUnauthorized {
|
||||
http.Error(w, "", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if err == service.ErrNotFound {
|
||||
http.Error(w, "", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, err.Error())
|
||||
http.Error(w, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
cookie := http.Cookie{
|
||||
Name: "token",
|
||||
Value: token,
|
||||
Path: "/",
|
||||
MaxAge: 3600,
|
||||
HttpOnly: true,
|
||||
Secure: false,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}
|
||||
http.SetCookie(w, &cookie)
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
}
|
||||
168
internal/auth/handler/login_test.go
Normal file
168
internal/auth/handler/login_test.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/auth/service"
|
||||
)
|
||||
|
||||
type MockLoginer struct {
|
||||
LoginFunc func(ctx context.Context, email string, password string) (string, error)
|
||||
}
|
||||
|
||||
func (m *MockLoginer) Login(ctx context.Context, email string, password string) (string, error) {
|
||||
return m.LoginFunc(ctx, email, password)
|
||||
}
|
||||
|
||||
func TestLogin(t *testing.T) {
|
||||
logger := slog.Default()
|
||||
t.Run("returns 200 for existing user with correct credentials", func(t *testing.T) {
|
||||
loginRequest := LoginRequest{Email: "test@test.com", Password: "password"}
|
||||
|
||||
token := "examplejwt"
|
||||
|
||||
service := MockLoginer{
|
||||
LoginFunc: func(ctx context.Context, email string, password string) (string, error) {
|
||||
return token, nil
|
||||
},
|
||||
}
|
||||
|
||||
handler := HandleLogin(logger, &service)
|
||||
|
||||
requestBody, err := json.Marshal(loginRequest)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request %+v: %v", loginRequest, err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(requestBody))
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
handler(res, req)
|
||||
|
||||
AssertStatusCodes(t, res.Code, http.StatusOK)
|
||||
AssertToken(t, res, token)
|
||||
})
|
||||
|
||||
t.Run("returns 401 for existing user with incorrect credentials", func(t *testing.T) {
|
||||
loginRequest := LoginRequest{Email: "test@test.com", Password: "password"}
|
||||
|
||||
service := MockLoginer{
|
||||
LoginFunc: func(ctx context.Context, email string, password string) (string, error) {
|
||||
return "", service.ErrUnauthorized
|
||||
},
|
||||
}
|
||||
|
||||
handler := HandleLogin(logger, &service)
|
||||
|
||||
requestBody, err := json.Marshal(loginRequest)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request %+v: %v", loginRequest, err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(requestBody))
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
handler(res, req)
|
||||
|
||||
AssertStatusCodes(t, res.Code, http.StatusUnauthorized)
|
||||
})
|
||||
|
||||
t.Run("returns 401 for non-existing user", func(t *testing.T) {
|
||||
loginRequest := LoginRequest{Email: "test@test.com", Password: "password"}
|
||||
|
||||
service := MockLoginer{
|
||||
LoginFunc: func(ctx context.Context, email string, password string) (string, error) {
|
||||
return "", service.ErrNotFound
|
||||
},
|
||||
}
|
||||
|
||||
handler := HandleLogin(logger, &service)
|
||||
|
||||
requestBody, err := json.Marshal(loginRequest)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request %+v: %v", loginRequest, err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(requestBody))
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
handler(res, req)
|
||||
|
||||
AssertStatusCodes(t, res.Code, http.StatusUnauthorized)
|
||||
})
|
||||
|
||||
t.Run("returns 400 with bad json", func(t *testing.T) {
|
||||
handler := HandleLogin(logger, nil)
|
||||
|
||||
badStruct := struct {
|
||||
Foo string
|
||||
}{
|
||||
Foo: "bar",
|
||||
}
|
||||
|
||||
requestBody, err := json.Marshal(badStruct)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request %+v: %v", badStruct, err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(requestBody))
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
handler(res, req)
|
||||
|
||||
AssertStatusCodes(t, res.Code, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("returns 500 arbitrary errors", func(t *testing.T) {
|
||||
loginRequest := LoginRequest{Email: "test@test.com", Password: "password"}
|
||||
|
||||
service := MockLoginer{
|
||||
LoginFunc: func(ctx context.Context, email string, password string) (string, error) {
|
||||
return "", errors.New("foo bar")
|
||||
},
|
||||
}
|
||||
|
||||
handler := HandleLogin(logger, &service)
|
||||
|
||||
requestBody, err := json.Marshal(loginRequest)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request %+v: %v", loginRequest, err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(requestBody))
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
handler(res, req)
|
||||
|
||||
AssertStatusCodes(t, res.Code, http.StatusInternalServerError)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func AssertStatusCodes(t testing.TB, got, want int) {
|
||||
t.Helper()
|
||||
if got != want {
|
||||
t.Errorf("got status code: %v, want status code: %v", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func AssertToken(t testing.TB, res *httptest.ResponseRecorder, want string) {
|
||||
t.Helper()
|
||||
got := res.Result().Cookies()[0].Value
|
||||
if got != want {
|
||||
t.Errorf("got cookie: %q, want cookie: %q", got, want)
|
||||
}
|
||||
}
|
||||
70
internal/auth/hashing.go
Normal file
70
internal/auth/hashing.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
)
|
||||
|
||||
type HashSalt struct {
|
||||
Hash, Salt []byte
|
||||
}
|
||||
|
||||
type Argon2IdHash struct {
|
||||
time uint32
|
||||
memory uint32
|
||||
threads uint8
|
||||
keyLen uint32
|
||||
saltLen uint32
|
||||
}
|
||||
|
||||
var (
|
||||
ErrNoMatch error = errors.New("hash doesn't match")
|
||||
)
|
||||
|
||||
func NewArgon2IdHash(time, saltLen uint32, memory uint32, threads uint8, keyLen uint32) *Argon2IdHash {
|
||||
return &Argon2IdHash{
|
||||
time: time,
|
||||
saltLen: saltLen,
|
||||
memory: memory,
|
||||
threads: threads,
|
||||
keyLen: keyLen,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Argon2IdHash) GenerateHash(password, salt []byte) (*HashSalt, error) {
|
||||
var err error
|
||||
if len(salt) == 0 {
|
||||
salt, err = randomSecret(a.saltLen)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hash := argon2.IDKey(password, salt, a.time, a.memory, a.threads, a.keyLen)
|
||||
return &HashSalt{Hash: hash, Salt: salt}, nil
|
||||
}
|
||||
|
||||
func (a *Argon2IdHash) Compare(hash, salt, password []byte) error {
|
||||
hashSalt, err := a.GenerateHash(password, salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !bytes.Equal(hash, hashSalt.Hash) {
|
||||
return ErrNoMatch
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func randomSecret(length uint32) ([]byte, error) {
|
||||
secret := make([]byte, length)
|
||||
|
||||
_, err := rand.Read(secret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return secret, nil
|
||||
}
|
||||
90
internal/auth/service/service.go
Normal file
90
internal/auth/service/service.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/auth"
|
||||
userrepository "gitea.michaelthomson.dev/mthomson/habits/internal/user/repository"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNotFound error = errors.New("user cannot be found")
|
||||
ErrUnauthorized error = errors.New("user password incorrect")
|
||||
)
|
||||
|
||||
type UserRepository interface {
|
||||
Create(ctx context.Context, user userrepository.UserRow) (userrepository.UserRow, error)
|
||||
GetByEmail(ctx context.Context, email string) (userrepository.UserRow, error)
|
||||
}
|
||||
|
||||
type AuthService struct {
|
||||
logger *slog.Logger
|
||||
jwtKey []byte
|
||||
userRepository UserRepository
|
||||
argon2IdHash *auth.Argon2IdHash
|
||||
}
|
||||
|
||||
func NewAuthService(logger *slog.Logger, jwtKey []byte, userRepository UserRepository, argon2IdHash *auth.Argon2IdHash) *AuthService {
|
||||
return &AuthService{
|
||||
logger: logger,
|
||||
jwtKey: jwtKey,
|
||||
userRepository: userRepository,
|
||||
argon2IdHash: argon2IdHash,
|
||||
}
|
||||
}
|
||||
|
||||
func (a AuthService) Login(ctx context.Context, email string, password string) (string, error) {
|
||||
// get user if exists
|
||||
userRow, err := a.userRepository.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if err == userrepository.ErrNotFound {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
|
||||
a.logger.ErrorContext(ctx, err.Error())
|
||||
return "", err
|
||||
}
|
||||
|
||||
// compare hashed passswords
|
||||
err = a.argon2IdHash.Compare(userRow.HashedPassword, userRow.Salt, []byte(password))
|
||||
|
||||
if err == auth.ErrNoMatch {
|
||||
return "", ErrUnauthorized
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
a.logger.ErrorContext(ctx, err.Error())
|
||||
return "", err
|
||||
}
|
||||
|
||||
// create token and return it
|
||||
token, err := a.CreateToken(ctx, email)
|
||||
if err != nil {
|
||||
a.logger.ErrorContext(ctx, err.Error())
|
||||
return "", err
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (a AuthService) CreateToken(ctx context.Context, email string) (string, error) {
|
||||
// Create a new JWT token with claims
|
||||
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"sub": email,
|
||||
"iss": "todo-app",
|
||||
"aud": "user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
})
|
||||
|
||||
tokenString, err := claims.SignedString(a.jwtKey)
|
||||
if err != nil {
|
||||
a.logger.ErrorContext(ctx, err.Error())
|
||||
}
|
||||
|
||||
return tokenString, nil
|
||||
}
|
||||
81
internal/auth/service/service_test.go
Normal file
81
internal/auth/service/service_test.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package service_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"testing"
|
||||
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/auth"
|
||||
authservice "gitea.michaelthomson.dev/mthomson/habits/internal/auth/service"
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/test"
|
||||
"gitea.michaelthomson.dev/mthomson/habits/internal/user/repository"
|
||||
userservice "gitea.michaelthomson.dev/mthomson/habits/internal/user/service"
|
||||
"github.com/gofrs/uuid/v5"
|
||||
)
|
||||
|
||||
func TestLogin(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
logger := slog.Default()
|
||||
tdb := test.NewTestDatabase(t)
|
||||
defer tdb.TearDown()
|
||||
userRepository := repository.NewUserRepository(logger, tdb.Db)
|
||||
argon2IdHash := auth.NewArgon2IdHash(1, 32, 64*1024, 32, 256)
|
||||
|
||||
userService := userservice.NewUserService(logger, userRepository, argon2IdHash)
|
||||
authService := authservice.NewAuthService(logger, []byte("secretkey"), userRepository, argon2IdHash)
|
||||
|
||||
_, err := userService.Register(ctx, "test@test.com", "supersecurepassword")
|
||||
AssertNoError(t, err)
|
||||
|
||||
t.Run("login existing user with correct credentials", func(t *testing.T) {
|
||||
want, err := authService.CreateToken(ctx, "test@test.com")
|
||||
AssertNoError(t, err)
|
||||
got, err := authService.Login(ctx, "test@test.com", "supersecurepassword")
|
||||
|
||||
AssertNoError(t, err)
|
||||
AssertTokens(t, got, want)
|
||||
})
|
||||
|
||||
t.Run("login existing user with incorrect credentials", func(t *testing.T) {
|
||||
_, err := authService.Login(ctx, "test@test.com", "superwrongpassword")
|
||||
|
||||
AssertErrors(t, err, authservice.ErrUnauthorized)
|
||||
})
|
||||
|
||||
t.Run("login nonexistant user", func(t *testing.T) {
|
||||
_, err := authService.Login(ctx, "foo@test.com", "supersecurepassword")
|
||||
|
||||
AssertErrors(t, err, authservice.ErrNotFound)
|
||||
})
|
||||
}
|
||||
|
||||
func AssertErrors(t testing.TB, got, want error) {
|
||||
t.Helper()
|
||||
if got != want {
|
||||
t.Errorf("got error: %v, want error: %v", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func AssertNoError(t testing.TB, err error) {
|
||||
t.Helper()
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func AssertTokens(t testing.TB, got, want string) {
|
||||
t.Helper()
|
||||
if got != want {
|
||||
t.Errorf("expected matching tokens, got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func NewUUID(t testing.TB) uuid.UUID {
|
||||
t.Helper()
|
||||
uuid, err := uuid.NewV4()
|
||||
if err != nil {
|
||||
t.Errorf("error generation uuid: %v", err)
|
||||
}
|
||||
return uuid
|
||||
}
|
||||
Reference in New Issue
Block a user