feat: shit shit
This commit is contained in:
113
internal/ym/rate_limiter.go
Normal file
113
internal/ym/rate_limiter.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package ym
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sipro-mps/internal/redis"
|
||||
"time"
|
||||
|
||||
"github.com/redis/rueidis"
|
||||
)
|
||||
|
||||
// RateLimit defines a rate limit configuration
|
||||
type RateLimit struct {
|
||||
Count int // Number of requests allowed
|
||||
TimeDelta time.Duration // Time window
|
||||
}
|
||||
|
||||
// Path rate limits for Yandex Market API
|
||||
var PathLimits = map[string]RateLimit{
|
||||
"/tariffs/calculate": {Count: 100, TimeDelta: time.Minute},
|
||||
"/campaigns": {Count: 300, TimeDelta: time.Minute},
|
||||
"/orders": {Count: 1000, TimeDelta: time.Minute},
|
||||
}
|
||||
|
||||
var rateLimitScript = rueidis.NewLuaScript(`
|
||||
local key = KEYS[1]
|
||||
local now = tonumber(ARGV[1])
|
||||
local window = tonumber(ARGV[2])
|
||||
local limit = tonumber(ARGV[3])
|
||||
|
||||
-- Remove old entries outside the time window
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', now - window)
|
||||
local count = redis.call('ZCARD', key)
|
||||
|
||||
if count < limit then
|
||||
-- Add new request timestamp and set TTL
|
||||
redis.call('ZADD', key, now, now)
|
||||
redis.call('EXPIRE', key, math.ceil(window / 1000))
|
||||
return 0
|
||||
else
|
||||
-- Find oldest request timestamp
|
||||
local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')[2]
|
||||
-- Return wait time until oldest request expires
|
||||
return (tonumber(oldest) + window) - now
|
||||
end
|
||||
`)
|
||||
|
||||
type RateLimitTransport struct {
|
||||
http.RoundTripper
|
||||
}
|
||||
|
||||
func (t *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
ctx := req.Context()
|
||||
|
||||
// Extract API key from headers
|
||||
apiKey := req.Header.Get("Api-Key")
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("Api-Key header is required for rate limiting")
|
||||
}
|
||||
|
||||
// Get path from header or URL
|
||||
var path string
|
||||
path = req.URL.Path
|
||||
|
||||
// Get rate limit for this path
|
||||
rateLimit, exists := PathLimits[path]
|
||||
if !exists {
|
||||
rateLimit = RateLimit{Count: 100, TimeDelta: time.Minute} // default limit
|
||||
}
|
||||
|
||||
// Create unique key based on API key and path
|
||||
rateLimitKey := fmt.Sprintf("ym:ratelimit:%s:%s", apiKey, path)
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
windowMillis := int64(rateLimit.TimeDelta / time.Millisecond)
|
||||
|
||||
client := *redis.Client
|
||||
|
||||
waitTime, err := rateLimitScript.Exec(ctx, client, []string{rateLimitKey}, []string{
|
||||
fmt.Sprintf("%d", now),
|
||||
fmt.Sprintf("%d", windowMillis),
|
||||
fmt.Sprintf("%d", rateLimit.Count),
|
||||
}).ToInt64()
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rate limit script error: %w", err)
|
||||
}
|
||||
|
||||
if waitTime > 0 {
|
||||
select {
|
||||
case <-time.After(time.Duration(waitTime) * time.Millisecond):
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
return t.RoundTripper.RoundTrip(req)
|
||||
}
|
||||
|
||||
// NewRateLimitTransport creates a new rate limiting transport
|
||||
func NewRateLimitTransport() *RateLimitTransport {
|
||||
return &RateLimitTransport{
|
||||
RoundTripper: http.DefaultTransport,
|
||||
}
|
||||
}
|
||||
|
||||
// SetPathLimit sets a custom rate limit for a specific path
|
||||
func SetPathLimit(path string, timeDelta time.Duration, count int) {
|
||||
PathLimits[path] = RateLimit{
|
||||
Count: count,
|
||||
TimeDelta: timeDelta,
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user