package ym import ( "fmt" "net/http" "sipro-mps/internal/redis" "time" "github.com/redis/rueidis" ) // RateLimit defines a rate limit configuration type RateLimit struct { Count int // Number of requests allowed TimeDelta time.Duration // Time window } // Path rate limits for Yandex Market API var PathLimits = map[string]RateLimit{ "/tariffs/calculate": {Count: 100, TimeDelta: time.Minute}, "/campaigns": {Count: 300, TimeDelta: time.Minute}, "/orders": {Count: 1000, TimeDelta: time.Minute}, } var rateLimitScript = rueidis.NewLuaScript(` local key = KEYS[1] local now = tonumber(ARGV[1]) local window = tonumber(ARGV[2]) local limit = tonumber(ARGV[3]) -- Remove old entries outside the time window redis.call('ZREMRANGEBYSCORE', key, '-inf', now - window) local count = redis.call('ZCARD', key) if count < limit then -- Add new request timestamp and set TTL redis.call('ZADD', key, now, now) redis.call('EXPIRE', key, math.ceil(window / 1000)) return 0 else -- Find oldest request timestamp local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')[2] -- Return wait time until oldest request expires return (tonumber(oldest) + window) - now end `) type RateLimitTransport struct { http.RoundTripper } func (t *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, error) { ctx := req.Context() // Extract API key from headers apiKey := req.Header.Get("Api-Key") if apiKey == "" { return nil, fmt.Errorf("Api-Key header is required for rate limiting") } // Get path from header or URL var path string path = req.URL.Path // Get rate limit for this path rateLimit, exists := PathLimits[path] if !exists { rateLimit = RateLimit{Count: 100, TimeDelta: time.Minute} // default limit } // Create unique key based on API key and path rateLimitKey := fmt.Sprintf("ym:ratelimit:%s:%s", apiKey, path) now := time.Now().UnixMilli() windowMillis := int64(rateLimit.TimeDelta / time.Millisecond) client := *redis.Client waitTime, err := rateLimitScript.Exec(ctx, client, []string{rateLimitKey}, []string{ fmt.Sprintf("%d", now), fmt.Sprintf("%d", windowMillis), fmt.Sprintf("%d", rateLimit.Count), }).ToInt64() if err != nil { return nil, fmt.Errorf("rate limit script error: %w", err) } if waitTime > 0 { select { case <-time.After(time.Duration(waitTime) * time.Millisecond): case <-ctx.Done(): return nil, ctx.Err() } } return t.RoundTripper.RoundTrip(req) } // NewRateLimitTransport creates a new rate limiting transport func NewRateLimitTransport() *RateLimitTransport { return &RateLimitTransport{ RoundTripper: http.DefaultTransport, } } // SetPathLimit sets a custom rate limit for a specific path func SetPathLimit(path string, timeDelta time.Duration, count int) { PathLimits[path] = RateLimit{ Count: count, TimeDelta: timeDelta, } }