From 7ba3426989cb5bdee283bc42b14dd83614db8577 Mon Sep 17 00:00:00 2001 From: fakz9 Date: Tue, 2 Jul 2024 08:55:24 +0300 Subject: [PATCH] 123 --- background/tasks.py | 13 +++--- background/update.py | 14 ++++++ database/sipro/models/general.py | 1 + database/sipro/models/products.py | 3 ++ main.py | 54 ++++++++++------------ marketplaces/__init__.py | 6 +-- marketplaces/base.py | 17 ++++--- marketplaces/factory.py | 16 +++---- marketplaces/ozon.py | 26 +++++------ marketplaces/wildberries.py | 25 ++++++----- queries/general.py | 33 ++++++++------ schemas/__init__.py | 0 schemas/general.py | 20 +++++++++ updaters/base.py | 36 ++++++++++++--- updaters/factory.py | 18 ++++++++ updaters/ozon_updater.py | 17 +++---- updaters/stocks_updater.py | 74 +++++++++++-------------------- updaters/wildberries_updater.py | 10 +++++ 18 files changed, 228 insertions(+), 155 deletions(-) create mode 100644 background/update.py create mode 100644 schemas/__init__.py create mode 100644 schemas/general.py create mode 100644 updaters/factory.py diff --git a/background/tasks.py b/background/tasks.py index 421fbcd..a95b664 100644 --- a/background/tasks.py +++ b/background/tasks.py @@ -1,9 +1,12 @@ -import json +import asyncio + +from asgiref.sync import async_to_sync from background import celery +import background.update -@celery.task(name='test') -def test_task(): - with open('test.json', 'a') as tf: - tf.write(json.dumps({'ok': True})) +@celery.task(name='process_update') +def process_update(product_ids: list[int]): + loop = asyncio.get_event_loop() + return loop.run_until_complete(background.update.process_update(product_ids)) diff --git a/background/update.py b/background/update.py new file mode 100644 index 0000000..a04ab83 --- /dev/null +++ b/background/update.py @@ -0,0 +1,14 @@ +import time + +from backend.session import get_session +from schemas.general import StockUpdate +from updaters.stocks_updater import StocksUpdater + + +async def process_update(product_ids: list[int]): + async for session in get_session(): + updates = [StockUpdate(product_id=product_id) for product_id in product_ids] + updater = StocksUpdater(session) + await updater.update(updates) + await session.close() + return {'message': f'Stocks for [{",".join(map(str, product_ids))}] successfully updated'} diff --git a/database/sipro/models/general.py b/database/sipro/models/general.py index f25db59..e33d5db 100644 --- a/database/sipro/models/general.py +++ b/database/sipro/models/general.py @@ -51,6 +51,7 @@ class Marketplace(BaseSiproModel): sell_from_price: Mapped[bool] = mapped_column() warehouses: Mapped[List["Warehouse"]] = relationship(secondary=marketplace_warehouses) + warehouse_id: Mapped[str] = mapped_column() company_id: Mapped[int] = mapped_column(ForeignKey('companies.id')) company: Mapped["Company"] = relationship() diff --git a/database/sipro/models/products.py b/database/sipro/models/products.py index d7becfb..8c00acd 100644 --- a/database/sipro/models/products.py +++ b/database/sipro/models/products.py @@ -24,11 +24,14 @@ class MarketplaceProduct(BaseSiproModel): product_id: Mapped[int] = mapped_column(ForeignKey("products.id")) product: Mapped["Product"] = relationship() + third_additional_article: Mapped[str] = mapped_column() + class SupplierProduct(BaseSiproModel): __tablename__ = 'supplier_products' id: Mapped[int] = mapped_column(primary_key=True) supplier_stock: Mapped[int] = mapped_column() + sold_today: Mapped[int] = mapped_column() supplier_id: Mapped[int] = mapped_column() product_id: Mapped[int] = mapped_column(ForeignKey("products.id")) diff --git a/main.py b/main.py index 55bc5e4..758437a 100644 --- a/main.py +++ b/main.py @@ -1,43 +1,37 @@ from typing import Annotated from celery.result import AsyncResult -from fastapi import FastAPI, Depends, Body -from sqlalchemy import select -from sqlalchemy.dialects.postgresql import insert -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import joinedload +from fastapi import FastAPI, Depends, HTTPException +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from starlette import status from starlette.responses import JSONResponse -from backend.session import get_session -from database import DailyStock -from database.sipro import * -from queries.general import get_stocks_data +import background.tasks from background.tasks import * -from updaters.stocks_updater import StockUpdate +from schemas.general import UpdateRequest, UpdateResponse -app = FastAPI() +auth_schema = HTTPBearer() -@app.get("/") -async def root( - session: Annotated[AsyncSession, Depends(get_session)], - marketplace_id: int +async def check_auth(token: Annotated[HTTPAuthorizationCredentials, Depends(auth_schema)]): + if token.credentials != 'vvHh1QNl7lS6c7OVwmxU1TVNd7DLlc9W810csZGf4rkqOrBy6fQwlhIDZsQZd9hQYZYK47yWv33aCq': + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail='Invalid credentials') + + +app = FastAPI( + dependencies=[Depends(check_auth)] +) + + +@app.post( + '/update', + response_model=UpdateResponse +) +async def update( + request: UpdateRequest ): - marketplace = await session.get(Marketplace, marketplace_id, options=[ - joinedload(Marketplace.warehouses).joinedload(Warehouse.suppliers), - joinedload(Marketplace.warehouses).joinedload(Warehouse.company_warehouses), - joinedload(Marketplace.company).joinedload(Company.warehouse) - ]) - data = await get_stocks_data(session, marketplace) - data = sorted(data, key=lambda x: x['denco_article']) - return {"message": data} - - -@app.post("/tasks", status_code=201) -def run_task(payload=Body(...)): - task_type = payload["type"] - task = test_task.delay() - return JSONResponse({"task_id": task.id}) + task = background.tasks.process_update.delay(request.product_ids) + return UpdateResponse(task_id=task.id) @app.get("/tasks/{task_id}") diff --git a/marketplaces/__init__.py b/marketplaces/__init__.py index 1f797af..12d4cd2 100644 --- a/marketplaces/__init__.py +++ b/marketplaces/__init__.py @@ -1,3 +1,3 @@ -from .ozon import OzonMarketplace -from .wildberries import WildberriesMarketplace -from .factory import MarketplaceFactory +from .ozon import OzonMarketplaceApi +from .wildberries import WildberriesMarketplaceApi +from .factory import MarketplaceApiFactory diff --git a/marketplaces/base.py b/marketplaces/base.py index f0254b4..5e501cd 100644 --- a/marketplaces/base.py +++ b/marketplaces/base.py @@ -7,7 +7,7 @@ from aiohttp import ClientResponse from database import Marketplace -class BaseJsonMarketplace(ABC): +class BaseMarketplaceApi(ABC): @abstractmethod def __init__(self, marketplace: Marketplace): pass @@ -20,18 +20,17 @@ class BaseJsonMarketplace(ABC): def get_headers(self): pass - @abstractmethod @property + @abstractmethod def api_url(self): pass async def _method(self, http_method: Literal['POST', 'GET', 'PATCH', 'PUT', 'DELETE'], method: str, data: dict) -> ClientResponse: - async with aiohttp.ClientSession as session: - async with session.request(http_method, - f'{self.api_url}{method}', - json=data, - headers=self.get_headers() - ) as response: - return response + async with aiohttp.ClientSession() as session: + return await session.request(http_method, + f'{self.api_url}{method}', + json=data, + headers=self.get_headers() + ) diff --git a/marketplaces/factory.py b/marketplaces/factory.py index 57ffb40..c45b653 100644 --- a/marketplaces/factory.py +++ b/marketplaces/factory.py @@ -2,18 +2,18 @@ from typing import Union from database import Marketplace from database.sipro.enums.general import BaseMarketplace -from .wildberries import WildberriesMarketplace -from .ozon import OzonMarketplace +from .wildberries import WildberriesMarketplaceApi +from .ozon import OzonMarketplaceApi -class MarketplaceFactory: +class MarketplaceApiFactory: @staticmethod - def get_marketplace(marketplace: Marketplace) -> Union[ - WildberriesMarketplace, - OzonMarketplace, + def get_marketplace_api(marketplace: Marketplace) -> Union[ + WildberriesMarketplaceApi, + OzonMarketplaceApi, ]: match marketplace.base_marketplace: case BaseMarketplace.OZON: - return OzonMarketplace(marketplace) + return OzonMarketplaceApi(marketplace) case BaseMarketplace.WILDBERRIES: - return WildberriesMarketplace(marketplace) + return WildberriesMarketplaceApi(marketplace) diff --git a/marketplaces/ozon.py b/marketplaces/ozon.py index 1604254..1c45a31 100644 --- a/marketplaces/ozon.py +++ b/marketplaces/ozon.py @@ -1,17 +1,15 @@ +import asyncio import json import logging from typing import Union -from aiolimiter import AsyncLimiter -from asynciolimiter import StrictLimiter - import utils from database import Marketplace from limiter import BatchLimiter -from marketplaces.base import BaseJsonMarketplace +from marketplaces.base import BaseMarketplaceApi -class OzonMarketplace(BaseJsonMarketplace): +class OzonMarketplaceApi(BaseMarketplaceApi): def __init__(self, marketplace: Marketplace): self.marketplace = marketplace @@ -25,6 +23,7 @@ class OzonMarketplace(BaseJsonMarketplace): def get_headers(self): return self.headers + @property def api_url(self): return 'https://api-seller.ozon.ru' @@ -33,22 +32,23 @@ class OzonMarketplace(BaseJsonMarketplace): return max_stocks = 100 chunks = utils.chunk_list(data, max_stocks) - limiter = BatchLimiter(max_requests=80, - period=60) - for chunk in chunks: + limiter = BatchLimiter(max_requests=80, period=60) + + async def send_stock_chunk(chunk): try: await limiter.acquire() - response = await self._method('POST', - '/v2/products/stocks', - data=chunk) + request_data = {'stocks': chunk} + response = await self._method('POST', '/v2/products/stocks', data=request_data) + print(request_data) response = await response.json() - # response = await error_message = response.get('message') error_code = response.get('code') if error_message: logging.warning( f'Error occurred when sending stocks to [{self.marketplace.id}]: {error_message} ({error_code})') - break except Exception as e: logging.error( f'Exception occurred while sending stocks to marketplace ID [{self.marketplace.id}]: {str(e)}') + + tasks = [send_stock_chunk(chunk) for chunk in chunks] + await asyncio.gather(*tasks) diff --git a/marketplaces/wildberries.py b/marketplaces/wildberries.py index dac4bf7..5bd87d9 100644 --- a/marketplaces/wildberries.py +++ b/marketplaces/wildberries.py @@ -1,3 +1,4 @@ +import asyncio import json import logging from typing import Union @@ -5,10 +6,10 @@ from typing import Union import utils from database import Marketplace from limiter import BatchLimiter -from marketplaces.base import BaseJsonMarketplace +from marketplaces.base import BaseMarketplaceApi -class WildberriesMarketplace(BaseJsonMarketplace): +class WildberriesMarketplaceApi(BaseMarketplaceApi): def __init__(self, marketplace: Marketplace): self.marketplace = marketplace auth_data = json.loads(marketplace.auth_data) @@ -21,6 +22,7 @@ class WildberriesMarketplace(BaseJsonMarketplace): def get_headers(self): return self.headers + @property def api_url(self): return 'https://suppliers-api.wildberries.ru' @@ -29,21 +31,24 @@ class WildberriesMarketplace(BaseJsonMarketplace): return max_stocks = 1000 chunks = utils.chunk_list(data, max_stocks) - limiter = BatchLimiter(max_requests=300, - period=60) - for chunk in chunks: + limiter = BatchLimiter(max_requests=300, period=60) + + async def send_stock_chunk(chunk): try: await limiter.acquire() - response = await self._method('PUT', - '/api/v3/stocks/{warehouseId}', - chunk) - if response.status != 204: + request_data = {'stocks': chunk} + response = await self._method('PUT', f'/api/v3/stocks/{self.marketplace.warehouse_id}', + data=request_data) + print(request_data) + if response.status not in [204, 409]: response = await response.json() error_message = response.get('message') error_code = response.get('code') logging.warning( f'Error occurred when sending stocks to [{self.marketplace.id}]: {error_message} ({error_code})') - break except Exception as e: logging.error( f'Exception occurred while sending stocks to marketplace ID [{self.marketplace.id}]: {str(e)}') + + tasks = [send_stock_chunk(chunk) for chunk in chunks] + await asyncio.gather(*tasks) diff --git a/queries/general.py b/queries/general.py index ce5ec82..134e633 100644 --- a/queries/general.py +++ b/queries/general.py @@ -1,4 +1,5 @@ -from typing import Union +from dataclasses import dataclass +from typing import Union, TypedDict from sqlalchemy import select, func, and_, cast, String, case, or_ from sqlalchemy.ext.asyncio import AsyncSession @@ -9,6 +10,12 @@ from database.sipro import * from database.sipro.enums.product import ProductRelationType +class StockData(TypedDict): + full_stock: int + article: Union[str, int] + marketplace_product: MarketplaceProduct + + def get_marketplace_suppliers_and_company_warehouses(marketplace: Marketplace): company = marketplace.company suppliers = set() @@ -30,7 +37,7 @@ async def get_stocks_data( session: AsyncSession, marketplace: Marketplace, product_ids: Union[list[int], None] = None -): +) -> List[StockData]: if not product_ids: product_ids = [] company = marketplace.company @@ -46,7 +53,7 @@ async def get_stocks_data( supplier_stock_subquery = ( select( func.greatest( - func.sum(SupplierProduct.supplier_stock) - func.coalesce(DailyStock.sold_today, 0), + func.sum(SupplierProduct.supplier_stock - SupplierProduct.sold_today), 0 ) .label('supplier_stock'), @@ -58,10 +65,6 @@ async def get_stocks_data( .join( Product ) - .outerjoin( - DailyStock, - DailyStock.product_id == SupplierProduct.product_id - ) .where( SupplierProduct.supplier_id.in_(supplier_ids) ) @@ -286,9 +289,13 @@ async def get_stocks_data( slaves_stock_subquery.c.product_id == MarketplaceProduct.product_id ) ) + print('-------------------------') + print(stmt.compile(compile_kwargs={ + 'literal_binds': True + })) result = await session.execute(stmt) marketplace_products = result.all() - result = [] + response: List[StockData] = [] for (marketplace_product, denco_article, price_purchase, @@ -301,8 +308,8 @@ async def get_stocks_data( price_recommended, is_archived) in marketplace_products: if is_archived or (sell_from_price > price_recommended): - result.append({ - 'denco_article': denco_article, + response.append({ + 'article': denco_article, 'full_stock': 0, 'marketplace_product': marketplace_product, }) @@ -328,10 +335,10 @@ async def get_stocks_data( full_stock = 0 full_stock = max([0, full_stock]) - result.append({ - 'denco_article': denco_article, + response.append({ + 'article': denco_article, 'full_stock': full_stock, 'marketplace_product': marketplace_product, }) - return result + return response diff --git a/schemas/__init__.py b/schemas/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/schemas/general.py b/schemas/general.py new file mode 100644 index 0000000..bab97b9 --- /dev/null +++ b/schemas/general.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +from pydantic import BaseModel + + +@dataclass +class StockUpdate: + product_id: int + + +class BaseSchema(BaseModel): + pass + + +class UpdateRequest(BaseSchema): + product_ids: list[int] + + +class UpdateResponse(BaseSchema): + task_id: str diff --git a/updaters/base.py b/updaters/base.py index a4feed9..2bb5e92 100644 --- a/updaters/base.py +++ b/updaters/base.py @@ -1,15 +1,41 @@ from abc import ABC, abstractmethod from typing import List +from sqlalchemy.ext.asyncio import AsyncSession + +import queries.general from database import Marketplace -from updaters.stocks_updater import StockUpdate +from marketplaces import MarketplaceApiFactory +from marketplaces.base import BaseMarketplaceApi +from queries.general import StockData +from schemas.general import StockUpdate class BaseMarketplaceUpdater(ABC): - @abstractmethod - def __init__(self, marketplace: Marketplace): - pass + marketplace: Marketplace + marketplace_api: BaseMarketplaceApi + session: AsyncSession + + def __init__(self, marketplace: Marketplace, session: AsyncSession): + self.marketplace = marketplace + self.session = session + self.marketplace_api = MarketplaceApiFactory.get_marketplace_api(marketplace) @abstractmethod - async def update(self, updates: List[StockUpdate]): + def get_update_for_marketplace(self, + stock_data: StockData) -> dict: pass + + async def update(self, updates: List[StockUpdate]): + product_ids = list(set([update.product_id for update in updates])) + stock_data_list = await queries.general.get_stocks_data( + session=self.session, + marketplace=self.marketplace, + product_ids=product_ids + ) + return + marketplace_updates = [] + for stock_data in stock_data_list: + marketplace_update = self.get_update_for_marketplace(stock_data) + marketplace_updates.append(marketplace_update) + await self.marketplace_api.update_stocks(marketplace_updates) diff --git a/updaters/factory.py b/updaters/factory.py new file mode 100644 index 0000000..19d20eb --- /dev/null +++ b/updaters/factory.py @@ -0,0 +1,18 @@ +from typing import Union + +from sqlalchemy.ext.asyncio import AsyncSession + +from database import Marketplace +from database.sipro.enums.general import BaseMarketplace +from updaters.ozon_updater import OzonUpdater +from updaters.wildberries_updater import WildberriesUpdater + + +class UpdaterFactory: + @staticmethod + def get_updater(session: AsyncSession, marketplace: Marketplace) -> Union[OzonUpdater, WildberriesUpdater]: + match marketplace.base_marketplace: + case BaseMarketplace.WILDBERRIES: + return WildberriesUpdater(marketplace, session) + case BaseMarketplace.OZON: + return OzonUpdater(marketplace, session) diff --git a/updaters/ozon_updater.py b/updaters/ozon_updater.py index 50edd94..171f947 100644 --- a/updaters/ozon_updater.py +++ b/updaters/ozon_updater.py @@ -1,14 +1,11 @@ -from typing import List - -from database import Marketplace -from marketplaces import MarketplaceFactory, OzonMarketplace +from queries.general import StockData from updaters.base import BaseMarketplaceUpdater -from updaters.stocks_updater import StockUpdate class OzonUpdater(BaseMarketplaceUpdater): - def __init__(self, marketplace: Marketplace): - self.ozon_marketplace: OzonMarketplace = MarketplaceFactory.get_marketplace(marketplace) - - async def update(self, updates: List[StockUpdate]): - pass + def get_update_for_marketplace(self, data: StockData) -> dict: + return { + 'offer_id': str(data['article']), + 'stock': 0, # $data['full_stock'], + 'warehouse_id': self.marketplace.warehouse_id + } diff --git a/updaters/stocks_updater.py b/updaters/stocks_updater.py index ae48628..17a1efc 100644 --- a/updaters/stocks_updater.py +++ b/updaters/stocks_updater.py @@ -1,53 +1,47 @@ +import asyncio from collections import defaultdict -from dataclasses import dataclass from enum import unique, IntEnum -from typing import List, Union +from typing import List from sqlalchemy import select -from sqlalchemy.dialects.postgresql import insert from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload -import database -from database import Marketplace, MarketplaceProduct, DailyStock - - -@unique -class StockUpdateType(IntEnum): - SALE = 0 - SUPPLIER_UPDATE = 1 - WAREHOUSE_UPDATE = 2 - -@dataclass -class StockUpdate: - product_id: int - type: StockUpdateType - quantity: int +from database import Marketplace, MarketplaceProduct, Warehouse, Company +from schemas.general import StockUpdate +from updaters.factory import UpdaterFactory class StocksUpdater: def __init__(self, session: AsyncSession): self.session = session + async def get_marketplace(self, marketplace_id: int): + marketplace = await self.session.get(Marketplace, marketplace_id, options=[ + joinedload(Marketplace.warehouses).joinedload(Warehouse.suppliers), + joinedload(Marketplace.warehouses).joinedload(Warehouse.company_warehouses), + joinedload(Marketplace.company).joinedload(Company.warehouse) + ]) + return marketplace + async def update_marketplace(self, marketplace_id: int, updates: List[StockUpdate]): - pass + marketplace = await self.get_marketplace(marketplace_id) + updater = UpdaterFactory.get_updater(self.session, marketplace) + if not updater: + return + await updater.update(updates) async def update(self, updates: list[StockUpdate]): updates_dict = defaultdict(list) - stock_update_values = [] for update in updates: - # Working with sold today - if update.type == StockUpdateType.SALE: - stock_update_values.append({ - 'product_id': update.product_id, - 'sold_today': update.quantity - }) # Working with marketplaces stmt = ( select( MarketplaceProduct.marketplace_id.distinct() ) .where( - MarketplaceProduct.product_id == update.product_id + MarketplaceProduct.product_id == update.product_id, + MarketplaceProduct.marketplace_id.in_([9, 41]) ) ) stmt_result = await self.session.execute(stmt) @@ -57,27 +51,9 @@ class StocksUpdater: for marketplace_id in marketplace_ids: updates_dict[marketplace_id].append(update) updates_list = list(updates_dict.items()) - updates_list = sorted(updates_list, key=lambda x: x[1]) - - # Updating DailyStock-s - insert_stmt = ( - insert( - DailyStock - ) - .values( - stock_update_values - ) - ) - insert_stmt = ( - insert_stmt.on_conflict_do_update( - index_elements=['product_id'], - set_={ - 'sold_today': DailyStock.sold_today + insert_stmt.excluded.sold_today - } - ) - ) - await self.session.execute(insert_stmt) - await self.session.commit() + updates_list = sorted(updates_list, key=lambda x: len(x[1])) + tasks = [] for marketplace_id, marketplace_updates in updates_list: - await self.update_marketplace(marketplace_id, marketplace_updates) + tasks.append(self.update_marketplace(marketplace_id, marketplace_updates)) + await asyncio.gather(*tasks) diff --git a/updaters/wildberries_updater.py b/updaters/wildberries_updater.py index 8b13789..f168276 100644 --- a/updaters/wildberries_updater.py +++ b/updaters/wildberries_updater.py @@ -1 +1,11 @@ +from queries.general import StockData +from updaters.base import BaseMarketplaceUpdater + +class WildberriesUpdater(BaseMarketplaceUpdater): + + def get_update_for_marketplace(self, stock_data: StockData) -> dict: + return { + 'sku': stock_data['marketplace_product'].third_additional_article, + 'amount': 0 # stock_data['full_stock'] + }