diff --git a/main.py b/main.py index 00d7542..150203c 100644 --- a/main.py +++ b/main.py @@ -4,17 +4,21 @@ from typing import Annotated from celery.result import AsyncResult from fastapi import FastAPI, Depends, HTTPException from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from sqlalchemy.ext.asyncio import AsyncSession from starlette import status from starlette.responses import JSONResponse import backend.config +from backend.session import get_session from background import taskiq_broker from buffer.core import TasksBuffer from schemas.general import * import background.tasks +from updaters.stocks_updater import StocksUpdater auth_schema = HTTPBearer() buffer = TasksBuffer() +SessionDependency = Annotated[AsyncSession, Depends(get_session)] async def check_auth(token: Annotated[HTTPAuthorizationCredentials, Depends(auth_schema)]): @@ -117,3 +121,13 @@ def get_status(task_id): "task_result": task_result.result } return JSONResponse(result) + + +@app.get('/marketplace/{marketplace_id}/stocks') +async def get_marketplace_stocks( + marketplace_id: int, + session: SessionDependency, + only_available: bool = False +): + updater = StocksUpdater(session) + return await updater.get_all_stocks_for_marketplace(int(marketplace_id), only_available) diff --git a/updaters/base.py b/updaters/base.py index 04954d6..967465e 100644 --- a/updaters/base.py +++ b/updaters/base.py @@ -58,6 +58,21 @@ class BaseMarketplaceUpdater(ABC): marketplace_updates.append(marketplace_update) await self.marketplace_api.update_stocks(marketplace_updates) + async def get_all_stocks(self, only_available: bool) -> List[StockData]: + if not self.marketplace_api: + return [] + stock_data_list = await queries.general.get_stocks_data( + session=self.session, + marketplace=self.marketplace + ) + if only_available: + stock_data_list = list(filter(lambda x: x["full_stock"] > 0, stock_data_list)) + for idx, stock_data in enumerate(stock_data_list): + stock_data['product_id'] = stock_data['marketplace_product'].product_id + del stock_data["marketplace_product"] + stock_data_list[idx] = stock_data + return stock_data_list + async def reset(self): if not self.marketplace_api: return diff --git a/updaters/stocks_updater.py b/updaters/stocks_updater.py index b5cdaa7..01c52f1 100644 --- a/updaters/stocks_updater.py +++ b/updaters/stocks_updater.py @@ -11,6 +11,7 @@ from sqlalchemy.orm import joinedload, selectinload from backend.session import session_factory from database import Marketplace, MarketplaceProduct, Warehouse, Company from database.sipro.enums.general import BaseMarketplace +from queries.general import StockData from schemas.general import StockUpdate from updaters.factory import UpdaterFactory @@ -47,7 +48,7 @@ class StocksUpdater: Company.is_archived == False, Marketplace.is_deleted == False, Marketplace.is_paused == False, - Marketplace.send_stocks==True, + Marketplace.send_stocks == True, Marketplace.base_marketplace.in_([ BaseMarketplace.OZON, BaseMarketplace.WILDBERRIES, @@ -72,6 +73,11 @@ class StocksUpdater: logging.info( f"{marketplace.name} successfully fully updated in {round(time.time() - start, 2)} seconds.") + async def get_all_stocks_for_marketplace(self, marketplace_id: int, only_available: bool) -> List[StockData]: + marketplace = await self.get_marketplace(marketplace_id) + updater = UpdaterFactory.get_updater(self.session, marketplace) + return await updater.get_all_stocks(only_available) + async def full_update_all_marketplaces(self, marketplace_ids: Union[List[int], None] = None): marketplaces = await self.get_marketplaces(marketplace_ids)