133 lines
		
	
	
		
			5.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			133 lines
		
	
	
		
			5.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import asyncio
 | 
						|
import logging
 | 
						|
import time
 | 
						|
from collections import defaultdict
 | 
						|
from typing import List, Union
 | 
						|
 | 
						|
from sqlalchemy import select, or_
 | 
						|
from sqlalchemy.ext.asyncio import AsyncSession
 | 
						|
from sqlalchemy.orm import joinedload, selectinload
 | 
						|
 | 
						|
from backend.session import session_factory
 | 
						|
from database import Marketplace, MarketplaceProduct, Warehouse, Company
 | 
						|
from database.sipro.enums.general import BaseMarketplace
 | 
						|
from schemas.general import StockUpdate
 | 
						|
from updaters.factory import UpdaterFactory
 | 
						|
 | 
						|
 | 
						|
class StocksUpdater:
    """Coordinate stock updates for marketplaces via ``UpdaterFactory`` updaters.

    ``self.session`` is used for this object's own queries. Work that is fanned
    out with ``asyncio.gather`` opens a separate session per task through
    ``session_factory`` because a single ``AsyncSession`` must not be used by
    several tasks at the same time.
    """

    # Upper bound on the number of product ids placed in one ``IN (...)`` clause,
    # so very large update batches do not hit driver bind-parameter limits.
    _PRODUCT_ID_CHUNK_SIZE = 1000

    def __init__(self, session: AsyncSession):
        """Remember the session used for queries issued by this object."""
        self.session = session

    @staticmethod
    def _elapsed(start: float) -> float:
        """Return the seconds elapsed since ``start`` (a ``perf_counter`` value), rounded to 2 places."""
        return round(time.perf_counter() - start, 2)

    async def get_marketplace(self, marketplace_id: int) -> Union[Marketplace, None]:
        """Load one marketplace by primary key with its relations eager-loaded.

        Eager-loads the marketplace's warehouses (with their suppliers and
        company warehouses) and its company (with the company's warehouse).

        Returns:
            The marketplace, or ``None`` when no row has this primary key.
        """
        return await self.session.get(
            Marketplace,
            marketplace_id,
            options=[
                joinedload(Marketplace.warehouses).joinedload(Warehouse.suppliers),
                joinedload(Marketplace.warehouses).joinedload(Warehouse.company_warehouses),
                joinedload(Marketplace.company).joinedload(Company.warehouse),
            ],
        )

    async def get_marketplaces(self, marketplace_ids: Union[list[int], None] = None) -> List[Marketplace]:
        """Load non-deleted OZON / Wildberries / Yandex Market marketplaces.

        Marketplaces whose company is deleted or archived are excluded.

        Args:
            marketplace_ids: Optional ids to restrict the result to. ``None`` or
                an empty list means "no restriction".

        Returns:
            The matching marketplaces with warehouses and company eager-loaded.
        """
        conditions = [
            Company.is_deleted == False,  # noqa: E712 - SQL expression, not a Python comparison
            Company.is_archived == False,  # noqa: E712
            Marketplace.is_deleted == False,  # noqa: E712
            Marketplace.base_marketplace.in_([
                BaseMarketplace.OZON,
                BaseMarketplace.WILDBERRIES,
                BaseMarketplace.YANDEX_MARKET,
            ]),
        ]
        # Only filter by id when ids were actually supplied; this replaces the
        # former ``or_(marketplace_ids == [], ...)`` construct, which fed a Python
        # boolean and an empty ``IN ()`` into the statement.
        if marketplace_ids:
            conditions.append(Marketplace.id.in_(marketplace_ids))

        stmt = (
            select(Marketplace)
            .join(Company)
            .options(
                selectinload(Marketplace.warehouses).selectinload(Warehouse.suppliers),
                selectinload(Marketplace.warehouses).selectinload(Warehouse.company_warehouses),
                joinedload(Marketplace.company).joinedload(Company.warehouse),
            )
            .where(*conditions)
        )
        query_result = await self.session.scalars(stmt)
        return list(query_result.all())

    async def _full_update(self, session: AsyncSession, marketplace: Marketplace) -> None:
        """Run ``update_all`` on the marketplace's updater and log the duration.

        Does nothing when the factory returns no updater for the marketplace.
        """
        start = time.perf_counter()
        updater = UpdaterFactory.get_updater(session, marketplace)
        if not updater:
            return
        await updater.update_all()
        logging.info(
            f"{marketplace.name} successfully fully updated in {self._elapsed(start)} seconds.")

    async def full_update_marketplace(self, marketplace_id: int):
        """Fully update a single marketplace using this object's session.

        A marketplace id that does not exist is logged and skipped.
        """
        marketplace = await self.get_marketplace(marketplace_id)
        if marketplace is None:
            logging.warning(f"Marketplace {marketplace_id} not found, full update skipped.")
            return
        await self._full_update(self.session, marketplace)

    async def full_update_all_marketplaces(self, marketplace_ids: Union[List[int], None] = None):
        """Fully update all eligible marketplaces concurrently.

        Args:
            marketplace_ids: Optional ids to restrict the update to; see
                ``get_marketplaces``.
        """
        marketplaces = await self.get_marketplaces(marketplace_ids)

        async def update_marketplace(marketplace):
            # Each concurrent task gets its own session.
            async with session_factory() as session:
                await self._full_update(session, marketplace)

        await asyncio.gather(*(update_marketplace(marketplace) for marketplace in marketplaces))

    async def _push_updates(
            self,
            marketplace_id: int,
            marketplace: Union[Marketplace, None],
            updates: List[StockUpdate],
    ) -> None:
        """Send ``updates`` to an already loaded marketplace in a fresh session.

        Does not touch ``self.session``, so several calls may run concurrently.
        A missing marketplace is logged and skipped; a marketplace without an
        updater is skipped silently.
        """
        if marketplace is None:
            logging.warning(f"Marketplace {marketplace_id} not found, stock updates skipped.")
            return
        async with session_factory() as session:
            start = time.perf_counter()
            updater = UpdaterFactory.get_updater(session, marketplace)
            if not updater:
                return
            await updater.update(updates)
            logging.info(
                f"Successfully uploaded {len(updates)} updates to {marketplace.name} in {self._elapsed(start)} seconds.")

    async def update_marketplace(self, marketplace_id: int, updates: List[StockUpdate]):
        """Send the given stock updates to one marketplace.

        The marketplace is loaded through ``self.session``; therefore do not run
        several calls of this method concurrently on the same instance.
        """
        marketplace = await self.get_marketplace(marketplace_id)
        await self._push_updates(marketplace_id, marketplace, updates)

    async def update_marketplace_products(self, marketplace_id: int, product_ids: list[int]):
        """Update the given products on one marketplace using this object's session.

        A missing marketplace is logged and skipped; a marketplace without an
        updater is skipped silently.
        """
        marketplace = await self.get_marketplace(marketplace_id)
        if marketplace is None:
            logging.warning(f"Marketplace {marketplace_id} not found, product update skipped.")
            return
        start = time.perf_counter()
        updater = UpdaterFactory.get_updater(self.session, marketplace)
        if not updater:
            return
        await updater.update_products(product_ids)
        logging.info(
            f"Successfully updated {len(product_ids)} products for {marketplace.name} in {self._elapsed(start)} seconds.")

    async def _marketplace_ids_by_product(self, product_ids: list[int]) -> dict[int, list[int]]:
        """Map each product id to the ids of active marketplaces that list it.

        "Active" means the marketplace is not deleted and its company is neither
        deleted nor archived. Products listed nowhere are absent from the result.
        """
        mapping: dict[int, list[int]] = defaultdict(list)
        chunk_size = self._PRODUCT_ID_CHUNK_SIZE
        for offset in range(0, len(product_ids), chunk_size):
            chunk = product_ids[offset:offset + chunk_size]
            stmt = (
                select(
                    MarketplaceProduct.product_id,
                    MarketplaceProduct.marketplace_id,
                )
                .distinct()
                .join(Marketplace)
                .join(Company)
                .where(
                    MarketplaceProduct.product_id.in_(chunk),
                    Marketplace.is_deleted == False,  # noqa: E712
                    Company.is_deleted == False,  # noqa: E712
                    Company.is_archived == False,  # noqa: E712
                )
            )
            stmt_result = await self.session.execute(stmt)
            for product_id, marketplace_id in stmt_result.all():
                mapping[product_id].append(marketplace_id)
        return mapping

    async def update(self, updates: list[StockUpdate]):
        """Route stock updates to every active marketplace listing the product.

        Updates are grouped per marketplace and the groups are pushed
        concurrently, smallest group first.
        """
        if not updates:
            return

        # One batched lookup instead of one query per update; dict.fromkeys
        # de-duplicates product ids while keeping their first-seen order.
        distinct_product_ids = list(dict.fromkeys(update.product_id for update in updates))
        marketplace_ids_by_product = await self._marketplace_ids_by_product(distinct_product_ids)

        updates_dict = defaultdict(list)
        for update in updates:
            for marketplace_id in marketplace_ids_by_product.get(update.product_id, ()):
                updates_dict[marketplace_id].append(update)

        updates_list = sorted(updates_dict.items(), key=lambda item: len(item[1]))

        # Load marketplaces one at a time: self.session is shared and an
        # AsyncSession must not be used by several tasks concurrently.
        loaded = []
        for marketplace_id, marketplace_updates in updates_list:
            marketplace = await self.get_marketplace(marketplace_id)
            loaded.append((marketplace_id, marketplace, marketplace_updates))

        # Only the pushes run concurrently; each opens its own session.
        await asyncio.gather(*(
            self._push_updates(marketplace_id, marketplace, marketplace_updates)
            for marketplace_id, marketplace, marketplace_updates in loaded
        ))
 |