diff --git a/background/tasks.py b/background/tasks.py index 3e16b20..a614df7 100644 --- a/background/tasks.py +++ b/background/tasks.py @@ -17,6 +17,12 @@ def update_marketplace(marketplace_id: int): return loop.run_until_complete(background.update.update_marketplace(marketplace_id)) +@celery.task(name='update_marketplace_products') +def update_marketplace_products(marketplace_id: int, product_ids: list[int]): + loop = asyncio.get_event_loop() + return loop.run_until_complete(background.update.update_marketplace_products(marketplace_id, product_ids)) + + @celery.task(name='update_marketplaces') def update_marketplaces(marketplace_ids: Union[List[int], None]): loop = asyncio.get_event_loop() diff --git a/background/update.py b/background/update.py index c074d2d..963e98f 100644 --- a/background/update.py +++ b/background/update.py @@ -20,6 +20,14 @@ async def update_marketplace(marketplace_id: int): return {'message': f'Stocks for marketplace {marketplace_id} successfully updated'} +async def update_marketplace_products(marketplace_id: int, product_ids: list[int]): + async with session_factory() as session: + updater = StocksUpdater(session) + await updater.update_marketplace_products(marketplace_id, product_ids) + return { + 'message': f'Products [{",".join(list(map(str, product_ids)))}] successfully updated for marketplace {marketplace_id}'} + + async def update_marketplaces(marketplace_ids: Union[List[int], None]): async with session_factory() as session: updater = StocksUpdater(session) diff --git a/main.py b/main.py index 9d7ce99..e248544 100644 --- a/main.py +++ b/main.py @@ -9,7 +9,8 @@ from starlette.responses import JSONResponse import backend.config import background.tasks from background.tasks import * -from schemas.general import UpdateRequest, UpdateResponse, UpdateMarketplaceRequest, UpdateMarketplacesRequest +from schemas.general import UpdateRequest, UpdateResponse, UpdateMarketplaceRequest, UpdateMarketplacesRequest, \ + UpdateMarketplaceProductsRequest auth_schema = HTTPBearer() @@ -52,6 +53,14 @@ async def update_marketplace( return UpdateResponse(task_id=task.id) +@app.post('/update/marketplace/products') +async def update_marketplace_products( + request: UpdateMarketplaceProductsRequest +): + task = background.tasks.update_marketplace_products.delay(request.marketplace_id, request.product_ids) + return UpdateResponse(task_id=task.id) + + @app.post('/update/marketplaces') async def update_marketplace( request: UpdateMarketplacesRequest diff --git a/schemas/general.py b/schemas/general.py index 118ea75..eb440f4 100644 --- a/schemas/general.py +++ b/schemas/general.py @@ -21,6 +21,10 @@ class UpdateMarketplaceRequest(BaseSchema): marketplace_id: int +class UpdateMarketplaceProductsRequest(UpdateMarketplaceRequest): + product_ids: List[int] + + class UpdateMarketplacesRequest(BaseSchema): marketplace_ids: Union[List[int], None] = None diff --git a/updaters/base.py b/updaters/base.py index 135e3b1..51e7ecc 100644 --- a/updaters/base.py +++ b/updaters/base.py @@ -31,6 +31,9 @@ class BaseMarketplaceUpdater(ABC): if not self.marketplace_api: return product_ids = list(set([update.product_id for update in updates])) + await self.update_products(product_ids) + + async def update_products(self, product_ids: list[int]): stock_data_list = await queries.general.get_stocks_data( session=self.session, marketplace=self.marketplace, diff --git a/updaters/stocks_updater.py b/updaters/stocks_updater.py index 65ef552..8d303f4 100644 --- a/updaters/stocks_updater.py +++ b/updaters/stocks_updater.py @@ -19,7 +19,7 @@ class StocksUpdater: def __init__(self, session: AsyncSession): self.session = session - async def get_marketplace(self, marketplace_id: int): + async def get_marketplace(self, marketplace_id: int) -> Marketplace: marketplace = await self.session.get(Marketplace, marketplace_id, options=[ joinedload(Marketplace.warehouses).joinedload(Warehouse.suppliers), joinedload(Marketplace.warehouses).joinedload(Warehouse.company_warehouses), @@ -94,6 +94,14 @@ class StocksUpdater: logging.info( f"Successfully uploaded {len(updates)} updates to {marketplace.name} in {round(time.time() - start, 2)} seconds.") + async def update_marketplace_products(self, marketplace_id: int, product_ids: list[int]): + marketplace = await self.get_marketplace(marketplace_id) + start = time.time() + updater = UpdaterFactory.get_updater(self.session, marketplace) + await updater.update_products(product_ids) + logging.info( + f"Successfully updated {len(product_ids)} products for {marketplace.name} in {round(time.time() - start, 2)} seconds.") + async def update(self, updates: list[StockUpdate]): updates_dict = defaultdict(list) for update in updates: