backend-server-v2/internal/http/middleware/current_user.go

127 lines
2.7 KiB
Go

package middleware
import (
"fmt"
"strings"
"trustcontact/internal/models"
"trustcontact/internal/repo"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/session"
"gorm.io/gorm"
)
// Keys shared by the middleware in this file: one session entry and the
// fiber Locals slots used to pass request-scoped values downstream.
const (
	// sessionUserIDKey names the session entry that CurrentUser reads to
	// find the logged-in user's ID.
	sessionUserIDKey = "user_id"

	// contextUserKey is the Locals slot where CurrentUserMiddleware stores
	// the resolved *models.User (which may be a nil pointer).
	contextUserKey = "current_user"

	// contextStoreKey is the Locals slot where SessionStoreMiddleware
	// exposes the *session.Store.
	contextStoreKey = "session_store"

	// contextTemplateKey is the Locals slot holding the per-request
	// map[string]any of template data.
	contextTemplateKey = "template_data"
)
// SessionStoreMiddleware returns a fiber handler that places store into
// the request's Locals under contextStoreKey and then continues the chain,
// so later handlers can retrieve the same store instance.
func SessionStoreMiddleware(store *session.Store) fiber.Handler {
	inject := func(fc *fiber.Ctx) error {
		fc.Locals(contextStoreKey, store)
		return fc.Next()
	}
	return inject
}
// CurrentUserMiddleware returns a fiber handler that resolves the logged-in
// user for each request via CurrentUser.
//
// On success it stores the user (possibly a nil *models.User) in Locals
// under contextUserKey and writes three template-data entries:
//   - "CurrentUser": the same user pointer;
//   - "UserLang":    the user's Properties.Lang with surrounding whitespace
//     trimmed, or "" when there is no user;
//   - "UserTheme":   "dark" or "light" according to Properties.Dark, but only
//     when Properties.UserId is non-empty; otherwise "".
//
// NOTE(review): a non-empty Properties.UserId presumably signals that a
// properties record was actually loaded for the user. Confirm this against
// models.User.
//
// Any error from CurrentUser is returned as-is and the chain is not continued.
func CurrentUserMiddleware(store *session.Store, database *gorm.DB) fiber.Handler {
	users := repo.NewUserRepo(database)

	return func(fc *fiber.Ctx) error {
		user, err := CurrentUser(fc, store, users)
		if err != nil {
			return err
		}

		fc.Locals(contextUserKey, user)
		setTemplateData(fc, "CurrentUser", user)

		lang, theme := "", ""
		if user != nil {
			lang = strings.TrimSpace(user.Properties.Lang)
			switch {
			case user.Properties.UserId == "":
				// No properties to read a preference from; leave the theme empty.
			case user.Properties.Dark:
				theme = "dark"
			default:
				theme = "light"
			}
		}

		setTemplateData(fc, "UserLang", lang)
		setTemplateData(fc, "UserTheme", theme)
		return fc.Next()
	}
}
// CurrentUser loads the user whose ID is stored in the request's session
// under sessionUserIDKey.
//
// It returns (nil, nil) in two cases:
//   - the session holds no usable ID (missing, non-string, or blank);
//   - userRepo.FindByID returns a nil user with no error. The stale ID is then
//     deleted from the session and the session is saved.
//
// Errors from store.Get, userRepo.FindByID, and sess.Save are wrapped with
// %w and returned with a nil user.
func CurrentUser(c *fiber.Ctx, store *session.Store, userRepo *repo.UserRepo) (*models.User, error) {
	sess, err := store.Get(c)
	if err != nil {
		return nil, fmt.Errorf("get session: %w", err)
	}

	id, found := normalizeUserID(sess.Get(sessionUserIDKey))
	if !found || len(id) == 0 {
		return nil, nil
	}

	user, err := userRepo.FindByID(id)
	switch {
	case err != nil:
		return nil, fmt.Errorf("load current user: %w", err)
	case user != nil:
		return user, nil
	}

	// The session points at a user the repository no longer returns. Remove
	// the ID so subsequent requests on this session take the early
	// "no usable ID" exit above instead of repeating the lookup.
	sess.Delete(sessionUserIDKey)
	if saveErr := sess.Save(); saveErr != nil {
		return nil, fmt.Errorf("save session: %w", saveErr)
	}
	return nil, nil
}
// CurrentUserFromContext returns the user stored in Locals under
// contextUserKey (as set by CurrentUserMiddleware). The boolean is false,
// and the user nil, when that slot is empty, holds a value that is not a
// *models.User, or holds a nil *models.User.
func CurrentUserFromContext(c *fiber.Ctx) (*models.User, bool) {
	if u, isUser := c.Locals(contextUserKey).(*models.User); isUser && u != nil {
		return u, true
	}
	return nil, false
}
// normalizeUserID turns a raw session value into a user ID.
//
// Only values whose dynamic type is exactly string are accepted. Leading and
// trailing whitespace is stripped, and the boolean reports whether anything
// remains. Every other input (nil, numbers, named string types, ...) yields
// ("", false).
func normalizeUserID(v any) (string, bool) {
	raw, isString := v.(string)
	if !isString {
		return "", false
	}
	id := strings.TrimSpace(raw)
	return id, id != ""
}
// setTemplateData stores value under key in the request's template-data
// map, creating that map on first use (see templateData).
func setTemplateData(c *fiber.Ctx, key string, value any) {
	values := templateData(c)
	values[key] = value
	// templateData already stores any map it creates, and maps are reference
	// types, so this re-store only re-affirms the same map in Locals.
	c.Locals(contextTemplateKey, values)
}
// SetTemplateData records value under key in the current request's
// template data so handlers outside this package can add values for the
// renderer. It delegates to the package-private setTemplateData.
func SetTemplateData(c *fiber.Ctx, key string, value any) {
	setTemplateData(c, key, value)
}
// templateData returns the per-request template-data map kept in Locals
// under contextTemplateKey.
//
// If the slot is unset, holds a value of another type, or holds a nil map,
// a new empty map is created, stored in that slot, and returned.
func templateData(c *fiber.Ctx) map[string]any {
	if current, ok := c.Locals(contextTemplateKey).(map[string]any); ok && current != nil {
		return current
	}
	created := map[string]any{}
	c.Locals(contextTemplateKey, created)
	return created
}