114 lines
		
	
	
		
			2.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			114 lines
		
	
	
		
			2.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package ym
 | 
						|
 | 
						|
import (
	"crypto/rand"
	"fmt"
	"net/http"
	"sipro-mps/internal/redis"
	"time"

	"github.com/redis/rueidis"
)
 | 
						|
 | 
						|
// RateLimit defines a rate limit configuration
 | 
						|
type RateLimit struct {
 | 
						|
	Count     int           // Number of requests allowed
 | 
						|
	TimeDelta time.Duration // Time window
 | 
						|
}
 | 
						|
 | 
						|
// Path rate limits for Yandex Market API
 | 
						|
var PathLimits = map[string]RateLimit{
 | 
						|
	"/tariffs/calculate": {Count: 100, TimeDelta: time.Minute},
 | 
						|
	"/campaigns":         {Count: 300, TimeDelta: time.Minute},
 | 
						|
	"/orders":            {Count: 1000, TimeDelta: time.Minute},
 | 
						|
}
 | 
						|
 | 
						|
// rateLimitScript atomically implements a sliding-window limiter on a Redis
// sorted set whose scores are admission times in milliseconds.
//
//	KEYS[1] - sorted set holding one member per admitted request.
//	ARGV[1] - current time in milliseconds.
//	ARGV[2] - window length in milliseconds.
//	ARGV[3] - maximum number of requests admitted per window (must be > 0).
//	ARGV[4] - optional unique member for this request; defaults to ARGV[1].
//
// It returns 0 when the request is admitted, and in that case records it.
// Otherwise it returns the number of milliseconds until the oldest recorded
// request leaves the window, and records nothing.
var rateLimitScript = rueidis.NewLuaScript(`
local key = KEYS[1]
local now = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local limit = tonumber(ARGV[3])
-- A distinct member per request stops concurrent requests that share a
-- millisecond timestamp from overwriting each other's entry.
local member = ARGV[4] or ARGV[1]

-- With no capacity there is no "oldest" entry to wait for; fail explicitly
-- instead of erroring on nil arithmetic below.
if limit <= 0 then
	return redis.error_reply('rate limit count must be positive')
end

-- Remove old entries outside the time window
redis.call('ZREMRANGEBYSCORE', key, '-inf', now - window)
local count = redis.call('ZCARD', key)

if count < limit then
	-- Record this request and refresh the key's TTL
	redis.call('ZADD', key, now, member)
	redis.call('EXPIRE', key, math.ceil(window / 1000))
	return 0
end

-- Wait until the oldest request in the window expires
local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')[2]
return (tonumber(oldest) + window) - now
`)
 | 
						|
 | 
						|
type RateLimitTransport struct {
 | 
						|
	http.RoundTripper
 | 
						|
}
 | 
						|
 | 
						|
func (t *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, error) {
 | 
						|
	ctx := req.Context()
 | 
						|
 | 
						|
	// Extract API key from headers
 | 
						|
	apiKey := req.Header.Get("Api-Key")
 | 
						|
	if apiKey == "" {
 | 
						|
		return nil, fmt.Errorf("Api-Key header is required for rate limiting")
 | 
						|
	}
 | 
						|
 | 
						|
	// Get path from header or URL
 | 
						|
	var path string
 | 
						|
	path = req.URL.Path
 | 
						|
 | 
						|
	// Get rate limit for this path
 | 
						|
	rateLimit, exists := PathLimits[path]
 | 
						|
	if !exists {
 | 
						|
		rateLimit = RateLimit{Count: 100, TimeDelta: time.Minute} // default limit
 | 
						|
	}
 | 
						|
 | 
						|
	// Create unique key based on API key and path
 | 
						|
	rateLimitKey := fmt.Sprintf("ym:ratelimit:%s:%s", apiKey, path)
 | 
						|
 | 
						|
	now := time.Now().UnixMilli()
 | 
						|
	windowMillis := int64(rateLimit.TimeDelta / time.Millisecond)
 | 
						|
 | 
						|
	client := *redis.Client
 | 
						|
 | 
						|
	waitTime, err := rateLimitScript.Exec(ctx, client, []string{rateLimitKey}, []string{
 | 
						|
		fmt.Sprintf("%d", now),
 | 
						|
		fmt.Sprintf("%d", windowMillis),
 | 
						|
		fmt.Sprintf("%d", rateLimit.Count),
 | 
						|
	}).ToInt64()
 | 
						|
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("rate limit script error: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	if waitTime > 0 {
 | 
						|
		select {
 | 
						|
		case <-time.After(time.Duration(waitTime) * time.Millisecond):
 | 
						|
		case <-ctx.Done():
 | 
						|
			return nil, ctx.Err()
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return t.RoundTripper.RoundTrip(req)
 | 
						|
}
 | 
						|
 | 
						|
// NewRateLimitTransport creates a new rate limiting transport
 | 
						|
func NewRateLimitTransport() *RateLimitTransport {
 | 
						|
	return &RateLimitTransport{
 | 
						|
		RoundTripper: http.DefaultTransport,
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// SetPathLimit sets a custom rate limit for a specific path
 | 
						|
func SetPathLimit(path string, timeDelta time.Duration, count int) {
 | 
						|
	PathLimits[path] = RateLimit{
 | 
						|
		Count:     count,
 | 
						|
		TimeDelta: timeDelta,
 | 
						|
	}
 | 
						|
}
 |