package middleware

import (
	"fmt"
	"strings"

	"trustcontact/internal/models"
	"trustcontact/internal/repo"

	"github.com/gofiber/fiber/v2"
	"github.com/gofiber/fiber/v2/middleware/session"
	"gorm.io/gorm"
)

// Keys for the session entry holding the user ID and for the values this
// package stores in fiber.Ctx Locals.
const (
	sessionUserIDKey   = "user_id"
	contextUserKey     = "current_user"
	contextStoreKey    = "session_store"
	contextTemplateKey = "template_data"
)

// SessionStoreMiddleware returns a handler that stores the given session
// store in the request Locals (under contextStoreKey) so later handlers in the
// chain can retrieve it, then passes control to the next handler.
func SessionStoreMiddleware(store *session.Store) fiber.Handler {
	return func(c *fiber.Ctx) error {
		c.Locals(contextStoreKey, store)
		return c.Next()
	}
}

// CurrentUserMiddleware returns a handler that resolves the logged-in user for
// every request via CurrentUser and publishes it:
//
//   - in Locals under contextUserKey (a nil *models.User for anonymous
//     requests), readable with CurrentUserFromContext;
//   - in the template data as "CurrentUser", "UserLang" and "UserTheme".
//
// "UserLang" is the user's language preference with surrounding whitespace
// trimmed, and "UserTheme" is "dark" or "light" depending on
// user.Properties.Dark. Both are "" for anonymous requests. The theme is also
// "" when user.Properties.UserId is empty. Any error from CurrentUser is
// returned to Fiber and the rest of the chain is not run.
func CurrentUserMiddleware(store *session.Store, database *gorm.DB) fiber.Handler {
	// Built once per middleware instance and reused for every request.
	userRepo := repo.NewUserRepo(database)

	return func(c *fiber.Ctx) error {
		user, err := CurrentUser(c, store, userRepo)
		if err != nil {
			return err
		}

		c.Locals(contextUserKey, user)
		setTemplateData(c, "CurrentUser", user)

		var lang, theme string
		if user != nil {
			lang = strings.TrimSpace(user.Properties.Lang)
			switch {
			case user.Properties.UserId == "":
				// NOTE(review): an empty UserId presumably means no properties
				// record is attached to this user — confirm. The theme is left
				// empty in that case.
			case user.Properties.Dark:
				theme = "dark"
			default:
				theme = "light"
			}
		}
		setTemplateData(c, "UserLang", lang)
		setTemplateData(c, "UserTheme", theme)

		return c.Next()
	}
}

// CurrentUser loads the user whose ID is stored in the request's session.
//
// It returns (nil, nil) when the session holds no usable user ID (missing, not
// a string, or blank). If the repository returns no user for the stored ID, it
// deletes the stale ID from the session, saves the session and returns
// (nil, nil). Errors from reading the session, loading the user or saving the
// session are wrapped and returned.
func CurrentUser(c *fiber.Ctx, store *session.Store, userRepo *repo.UserRepo) (*models.User, error) {
	sess, err := store.Get(c)
	if err != nil {
		return nil, fmt.Errorf("get session: %w", err)
	}

	// normalizeUserID only reports ok for a non-empty, trimmed ID.
	uid, ok := normalizeUserID(sess.Get(sessionUserIDKey))
	if !ok {
		return nil, nil
	}

	user, err := userRepo.FindByID(uid)
	if err != nil {
		return nil, fmt.Errorf("load current user: %w", err)
	}
	if user == nil {
		// The stored ID no longer resolves to a user (presumably the account
		// was removed). Drop it so later requests skip the lookup and are
		// treated as anonymous.
		sess.Delete(sessionUserIDKey)
		if err := sess.Save(); err != nil {
			return nil, fmt.Errorf("save session: %w", err)
		}
		return nil, nil
	}
	return user, nil
}

// CurrentUserFromContext returns the user that CurrentUserMiddleware stored in
// the request Locals. The boolean is false when no user is present, either
// because the request is anonymous or because the middleware did not run.
func CurrentUserFromContext(c *fiber.Ctx) (*models.User, bool) {
	user, ok := c.Locals(contextUserKey).(*models.User)
	if !ok || user == nil {
		return nil, false
	}
	return user, true
}

// normalizeUserID converts a raw session value into a user ID. Only string
// values are accepted. They are trimmed of surrounding whitespace, and a blank
// result is rejected. Any other type yields ("", false).
func normalizeUserID(v any) (string, bool) {
	s, isString := v.(string)
	if !isString {
		return "", false
	}
	trimmed := strings.TrimSpace(s)
	if trimmed == "" {
		return "", false
	}
	return trimmed, true
}

// setTemplateData stores value under key in the request's template data map,
// creating the map on first use.
func setTemplateData(c *fiber.Ctx, key string, value any) {
	// templateData already registers a newly created map in Locals, and maps
	// are reference types, so writing into the returned map is sufficient.
	templateData(c)[key] = value
}

// SetTemplateData stores value under key in the per-request template data map
// kept in Locals under contextTemplateKey, creating the map if needed. It is
// the exported form of setTemplateData.
func SetTemplateData(c *fiber.Ctx, key string, value any) {
	setTemplateData(c, key, value)
}

// templateData returns the request's template data map from Locals. If no
// non-nil map[string]any is stored there, it creates a new empty map, stores
// it in Locals and returns it.
func templateData(c *fiber.Ctx) map[string]any {
	if existing, ok := c.Locals(contextTemplateKey).(map[string]any); ok && existing != nil {
		return existing
	}
	fresh := make(map[string]any)
	c.Locals(contextTemplateKey, fresh)
	return fresh
}