172 lines
		
	
	
		
			4.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			172 lines
		
	
	
		
			4.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package wb
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"encoding/json"
 | 
						|
	"fmt"
 | 
						|
	"github.com/redis/rueidis"
 | 
						|
	"net/http"
 | 
						|
	"sipro-mps/internal/redis"
 | 
						|
	"sipro-mps/pkg/utils"
 | 
						|
	"time"
 | 
						|
)
 | 
						|
 | 
						|
const (
	// defaultBucketCapacity is the burst size (maximum stored tokens) used when
	// no per-seller capacity has been stored in Redis.
	defaultBucketCapacity = 10

	// refillRate is expressed in tokens per millisecond: 100 requests per
	// minute, i.e. one token every 600ms.
	// NOTE(review): a previous comment claimed 300 requests/minute (one token
	// per 200ms) while the value was 100/minute; the value is kept — confirm
	// which limit the Wildberries API actually enforces.
	refillRate = 100.0 / 60000

	// tokenTTLMillis is the Redis TTL, in milliseconds, applied to the bucket's
	// ":tokens" and ":last_refill" keys (60s).
	tokenTTLMillis = 60000
)
 | 
						|
 | 
						|
// tokenBucketScript atomically checks a per-seller retry lock and then tries
// to take one token from a Redis-backed token bucket.
//
// KEYS[1] is the key prefix (the seller ID). The script derives
// "<prefix>:retry_until", "<prefix>:capacity", "<prefix>:tokens" and
// "<prefix>:last_refill" from it.
//
// ARGV: [1] current time in Unix milliseconds, [2] default capacity used when
// "<prefix>:capacity" is unset, [3] refill rate in tokens per millisecond,
// [4] TTL in milliseconds for the ":tokens" and ":last_refill" keys.
//
// Returns 0 when a token was taken, otherwise the number of milliseconds to
// wait before trying again. The bucket state is only written when a token is
// taken.
//
// NOTE(review): the derived keys are not declared in KEYS, so this script is
// not Redis Cluster safe unless all of them hash to the same slot.
var tokenBucketScript = rueidis.NewLuaScript(`
local key = KEYS[1]
local now = tonumber(ARGV[1])
local default_capacity = tonumber(ARGV[2])
local refill_rate = tonumber(ARGV[3])
local ttl = tonumber(ARGV[4])

-- Retry lock: honour a server-imposed back-off before touching the bucket.
local retry_key = key .. ":retry_until"
local retry_until = tonumber(redis.call("GET", retry_key))
if retry_until and now < retry_until then
	return retry_until - now
end

-- Token bucket
local capacity_key = key .. ":capacity"
local token_key = key .. ":tokens"
local time_key = key .. ":last_refill"

local capacity = tonumber(redis.call("GET", capacity_key)) or default_capacity
local tokens = tonumber(redis.call("GET", token_key))
local last_refill = tonumber(redis.call("GET", time_key))

if tokens == nil then tokens = capacity end
if last_refill == nil then last_refill = now end

-- "now" comes from the caller's clock: a caller whose clock lags must neither
-- drain the bucket (negative refill) nor move last_refill backwards.
local elapsed = math.max(0, now - last_refill)
tokens = math.min(capacity, tokens + elapsed * refill_rate)
last_refill = math.max(last_refill, now)

if tokens >= 1 then
	tokens = tokens - 1
	redis.call("SET", token_key, tokens, "PX", ttl)
	redis.call("SET", time_key, last_refill, "PX", ttl)
	return 0
end

return math.ceil((1 - tokens) / refill_rate)
`)
 | 
						|
 | 
						|
// RateLimitTransport is an http.RoundTripper that applies a per-seller,
// Redis-backed rate limit before passing each request on to the embedded
// RoundTripper.
type RateLimitTransport struct {
	// RoundTripper performs the actual HTTP exchange once the limiter admits
	// the request.
	http.RoundTripper
}
 | 
						|
 | 
						|
