package wb import ( "context" "encoding/json" "fmt" "net/http" "sipro-mps/internal/redis" "sipro-mps/pkg/utils" "time" "github.com/redis/rueidis" ) const ( defaultBucketCapacity = 5 // max burst size refillRate = 100.0 / 60000 // 300 requests per minute → 1 token per 200ms tokenTTLMillis = 60000 // Redis key TTL: 60s ) var tokenBucketScript = rueidis.NewLuaScript(` local key = KEYS[1] local now = tonumber(ARGV[1]) local default_capacity = tonumber(ARGV[2]) local refill_rate = tonumber(ARGV[3]) local ttl = tonumber(ARGV[4]) -- Retry lock local retry_key = key .. ":retry_until" local retry_until = tonumber(redis.call("GET", retry_key)) if retry_until and now < retry_until then return retry_until - now end -- Token Bucket local capacity_key = key .. ":capacity" local token_key = key .. ":tokens" local time_key = key .. ":last_refill" local capacity = tonumber(redis.call("GET", capacity_key)) or default_capacity local tokens = tonumber(redis.call("GET", token_key)) local last_refill = tonumber(redis.call("GET", time_key)) if tokens == nil then tokens = capacity end if last_refill == nil then last_refill = now end local elapsed = now - last_refill local refill = elapsed * refill_rate tokens = math.min(capacity, tokens + refill) last_refill = now if tokens >= 1 then tokens = tokens - 1 redis.call("SET", token_key, tokens) redis.call("SET", time_key, last_refill) redis.call("PEXPIRE", token_key, ttl) redis.call("PEXPIRE", time_key, ttl) return 0 else local wait_time = math.ceil((1 - tokens) / refill_rate) return wait_time end `) type RateLimitTransport struct { http.RoundTripper redis rueidis.Client } func (t *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, error) { ctx := req.Context() tokenString := req.Header.Get("Authorization") authData := utils.NewWbAuthData(tokenString) authDataBytes, err := json.Marshal(authData) if err != nil { return nil, fmt.Errorf("failed to marshal Wildberries auth data: %w", err) } _, claims, err := utils.DecodeWildberriesJwt(authDataBytes) if err != nil { return nil, fmt.Errorf("failed to decode Wildberries JWT: %w", err) } sellerId := claims["sid"].(string) if sellerId == "" { return nil, fmt.Errorf("sellerId is required in JWT claims") } now := time.Now().UnixMilli() client := t.redis waitTime, err := tokenBucketScript.Exec(ctx, client, []string{sellerId}, []string{ fmt.Sprintf("%d", now), fmt.Sprintf("%d", defaultBucketCapacity), fmt.Sprintf("%f", refillRate), fmt.Sprintf("%d", tokenTTLMillis), }).ToInt64() if err != nil { return nil, fmt.Errorf("rate limit script error: %w", err) } if waitTime > 0 { select { case <-time.After(time.Duration(waitTime) * time.Millisecond): case <-ctx.Done(): return nil, ctx.Err() } } return t.RoundTripper.RoundTrip(req) } func SyncRateLimitRemaining(ctx context.Context, sellerId string, remaining int) error { if sellerId == "" || remaining < 0 { return fmt.Errorf("invalid sellerId or remaining") } now := time.Now().UnixMilli() client := cmds := []rueidis.Completed{ client.B().Set().Key(sellerId + ":capacity").Value(fmt.Sprintf("%d", defaultBucketCapacity)).Ex(time.Minute).Build(), client.B().Set().Key(sellerId + ":tokens").Value(fmt.Sprintf("%d", remaining)).Ex(time.Minute).Build(), client.B().Set().Key(sellerId + ":last_refill").Value(fmt.Sprintf("%d", now)).Ex(time.Minute).Build(), } results := client.DoMulti(ctx, cmds...) for _, res := range results { if res.Error() != nil { return fmt.Errorf("failed to sync rate limit: %w", res.Error()) } } return nil } func SetRateLimitRetry(ctx context.Context, sellerId string, retrySeconds int, limit int, resetSeconds int) error { if sellerId == "" { return fmt.Errorf("sellerId is required") } now := time.Now() retryUntil := now.Add(time.Duration(retrySeconds) * time.Second).UnixMilli() client := *redis.Client cmds := []rueidis.Completed{ client.B().Set(). Key(sellerId + ":retry_until"). Value(fmt.Sprintf("%d", retryUntil)). Px(time.Duration(retrySeconds+5) * time.Second).Build(), } if limit > 0 { cmds = append(cmds, client.B().Set(). Key(sellerId+":capacity"). Value(fmt.Sprintf("%d", limit)). Ex(time.Hour).Build()) } if resetSeconds > 0 { resetAt := now.Add(time.Duration(resetSeconds) * time.Second) fmt.Printf("Seller %s rate limit resets at %v (limit: %d)\n", sellerId, resetAt, limit) } results := client.DoMulti(ctx, cmds...) for _, res := range results { if res.Error() != nil { return fmt.Errorf("failed to set retry info: %w", res.Error()) } } return nil } func NewRateLimitTransport(client rueidis.Client) *RateLimitTransport { return &RateLimitTransport{RoundTripper: http.DefaultTransport} }