171 lines
4.6 KiB
Go
171 lines
4.6 KiB
Go
package wb
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"github.com/redis/rueidis"
|
|
"net/http"
|
|
"sipro-mps/internal/redis"
|
|
"time"
|
|
)
|
|
|
|
const (
	// defaultBucketCapacity is the maximum burst size. The Lua script uses it
	// only when no per-seller ":capacity" key is stored in Redis.
	defaultBucketCapacity = 10

	// refillRate is expressed in tokens per millisecond:
	// 100 tokens / 60000 ms = 100 requests per minute, i.e. one token every 600 ms.
	// NOTE(review): confirm that 100 requests per minute is the intended limit.
	refillRate = 100.0 / 60000

	// tokenTTLMillis is the TTL, in milliseconds, applied to the bucket keys
	// in Redis (60s).
	tokenTTLMillis = 60000
)
|
|
|
|
// tokenBucketScript is a Redis Lua script implementing a token bucket with an
// additional "retry lock".
//
// Inputs:
//   - KEYS[1]: base key; the script derives "<key>:retry_until",
//     "<key>:capacity", "<key>:tokens" and "<key>:last_refill" from it.
//   - ARGV[1]: current time as a number (callers in this file pass Unix
//     milliseconds).
//   - ARGV[2]: capacity used when "<key>:capacity" is not set.
//   - ARGV[3]: refill rate in tokens per unit of ARGV[1] time.
//   - ARGV[4]: TTL in milliseconds (PEXPIRE) for the tokens/last_refill keys.
//
// Result: 0 when a token was taken; otherwise the time the caller should
// wait, either until "<key>:retry_until" or until one full token has been
// refilled. Bucket state is written back only when a token is taken.
//
// NOTE(review): the derived keys are not passed through KEYS; verify this is
// acceptable for the Redis deployment in use (Redis Cluster expects every
// accessed key to be declared).
var tokenBucketScript = rueidis.NewLuaScript(`
local key = KEYS[1]
local now = tonumber(ARGV[1])
local default_capacity = tonumber(ARGV[2])
local refill_rate = tonumber(ARGV[3])
local ttl = tonumber(ARGV[4])

-- Retry lock
local retry_key = key .. ":retry_until"
local retry_until = tonumber(redis.call("GET", retry_key))
if retry_until and now < retry_until then
    return retry_until - now
end

-- Token Bucket
local capacity_key = key .. ":capacity"
local token_key = key .. ":tokens"
local time_key = key .. ":last_refill"

local capacity = tonumber(redis.call("GET", capacity_key)) or default_capacity
local tokens = tonumber(redis.call("GET", token_key))
local last_refill = tonumber(redis.call("GET", time_key))

if tokens == nil then tokens = capacity end
if last_refill == nil then last_refill = now end

local elapsed = now - last_refill
local refill = elapsed * refill_rate
tokens = math.min(capacity, tokens + refill)
last_refill = now

if tokens >= 1 then
    tokens = tokens - 1
    redis.call("SET", token_key, tokens)
    redis.call("SET", time_key, last_refill)
    redis.call("PEXPIRE", token_key, ttl)
    redis.call("PEXPIRE", time_key, ttl)
    return 0
else
    local wait_time = math.ceil((1 - tokens) / refill_rate)
    return wait_time
end
`)
|
|
|
|
// RateLimitTransport is an http.RoundTripper that wraps another RoundTripper.
// The embedded RoundTripper is the transport that actually sends the request.
type RateLimitTransport struct {
	http.RoundTripper
}
|
|
|
|
func (t *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
ctx := req.Context()
|
|
|
|
tokenString := req.Header.Get("Authorization")
|
|
authData := NewWbAuthData(tokenString)
|
|
authDataBytes, err := json.Marshal(authData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal Wildberries auth data: %w", err)
|
|
}
|
|
_, claims, err := DecodeWildberriesJwt(authDataBytes)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decode Wildberries JWT: %w", err)
|
|
}
|
|
sellerId := claims["sid"].(string)
|
|
if sellerId == "" {
|
|
return nil, fmt.Errorf("sellerId is required in JWT claims")
|
|
}
|
|
now := time.Now().UnixMilli()
|
|
client := *redis.Client
|
|
|
|
waitTime, err := tokenBucketScript.Exec(ctx, client, []string{sellerId}, []string{
|
|
fmt.Sprintf("%d", now),
|
|
fmt.Sprintf("%d", defaultBucketCapacity),
|
|
fmt.Sprintf("%f", refillRate),
|
|
fmt.Sprintf("%d", tokenTTLMillis),
|
|
}).ToInt64()
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("rate limit script error: %w", err)
|
|
}
|
|
|
|
if waitTime > 0 {
|
|
select {
|
|
case <-time.After(time.Duration(waitTime) * time.Millisecond):
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
}
|
|
}
|
|
|
|
return t.RoundTripper.RoundTrip(req)
|
|
}
|
|
|
|
func SyncRateLimitRemaining(ctx context.Context, sellerId string, remaining int) error {
|
|
if sellerId == "" || remaining < 0 {
|
|
return fmt.Errorf("invalid sellerId or remaining")
|
|
}
|
|
now := time.Now().UnixMilli()
|
|
client := *redis.Client
|
|
|
|
cmds := []rueidis.Completed{
|
|
client.B().Set().Key(sellerId + ":capacity").Value(fmt.Sprintf("%d", defaultBucketCapacity)).Ex(time.Minute).Build(),
|
|
client.B().Set().Key(sellerId + ":tokens").Value(fmt.Sprintf("%d", remaining)).Ex(time.Minute).Build(),
|
|
client.B().Set().Key(sellerId + ":last_refill").Value(fmt.Sprintf("%d", now)).Ex(time.Minute).Build(),
|
|
}
|
|
|
|
results := client.DoMulti(ctx, cmds...)
|
|
for _, res := range results {
|
|
if res.Error() != nil {
|
|
return fmt.Errorf("failed to sync rate limit: %w", res.Error())
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func SetRateLimitRetry(ctx context.Context, sellerId string, retrySeconds int, limit int, resetSeconds int) error {
|
|
if sellerId == "" {
|
|
return fmt.Errorf("sellerId is required")
|
|
}
|
|
now := time.Now()
|
|
retryUntil := now.Add(time.Duration(retrySeconds) * time.Second).UnixMilli()
|
|
client := *redis.Client
|
|
|
|
cmds := []rueidis.Completed{
|
|
client.B().Set().
|
|
Key(sellerId + ":retry_until").
|
|
Value(fmt.Sprintf("%d", retryUntil)).
|
|
Px(time.Duration(retrySeconds+5) * time.Second).Build(),
|
|
}
|
|
|
|
if limit > 0 {
|
|
cmds = append(cmds, client.B().Set().
|
|
Key(sellerId+":capacity").
|
|
Value(fmt.Sprintf("%d", limit)).
|
|
Ex(time.Hour).Build())
|
|
}
|
|
|
|
if resetSeconds > 0 {
|
|
resetAt := now.Add(time.Duration(resetSeconds) * time.Second)
|
|
fmt.Printf("Seller %s rate limit resets at %v (limit: %d)\n", sellerId, resetAt, limit)
|
|
}
|
|
|
|
results := client.DoMulti(ctx, cmds...)
|
|
for _, res := range results {
|
|
if res.Error() != nil {
|
|
return fmt.Errorf("failed to set retry info: %w", res.Error())
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func NewRateLimitTransport() *RateLimitTransport {
|
|
return &RateLimitTransport{RoundTripper: http.DefaultTransport}
|
|
}
|