func (t *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, error) {
 | 
						|
	ctx := req.Context()
 | 
						|
 | 
						|
	tokenString := req.Header.Get("Authorization")
 | 
						|
	authData := utils.NewWbAuthData(tokenString)
 | 
						|
	authDataBytes, err := json.Marshal(authData)
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("failed to marshal Wildberries auth data: %w", err)
 | 
						|
	}
 | 
						|
	_, claims, err := utils.DecodeWildberriesJwt(authDataBytes)
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("failed to decode Wildberries JWT: %w", err)
 | 
						|
	}
 | 
						|
	sellerId := claims["sid"].(string)
 | 
						|
	if sellerId == "" {
 | 
						|
		return nil, fmt.Errorf("sellerId is required in JWT claims")
 | 
						|
	}
 | 
						|
	now := time.Now().UnixMilli()
 | 
						|
	client := *redis.Client
 | 
						|
 | 
						|
	waitTime, err := tokenBucketScript.Exec(ctx, client, []string{sellerId}, []string{
 | 
						|
		fmt.Sprintf("%d", now),
 | 
						|
		fmt.Sprintf("%d", defaultBucketCapacity),
 | 
						|
		fmt.Sprintf("%f", refillRate),
 | 
						|
		fmt.Sprintf("%d", tokenTTLMillis),
 | 
						|
	}).ToInt64()
 | 
						|
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("rate limit script error: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	if waitTime > 0 {
 | 
						|
		select {
 | 
						|
		case <-time.After(time.Duration(waitTime) * time.Millisecond):
 | 
						|
		case <-ctx.Done():
 | 
						|
			return nil, ctx.Err()
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return t.RoundTripper.RoundTrip(req)
 | 
						|
}
 | 
						|
 | 
						|
func SyncRateLimitRemaining(ctx context.Context, sellerId string, remaining int) error {
 | 
						|
	if sellerId == "" || remaining < 0 {
 | 
						|
		return fmt.Errorf("invalid sellerId or remaining")
 | 
						|
	}
 | 
						|
	now := time.Now().UnixMilli()
 | 
						|
	client := *redis.Client
 | 
						|
 | 
						|
	cmds := []rueidis.Completed{
 | 
						|
		client.B().Set().Key(sellerId + ":capacity").Value(fmt.Sprintf("%d", defaultBucketCapacity)).Ex(time.Minute).Build(),
 | 
						|
		client.B().Set().Key(sellerId + ":tokens").Value(fmt.Sprintf("%d", remaining)).Ex(time.Minute).Build(),
 | 
						|
		client.B().Set().Key(sellerId + ":last_refill").Value(fmt.Sprintf("%d", now)).Ex(time.Minute).Build(),
 | 
						|
	}
 | 
						|
 | 
						|
	results := client.DoMulti(ctx, cmds...)
 | 
						|
	for _, res := range results {
 | 
						|
		if res.Error() != nil {
 | 
						|
			return fmt.Errorf("failed to sync rate limit: %w", res.Error())
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func SetRateLimitRetry(ctx context.Context, sellerId string, retrySeconds int, limit int, resetSeconds int) error {
 | 
						|
	if sellerId == "" {
 | 
						|
		return fmt.Errorf("sellerId is required")
 | 
						|
	}
 | 
						|
	now := time.Now()
 | 
						|
	retryUntil := now.Add(time.Duration(retrySeconds) * time.Second).UnixMilli()
 | 
						|
	client := *redis.Client
 | 
						|
 | 
						|
	cmds := []rueidis.Completed{
 | 
						|
		client.B().Set().
 | 
						|
			Key(sellerId + ":retry_until").
 | 
						|
			Value(fmt.Sprintf("%d", retryUntil)).
 | 
						|
			Px(time.Duration(retrySeconds+5) * time.Second).Build(),
 | 
						|
	}
 | 
						|
 | 
						|
	if limit > 0 {
 | 
						|
		cmds = append(cmds, client.B().Set().
 | 
						|
			Key(sellerId+":capacity").
 | 
						|
			Value(fmt.Sprintf("%d", limit)).
 | 
						|
			Ex(time.Hour).Build())
 | 
						|
	}
 | 
						|
 | 
						|
	if resetSeconds > 0 {
 | 
						|
		resetAt := now.Add(time.Duration(resetSeconds) * time.Second)
 | 
						|
		fmt.Printf("Seller %s rate limit resets at %v (limit: %d)\n", sellerId, resetAt, limit)
 | 
						|
	}
 | 
						|
 | 
						|
	results := client.DoMulti(ctx, cmds...)
 | 
						|
	for _, res := range results {
 | 
						|
		if res.Error() != nil {
 | 
						|
			return fmt.Errorf("failed to set retry info: %w", res.Error())
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func NewRateLimitTransport() *RateLimitTransport {
 | 
						|
	return &RateLimitTransport{RoundTripper: http.DefaultTransport}
 | 
						|
}
 |