Refactor configuration and database initialization: replace LoadConfig with GetConfig, introduce DbConfig struct, and streamline token service retrieval. Remove unused request types and enhance user refresh token handling.

This commit is contained in:
fabio 2026-04-09 09:21:38 +02:00
parent 3731e6e409
commit 5b9fe6c9b7
11 changed files with 138 additions and 138 deletions

View File

@ -36,20 +36,12 @@ func main() {
seedCount := flag.Int("seed", 0, "seed N fake users at startup (0 to skip)")
flag.Parse()
cfg, err := config.LoadConfig()
cfg, err := config.GetConfig()
if err != nil {
log.Fatalf("load config: %v", err)
log.Fatalf("config: %v", err)
}
if secret := os.Getenv("AUTH_SECRET"); secret != "" {
cfg.Auth.Secret = secret
}
dbCfg := db.Config{
Driver: envOrDefault("DB_driver", "sqlite"),
DSN: envOrDefault("DB_dsn", "file:./data/data.db?_foreign_keys=on"),
}
dbConn, err := db.Init(dbCfg)
dbConn, err := db.GetDB()
if err != nil {
log.Fatalf("init db: %v", err)
}
@ -58,12 +50,10 @@ func main() {
log.Fatalf("migrate user: %v", err)
}
tockenService, _ := tokens.NewTockenService(tokens.Config{
Secret: "your-secret-key",
Issuer: "your-issuer",
AccessTokenExpiry: time.Hour,
RefreshTokenExpiry: 24 * time.Hour,
})
tokenService, err := tokens.GetTockenService()
if err != nil {
log.Fatalf("init tokens: %v", err)
}
app := fiber.New(fiber.Config{
AppName: cfg.AppName,
@ -107,7 +97,7 @@ func main() {
return c.Next()
})
app.Use(roles.RequireEndpointPermission(dbConn, tockenService))
app.Use(roles.RequireEndpointPermission(dbConn, tokenService))
routes.Register(app)

View File

@ -1,23 +0,0 @@
package auth
// Typescript: interface
type LoginRequest struct {
Username string `json:"username" validate:"required,email"`
Password string `json:"password" validate:"required,min=8,max=128"`
}
// Typescript: interface
type RefreshRequest struct {
RefreshToken string `json:"refresh_token"`
}
// Typescript: interface
type ForgotPasswordRequest struct {
Email string `json:"email" validate:"required,email"`
}
// Typescript: interface
type ResetPasswordRequest struct {
Token string `json:"token" validate:"required,min=20,max=255"`
Password string `json:"password" validate:"required,min=8,max=128"`
}

View File

@ -14,6 +14,7 @@ type ServerConfig struct {
DisableStartupMessage bool `json:"disable_startup_message"`
Auth AuthConfig `json:"auth"`
Mail MailConfig `json:"mail"`
Db DbConfig `json:"db_config"`
RolesConfigPath string `json:"roles_config_path"`
}
@ -42,6 +43,25 @@ type SMTPMailConfig struct {
Password string `json:"password"`
}
type DbConfig struct {
Driver string
DSN string
}
var Config *ServerConfig = nil
func GetConfig() (*ServerConfig, error) {
if Config == nil {
var err error
Config, err = loadConfig()
if err != nil {
fmt.Printf("Failed to load config: %v\n", err)
return nil, err
}
}
return Config, nil
}
func envOrDefault(key, defaultValue string) string {
if value, exists := os.LookupEnv(key); exists {
return value
@ -49,26 +69,28 @@ func envOrDefault(key, defaultValue string) string {
return defaultValue
}
func LoadConfig() (ServerConfig, error) {
func loadConfig() (*ServerConfig, error) {
path := envOrDefault("CONFIG_PATH", "configs/config.json")
data, err := os.ReadFile(path)
if err != nil {
return ServerConfig{}, fmt.Errorf("read config: %w", err)
return nil, fmt.Errorf("read config: %w", err)
}
var cfg ServerConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return ServerConfig{}, fmt.Errorf("parse config: %w", err)
return nil, fmt.Errorf("parse config: %w", err)
}
if secret := os.Getenv("AUTH_SECRET"); secret != "" {
cfg.Auth.Secret = secret
}
if cfg.Auth.Secret == "" {
return ServerConfig{}, fmt.Errorf("auth.secret must be set")
return nil, fmt.Errorf("auth.secret must be set")
}
if cfg.Auth.AccessTokenExpiryMinutes <= 0 {
return ServerConfig{}, fmt.Errorf("auth.access_token_expiry_minutes must be greater than zero")
return nil, fmt.Errorf("auth.access_token_expiry_minutes must be greater than zero")
}
if cfg.Auth.RefreshTokenExpiryMinutes <= 0 {
return ServerConfig{}, fmt.Errorf("auth.refresh_token_expiry_minutes must be greater than zero")
return nil, fmt.Errorf("auth.refresh_token_expiry_minutes must be greater than zero")
}
if cfg.Mail.Mode == "" {
cfg.Mail.Mode = "file"
@ -80,21 +102,26 @@ func LoadConfig() (ServerConfig, error) {
cfg.Mail.ResetPasswordPath = "/#reset-password"
}
if cfg.Mail.Mode != "smtp" && cfg.Mail.Mode != "file" {
return ServerConfig{}, fmt.Errorf("mail.mode must be either smtp or file")
return nil, fmt.Errorf("mail.mode must be either smtp or file")
}
if cfg.Mail.From == "" {
return ServerConfig{}, fmt.Errorf("mail.from must be set")
return nil, fmt.Errorf("mail.from must be set")
}
if cfg.Mail.Mode == "smtp" {
if cfg.Mail.SMTP.Host == "" {
return ServerConfig{}, fmt.Errorf("mail.smtp.host must be set when mail.mode=smtp")
return nil, fmt.Errorf("mail.smtp.host must be set when mail.mode=smtp")
}
if cfg.Mail.SMTP.Port <= 0 {
return ServerConfig{}, fmt.Errorf("mail.smtp.port must be greater than zero when mail.mode=smtp")
return nil, fmt.Errorf("mail.smtp.port must be greater than zero when mail.mode=smtp")
}
} else if cfg.Mail.DebugDir == "" {
cfg.Mail.DebugDir = "data/mail-debug"
}
return cfg, nil
cfg.Db = DbConfig{
Driver: envOrDefault("DB_driver", "sqlite"),
DSN: envOrDefault("DB_dsn", "file:./data/data.db?_foreign_keys=on"),
}
return &cfg, nil
}

View File

@ -4,6 +4,7 @@ import (
"fmt"
"os"
"path/filepath"
"server/internal/config"
"strings"
"github.com/gofiber/fiber/v3"
@ -12,33 +13,42 @@ import (
"gorm.io/gorm"
)
type Config struct {
Driver string
DSN string
}
var DB *gorm.DB
// GetDB returns the global *gorm.DB instance. It panics if the database is not initialized.
func GetDB() (*gorm.DB, error) {
if DB == nil {
cfg, err := config.GetConfig()
if err != nil {
return nil, err
}
DB, err = InitDB(cfg.Db)
if err != nil {
fmt.Printf("Failed to initialize database: %v\n", err)
return nil, err
}
}
return DB, nil
}
// Init opens the configured database connection and runs schema migrations.
func Init(cfg Config) (*gorm.DB, error) {
func InitDB(cfg config.DbConfig) (*gorm.DB, error) {
switch cfg.Driver {
case "sqlite":
if err := ensureSQLiteDir(cfg.DSN); err != nil {
return nil, fmt.Errorf("prepare sqlite path: %w", err)
}
db, err := gorm.Open(sqlite.Open(cfg.DSN), &gorm.Config{})
DB, err := gorm.Open(sqlite.Open(cfg.DSN), &gorm.Config{})
if err != nil {
return nil, fmt.Errorf("open sqlite: %w", err)
}
DB = db
return db, nil
return DB, nil
case "postgres":
db, err := gorm.Open(postgres.Open(cfg.DSN), &gorm.Config{})
DB, err := gorm.Open(postgres.Open(cfg.DSN), &gorm.Config{})
if err != nil {
return nil, fmt.Errorf("open postgres: %w", err)
}
DB = db
return db, nil
return DB, nil
default:
return nil, fmt.Errorf("unsupported driver %q", cfg.Driver)
}

View File

@ -6,7 +6,6 @@ import (
"crypto/tls"
"fmt"
"html/template"
"log"
"net"
"net/smtp"
"os"
@ -42,7 +41,7 @@ type Message struct {
TemplateData any
}
type Service struct {
type MailService struct {
cfg Config
}
@ -54,11 +53,12 @@ type TemplateData struct {
ResetToken string
}
func New() (*Service, error) {
// if service fail send admin allert instead a response to user or a simple response server error.
func New() (*MailService, error) {
serverCfg, err := config.LoadConfig()
serverCfg, err := config.GetConfig()
if err != nil {
log.Fatalf("load config: %v", err)
return nil, err
}
cfg := Config{
@ -102,10 +102,10 @@ func New() (*Service, error) {
return nil, fmt.Errorf("smtp host and port are required")
}
}
return &Service{cfg: cfg}, nil
return &MailService{cfg: cfg}, nil
}
func (s *Service) Send(ctx context.Context, msg Message) error {
func (s *MailService) Send(ctx context.Context, msg Message) error {
htmlBody, textBody, err := s.renderBodies(msg.Template, msg.TemplateData)
if err != nil {
return err
@ -122,7 +122,7 @@ func (s *Service) Send(ctx context.Context, msg Message) error {
}
}
func (s *Service) ResetLink(token string) string {
func (s *MailService) ResetLink(token string) string {
base := strings.TrimRight(s.cfg.FrontendBaseURL, "/")
path := s.cfg.ResetPasswordPath
if path == "" {
@ -137,11 +137,11 @@ func (s *Service) ResetLink(token string) string {
return base + path + "?token=" + token
}
func (s *Service) AppName() string {
func (s *MailService) AppName() string {
return s.cfg.AppName
}
func (s *Service) renderBodies(templateName string, data any) (string, string, error) {
func (s *MailService) renderBodies(templateName string, data any) (string, string, error) {
htmlPath := filepath.Join(s.cfg.TemplatesDir, templateName+".html.tmpl")
textPath := filepath.Join(s.cfg.TemplatesDir, templateName+".txt.tmpl")
@ -195,7 +195,7 @@ func buildMessage(from, to, subject, textBody, htmlBody string) []byte {
return []byte(strings.Join(append(headers, body...), "\r\n"))
}
func (s *Service) sendSMTP(ctx context.Context, to string, raw []byte) error {
func (s *MailService) sendSMTP(ctx context.Context, to string, raw []byte) error {
addr := fmt.Sprintf("%s:%d", s.cfg.SMTP.Host, s.cfg.SMTP.Port)
dialer := &net.Dialer{}
conn, err := dialer.DialContext(ctx, "tcp", addr)
@ -245,7 +245,7 @@ func (s *Service) sendSMTP(ctx context.Context, to string, raw []byte) error {
return nil
}
func (s *Service) writeDebugMail(to, subject string, raw []byte) error {
func (s *MailService) writeDebugMail(to, subject string, raw []byte) error {
safeRecipient := strings.NewReplacer("@", "_at_", "/", "_", "\\", "_", ":", "_", " ", "_").Replace(to)
filename := fmt.Sprintf("%d_%s.eml", time.Now().UnixNano(), safeRecipient)
path := filepath.Join(s.cfg.DebugDir, filename)

View File

@ -8,13 +8,9 @@ type SimpleResponse struct {
}
// success wraps a payload in the standard API envelope.
func success(data any) fiber.Map {
func Success(data any) fiber.Map {
return fiber.Map{
"data": data,
"error": nil,
}
}
func Success(data any) fiber.Map {
return success(data)
}

View File

@ -8,17 +8,6 @@ import (
"github.com/gofiber/fiber/v3"
)
// Typescript: interface
type FormRequest struct {
Req string `json:"req"`
Count int `json:"count"`
}
// Typescript: interface
type FormResponse struct {
Test string `json:"test"`
}
func Register(app *fiber.App) {
systemUtils.RegisterSystemRoutes(app)
users.RegisterUserRoutes(app)

View File

@ -6,6 +6,7 @@ import (
"encoding/base64"
"encoding/hex"
"errors"
"server/internal/config"
"time"
"github.com/gofiber/fiber/v3"
@ -13,19 +14,12 @@ import (
)
type TockenService struct {
cfg Config
cfg config.AuthConfig
secret []byte
accessExpiry time.Duration
refreshExpiry time.Duration
}
type Config struct {
Secret string
Issuer string
AccessTokenExpiry time.Duration
RefreshTokenExpiry time.Duration
}
type Claims struct {
Username string `json:"username"`
Role string `json:"role"`
@ -44,22 +38,38 @@ const (
TokenTypeRefresh = "refresh"
)
func NewTockenService(cfg Config) (*TockenService, error) {
var Tockens *TockenService
func GetTockenService() (*TockenService, error) {
if Tockens == nil {
cfg, err := config.GetConfig()
if err != nil {
return nil, err
}
Tockens, err = NewTockenService(cfg.Auth)
if err != nil {
return nil, err
}
}
return Tockens, nil
}
func NewTockenService(cfg config.AuthConfig) (*TockenService, error) {
if cfg.Secret == "" {
return nil, errors.New("jwt secret is required")
}
if cfg.AccessTokenExpiry <= 0 {
if cfg.AccessTokenExpiryMinutes <= 0 {
return nil, errors.New("access token expiry must be positive")
}
if cfg.RefreshTokenExpiry <= 0 {
if cfg.RefreshTokenExpiryMinutes <= 0 {
return nil, errors.New("refresh token expiry must be positive")
}
return &TockenService{
cfg: cfg,
secret: []byte(cfg.Secret),
accessExpiry: cfg.AccessTokenExpiry,
refreshExpiry: cfg.RefreshTokenExpiry,
accessExpiry: time.Duration(cfg.AccessTokenExpiryMinutes) * time.Minute,
refreshExpiry: time.Duration(cfg.RefreshTokenExpiryMinutes) * time.Minute,
}, nil
}

View File

@ -6,8 +6,6 @@ import (
"fmt"
"os"
"os/exec"
"server/internal/config"
tsrpc "server/pkg/ts-rpc"
)
@ -32,14 +30,6 @@ func TsGenerate() (string, error) {
return "", fmt.Errorf("write local generated typescript: %w", err)
}
configPath := os.Getenv("CONFIG_PATH")
if configPath == "" {
configPath = "configs/config.json"
}
if _, err := config.LoadConfig(); err != nil {
return "", fmt.Errorf("load config from %s: %w", configPath, err)
}
frontendAPIPath := os.Getenv("FRONTEND_API_PATH")
if frontendAPIPath == "" {
return "", errors.New("FRONTEND_API_PATH must be set")

View File

@ -670,13 +670,25 @@ func (uc *UserController) Me(c fiber.Ctx) error {
return c.JSON(responses.Success(&user))
}
func (us *UserController) Refresh(refreshToken string) (tokens.TokenPair, error) {
claims, err := us.TockenService.ParseToken(refreshToken)
func (us *UserController) Refresh(c fiber.Ctx) error {
var req RefreshRequest
if err := c.Bind().Body(&req); err != nil {
return fiber.NewError(fiber.StatusBadRequest, "invalid payload")
}
if req.RefreshToken == "" {
return fiber.NewError(fiber.StatusBadRequest, "refresh_token is required")
}
claims, err := us.TockenService.ParseToken(req.RefreshToken)
if err != nil {
return tokens.TokenPair{}, err
return fiber.NewError(fiber.StatusUnauthorized, err.Error())
}
if claims.TokenType != tokens.TokenTypeRefresh {
return tokens.TokenPair{}, errors.New("refresh token required")
return fiber.NewError(fiber.StatusUnauthorized, "refresh token required")
}
return us.TockenService.GenerateTokenPair(claims.Username)
tokens, err := us.TockenService.GenerateTokenPair(claims.Username)
if err != nil {
return fiber.NewError(fiber.StatusInternalServerError, err.Error())
}
return c.JSON(responses.Success(tokens))
}

View File

@ -1,6 +1,7 @@
package users
import (
"fmt"
"server/internal/roles"
"server/internal/tokens"
"time"
@ -10,12 +11,10 @@ import (
)
func RegisterUserRoutes(app *fiber.App) {
tockenService, _ := tokens.NewTockenService(tokens.Config{
Secret: "your-secret-key",
Issuer: "your-issuer",
AccessTokenExpiry: time.Hour,
RefreshTokenExpiry: 24 * time.Hour,
})
tockenService, err := tokens.GetTockenService()
if err != nil {
panic(fmt.Sprintf("token service: %v", err))
}
authRateLimiter := limiter.New(limiter.Config{
Max: 10,
@ -25,37 +24,37 @@ func RegisterUserRoutes(app *fiber.App) {
userController := NewUserController(tockenService)
// Typescript: TSEndpoint= path=/users/:uuid; name=getUser; method=GET; response=models.UserProfile
// Typescript: TSEndpoint= path=/users/:uuid; name=getUser; method=GET; response=users.UserProfile
app.Get("/users/:uuid", tockenService.Middleware(), userController.GetUser)
// Typescript: TSEndpoint= path=/users; name=createUser; method=POST; request=models.UserCreateInput; response=models.UserProfile
// Typescript: TSEndpoint= path=/users; name=createUser; method=POST; request=users.UserCreateInput; response=users.UserProfile
app.Post("/users", tockenService.Middleware(), userController.CreateUser)
// Typescript: TSEndpoint= path=/users/:uuid; name=updateUser; method=PUT; request=controllers.UpdateUserRequest; response=models.UserProfile
// Typescript: TSEndpoint= path=/users/:uuid; name=updateUser; method=PUT; request=users.UpdateUserRequest; response=users.UserProfile
app.Put("/users/:uuid", tockenService.Middleware(), userController.UpdateUser)
// Typescript: TSEndpoint= path=/users/:uuid; name=deleteUser; method=DELETE; response=controllers.SimpleResponse
// Typescript: TSEndpoint= path=/users/:uuid; name=deleteUser; method=DELETE; response=responses.SimpleResponse
app.Delete("/users/:uuid", tockenService.Middleware(), userController.DeleteUser)
// Typescript: TSEndpoint= path=/auth/me; name=me; method=GET; response=models.UserShort
// Typescript: TSEndpoint= path=/auth/me; name=me; method=GET; response=users.User
app.Get("/auth/me", tockenService.Middleware(), userController.Me)
roles.RegisterEndpoint("GET/auth/me", int(roles.UserPermission))
// Typescript: TSEndpoint= path=/auth/login; name=login; method=POST; request=model.LoginRequest; response=model.TokenPair
// Typescript: TSEndpoint= path=/auth/login; name=login; method=POST; request=users.LoginRequest; response=tokens.TokenPair
app.Post("/auth/login", authRateLimiter, userController.Login)
// Typescript: TSEndpoint= path=/auth/refresh; name=refresh; method=POST; request=model.RefreshRequest; response=model.TokenPair
// Typescript: TSEndpoint= path=/auth/refresh; name=refresh; method=POST; request=tokens.RefreshRequest; response=tokens.TokenPair
app.Post("/auth/refresh", authRateLimiter, userController.Refresh)
// Typescript: TSEndpoint= path=/auth/register; name=register; method=POST; request=models.UserCreateInput; response=models.UserShort
// Typescript: TSEndpoint= path=/auth/register; name=register; method=POST; request=users.UserCreateInput; response=users.User
app.Post("/auth/register", authRateLimiter, userController.Register)
// Typescript: TSEndpoint= path=/auth/password/forgot; name=forgotPassword; method=POST; request=model.ForgotPasswordRequest; response=controllers.SimpleResponse
// Typescript: TSEndpoint= path=/auth/password/forgot; name=forgotPassword; method=POST; request=users.ForgotPasswordRequest; response=responses.SimpleResponse
app.Post("/auth/password/forgot", authRateLimiter, userController.ForgotPassword)
// Typescript: TSEndpoint= path=/auth/password/reset; name=resetPassword; method=POST; request=model.ResetPasswordRequest; response=controllers.SimpleResponse
// Typescript: TSEndpoint= path=/auth/password/reset; name=resetPassword; method=POST; request=users.ResetPasswordRequest; response=responses.SimpleResponse
app.Post("/auth/password/reset", authRateLimiter, userController.ResetPassword)
// Typescript: TSEndpoint= path=/auth/password/valid; name=validToken; method=POST; request=string; response=controllers.SimpleResponse
// Typescript: TSEndpoint= path=/auth/password/valid; name=validToken; method=POST; request=string; response=responses.SimpleResponse
app.Post("/auth/password/valid", authRateLimiter, userController.ValidToken)
}