feat: shit shit
This commit is contained in:
@@ -15,4 +15,5 @@ type Marketplace struct {
|
||||
AuthData pgtype.Text
|
||||
WarehouseID pgtype.Text
|
||||
AuthDataJson []byte
|
||||
CampaignID pgtype.Text
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
const getMarketplaceByID = `-- name: GetMarketplaceByID :one
|
||||
SELECT id, base_marketplace, name, auth_data, warehouse_id, auth_data_json FROM marketplaces
|
||||
SELECT id, base_marketplace, name, auth_data, warehouse_id, auth_data_json, campaign_id FROM marketplaces
|
||||
WHERE id = $1 LIMIT 1
|
||||
`
|
||||
|
||||
@@ -24,6 +24,7 @@ func (q *Queries) GetMarketplaceByID(ctx context.Context, id int32) (Marketplace
|
||||
&i.AuthData,
|
||||
&i.WarehouseID,
|
||||
&i.AuthDataJson,
|
||||
&i.CampaignID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
@@ -10,5 +10,6 @@ create table marketplaces
|
||||
CASE
|
||||
WHEN ((auth_data)::text IS JSON) THEN (auth_data)::jsonb
|
||||
ELSE NULL::jsonb
|
||||
END) stored
|
||||
END) stored,
|
||||
campaign_id varchar
|
||||
);
|
||||
@@ -2,8 +2,9 @@ package marketplace
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/go-faster/errors"
|
||||
"sipro-mps/pkg/utils"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -18,6 +19,7 @@ type Marketplace struct {
|
||||
AuthData string `json:"auth_data"`
|
||||
WarehouseID string `json:"warehouse_id"`
|
||||
AuthDataJson []byte `json:"auth_data_json,omitempty"`
|
||||
CampaignID string `json:"campaign_id,omitempty"`
|
||||
}
|
||||
|
||||
func (m *Marketplace) getIdentifierWildberries() (string, error) {
|
||||
|
||||
@@ -25,5 +25,6 @@ func (r *dbRepository) GetMarketplaceByID(ctx context.Context, id int) (*Marketp
|
||||
AuthData: marketplace.AuthData.String,
|
||||
WarehouseID: marketplace.WarehouseID.String,
|
||||
AuthDataJson: marketplace.AuthDataJson,
|
||||
CampaignID: marketplace.CampaignID.String,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -2,20 +2,16 @@ package products
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/redis/rueidis"
|
||||
"github.com/samber/lo"
|
||||
pb "sipro-mps/api/generated/v1/wb/products"
|
||||
"sipro-mps/internal/marketplace"
|
||||
"sipro-mps/internal/redis"
|
||||
"sipro-mps/internal/tasks/client"
|
||||
"sipro-mps/internal/tasks/types"
|
||||
"sipro-mps/internal/wb"
|
||||
"sipro-mps/internal/wb/products/mapping/generated"
|
||||
wbapi "sipro-mps/pkg/api/wb/client"
|
||||
"sipro-mps/pkg/utils"
|
||||
|
||||
"github.com/deliveryhero/pipeline/v2"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -97,79 +93,99 @@ func fetchProducts(
|
||||
}
|
||||
|
||||
func (a apiRepository) StreamAllProductsCache(ctx context.Context, marketplaceId int, resultChan chan<- []pb.Product, errChan chan<- error) {
|
||||
defer close(resultChan)
|
||||
defer close(errChan)
|
||||
_, sellerId, err := a.ParseMarketplace(ctx, marketplaceId)
|
||||
// DO NOT close channels here - WithCache will handle it (caller/creator owns them)
|
||||
mp, err := a.marketplaceRepository.GetMarketplaceByID(ctx, marketplaceId)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
errChan <- fmt.Errorf("getting marketplace by ID: %w", err)
|
||||
return
|
||||
}
|
||||
c := *redis.Client
|
||||
key := fmt.Sprintf("wb:products:%s", sellerId)
|
||||
jsonString, err := c.Do(ctx, c.B().Get().Key(key).Build()).ToString()
|
||||
if err == nil && jsonString != "null" {
|
||||
var result []pb.Product
|
||||
err = json.Unmarshal([]byte(jsonString), &result)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("unmarshalling products from cache: %w", err)
|
||||
return
|
||||
}
|
||||
task, err := types.NewFetchProductsTask(types.TypeWbFetchProducts, marketplaceId)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("creating fetch products task: %w", err)
|
||||
return
|
||||
}
|
||||
_, err = client.Client.Enqueue(task)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("enqueueing fetch products task: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
resultChan <- result
|
||||
identifier, err := mp.GetIdentifier()
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("getting marketplace identifier: %w", err)
|
||||
return
|
||||
}
|
||||
if !errors.As(err, &rueidis.Nil) && err != nil {
|
||||
errChan <- fmt.Errorf("fetching products from cache: %w", err)
|
||||
client, err := wb.GetClientFromMarketplace(mp)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
converter := generated.ConverterImpl{}
|
||||
|
||||
innerResultChan := make(chan []WbProduct)
|
||||
innerErrChan := make(chan error)
|
||||
go a.StreamAllProducts(ctx, marketplaceId, innerResultChan, innerErrChan)
|
||||
var allProducts []pb.Product
|
||||
defer func() {
|
||||
jsonData, err := json.Marshal(allProducts)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("marshalling products to cache: %w", err)
|
||||
return
|
||||
}
|
||||
err = c.Do(ctx, c.B().Set().Key(key).Value(string(jsonData)).Build()).Error()
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("setting products to cache: %w", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case err, ok := <-innerErrChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
errChan <- fmt.Errorf("streaming products: %w", err)
|
||||
return
|
||||
case products, ok := <-innerResultChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
pbProducts := lo.Map(products, func(p WbProduct, _ int) pb.Product {
|
||||
return *converter.ToProto(&p)
|
||||
})
|
||||
allProducts = append(allProducts, pbProducts...)
|
||||
resultChan <- pbProducts
|
||||
}
|
||||
transform := pipeline.NewProcessor(func(_ context.Context, products []WbProduct) ([]pb.Product, error) {
|
||||
return lo.Map(products, func(item WbProduct, _ int) pb.Product {
|
||||
return *converter.ToProto(&item)
|
||||
}), nil
|
||||
}, nil)
|
||||
inputChan := make(chan []WbProduct)
|
||||
fetchProducts(ctx, client, identifier, inputChan, nil)
|
||||
for out := range pipeline.Process(ctx, transform, inputChan) {
|
||||
resultChan <- out
|
||||
}
|
||||
|
||||
//c := *redis.Client
|
||||
//key := fmt.Sprintf("wb:products:%s", sellerId)
|
||||
//jsonString, err := c.Do(ctx, c.B().Get().Key(key).Build()).ToString()
|
||||
//if err == nil && jsonString != "null" {
|
||||
// var result []pb.Product
|
||||
// err = json.Unmarshal([]byte(jsonString), &result)
|
||||
// if err != nil {
|
||||
// errChan <- fmt.Errorf("unmarshalling products from cache: %w", err)
|
||||
// return
|
||||
// }
|
||||
// task, err := types.NewFetchProductsTask(types.TypeWbFetchProducts, marketplaceId)
|
||||
// if err != nil {
|
||||
// errChan <- fmt.Errorf("creating fetch products task: %w", err)
|
||||
// return
|
||||
// }
|
||||
// _, err = client.Client.Enqueue(task)
|
||||
// if err != nil {
|
||||
// errChan <- fmt.Errorf("enqueueing fetch products task: %w", err)
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// resultChan <- result
|
||||
// return
|
||||
//}
|
||||
//if !errors.As(err, &rueidis.Nil) && err != nil {
|
||||
// errChan <- fmt.Errorf("fetching products from cache: %w", err)
|
||||
// return
|
||||
//}
|
||||
//converter := generated.ConverterImpl{}
|
||||
//
|
||||
//innerResultChan := make(chan []WbProduct)
|
||||
//innerErrChan := make(chan error)
|
||||
//go a.StreamAllProducts(ctx, marketplaceId, innerResultChan, innerErrChan)
|
||||
//var allProducts []pb.Product
|
||||
//defer func() {
|
||||
// jsonData, err := json.Marshal(allProducts)
|
||||
// if err != nil {
|
||||
// errChan <- fmt.Errorf("marshalling products to cache: %w", err)
|
||||
// return
|
||||
// }
|
||||
// err = c.Do(ctx, c.B().Set().Key(key).Value(string(jsonData)).Build()).Error()
|
||||
// if err != nil {
|
||||
// errChan <- fmt.Errorf("setting products to cache: %w", err)
|
||||
// return
|
||||
// }
|
||||
//}()
|
||||
//for {
|
||||
// select {
|
||||
// case err, ok := <-innerErrChan:
|
||||
// if !ok {
|
||||
// return
|
||||
// }
|
||||
// errChan <- fmt.Errorf("streaming products: %w", err)
|
||||
// return
|
||||
// case products, ok := <-innerResultChan:
|
||||
// if !ok {
|
||||
// return
|
||||
// }
|
||||
// pbProducts := lo.Map(products, func(p WbProduct, _ int) pb.Product {
|
||||
// return *converter.ToProto(&p)
|
||||
// })
|
||||
// allProducts = append(allProducts, pbProducts...)
|
||||
// resultChan <- pbProducts
|
||||
// }
|
||||
//}
|
||||
|
||||
}
|
||||
func (a apiRepository) GetAllProducts(ctx context.Context, marketplaceId int) ([]WbProduct, error) {
|
||||
marketplaceByID, sellerId, err := a.ParseMarketplace(ctx, marketplaceId)
|
||||
|
||||
36
internal/ym/common.go
Normal file
36
internal/ym/common.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package ym
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"sipro-mps/internal/marketplace"
|
||||
"sipro-mps/pkg/api/yandex/ymclient"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func GetClientFromMarketplace(mp *marketplace.Marketplace) (*ymclient.APIClient, error) {
|
||||
authDataParsed := gjson.Parse(mp.AuthData)
|
||||
apiKeyResult := authDataParsed.Get("apiKey")
|
||||
if !apiKeyResult.Exists() {
|
||||
return nil, errors.New("API key not found in marketplace auth data")
|
||||
}
|
||||
apiKey := apiKeyResult.String()
|
||||
if apiKey == "" {
|
||||
return nil, errors.New("API key is empty")
|
||||
}
|
||||
if !strings.HasPrefix(apiKey, "ACMA") {
|
||||
return nil, errors.New("API key does not start with 'ACMA'")
|
||||
}
|
||||
// Create HTTP client with rate limiting
|
||||
httpClient := &http.Client{
|
||||
Transport: NewRateLimitTransport(),
|
||||
}
|
||||
|
||||
cfg := ymclient.NewConfiguration()
|
||||
cfg.AddDefaultHeader("Api-Key", apiKey)
|
||||
cfg.HTTPClient = httpClient
|
||||
client := ymclient.NewAPIClient(cfg)
|
||||
return client, nil
|
||||
}
|
||||
103
internal/ym/products/adapter_grpc.go
Normal file
103
internal/ym/products/adapter_grpc.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package products
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
pb "sipro-mps/api/generated/v1/yandexmarket/products"
|
||||
"sipro-mps/internal/marketplace"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
type AdapterGRPC struct {
|
||||
pb.UnimplementedProductsServiceServer
|
||||
repo Repository
|
||||
}
|
||||
|
||||
func NewAdapterGRPC(repo Repository) *AdapterGRPC {
|
||||
return &AdapterGRPC{
|
||||
repo: repo,
|
||||
}
|
||||
}
|
||||
|
||||
func RegisterAdapterGRPC(server *grpc.Server, marketplacesRepository marketplace.Repository) (*Repository, error) {
|
||||
repo := NewAPIRepository(marketplacesRepository)
|
||||
adapter := NewAdapterGRPC(repo)
|
||||
pb.RegisterProductsServiceServer(server, adapter)
|
||||
return &repo, nil
|
||||
}
|
||||
|
||||
func (a *AdapterGRPC) GetProducts(req *pb.GetProductsRequest, stream pb.ProductsService_GetProductsServer) error {
|
||||
ctx := stream.Context()
|
||||
fmt.Printf("GetProducts called with marketplace_id: %d, offer_ids count: %d\n", req.MarketplaceId, len(req.OfferIds))
|
||||
|
||||
resultChan := make(chan []*pb.GetProductsResponse_Offer, 10)
|
||||
errChan := make(chan error)
|
||||
|
||||
go a.repo.GetProducts(ctx, int(req.MarketplaceId), req, resultChan, errChan)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Println("GetProducts: context cancelled or deadline exceeded:", ctx.Err())
|
||||
return ctx.Err()
|
||||
case offers, ok := <-resultChan:
|
||||
if !ok {
|
||||
fmt.Println("GetProducts: result channel closed")
|
||||
return nil
|
||||
}
|
||||
// Send offers in response
|
||||
response := &pb.GetProductsResponse{
|
||||
Offers: offers,
|
||||
}
|
||||
if err := stream.Send(response); err != nil {
|
||||
fmt.Println("GetProducts: error sending response:", err)
|
||||
return err
|
||||
}
|
||||
case err, ok := <-errChan:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Println("GetProducts: error received from channel:", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *AdapterGRPC) CalculateProductTariffs(req *pb.CalculateProductTariffsRequest, stream pb.ProductsService_CalculateProductTariffsServer) error {
|
||||
ctx := stream.Context()
|
||||
fmt.Printf("CalculateProductTariffs called with marketplace_id: %d, offers count: %d\n", req.MarketplaceId, len(req.Offers))
|
||||
|
||||
resultChan := make(chan []*pb.CalculateProductTariffsResponse, 10)
|
||||
errChan := make(chan error)
|
||||
|
||||
go a.repo.CalculateProductTariffs(ctx, int(req.MarketplaceId), req, resultChan, errChan)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Println("CalculateProductTariffs: context cancelled or deadline exceeded:", ctx.Err())
|
||||
return ctx.Err()
|
||||
case responses, ok := <-resultChan:
|
||||
if !ok {
|
||||
fmt.Println("CalculateProductTariffs: result channel closed")
|
||||
return nil
|
||||
}
|
||||
for _, response := range responses {
|
||||
if err := stream.Send(response); err != nil {
|
||||
fmt.Println("CalculateProductTariffs: error sending response:", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
case err, ok := <-errChan:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Println("CalculateProductTariffs: error received from channel:", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
13
internal/ym/products/entities.go
Normal file
13
internal/ym/products/entities.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package products
|
||||
|
||||
import (
|
||||
pb "sipro-mps/api/generated/v1/yandexmarket/products"
|
||||
)
|
||||
|
||||
type PbCalculateProductTariffsRequest = pb.CalculateProductTariffsRequest
|
||||
type PbCalculateProductTariffsResponse = pb.CalculateProductTariffsResponse
|
||||
type PbOffer = pb.CalculateProductTariffsRequest_Offers
|
||||
type PbResponseOffer = pb.CalculateProductTariffsResponse_Offers
|
||||
type PbParameters = pb.CalculateProductTariffsRequest_Parameters
|
||||
type PbGetProductsOffer = pb.GetProductsResponse_Offer
|
||||
type PbGetProductsRequest = pb.GetProductsRequest
|
||||
46
internal/ym/products/mapping/converter.go
Normal file
46
internal/ym/products/mapping/converter.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package mapping
|
||||
|
||||
import (
|
||||
proto "sipro-mps/api/generated/v1/yandexmarket/products"
|
||||
"sipro-mps/pkg/api/yandex/ymclient"
|
||||
)
|
||||
|
||||
//go:generate go run github.com/jmattheis/goverter/cmd/goverter gen .
|
||||
|
||||
// goverter:converter
|
||||
// goverter:output:file ./generated/generated.go
|
||||
// goverter:output:package generated
|
||||
// goverter:ignoreUnexported yes
|
||||
// goverter:matchIgnoreCase yes
|
||||
// goverter:useZeroValueOnPointerInconsistency yes
|
||||
// goverter:extend Int64ToFloat32 Int64ToInt32 Float32ToInt64 Int32ToInt64 PointerInt32ToInt64
|
||||
type Converter interface {
|
||||
ProtoOfferToYmOffer(details *proto.CalculateProductTariffsRequest_Offers) *ymclient.CalculateTariffsOfferDTO
|
||||
ProtoParametersToYmParameters(details *proto.CalculateProductTariffsRequest_Parameters) *ymclient.CalculateTariffsParametersDTO
|
||||
|
||||
// Response converters
|
||||
YmOfferToProtoResponseOffer(details *ymclient.CalculateTariffsOfferInfoDTO) *proto.CalculateProductTariffsResponse_Offers
|
||||
YmTariffToProtoTariff(details *ymclient.CalculatedTariffDTO) *proto.CalculateProductTariffsResponse_Tariff
|
||||
|
||||
YmOfferToProtoOffer(details *ymclient.GetOfferDTO) *proto.GetProductsResponse_Offer
|
||||
}
|
||||
|
||||
func Int64ToFloat32(i int64) float32 {
|
||||
return float32(i)
|
||||
}
|
||||
func Int64ToInt32(i int64) int32 {
|
||||
return int32(i)
|
||||
}
|
||||
func Float32ToInt64(f float32) int64 {
|
||||
return int64(f)
|
||||
}
|
||||
func Int32ToInt64(i int32) int64 {
|
||||
return int64(i)
|
||||
}
|
||||
func PointerInt32ToInt64(i *int32) *int64 {
|
||||
if i == nil {
|
||||
return nil
|
||||
}
|
||||
val := int64(*i)
|
||||
return &val
|
||||
}
|
||||
152
internal/ym/products/mapping/generated/generated.go
Normal file
152
internal/ym/products/mapping/generated/generated.go
Normal file
@@ -0,0 +1,152 @@
|
||||
// Code generated by github.com/jmattheis/goverter, DO NOT EDIT.
|
||||
//go:build !goverter
|
||||
|
||||
package generated
|
||||
|
||||
import (
|
||||
products "sipro-mps/api/generated/v1/yandexmarket/products"
|
||||
mapping "sipro-mps/internal/ym/products/mapping"
|
||||
ymclient "sipro-mps/pkg/api/yandex/ymclient"
|
||||
)
|
||||
|
||||
type ConverterImpl struct{}
|
||||
|
||||
func (c *ConverterImpl) ProtoOfferToYmOffer(source *products.CalculateProductTariffsRequest_Offers) *ymclient.CalculateTariffsOfferDTO {
|
||||
var pYmclientCalculateTariffsOfferDTO *ymclient.CalculateTariffsOfferDTO
|
||||
if source != nil {
|
||||
var ymclientCalculateTariffsOfferDTO ymclient.CalculateTariffsOfferDTO
|
||||
ymclientCalculateTariffsOfferDTO.CategoryId = (*source).CategoryId
|
||||
ymclientCalculateTariffsOfferDTO.Price = mapping.Int64ToFloat32((*source).Price)
|
||||
ymclientCalculateTariffsOfferDTO.Length = mapping.Int64ToFloat32((*source).Length)
|
||||
ymclientCalculateTariffsOfferDTO.Width = mapping.Int64ToFloat32((*source).Width)
|
||||
ymclientCalculateTariffsOfferDTO.Height = mapping.Int64ToFloat32((*source).Height)
|
||||
ymclientCalculateTariffsOfferDTO.Weight = mapping.Int64ToFloat32((*source).Weight)
|
||||
pInt32 := mapping.Int64ToInt32((*source).Quantity)
|
||||
ymclientCalculateTariffsOfferDTO.Quantity = &pInt32
|
||||
pYmclientCalculateTariffsOfferDTO = &ymclientCalculateTariffsOfferDTO
|
||||
}
|
||||
return pYmclientCalculateTariffsOfferDTO
|
||||
}
|
||||
func (c *ConverterImpl) ProtoParametersToYmParameters(source *products.CalculateProductTariffsRequest_Parameters) *ymclient.CalculateTariffsParametersDTO {
|
||||
var pYmclientCalculateTariffsParametersDTO *ymclient.CalculateTariffsParametersDTO
|
||||
if source != nil {
|
||||
var ymclientCalculateTariffsParametersDTO ymclient.CalculateTariffsParametersDTO
|
||||
pInt64 := (*source).CampaignId
|
||||
ymclientCalculateTariffsParametersDTO.CampaignId = &pInt64
|
||||
pYmclientSellingProgramType := ymclient.SellingProgramType((*source).SellingProgram)
|
||||
ymclientCalculateTariffsParametersDTO.SellingProgram = &pYmclientSellingProgramType
|
||||
pYmclientPaymentFrequencyType := ymclient.PaymentFrequencyType((*source).Frequency)
|
||||
ymclientCalculateTariffsParametersDTO.Frequency = &pYmclientPaymentFrequencyType
|
||||
pYmclientCurrencyType := ymclient.CurrencyType((*source).Currency)
|
||||
ymclientCalculateTariffsParametersDTO.Currency = &pYmclientCurrencyType
|
||||
pYmclientCalculateTariffsParametersDTO = &ymclientCalculateTariffsParametersDTO
|
||||
}
|
||||
return pYmclientCalculateTariffsParametersDTO
|
||||
}
|
||||
func (c *ConverterImpl) YmOfferToProtoOffer(source *ymclient.GetOfferDTO) *products.GetProductsResponse_Offer {
|
||||
var pProductsGetProductsResponse_Offer *products.GetProductsResponse_Offer
|
||||
if source != nil {
|
||||
var productsGetProductsResponse_Offer products.GetProductsResponse_Offer
|
||||
if (*source).MarketCategoryId != nil {
|
||||
productsGetProductsResponse_Offer.MarketCategoryId = *(*source).MarketCategoryId
|
||||
}
|
||||
productsGetProductsResponse_Offer.WeightDimensions = c.pYmclientOfferWeightDimensionsDTOToPProductsGetProductsResponse_Offer_WeightDimensions((*source).WeightDimensions)
|
||||
productsGetProductsResponse_Offer.BasicPrice = c.pYmclientGetPriceWithDiscountDTOToPProductsGetProductsResponse_Offer_BasicPrice((*source).BasicPrice)
|
||||
productsGetProductsResponse_Offer.OfferId = (*source).OfferId
|
||||
pProductsGetProductsResponse_Offer = &productsGetProductsResponse_Offer
|
||||
}
|
||||
return pProductsGetProductsResponse_Offer
|
||||
}
|
||||
func (c *ConverterImpl) YmOfferToProtoResponseOffer(source *ymclient.CalculateTariffsOfferInfoDTO) *products.CalculateProductTariffsResponse_Offers {
|
||||
var pProductsCalculateProductTariffsResponse_Offers *products.CalculateProductTariffsResponse_Offers
|
||||
if source != nil {
|
||||
var productsCalculateProductTariffsResponse_Offers products.CalculateProductTariffsResponse_Offers
|
||||
productsCalculateProductTariffsResponse_Offers.Offer = c.ymclientCalculateTariffsOfferDTOToPProductsCalculateProductTariffsResponse_Offer((*source).Offer)
|
||||
if (*source).Tariffs != nil {
|
||||
productsCalculateProductTariffsResponse_Offers.Tariffs = make([]*products.CalculateProductTariffsResponse_Tariff, len((*source).Tariffs))
|
||||
for i := 0; i < len((*source).Tariffs); i++ {
|
||||
productsCalculateProductTariffsResponse_Offers.Tariffs[i] = c.ymclientCalculatedTariffDTOToPProductsCalculateProductTariffsResponse_Tariff((*source).Tariffs[i])
|
||||
}
|
||||
}
|
||||
pProductsCalculateProductTariffsResponse_Offers = &productsCalculateProductTariffsResponse_Offers
|
||||
}
|
||||
return pProductsCalculateProductTariffsResponse_Offers
|
||||
}
|
||||
func (c *ConverterImpl) YmTariffToProtoTariff(source *ymclient.CalculatedTariffDTO) *products.CalculateProductTariffsResponse_Tariff {
|
||||
var pProductsCalculateProductTariffsResponse_Tariff *products.CalculateProductTariffsResponse_Tariff
|
||||
if source != nil {
|
||||
var productsCalculateProductTariffsResponse_Tariff products.CalculateProductTariffsResponse_Tariff
|
||||
productsCalculateProductTariffsResponse_Tariff.Type = string((*source).Type)
|
||||
if (*source).Amount != nil {
|
||||
productsCalculateProductTariffsResponse_Tariff.Amount = mapping.Float32ToInt64(*(*source).Amount)
|
||||
}
|
||||
if (*source).Currency != nil {
|
||||
productsCalculateProductTariffsResponse_Tariff.Currency = string(*(*source).Currency)
|
||||
}
|
||||
if (*source).Parameters != nil {
|
||||
productsCalculateProductTariffsResponse_Tariff.Parameters = make([]*products.CalculateProductTariffsResponse_Parameter, len((*source).Parameters))
|
||||
for i := 0; i < len((*source).Parameters); i++ {
|
||||
productsCalculateProductTariffsResponse_Tariff.Parameters[i] = c.ymclientTariffParameterDTOToPProductsCalculateProductTariffsResponse_Parameter((*source).Parameters[i])
|
||||
}
|
||||
}
|
||||
pProductsCalculateProductTariffsResponse_Tariff = &productsCalculateProductTariffsResponse_Tariff
|
||||
}
|
||||
return pProductsCalculateProductTariffsResponse_Tariff
|
||||
}
|
||||
func (c *ConverterImpl) pYmclientGetPriceWithDiscountDTOToPProductsGetProductsResponse_Offer_BasicPrice(source *ymclient.GetPriceWithDiscountDTO) *products.GetProductsResponse_Offer_BasicPrice {
|
||||
var pProductsGetProductsResponse_Offer_BasicPrice *products.GetProductsResponse_Offer_BasicPrice
|
||||
if source != nil {
|
||||
var productsGetProductsResponse_Offer_BasicPrice products.GetProductsResponse_Offer_BasicPrice
|
||||
productsGetProductsResponse_Offer_BasicPrice.Value = (*source).Value
|
||||
pProductsGetProductsResponse_Offer_BasicPrice = &productsGetProductsResponse_Offer_BasicPrice
|
||||
}
|
||||
return pProductsGetProductsResponse_Offer_BasicPrice
|
||||
}
|
||||
func (c *ConverterImpl) pYmclientOfferWeightDimensionsDTOToPProductsGetProductsResponse_Offer_WeightDimensions(source *ymclient.OfferWeightDimensionsDTO) *products.GetProductsResponse_Offer_WeightDimensions {
|
||||
var pProductsGetProductsResponse_Offer_WeightDimensions *products.GetProductsResponse_Offer_WeightDimensions
|
||||
if source != nil {
|
||||
var productsGetProductsResponse_Offer_WeightDimensions products.GetProductsResponse_Offer_WeightDimensions
|
||||
productsGetProductsResponse_Offer_WeightDimensions.Length = (*source).Length
|
||||
productsGetProductsResponse_Offer_WeightDimensions.Width = (*source).Width
|
||||
productsGetProductsResponse_Offer_WeightDimensions.Height = (*source).Height
|
||||
productsGetProductsResponse_Offer_WeightDimensions.Weight = (*source).Weight
|
||||
pProductsGetProductsResponse_Offer_WeightDimensions = &productsGetProductsResponse_Offer_WeightDimensions
|
||||
}
|
||||
return pProductsGetProductsResponse_Offer_WeightDimensions
|
||||
}
|
||||
func (c *ConverterImpl) ymclientCalculateTariffsOfferDTOToPProductsCalculateProductTariffsResponse_Offer(source ymclient.CalculateTariffsOfferDTO) *products.CalculateProductTariffsResponse_Offer {
|
||||
var productsCalculateProductTariffsResponse_Offer products.CalculateProductTariffsResponse_Offer
|
||||
productsCalculateProductTariffsResponse_Offer.CategoryId = source.CategoryId
|
||||
productsCalculateProductTariffsResponse_Offer.Price = mapping.Float32ToInt64(source.Price)
|
||||
productsCalculateProductTariffsResponse_Offer.Length = mapping.Float32ToInt64(source.Length)
|
||||
productsCalculateProductTariffsResponse_Offer.Width = mapping.Float32ToInt64(source.Width)
|
||||
productsCalculateProductTariffsResponse_Offer.Height = mapping.Float32ToInt64(source.Height)
|
||||
productsCalculateProductTariffsResponse_Offer.Weight = mapping.Float32ToInt64(source.Weight)
|
||||
if source.Quantity != nil {
|
||||
productsCalculateProductTariffsResponse_Offer.Quantity = mapping.Int32ToInt64(*source.Quantity)
|
||||
}
|
||||
return &productsCalculateProductTariffsResponse_Offer
|
||||
}
|
||||
func (c *ConverterImpl) ymclientCalculatedTariffDTOToPProductsCalculateProductTariffsResponse_Tariff(source ymclient.CalculatedTariffDTO) *products.CalculateProductTariffsResponse_Tariff {
|
||||
var productsCalculateProductTariffsResponse_Tariff products.CalculateProductTariffsResponse_Tariff
|
||||
productsCalculateProductTariffsResponse_Tariff.Type = string(source.Type)
|
||||
if source.Amount != nil {
|
||||
productsCalculateProductTariffsResponse_Tariff.Amount = mapping.Float32ToInt64(*source.Amount)
|
||||
}
|
||||
if source.Currency != nil {
|
||||
productsCalculateProductTariffsResponse_Tariff.Currency = string(*source.Currency)
|
||||
}
|
||||
if source.Parameters != nil {
|
||||
productsCalculateProductTariffsResponse_Tariff.Parameters = make([]*products.CalculateProductTariffsResponse_Parameter, len(source.Parameters))
|
||||
for i := 0; i < len(source.Parameters); i++ {
|
||||
productsCalculateProductTariffsResponse_Tariff.Parameters[i] = c.ymclientTariffParameterDTOToPProductsCalculateProductTariffsResponse_Parameter(source.Parameters[i])
|
||||
}
|
||||
}
|
||||
return &productsCalculateProductTariffsResponse_Tariff
|
||||
}
|
||||
func (c *ConverterImpl) ymclientTariffParameterDTOToPProductsCalculateProductTariffsResponse_Parameter(source ymclient.TariffParameterDTO) *products.CalculateProductTariffsResponse_Parameter {
|
||||
var productsCalculateProductTariffsResponse_Parameter products.CalculateProductTariffsResponse_Parameter
|
||||
productsCalculateProductTariffsResponse_Parameter.Name = source.Name
|
||||
productsCalculateProductTariffsResponse_Parameter.Value = source.Value
|
||||
return &productsCalculateProductTariffsResponse_Parameter
|
||||
}
|
||||
12
internal/ym/products/repository.go
Normal file
12
internal/ym/products/repository.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package products
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
pb "sipro-mps/api/generated/v1/yandexmarket/products"
|
||||
)
|
||||
|
||||
type Repository interface {
|
||||
CalculateProductTariffs(ctx context.Context, marketplaceID int, req *pb.CalculateProductTariffsRequest, resultChan chan<- []*pb.CalculateProductTariffsResponse, errChan chan<- error)
|
||||
GetProducts(ctx context.Context, marketplaceID int, req *pb.GetProductsRequest, resultChan chan<- []*pb.GetProductsResponse_Offer, errChan chan<- error)
|
||||
}
|
||||
275
internal/ym/products/repository_api.go
Normal file
275
internal/ym/products/repository_api.go
Normal file
@@ -0,0 +1,275 @@
|
||||
package products
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
pb "sipro-mps/api/generated/v1/yandexmarket/products"
|
||||
"sipro-mps/internal/marketplace"
|
||||
"sipro-mps/internal/ym"
|
||||
"sipro-mps/internal/ym/products/mapping/generated"
|
||||
"sipro-mps/pkg/api/yandex/ymclient"
|
||||
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultChunkSize = 200
|
||||
offerMappingsRateLimit = 600
|
||||
tariffsRateLimit = 100
|
||||
rateLimitWindow = time.Minute
|
||||
maxPageSize = math.MaxInt32
|
||||
)
|
||||
|
||||
// apiRepository implements the Repository interface using Yandex Market API
|
||||
type apiRepository struct {
|
||||
marketplaceRepository marketplace.Repository
|
||||
converter *generated.ConverterImpl
|
||||
}
|
||||
|
||||
// NewAPIRepository creates a new API-based repository implementation
|
||||
func NewAPIRepository(marketplaceRepository marketplace.Repository) Repository {
|
||||
return &apiRepository{
|
||||
marketplaceRepository: marketplaceRepository,
|
||||
converter: &generated.ConverterImpl{},
|
||||
}
|
||||
}
|
||||
|
||||
// getBusinessID retrieves the business ID for a given marketplace by looking up the campaign
|
||||
func (r *apiRepository) getBusinessID(ctx context.Context, mp *marketplace.Marketplace) (int64, error) {
|
||||
if mp.CampaignID == "" {
|
||||
return 0, fmt.Errorf("campaign ID is not set for marketplace %d", mp.ID)
|
||||
}
|
||||
|
||||
campaignID, err := r.validateCampaignID(mp.CampaignID, mp.ID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
client, err := ym.GetClientFromMarketplace(mp)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to create Yandex Market client: %w", err)
|
||||
}
|
||||
|
||||
businessID, err := r.fetchBusinessIDFromCampaigns(ctx, client, campaignID, mp.ID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return businessID, nil
|
||||
}
|
||||
|
||||
// validateCampaignID validates and parses the campaign ID string
|
||||
func (r *apiRepository) validateCampaignID(campaignIDStr string, marketplaceID int) (int64, error) {
|
||||
campaignID, err := strconv.ParseInt(campaignIDStr, 10, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid campaign ID '%s' for marketplace %d: %w", campaignIDStr, marketplaceID, err)
|
||||
}
|
||||
return campaignID, nil
|
||||
}
|
||||
|
||||
// fetchBusinessIDFromCampaigns retrieves business ID by searching through campaigns
|
||||
func (r *apiRepository) fetchBusinessIDFromCampaigns(ctx context.Context, client *ymclient.APIClient, campaignID int64, marketplaceID int) (int64, error) {
|
||||
rsp, _, err := client.CampaignsAPI.GetCampaigns(ctx).Page(1).PageSize(maxPageSize).Execute()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to call GetCampaigns: %w", err)
|
||||
}
|
||||
if rsp == nil {
|
||||
return 0, fmt.Errorf("GetCampaigns returned nil response")
|
||||
}
|
||||
|
||||
for _, campaign := range rsp.Campaigns {
|
||||
if campaign.GetId() == campaignID {
|
||||
return campaign.Business.GetId(), nil
|
||||
}
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("campaign ID %d not found in GetCampaigns response for marketplace %d", campaignID, marketplaceID)
|
||||
}
|
||||
|
||||
// GetProducts retrieves products from Yandex Market API in chunks and sends results to channels
|
||||
func (r *apiRepository) GetProducts(ctx context.Context, marketplaceID int, req *pb.GetProductsRequest, resultChan chan<- []*pb.GetProductsResponse_Offer, errChan chan<- error) {
|
||||
defer close(resultChan)
|
||||
defer close(errChan)
|
||||
|
||||
_, client, businessID, err := r.setupMarketplaceClient(ctx, marketplaceID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
r.setOfferMappingsRateLimit(businessID)
|
||||
|
||||
for _, chunk := range lo.Chunk(req.OfferIds, defaultChunkSize) {
|
||||
offers, err := r.fetchOfferMappings(ctx, client, businessID, chunk)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
resultChan <- offers
|
||||
}
|
||||
}
|
||||
|
||||
// setupMarketplaceClient initializes marketplace, API client, and business ID
|
||||
func (r *apiRepository) setupMarketplaceClient(ctx context.Context, marketplaceID int) (*marketplace.Marketplace, *ymclient.APIClient, int64, error) {
|
||||
mp, err := r.marketplaceRepository.GetMarketplaceByID(ctx, marketplaceID)
|
||||
if err != nil {
|
||||
return nil, nil, 0, fmt.Errorf("failed to get marketplace: %w", err)
|
||||
}
|
||||
|
||||
client, err := ym.GetClientFromMarketplace(mp)
|
||||
if err != nil {
|
||||
return nil, nil, 0, fmt.Errorf("failed to create Yandex Market client: %w", err)
|
||||
}
|
||||
|
||||
businessID, err := r.getBusinessID(ctx, mp)
|
||||
if err != nil {
|
||||
return nil, nil, 0, fmt.Errorf("failed to get business ID: %w", err)
|
||||
}
|
||||
|
||||
return mp, client, businessID, nil
|
||||
}
|
||||
|
||||
// setOfferMappingsRateLimit configures rate limiting for offer mappings endpoint
|
||||
func (r *apiRepository) setOfferMappingsRateLimit(businessID int64) {
|
||||
path := "/businesses/" + strconv.Itoa(int(businessID)) + "/offer-mappings"
|
||||
ym.SetPathLimit(path, rateLimitWindow, offerMappingsRateLimit)
|
||||
}
|
||||
|
||||
// fetchOfferMappings retrieves offer mappings for a given set of offer IDs
|
||||
func (r *apiRepository) fetchOfferMappings(ctx context.Context, client *ymclient.APIClient, businessID int64, offerIDs []string) ([]*pb.GetProductsResponse_Offer, error) {
|
||||
req := ymclient.NewGetOfferMappingsRequest()
|
||||
req.OfferIds = offerIDs
|
||||
|
||||
rsp, _, err := client.BusinessOfferMappingsAPI.GetOfferMappings(ctx, businessID).GetOfferMappingsRequest(*req).Execute()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to call GetOfferMappings: %w", err)
|
||||
}
|
||||
if rsp == nil {
|
||||
return nil, fmt.Errorf("GetOfferMappings returned nil response")
|
||||
}
|
||||
|
||||
return r.processOfferMappings(rsp.Result.GetOfferMappings())
|
||||
}
|
||||
|
||||
// processOfferMappings converts YM offer mappings to protobuf format
|
||||
func (r *apiRepository) processOfferMappings(offerMappings []ymclient.GetOfferMappingDTO) ([]*pb.GetProductsResponse_Offer, error) {
|
||||
var resultOffers []*pb.GetProductsResponse_Offer
|
||||
|
||||
for _, offerMapping := range offerMappings {
|
||||
protoOffer := r.converter.YmOfferToProtoOffer(offerMapping.Offer)
|
||||
if protoOffer == nil {
|
||||
fmt.Printf("Warning: received nil offer for ID %s\n", offerMapping.Offer.OfferId)
|
||||
resultOffers = append(resultOffers, &pb.GetProductsResponse_Offer{})
|
||||
continue
|
||||
}
|
||||
|
||||
resultOffers = append(resultOffers, protoOffer)
|
||||
}
|
||||
|
||||
return resultOffers, nil
|
||||
}
|
||||
|
||||
// CalculateProductTariffs calculates tariffs for products using Yandex Market API
|
||||
func (r *apiRepository) CalculateProductTariffs(ctx context.Context, marketplaceID int, req *pb.CalculateProductTariffsRequest, resultChan chan<- []*pb.CalculateProductTariffsResponse, errChan chan<- error) {
|
||||
defer close(resultChan)
|
||||
defer close(errChan)
|
||||
|
||||
_, client, _, err := r.setupMarketplaceClient(ctx, marketplaceID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
r.setTariffsRateLimit()
|
||||
|
||||
ymParameters := r.converter.ProtoParametersToYmParameters(req.Parameters)
|
||||
if ymParameters == nil {
|
||||
errChan <- fmt.Errorf("failed to convert request parameters")
|
||||
return
|
||||
}
|
||||
|
||||
offerChunks := lo.Chunk(req.Offers, defaultChunkSize)
|
||||
|
||||
for chunkIndex, offerChunk := range offerChunks {
|
||||
fmt.Printf("Processing chunk %d/%d with %d offers\n", chunkIndex+1, len(offerChunks), len(offerChunk))
|
||||
|
||||
response, err := r.processTariffChunk(ctx, client, ymParameters, offerChunk, chunkIndex)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
if response != nil {
|
||||
resultChan <- []*pb.CalculateProductTariffsResponse{response}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// setTariffsRateLimit configures rate limiting for tariffs calculation endpoint
|
||||
func (r *apiRepository) setTariffsRateLimit() {
|
||||
ym.SetPathLimit("/tariffs/calculate", rateLimitWindow, tariffsRateLimit)
|
||||
}
|
||||
|
||||
// processTariffChunk processes a single chunk of offers for tariff calculation
|
||||
func (r *apiRepository) processTariffChunk(ctx context.Context, client *ymclient.APIClient, ymParameters *ymclient.CalculateTariffsParametersDTO, offerChunk []*pb.CalculateProductTariffsRequest_Offers, chunkIndex int) (*pb.CalculateProductTariffsResponse, error) {
|
||||
ymOffers := r.convertOffersToYM(offerChunk)
|
||||
if len(ymOffers) == 0 {
|
||||
fmt.Printf("Skipping chunk %d: no valid offers\n", chunkIndex+1)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ymRequest := ymclient.NewCalculateTariffsRequest(*ymParameters, ymOffers)
|
||||
|
||||
response, httpResp, err := client.TariffsAPI.CalculateTariffs(ctx).
|
||||
CalculateTariffsRequest(*ymRequest).
|
||||
Execute()
|
||||
|
||||
if httpResp != nil && httpResp.Body != nil {
|
||||
_ = httpResp.Body.Close()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to call Yandex Market API for chunk %d: %w", chunkIndex+1, err)
|
||||
}
|
||||
|
||||
if response == nil || response.Result == nil {
|
||||
fmt.Printf("Warning: received empty response for chunk %d\n", chunkIndex+1)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return r.convertResponseToProto(response), nil
|
||||
}
|
||||
|
||||
// convertOffersToYM converts protobuf offers to Yandex Market format
|
||||
func (r *apiRepository) convertOffersToYM(offers []*pb.CalculateProductTariffsRequest_Offers) []ymclient.CalculateTariffsOfferDTO {
|
||||
var ymOffers []ymclient.CalculateTariffsOfferDTO
|
||||
for _, offer := range offers {
|
||||
ymOffer := r.converter.ProtoOfferToYmOffer(offer)
|
||||
if ymOffer != nil {
|
||||
ymOffers = append(ymOffers, *ymOffer)
|
||||
}
|
||||
}
|
||||
return ymOffers
|
||||
}
|
||||
|
||||
// convertResponseToProto converts Yandex Market response to protobuf format
|
||||
func (r *apiRepository) convertResponseToProto(response *ymclient.CalculateTariffsResponse) *pb.CalculateProductTariffsResponse {
|
||||
var offers []*pb.CalculateProductTariffsResponse_Offers
|
||||
|
||||
if response.Result.Offers != nil {
|
||||
for _, ymOfferInfo := range response.Result.Offers {
|
||||
pbResponseOffer := r.converter.YmOfferToProtoResponseOffer(&ymOfferInfo)
|
||||
if pbResponseOffer != nil {
|
||||
offers = append(offers, pbResponseOffer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := &pb.CalculateProductTariffsResponse{}
|
||||
result.Offers = offers
|
||||
return result
|
||||
}
|
||||
113
internal/ym/rate_limiter.go
Normal file
113
internal/ym/rate_limiter.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package ym
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sipro-mps/internal/redis"
|
||||
"time"
|
||||
|
||||
"github.com/redis/rueidis"
|
||||
)
|
||||
|
||||
// RateLimit defines a rate limit configuration
|
||||
type RateLimit struct {
|
||||
Count int // Number of requests allowed
|
||||
TimeDelta time.Duration // Time window
|
||||
}
|
||||
|
||||
// Path rate limits for Yandex Market API
|
||||
var PathLimits = map[string]RateLimit{
|
||||
"/tariffs/calculate": {Count: 100, TimeDelta: time.Minute},
|
||||
"/campaigns": {Count: 300, TimeDelta: time.Minute},
|
||||
"/orders": {Count: 1000, TimeDelta: time.Minute},
|
||||
}
|
||||
|
||||
var rateLimitScript = rueidis.NewLuaScript(`
|
||||
local key = KEYS[1]
|
||||
local now = tonumber(ARGV[1])
|
||||
local window = tonumber(ARGV[2])
|
||||
local limit = tonumber(ARGV[3])
|
||||
|
||||
-- Remove old entries outside the time window
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', now - window)
|
||||
local count = redis.call('ZCARD', key)
|
||||
|
||||
if count < limit then
|
||||
-- Add new request timestamp and set TTL
|
||||
redis.call('ZADD', key, now, now)
|
||||
redis.call('EXPIRE', key, math.ceil(window / 1000))
|
||||
return 0
|
||||
else
|
||||
-- Find oldest request timestamp
|
||||
local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')[2]
|
||||
-- Return wait time until oldest request expires
|
||||
return (tonumber(oldest) + window) - now
|
||||
end
|
||||
`)
|
||||
|
||||
type RateLimitTransport struct {
|
||||
http.RoundTripper
|
||||
}
|
||||
|
||||
func (t *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
ctx := req.Context()
|
||||
|
||||
// Extract API key from headers
|
||||
apiKey := req.Header.Get("Api-Key")
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("Api-Key header is required for rate limiting")
|
||||
}
|
||||
|
||||
// Get path from header or URL
|
||||
var path string
|
||||
path = req.URL.Path
|
||||
|
||||
// Get rate limit for this path
|
||||
rateLimit, exists := PathLimits[path]
|
||||
if !exists {
|
||||
rateLimit = RateLimit{Count: 100, TimeDelta: time.Minute} // default limit
|
||||
}
|
||||
|
||||
// Create unique key based on API key and path
|
||||
rateLimitKey := fmt.Sprintf("ym:ratelimit:%s:%s", apiKey, path)
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
windowMillis := int64(rateLimit.TimeDelta / time.Millisecond)
|
||||
|
||||
client := *redis.Client
|
||||
|
||||
waitTime, err := rateLimitScript.Exec(ctx, client, []string{rateLimitKey}, []string{
|
||||
fmt.Sprintf("%d", now),
|
||||
fmt.Sprintf("%d", windowMillis),
|
||||
fmt.Sprintf("%d", rateLimit.Count),
|
||||
}).ToInt64()
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rate limit script error: %w", err)
|
||||
}
|
||||
|
||||
if waitTime > 0 {
|
||||
select {
|
||||
case <-time.After(time.Duration(waitTime) * time.Millisecond):
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
return t.RoundTripper.RoundTrip(req)
|
||||
}
|
||||
|
||||
// NewRateLimitTransport creates a new rate limiting transport
|
||||
func NewRateLimitTransport() *RateLimitTransport {
|
||||
return &RateLimitTransport{
|
||||
RoundTripper: http.DefaultTransport,
|
||||
}
|
||||
}
|
||||
|
||||
// SetPathLimit sets a custom rate limit for a specific path
|
||||
func SetPathLimit(path string, timeDelta time.Duration, count int) {
|
||||
PathLimits[path] = RateLimit{
|
||||
Count: count,
|
||||
TimeDelta: timeDelta,
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user