Files
Sipro-Marketplaces/internal/wb/rate_limiter.go

171 lines
4.6 KiB
Go

package wb
import (
"context"
"encoding/json"
"fmt"
"github.com/redis/rueidis"
"net/http"
"sipro-mps/internal/redis"
"time"
)
const (
	// defaultBucketCapacity is the bucket size (maximum burst) used by
	// tokenBucketScript when no per-seller capacity is stored in Redis.
	defaultBucketCapacity = 10
	// refillRate is the refill speed in tokens per millisecond:
	// 100 tokens per 60000 ms = 100 requests per minute = one token every 600 ms.
	// NOTE(review): 300 requests per minute would be 300.0 / 60000 (one token
	// per 200 ms); confirm which limit the target Wildberries endpoint uses.
	refillRate = 100.0 / 60000
	// tokenTTLMillis is the TTL in milliseconds (60 s) that the script applies
	// to a seller's ":tokens" and ":last_refill" keys.
	tokenTTLMillis = 60000
)
// tokenBucketScript checks a per-seller retry lock and then tries to take one
// token from a token bucket whose state lives in Redis.
//
// KEYS[1] is the key prefix (the seller ID); the script derives
// "<prefix>:retry_until", "<prefix>:capacity", "<prefix>:tokens" and
// "<prefix>:last_refill" from it.
// ARGV[1] = current time in Unix milliseconds, ARGV[2] = capacity used when
// "<prefix>:capacity" is unset, ARGV[3] = refill rate in tokens per
// millisecond, ARGV[4] = TTL in milliseconds for the tokens/last_refill keys.
//
// It returns 0 when a token was taken. Otherwise it returns the number of
// milliseconds to wait before trying again, and no token is consumed.
//
// NOTE(review): the derived keys are not declared in KEYS, so this presumably
// assumes a non-cluster Redis (or keys that hash to one slot) — confirm.
var tokenBucketScript = rueidis.NewLuaScript(`
local key = KEYS[1]
local now = tonumber(ARGV[1])
local default_capacity = tonumber(ARGV[2])
local refill_rate = tonumber(ARGV[3])
local ttl = tonumber(ARGV[4])

-- Retry lock: while it is in the future, report how long to wait.
local retry_key = key .. ":retry_until"
local retry_until = tonumber(redis.call("GET", retry_key))
if retry_until and now < retry_until then
  return retry_until - now
end

-- Token bucket
local capacity_key = key .. ":capacity"
local token_key = key .. ":tokens"
local time_key = key .. ":last_refill"
local capacity = tonumber(redis.call("GET", capacity_key)) or default_capacity
local tokens = tonumber(redis.call("GET", token_key))
local last_refill = tonumber(redis.call("GET", time_key))
if tokens == nil then tokens = capacity end
if last_refill == nil then last_refill = now end

-- "now" comes from the calling host's clock, but last_refill may have been
-- written by another host. Clamp so a lagging clock never removes tokens, and
-- never move last_refill backwards (that would credit an interval twice).
local elapsed = math.max(0, now - last_refill)
tokens = math.min(capacity, tokens + elapsed * refill_rate)
last_refill = math.max(last_refill, now)

if tokens >= 1 then
  tokens = tokens - 1
  redis.call("SET", token_key, tokens)
  redis.call("SET", time_key, last_refill)
  redis.call("PEXPIRE", token_key, ttl)
  redis.call("PEXPIRE", time_key, ttl)
  return 0
end

local wait_time = math.ceil((1 - tokens) / refill_rate)
return wait_time
`)
// RateLimitTransport is an http.RoundTripper that throttles outgoing requests
// per seller with the Redis token bucket in tokenBucketScript. Once a token is
// granted, it delegates to the embedded RoundTripper.
//
// The seller is taken from the "sid" claim of the token carried in the
// request's Authorization header. A nil embedded RoundTripper falls back to
// http.DefaultTransport.
type RateLimitTransport struct {
	http.RoundTripper
}

// RoundTrip blocks until the seller's bucket grants a token, then forwards
// req to the wrapped transport. It also returns if the request context ends
// first.
//
// It returns an error without sending the request when the Authorization
// token cannot be marshalled or decoded, or when it has no non-empty string
// "sid" claim. The same applies when the Redis script fails, or when the
// request context is done while waiting (ctx.Err()).
func (t *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	ctx := req.Context()

	authData := NewWbAuthData(req.Header.Get("Authorization"))
	authDataBytes, err := json.Marshal(authData)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal Wildberries auth data: %w", err)
	}
	_, claims, err := DecodeWildberriesJwt(authDataBytes)
	if err != nil {
		return nil, fmt.Errorf("failed to decode Wildberries JWT: %w", err)
	}
	// Comma-ok form: a missing or non-string "sid" claim must surface as an
	// error instead of panicking inside the HTTP client.
	sellerID, _ := claims["sid"].(string)
	if sellerID == "" {
		return nil, fmt.Errorf("sellerId is required in JWT claims")
	}

	if err := t.waitForToken(ctx, sellerID); err != nil {
		return nil, err
	}

	next := t.RoundTripper
	if next == nil {
		next = http.DefaultTransport
	}
	return next.RoundTrip(req)
}

