114 lines
2.8 KiB
Go
114 lines
2.8 KiB
Go
package ym
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
"sipro-mps/internal/redis"
|
|
"time"
|
|
|
|
"github.com/redis/rueidis"
|
|
)
|
|
|
|
// RateLimit describes a rate limit: how many requests are allowed within a
// time window.
type RateLimit struct {
	// Count is the number of requests allowed per window.
	Count int
	// TimeDelta is the length of the time window.
	TimeDelta time.Duration
}
|
|
|
|
// Path rate limits for Yandex Market API
|
|
var PathLimits = map[string]RateLimit{
|
|
"/tariffs/calculate": {Count: 95, TimeDelta: time.Minute},
|
|
"/campaigns": {Count: 300, TimeDelta: time.Minute},
|
|
"/orders": {Count: 1000, TimeDelta: time.Minute},
|
|
}
|
|
|
|
// rateLimitScript is a sliding-window rate limiter built on a Redis sorted
// set. Running it as a single Lua script makes the check-and-record step
// atomic.
//
// Arguments:
//   - KEYS[1]: sorted-set key holding one entry per admitted request, scored
//     by the request time in milliseconds.
//   - ARGV[1]: current time in milliseconds.
//   - ARGV[2]: window length in milliseconds.
//   - ARGV[3]: maximum number of requests allowed per window.
//
// It returns 0 when the request was admitted and recorded. Otherwise it
// returns the number of milliseconds until the oldest recorded request leaves
// the window. A denied request is NOT recorded, so a caller that waits must
// run the script again before sending the request.
var rateLimitScript = rueidis.NewLuaScript(`
local key = KEYS[1]
local now = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local limit = tonumber(ARGV[3])

-- Remove old entries outside the time window
redis.call('ZREMRANGEBYSCORE', key, '-inf', now - window)
local count = redis.call('ZCARD', key)

if count < limit then
    -- Record the request and refresh the TTL. The member has to be unique per
    -- admitted request: using the bare timestamp as the member would make
    -- requests that share a millisecond collapse into a single entry and be
    -- under-counted. ARGV[1] is used verbatim to avoid number formatting.
    redis.call('ZADD', key, now, ARGV[1] .. ':' .. count)
    redis.call('EXPIRE', key, math.ceil(window / 1000))
    return 0
else
    -- Find the score (timestamp) of the oldest recorded request
    local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')[2]
    -- Return wait time until the oldest request leaves the window
    return (tonumber(oldest) + window) - now
end
`)
|
|
|
|
// RateLimitTransport wraps an http.RoundTripper. Its RoundTrip method applies
// rate limiting before handing the request to the embedded RoundTripper.
type RateLimitTransport struct {
	// RoundTripper is the underlying transport that performs the request.
	http.RoundTripper
}
|
|
|
|
func (t *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
ctx := req.Context()
|
|
|
|
// Extract API key from headers
|
|
apiKey := req.Header.Get("Api-Key")
|
|
if apiKey == "" {
|
|
return nil, fmt.Errorf("Api-Key header is required for rate limiting")
|
|
}
|
|
|
|
// Get path from header or URL
|
|
var path string
|
|
path = req.URL.Path
|
|
|
|
// Get rate limit for this path
|
|
rateLimit, exists := PathLimits[path]
|
|
if !exists {
|
|
rateLimit = RateLimit{Count: 100, TimeDelta: time.Minute} // default limit
|
|
}
|
|
|
|
// Create unique key based on API key and path
|
|
rateLimitKey := fmt.Sprintf("ym:ratelimit:%s:%s", apiKey, path)
|
|
|
|
now := time.Now().UnixMilli()
|
|
windowMillis := int64(rateLimit.TimeDelta / time.Millisecond)
|
|
|
|
client := *redis.Client
|
|
|
|
waitTime, err := rateLimitScript.Exec(ctx, client, []string{rateLimitKey}, []string{
|
|
fmt.Sprintf("%d", now),
|
|
fmt.Sprintf("%d", windowMillis),
|
|
fmt.Sprintf("%d", rateLimit.Count),
|
|
}).ToInt64()
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("rate limit script error: %w", err)
|
|
}
|
|
|
|
if waitTime > 0 {
|
|
select {
|
|
case <-time.After(time.Duration(waitTime) * time.Millisecond):
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
}
|
|
}
|
|
|
|
return t.RoundTripper.RoundTrip(req)
|
|
}
|
|
|
|
// NewRateLimitTransport creates a new rate limiting transport
|
|
func NewRateLimitTransport() *RateLimitTransport {
|
|
return &RateLimitTransport{
|
|
RoundTripper: http.DefaultTransport,
|
|
}
|
|
}
|
|
|
|
// SetPathLimit sets a custom rate limit for a specific path
|
|
func SetPathLimit(path string, timeDelta time.Duration, count int) {
|
|
PathLimits[path] = RateLimit{
|
|
Count: count,
|
|
TimeDelta: timeDelta,
|
|
}
|
|
}
|