package middleware import ( "bytes" "crypto/sha256" "encoding/hex" "encoding/json" "fmt" "net/http" "strings" "time" "github.com/gin-gonic/gin" "github.com/redis/go-redis/v9" "go.uber.org/zap" ) // ResponseCacheConfig configures the HTTP response cache middleware type ResponseCacheConfig struct { RedisClient *redis.Client Logger *zap.Logger DefaultTTL time.Duration KeyPrefix string EndpointTTLs map[string]time.Duration // path prefix → TTL override } // cachedResponse stores the cached HTTP response data type cachedResponse struct { Status int `json:"status"` ContentType string `json:"content_type"` Body string `json:"body"` Headers map[string]string `json:"headers"` } // responseWriter captures the response for caching type cacheResponseWriter struct { gin.ResponseWriter body *bytes.Buffer } func (w *cacheResponseWriter) Write(b []byte) (int, error) { w.body.Write(b) return w.ResponseWriter.Write(b) } // ResponseCache returns a middleware that caches GET responses in Redis. // Only caches successful (2xx) GET requests for unauthenticated or public endpoints. func ResponseCache(cfg ResponseCacheConfig) gin.HandlerFunc { if cfg.RedisClient == nil { return func(c *gin.Context) { c.Next() } } if cfg.DefaultTTL == 0 { cfg.DefaultTTL = 5 * time.Minute } if cfg.KeyPrefix == "" { cfg.KeyPrefix = "http_cache" } if cfg.Logger == nil { cfg.Logger = zap.NewNop() } return func(c *gin.Context) { // Only cache GET requests if c.Request.Method != http.MethodGet { c.Next() return } // Skip caching for authenticated requests (user-specific data) if c.GetHeader("Authorization") != "" { c.Next() return } // Skip caching for cookie-authenticated requests (httpOnly auth cookies) if _, err := c.Cookie("access_token"); err == nil { c.Next() return } // Skip caching for auth endpoints (must never serve cached user data) if strings.Contains(c.Request.URL.Path, "/auth/") { c.Next() return } // Generate cache key from URL + query params cacheKey := generateCacheKey(cfg.KeyPrefix, c.Request.URL.RequestURI()) // Try to serve from cache ctx := c.Request.Context() cached, err := cfg.RedisClient.Get(ctx, cacheKey).Result() if err == nil { // Cache hit — serve from cache var resp cachedResponse if jsonErr := json.Unmarshal([]byte(cached), &resp); jsonErr == nil { for k, v := range resp.Headers { c.Header(k, v) } c.Header("X-Cache", "HIT") c.Data(resp.Status, resp.ContentType, []byte(resp.Body)) c.Abort() return } } // Cache miss — capture response writer := &cacheResponseWriter{ ResponseWriter: c.Writer, body: &bytes.Buffer{}, } c.Writer = writer c.Header("X-Cache", "MISS") c.Next() // Only cache successful responses status := writer.Status() if status < 200 || status >= 300 { return } // Determine TTL for this endpoint ttl := cfg.DefaultTTL for prefix, override := range cfg.EndpointTTLs { if strings.HasPrefix(c.Request.URL.Path, prefix) { ttl = override break } } // Store in cache resp := cachedResponse{ Status: status, ContentType: writer.Header().Get("Content-Type"), Body: writer.body.String(), Headers: map[string]string{ "Content-Type": writer.Header().Get("Content-Type"), }, } data, err := json.Marshal(resp) if err != nil { cfg.Logger.Debug("Failed to marshal response for cache", zap.Error(err)) return } if setErr := cfg.RedisClient.Set(ctx, cacheKey, data, ttl).Err(); setErr != nil { cfg.Logger.Debug("Failed to store response in cache", zap.String("key", cacheKey), zap.Error(setErr)) } } } // generateCacheKey creates a deterministic cache key from a URI func generateCacheKey(prefix, uri string) string { hash := sha256.Sum256([]byte(uri)) return fmt.Sprintf("%s:%s", prefix, hex.EncodeToString(hash[:16])) }