From 5b9fe6c9b72897b1df6160af622aed0cf6a4ed55 Mon Sep 17 00:00:00 2001 From: fabio Date: Thu, 9 Apr 2026 09:21:38 +0200 Subject: [PATCH] Refactor configuration and database initialization: replace LoadConfig with GetConfig, introduce DbConfig struct, and streamline token service retrieval. Remove unused request types and enhance user refresh token handling. --- backend/cmd/server/main.go | 26 ++++-------- backend/internal/auth/request.go | 23 ---------- backend/internal/config/config.go | 51 +++++++++++++++++------ backend/internal/db/db.go | 34 +++++++++------ backend/internal/mail/service.go | 24 +++++------ backend/internal/responses/services.go | 6 +-- backend/internal/routes/register.go | 11 ----- backend/internal/tokens/services.go | 36 ++++++++++------ backend/internal/tsgenerator/generator.go | 10 ----- backend/internal/user/controller.go | 22 +++++++--- backend/internal/user/routes.go | 33 +++++++-------- 11 files changed, 138 insertions(+), 138 deletions(-) delete mode 100644 backend/internal/auth/request.go diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index bd2783c..9de16e1 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -36,20 +36,12 @@ func main() { seedCount := flag.Int("seed", 0, "seed N fake users at startup (0 to skip)") flag.Parse() - cfg, err := config.LoadConfig() + cfg, err := config.GetConfig() if err != nil { - log.Fatalf("load config: %v", err) + log.Fatalf("config: %v", err) } - if secret := os.Getenv("AUTH_SECRET"); secret != "" { - cfg.Auth.Secret = secret - } - - dbCfg := db.Config{ - Driver: envOrDefault("DB_driver", "sqlite"), - DSN: envOrDefault("DB_dsn", "file:./data/data.db?_foreign_keys=on"), - } - dbConn, err := db.Init(dbCfg) + dbConn, err := db.GetDB() if err != nil { log.Fatalf("init db: %v", err) } @@ -58,12 +50,10 @@ func main() { log.Fatalf("migrate user: %v", err) } - tockenService, _ := tokens.NewTockenService(tokens.Config{ - Secret: "your-secret-key", - Issuer: "your-issuer", - AccessTokenExpiry: time.Hour, - RefreshTokenExpiry: 24 * time.Hour, - }) + tokenService, err := tokens.GetTockenService() + if err != nil { + log.Fatalf("init tokens: %v", err) + } app := fiber.New(fiber.Config{ AppName: cfg.AppName, @@ -107,7 +97,7 @@ func main() { return c.Next() }) - app.Use(roles.RequireEndpointPermission(dbConn, tockenService)) + app.Use(roles.RequireEndpointPermission(dbConn, tokenService)) routes.Register(app) diff --git a/backend/internal/auth/request.go b/backend/internal/auth/request.go deleted file mode 100644 index 0e660dd..0000000 --- a/backend/internal/auth/request.go +++ /dev/null @@ -1,23 +0,0 @@ -package auth - -// Typescript: interface -type LoginRequest struct { - Username string `json:"username" validate:"required,email"` - Password string `json:"password" validate:"required,min=8,max=128"` -} - -// Typescript: interface -type RefreshRequest struct { - RefreshToken string `json:"refresh_token"` -} - -// Typescript: interface -type ForgotPasswordRequest struct { - Email string `json:"email" validate:"required,email"` -} - -// Typescript: interface -type ResetPasswordRequest struct { - Token string `json:"token" validate:"required,min=20,max=255"` - Password string `json:"password" validate:"required,min=8,max=128"` -} diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 5c0b2c0..1b880e6 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -14,6 +14,7 @@ type ServerConfig struct { DisableStartupMessage bool `json:"disable_startup_message"` Auth AuthConfig `json:"auth"` Mail MailConfig `json:"mail"` + Db DbConfig `json:"db_config"` RolesConfigPath string `json:"roles_config_path"` } @@ -42,6 +43,25 @@ type SMTPMailConfig struct { Password string `json:"password"` } +type DbConfig struct { + Driver string + DSN string +} + +var Config *ServerConfig = nil + +func GetConfig() (*ServerConfig, error) { + if Config == nil { + var err error + Config, err = loadConfig() + if err != nil { + fmt.Printf("Failed to load config: %v\n", err) + return nil, err + } + } + return Config, nil +} + func envOrDefault(key, defaultValue string) string { if value, exists := os.LookupEnv(key); exists { return value @@ -49,26 +69,28 @@ func envOrDefault(key, defaultValue string) string { return defaultValue } -func LoadConfig() (ServerConfig, error) { - +func loadConfig() (*ServerConfig, error) { path := envOrDefault("CONFIG_PATH", "configs/config.json") data, err := os.ReadFile(path) if err != nil { - return ServerConfig{}, fmt.Errorf("read config: %w", err) + return nil, fmt.Errorf("read config: %w", err) } var cfg ServerConfig if err := json.Unmarshal(data, &cfg); err != nil { - return ServerConfig{}, fmt.Errorf("parse config: %w", err) + return nil, fmt.Errorf("parse config: %w", err) + } + if secret := os.Getenv("AUTH_SECRET"); secret != "" { + cfg.Auth.Secret = secret } if cfg.Auth.Secret == "" { - return ServerConfig{}, fmt.Errorf("auth.secret must be set") + return nil, fmt.Errorf("auth.secret must be set") } if cfg.Auth.AccessTokenExpiryMinutes <= 0 { - return ServerConfig{}, fmt.Errorf("auth.access_token_expiry_minutes must be greater than zero") + return nil, fmt.Errorf("auth.access_token_expiry_minutes must be greater than zero") } if cfg.Auth.RefreshTokenExpiryMinutes <= 0 { - return ServerConfig{}, fmt.Errorf("auth.refresh_token_expiry_minutes must be greater than zero") + return nil, fmt.Errorf("auth.refresh_token_expiry_minutes must be greater than zero") } if cfg.Mail.Mode == "" { cfg.Mail.Mode = "file" @@ -80,21 +102,26 @@ func LoadConfig() (ServerConfig, error) { cfg.Mail.ResetPasswordPath = "/#reset-password" } if cfg.Mail.Mode != "smtp" && cfg.Mail.Mode != "file" { - return ServerConfig{}, fmt.Errorf("mail.mode must be either smtp or file") + return nil, fmt.Errorf("mail.mode must be either smtp or file") } if cfg.Mail.From == "" { - return ServerConfig{}, fmt.Errorf("mail.from must be set") + return nil, fmt.Errorf("mail.from must be set") } if cfg.Mail.Mode == "smtp" { if cfg.Mail.SMTP.Host == "" { - return ServerConfig{}, fmt.Errorf("mail.smtp.host must be set when mail.mode=smtp") + return nil, fmt.Errorf("mail.smtp.host must be set when mail.mode=smtp") } if cfg.Mail.SMTP.Port <= 0 { - return ServerConfig{}, fmt.Errorf("mail.smtp.port must be greater than zero when mail.mode=smtp") + return nil, fmt.Errorf("mail.smtp.port must be greater than zero when mail.mode=smtp") } } else if cfg.Mail.DebugDir == "" { cfg.Mail.DebugDir = "data/mail-debug" } - return cfg, nil + cfg.Db = DbConfig{ + Driver: envOrDefault("DB_driver", "sqlite"), + DSN: envOrDefault("DB_dsn", "file:./data/data.db?_foreign_keys=on"), + } + + return &cfg, nil } diff --git a/backend/internal/db/db.go b/backend/internal/db/db.go index 1a689f0..2d24663 100644 --- a/backend/internal/db/db.go +++ b/backend/internal/db/db.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "path/filepath" + "server/internal/config" "strings" "github.com/gofiber/fiber/v3" @@ -12,33 +13,42 @@ import ( "gorm.io/gorm" ) -type Config struct { - Driver string - DSN string -} - var DB *gorm.DB +// GetDB returns the global *gorm.DB instance. It panics if the database is not initialized. +func GetDB() (*gorm.DB, error) { + if DB == nil { + cfg, err := config.GetConfig() + if err != nil { + return nil, err + } + DB, err = InitDB(cfg.Db) + if err != nil { + fmt.Printf("Failed to initialize database: %v\n", err) + return nil, err + } + } + return DB, nil +} + // Init opens the configured database connection and runs schema migrations. -func Init(cfg Config) (*gorm.DB, error) { +func InitDB(cfg config.DbConfig) (*gorm.DB, error) { switch cfg.Driver { case "sqlite": if err := ensureSQLiteDir(cfg.DSN); err != nil { return nil, fmt.Errorf("prepare sqlite path: %w", err) } - db, err := gorm.Open(sqlite.Open(cfg.DSN), &gorm.Config{}) + DB, err := gorm.Open(sqlite.Open(cfg.DSN), &gorm.Config{}) if err != nil { return nil, fmt.Errorf("open sqlite: %w", err) } - DB = db - return db, nil + return DB, nil case "postgres": - db, err := gorm.Open(postgres.Open(cfg.DSN), &gorm.Config{}) + DB, err := gorm.Open(postgres.Open(cfg.DSN), &gorm.Config{}) if err != nil { return nil, fmt.Errorf("open postgres: %w", err) } - DB = db - return db, nil + return DB, nil default: return nil, fmt.Errorf("unsupported driver %q", cfg.Driver) } diff --git a/backend/internal/mail/service.go b/backend/internal/mail/service.go index d88593b..7116d05 100644 --- a/backend/internal/mail/service.go +++ b/backend/internal/mail/service.go @@ -6,7 +6,6 @@ import ( "crypto/tls" "fmt" "html/template" - "log" "net" "net/smtp" "os" @@ -42,7 +41,7 @@ type Message struct { TemplateData any } -type Service struct { +type MailService struct { cfg Config } @@ -54,11 +53,12 @@ type TemplateData struct { ResetToken string } -func New() (*Service, error) { +// if service fail send admin allert instead a response to user or a simple response server error. +func New() (*MailService, error) { - serverCfg, err := config.LoadConfig() + serverCfg, err := config.GetConfig() if err != nil { - log.Fatalf("load config: %v", err) + return nil, err } cfg := Config{ @@ -102,10 +102,10 @@ func New() (*Service, error) { return nil, fmt.Errorf("smtp host and port are required") } } - return &Service{cfg: cfg}, nil + return &MailService{cfg: cfg}, nil } -func (s *Service) Send(ctx context.Context, msg Message) error { +func (s *MailService) Send(ctx context.Context, msg Message) error { htmlBody, textBody, err := s.renderBodies(msg.Template, msg.TemplateData) if err != nil { return err @@ -122,7 +122,7 @@ func (s *Service) Send(ctx context.Context, msg Message) error { } } -func (s *Service) ResetLink(token string) string { +func (s *MailService) ResetLink(token string) string { base := strings.TrimRight(s.cfg.FrontendBaseURL, "/") path := s.cfg.ResetPasswordPath if path == "" { @@ -137,11 +137,11 @@ func (s *Service) ResetLink(token string) string { return base + path + "?token=" + token } -func (s *Service) AppName() string { +func (s *MailService) AppName() string { return s.cfg.AppName } -func (s *Service) renderBodies(templateName string, data any) (string, string, error) { +func (s *MailService) renderBodies(templateName string, data any) (string, string, error) { htmlPath := filepath.Join(s.cfg.TemplatesDir, templateName+".html.tmpl") textPath := filepath.Join(s.cfg.TemplatesDir, templateName+".txt.tmpl") @@ -195,7 +195,7 @@ func buildMessage(from, to, subject, textBody, htmlBody string) []byte { return []byte(strings.Join(append(headers, body...), "\r\n")) } -func (s *Service) sendSMTP(ctx context.Context, to string, raw []byte) error { +func (s *MailService) sendSMTP(ctx context.Context, to string, raw []byte) error { addr := fmt.Sprintf("%s:%d", s.cfg.SMTP.Host, s.cfg.SMTP.Port) dialer := &net.Dialer{} conn, err := dialer.DialContext(ctx, "tcp", addr) @@ -245,7 +245,7 @@ func (s *Service) sendSMTP(ctx context.Context, to string, raw []byte) error { return nil } -func (s *Service) writeDebugMail(to, subject string, raw []byte) error { +func (s *MailService) writeDebugMail(to, subject string, raw []byte) error { safeRecipient := strings.NewReplacer("@", "_at_", "/", "_", "\\", "_", ":", "_", " ", "_").Replace(to) filename := fmt.Sprintf("%d_%s.eml", time.Now().UnixNano(), safeRecipient) path := filepath.Join(s.cfg.DebugDir, filename) diff --git a/backend/internal/responses/services.go b/backend/internal/responses/services.go index ef5f8af..6b1edef 100644 --- a/backend/internal/responses/services.go +++ b/backend/internal/responses/services.go @@ -8,13 +8,9 @@ type SimpleResponse struct { } // success wraps a payload in the standard API envelope. -func success(data any) fiber.Map { +func Success(data any) fiber.Map { return fiber.Map{ "data": data, "error": nil, } } - -func Success(data any) fiber.Map { - return success(data) -} diff --git a/backend/internal/routes/register.go b/backend/internal/routes/register.go index 39ca61d..01fbc43 100644 --- a/backend/internal/routes/register.go +++ b/backend/internal/routes/register.go @@ -8,17 +8,6 @@ import ( "github.com/gofiber/fiber/v3" ) -// Typescript: interface -type FormRequest struct { - Req string `json:"req"` - Count int `json:"count"` -} - -// Typescript: interface -type FormResponse struct { - Test string `json:"test"` -} - func Register(app *fiber.App) { systemUtils.RegisterSystemRoutes(app) users.RegisterUserRoutes(app) diff --git a/backend/internal/tokens/services.go b/backend/internal/tokens/services.go index d5c256b..b48e0b7 100644 --- a/backend/internal/tokens/services.go +++ b/backend/internal/tokens/services.go @@ -6,6 +6,7 @@ import ( "encoding/base64" "encoding/hex" "errors" + "server/internal/config" "time" "github.com/gofiber/fiber/v3" @@ -13,19 +14,12 @@ import ( ) type TockenService struct { - cfg Config + cfg config.AuthConfig secret []byte accessExpiry time.Duration refreshExpiry time.Duration } -type Config struct { - Secret string - Issuer string - AccessTokenExpiry time.Duration - RefreshTokenExpiry time.Duration -} - type Claims struct { Username string `json:"username"` Role string `json:"role"` @@ -44,22 +38,38 @@ const ( TokenTypeRefresh = "refresh" ) -func NewTockenService(cfg Config) (*TockenService, error) { +var Tockens *TockenService + +func GetTockenService() (*TockenService, error) { + if Tockens == nil { + cfg, err := config.GetConfig() + if err != nil { + return nil, err + } + Tockens, err = NewTockenService(cfg.Auth) + if err != nil { + return nil, err + } + } + return Tockens, nil +} + +func NewTockenService(cfg config.AuthConfig) (*TockenService, error) { if cfg.Secret == "" { return nil, errors.New("jwt secret is required") } - if cfg.AccessTokenExpiry <= 0 { + if cfg.AccessTokenExpiryMinutes <= 0 { return nil, errors.New("access token expiry must be positive") } - if cfg.RefreshTokenExpiry <= 0 { + if cfg.RefreshTokenExpiryMinutes <= 0 { return nil, errors.New("refresh token expiry must be positive") } return &TockenService{ cfg: cfg, secret: []byte(cfg.Secret), - accessExpiry: cfg.AccessTokenExpiry, - refreshExpiry: cfg.RefreshTokenExpiry, + accessExpiry: time.Duration(cfg.AccessTokenExpiryMinutes) * time.Minute, + refreshExpiry: time.Duration(cfg.RefreshTokenExpiryMinutes) * time.Minute, }, nil } diff --git a/backend/internal/tsgenerator/generator.go b/backend/internal/tsgenerator/generator.go index 25312ff..54be96e 100644 --- a/backend/internal/tsgenerator/generator.go +++ b/backend/internal/tsgenerator/generator.go @@ -6,8 +6,6 @@ import ( "fmt" "os" "os/exec" - - "server/internal/config" tsrpc "server/pkg/ts-rpc" ) @@ -32,14 +30,6 @@ func TsGenerate() (string, error) { return "", fmt.Errorf("write local generated typescript: %w", err) } - configPath := os.Getenv("CONFIG_PATH") - if configPath == "" { - configPath = "configs/config.json" - } - if _, err := config.LoadConfig(); err != nil { - return "", fmt.Errorf("load config from %s: %w", configPath, err) - } - frontendAPIPath := os.Getenv("FRONTEND_API_PATH") if frontendAPIPath == "" { return "", errors.New("FRONTEND_API_PATH must be set") diff --git a/backend/internal/user/controller.go b/backend/internal/user/controller.go index c8d363c..2481fc2 100644 --- a/backend/internal/user/controller.go +++ b/backend/internal/user/controller.go @@ -670,13 +670,25 @@ func (uc *UserController) Me(c fiber.Ctx) error { return c.JSON(responses.Success(&user)) } -func (us *UserController) Refresh(refreshToken string) (tokens.TokenPair, error) { - claims, err := us.TockenService.ParseToken(refreshToken) +func (us *UserController) Refresh(c fiber.Ctx) error { + var req RefreshRequest + if err := c.Bind().Body(&req); err != nil { + return fiber.NewError(fiber.StatusBadRequest, "invalid payload") + } + if req.RefreshToken == "" { + return fiber.NewError(fiber.StatusBadRequest, "refresh_token is required") + } + + claims, err := us.TockenService.ParseToken(req.RefreshToken) if err != nil { - return tokens.TokenPair{}, err + return fiber.NewError(fiber.StatusUnauthorized, err.Error()) } if claims.TokenType != tokens.TokenTypeRefresh { - return tokens.TokenPair{}, errors.New("refresh token required") + return fiber.NewError(fiber.StatusUnauthorized, "refresh token required") } - return us.TockenService.GenerateTokenPair(claims.Username) + tokens, err := us.TockenService.GenerateTokenPair(claims.Username) + if err != nil { + return fiber.NewError(fiber.StatusInternalServerError, err.Error()) + } + return c.JSON(responses.Success(tokens)) } diff --git a/backend/internal/user/routes.go b/backend/internal/user/routes.go index c6c192c..b2edb7a 100644 --- a/backend/internal/user/routes.go +++ b/backend/internal/user/routes.go @@ -1,6 +1,7 @@ package users import ( + "fmt" "server/internal/roles" "server/internal/tokens" "time" @@ -10,12 +11,10 @@ import ( ) func RegisterUserRoutes(app *fiber.App) { - tockenService, _ := tokens.NewTockenService(tokens.Config{ - Secret: "your-secret-key", - Issuer: "your-issuer", - AccessTokenExpiry: time.Hour, - RefreshTokenExpiry: 24 * time.Hour, - }) + tockenService, err := tokens.GetTockenService() + if err != nil { + panic(fmt.Sprintf("token service: %v", err)) + } authRateLimiter := limiter.New(limiter.Config{ Max: 10, @@ -25,37 +24,37 @@ func RegisterUserRoutes(app *fiber.App) { userController := NewUserController(tockenService) - // Typescript: TSEndpoint= path=/users/:uuid; name=getUser; method=GET; response=models.UserProfile + // Typescript: TSEndpoint= path=/users/:uuid; name=getUser; method=GET; response=users.UserProfile app.Get("/users/:uuid", tockenService.Middleware(), userController.GetUser) - // Typescript: TSEndpoint= path=/users; name=createUser; method=POST; request=models.UserCreateInput; response=models.UserProfile + // Typescript: TSEndpoint= path=/users; name=createUser; method=POST; request=users.UserCreateInput; response=users.UserProfile app.Post("/users", tockenService.Middleware(), userController.CreateUser) - // Typescript: TSEndpoint= path=/users/:uuid; name=updateUser; method=PUT; request=controllers.UpdateUserRequest; response=models.UserProfile + // Typescript: TSEndpoint= path=/users/:uuid; name=updateUser; method=PUT; request=users.UpdateUserRequest; response=users.UserProfile app.Put("/users/:uuid", tockenService.Middleware(), userController.UpdateUser) - // Typescript: TSEndpoint= path=/users/:uuid; name=deleteUser; method=DELETE; response=controllers.SimpleResponse + // Typescript: TSEndpoint= path=/users/:uuid; name=deleteUser; method=DELETE; response=responses.SimpleResponse app.Delete("/users/:uuid", tockenService.Middleware(), userController.DeleteUser) - // Typescript: TSEndpoint= path=/auth/me; name=me; method=GET; response=models.UserShort + // Typescript: TSEndpoint= path=/auth/me; name=me; method=GET; response=users.User app.Get("/auth/me", tockenService.Middleware(), userController.Me) roles.RegisterEndpoint("GET/auth/me", int(roles.UserPermission)) - // Typescript: TSEndpoint= path=/auth/login; name=login; method=POST; request=model.LoginRequest; response=model.TokenPair + // Typescript: TSEndpoint= path=/auth/login; name=login; method=POST; request=users.LoginRequest; response=tokens.TokenPair app.Post("/auth/login", authRateLimiter, userController.Login) - // Typescript: TSEndpoint= path=/auth/refresh; name=refresh; method=POST; request=model.RefreshRequest; response=model.TokenPair + // Typescript: TSEndpoint= path=/auth/refresh; name=refresh; method=POST; request=tokens.RefreshRequest; response=tokens.TokenPair app.Post("/auth/refresh", authRateLimiter, userController.Refresh) - // Typescript: TSEndpoint= path=/auth/register; name=register; method=POST; request=models.UserCreateInput; response=models.UserShort + // Typescript: TSEndpoint= path=/auth/register; name=register; method=POST; request=users.UserCreateInput; response=users.User app.Post("/auth/register", authRateLimiter, userController.Register) - // Typescript: TSEndpoint= path=/auth/password/forgot; name=forgotPassword; method=POST; request=model.ForgotPasswordRequest; response=controllers.SimpleResponse + // Typescript: TSEndpoint= path=/auth/password/forgot; name=forgotPassword; method=POST; request=users.ForgotPasswordRequest; response=responses.SimpleResponse app.Post("/auth/password/forgot", authRateLimiter, userController.ForgotPassword) - // Typescript: TSEndpoint= path=/auth/password/reset; name=resetPassword; method=POST; request=model.ResetPasswordRequest; response=controllers.SimpleResponse + // Typescript: TSEndpoint= path=/auth/password/reset; name=resetPassword; method=POST; request=users.ResetPasswordRequest; response=responses.SimpleResponse app.Post("/auth/password/reset", authRateLimiter, userController.ResetPassword) - // Typescript: TSEndpoint= path=/auth/password/valid; name=validToken; method=POST; request=string; response=controllers.SimpleResponse + // Typescript: TSEndpoint= path=/auth/password/valid; name=validToken; method=POST; request=string; response=responses.SimpleResponse app.Post("/auth/password/valid", authRateLimiter, userController.ValidToken) }