package middleware

import (
	"fmt"
	"strconv"

	"github.com/gofiber/fiber/v2"
	"github.com/gofiber/fiber/v2/middleware/session"
	"gorm.io/gorm"

	"trustcontact/internal/models"
	"trustcontact/internal/repo"
)

const (
	// sessionUserIDKey is the session key holding the logged-in user's ID.
	sessionUserIDKey = "user_id"
	// contextUserKey is the c.Locals key under which CurrentUserMiddleware
	// stores the resolved *models.User.
	contextUserKey = "current_user"
	// contextStoreKey is the c.Locals key under which SessionStoreMiddleware
	// stores the *session.Store.
	contextStoreKey = "session_store"
	// contextTemplateKey is the c.Locals key holding the per-request
	// template data map (map[string]any).
	contextTemplateKey = "template_data"
)

// SessionStoreMiddleware returns a handler that exposes store to downstream
// handlers via c.Locals under contextStoreKey, then continues the chain.
func SessionStoreMiddleware(store *session.Store) fiber.Handler {
	return func(c *fiber.Ctx) error {
		c.Locals(contextStoreKey, store)
		return c.Next()
	}
}

// CurrentUserMiddleware returns a handler that resolves the logged-in user
// for each request (see CurrentUser). It publishes the result both in
// c.Locals (read it back with CurrentUserFromContext) and in the template
// data under "CurrentUser". When nobody is logged in, a nil *models.User is
// stored. Any error from CurrentUser is returned and aborts the request.
func CurrentUserMiddleware(store *session.Store, database *gorm.DB) fiber.Handler {
	// The repository is built once here and shared by every request.
	userRepo := repo.NewUserRepo(database)

	return func(c *fiber.Ctx) error {
		user, err := CurrentUser(c, store, userRepo)
		if err != nil {
			return err
		}
		c.Locals(contextUserKey, user)
		setTemplateData(c, "CurrentUser", user)
		return c.Next()
	}
}

// CurrentUser loads the user whose ID is stored in the request's session.
//
// It returns (nil, nil) when the session holds no usable user ID. When the
// repository yields no user for the stored ID, the ID is removed from the
// session and the session is saved before returning (nil, nil). Failures to
// read the session, load the user, or save the session are returned wrapped.
func CurrentUser(c *fiber.Ctx, store *session.Store, userRepo *repo.UserRepo) (*models.User, error) {
	sess, err := store.Get(c)
	if err != nil {
		return nil, fmt.Errorf("get session: %w", err)
	}

	// normalizeUserID only reports ok for a non-zero ID.
	uid, ok := normalizeUserID(sess.Get(sessionUserIDKey))
	if !ok {
		// No ID in the session, or one of an unsupported type/format.
		return nil, nil
	}

	user, err := userRepo.FindByID(uid)
	if err != nil {
		return nil, fmt.Errorf("load current user: %w", err)
	}
	if user == nil {
		// NOTE(review): this branch assumes FindByID reports "no such user"
		// as (nil, nil) rather than as an error — confirm in repo.UserRepo.
		// Drop the stale ID so subsequent requests return early above
		// instead of repeating the lookup.
		sess.Delete(sessionUserIDKey)
		if err := sess.Save(); err != nil {
			return nil, fmt.Errorf("save session: %w", err)
		}
		return nil, nil
	}
	return user, nil
}

// CurrentUserFromContext returns the user stored in c.Locals by
// CurrentUserMiddleware. The boolean is false when no *models.User is stored
// under contextUserKey or the stored pointer is nil.
func CurrentUserFromContext(c *fiber.Ctx) (*models.User, bool) {
	user, ok := c.Locals(contextUserKey).(*models.User)
	if !ok || user == nil {
		return nil, false
	}
	return user, true
}

// normalizeUserID converts a user ID read from the session into a uint.
//
// It accepts uint, uint32, uint64, int, int32, int64 and base-10 strings.
// The boolean is false for unsupported types, unparsable strings, zero or
// negative values, and values that do not fit in uint on the current
// platform. A true result therefore always carries a non-zero ID.
func normalizeUserID(v any) (uint, bool) {
	switch value := v.(type) {
	case uint:
		return value, value != 0
	case uint64:
		return userIDFromUint64(value)
	case uint32:
		return userIDFromUint64(uint64(value))
	case int:
		return userIDFromInt64(int64(value))
	case int64:
		return userIDFromInt64(value)
	case int32:
		return userIDFromInt64(int64(value))
	case string:
		// Parse at uint's width so out-of-range values are rejected instead
		// of being silently truncated on 32-bit platforms.
		parsed, err := strconv.ParseUint(value, 10, strconv.IntSize)
		if err != nil {
			return 0, false
		}
		return userIDFromUint64(parsed)
	default:
		return 0, false
	}
}

// userIDFromUint64 narrows v to uint, rejecting zero and any value the
// conversion would truncate (possible when uint is 32 bits wide).
func userIDFromUint64(v uint64) (uint, bool) {
	u := uint(v)
	if u == 0 || uint64(u) != v {
		return 0, false
	}
	return u, true
}

// userIDFromInt64 converts v to uint, rejecting non-positive values and any
// value the conversion would truncate.
func userIDFromInt64(v int64) (uint, bool) {
	if v <= 0 {
		return 0, false
	}
	return userIDFromUint64(uint64(v))
}

// setTemplateData records value under key in the request's template data.
func setTemplateData(c *fiber.Ctx, key string, value any) {
	// templateData has already stored the map in c.Locals, and maps are
	// reference types, so writing the entry is sufficient.
	templateData(c)[key] = value
}

// SetTemplateData records value under key in the template data for the
// current request, creating the data map on first use.
func SetTemplateData(c *fiber.Ctx, key string, value any) {
	setTemplateData(c, key, value)
}

// templateData returns the per-request template data map stored in c.Locals.
// If none is stored yet, or the stored map is nil, it creates an empty one,
// stores it, and returns it.
func templateData(c *fiber.Ctx) map[string]any {
	if existing, ok := c.Locals(contextTemplateKey).(map[string]any); ok && existing != nil {
		return existing
	}
	fresh := make(map[string]any)
	c.Locals(contextTemplateKey, fresh)
	return fresh
}