// waitForToken runs tokenBucketScript until it reports that a token was taken
// for sellerID. Between attempts it sleeps for the wait time the script
// returned. Re-running the script after each sleep is what actually consumes
// a token. Proceeding right after the sleep would let every waiter through
// without touching the bucket, including waiters blocked by the retry lock.
func (t *RateLimitTransport) waitForToken(ctx context.Context, sellerID string) error {
	client := *redis.Client
	for {
		now := time.Now().UnixMilli()
		waitMillis, err := tokenBucketScript.Exec(ctx, client, []string{sellerID}, []string{
			fmt.Sprintf("%d", now),
			fmt.Sprintf("%d", defaultBucketCapacity),
			// %g prints the shortest exact representation; %f would round
			// the rate to six decimals and skew the refill speed.
			fmt.Sprintf("%g", refillRate),
			fmt.Sprintf("%d", tokenTTLMillis),
		}).ToInt64()
		if err != nil {
			return fmt.Errorf("rate limit script error: %w", err)
		}
		if waitMillis <= 0 {
			return nil
		}

		timer := time.NewTimer(time.Duration(waitMillis) * time.Millisecond)
		select {
		case <-timer.C:
		case <-ctx.Done():
			timer.Stop()
			return ctx.Err()
		}
	}
}
// SyncRateLimitRemaining overwrites a seller's bucket state in Redis. It sets
// "<sellerId>:capacity" to defaultBucketCapacity, "<sellerId>:tokens" to
// remaining, and "<sellerId>:last_refill" to the current Unix time in
// milliseconds. Each key expires after one minute. The three SETs are sent as
// one pipelined batch.
//
// It returns an error if sellerId is empty, if remaining is negative, or if
// any of the SETs fails. A remaining value above the capacity is clamped by
// the token-bucket script on the next acquisition.
//
// NOTE(review): resetting capacity here overrides any capacity stored by
// SetRateLimitRetry; confirm that is intended.
func SyncRateLimitRemaining(ctx context.Context, sellerId string, remaining int) error {
	if remaining < 0 || sellerId == "" {
		return fmt.Errorf("invalid sellerId or remaining")
	}

	rc := *redis.Client
	stamp := time.Now().UnixMilli()

	// setFor builds "SET <sellerId><suffix> <v> EX 60".
	setFor := func(suffix string, v int64) rueidis.Completed {
		return rc.B().Set().Key(sellerId + suffix).Value(fmt.Sprintf("%d", v)).Ex(time.Minute).Build()
	}
	batch := []rueidis.Completed{
		setFor(":capacity", defaultBucketCapacity),
		setFor(":tokens", int64(remaining)),
		setFor(":last_refill", stamp),
	}

	for _, r := range rc.DoMulti(ctx, batch...) {
		if e := r.Error(); e != nil {
			return fmt.Errorf("failed to sync rate limit: %w", e)
		}
	}
	return nil
}
// SetRateLimitRetry records a back-off for sellerId.
//
// It stores "<sellerId>:retry_until" as a Unix-millisecond deadline of now +
// retrySeconds, with a TTL of retrySeconds+5 seconds. While that deadline is
// in the future, the token-bucket script returns a wait time instead of
// granting tokens. When limit > 0 it also stores "<sellerId>:capacity" =
// limit for one hour. resetSeconds is only used for an informational message
// printed to stdout.
//
// A negative retrySeconds is treated as 0, i.e. no back-off.
//
// It returns an error if sellerId is empty or if any Redis command fails.
func SetRateLimitRetry(ctx context.Context, sellerId string, retrySeconds int, limit int, resetSeconds int) error {
	if sellerId == "" {
		return fmt.Errorf("sellerId is required")
	}
	// retrySeconds <= -5 would give PX <= 0, which Redis rejects with
	// "invalid expire time". A deadline in the past means the same as none.
	if retrySeconds < 0 {
		retrySeconds = 0
	}

	now := time.Now()
	retryUntil := now.Add(time.Duration(retrySeconds) * time.Second).UnixMilli()
	client := *redis.Client

	cmds := []rueidis.Completed{
		client.B().Set().
			Key(sellerId + ":retry_until").
			Value(fmt.Sprintf("%d", retryUntil)).
			// Keep the key a little longer than the deadline so it is still
			// readable when the deadline passes.
			Px(time.Duration(retrySeconds+5) * time.Second).Build(),
	}
	if limit > 0 {
		cmds = append(cmds, client.B().Set().
			Key(sellerId+":capacity").
			Value(fmt.Sprintf("%d", limit)).
			Ex(time.Hour).Build())
	}
	if resetSeconds > 0 {
		resetAt := now.Add(time.Duration(resetSeconds) * time.Second)
		fmt.Printf("Seller %s rate limit resets at %v (limit: %d)\n", sellerId, resetAt, limit)
	}

	for _, res := range client.DoMulti(ctx, cmds...) {
		if err := res.Error(); err != nil {
			return fmt.Errorf("failed to set retry info: %w", err)
		}
	}
	return nil
}
// NewRateLimitTransport returns a RateLimitTransport that sends requests
// through http.DefaultTransport once the seller's rate limit allows them.
func NewRateLimitTransport() *RateLimitTransport {
	rt := new(RateLimitTransport)
	rt.RoundTripper = http.DefaultTransport
	return rt
}