diff --git a/routers/deal.py b/routers/deal.py index 279ed0e..e1dda39 100644 --- a/routers/deal.py +++ b/routers/deal.py @@ -132,3 +132,4 @@ async def get_deal_by_id( session: Annotated[AsyncSession, Depends(get_session)] ): return await DealService(session).get_by_id(deal_id) + diff --git a/routers/product.py b/routers/product.py index 70dae9f..1359a3e 100644 --- a/routers/product.py +++ b/routers/product.py @@ -1,4 +1,4 @@ -from typing import Annotated +from typing import Annotated, Union from fastapi import APIRouter, Depends from sqlalchemy.ext.asyncio import AsyncSession @@ -38,6 +38,7 @@ async def delete_product( ): return await ProductService(session).delete(request) + @product_router.post( '/update', response_model=ProductUpdateResponse, @@ -49,6 +50,7 @@ async def delete_product( ): return await ProductService(session).update(request) + @product_router.get( '/get', response_model=ProductGetResponse, diff --git a/schemas/base.py b/schemas/base.py index 584bb48..62c6f5f 100644 --- a/schemas/base.py +++ b/schemas/base.py @@ -31,8 +31,8 @@ class OkMessageSchema(CustomModelCamel): class PaginationSchema(CustomModelCamel): - page: int - items_per_page: int + page: int | None = None + items_per_page: int | None = None class PaginationInfoSchema(CustomModelCamel): diff --git a/schemas/deal.py b/schemas/deal.py index e2915d0..bc53ec3 100644 --- a/schemas/deal.py +++ b/schemas/deal.py @@ -3,6 +3,7 @@ from typing import List from schemas.base import CustomModelCamel, OkMessageSchema from schemas.client import ClientSchema +from schemas.product import ProductSchema from schemas.service import ServiceSchema @@ -28,6 +29,11 @@ class DealServiceSchema(CustomModelCamel): quantity: int +class DealProductSchema(CustomModelCamel): + product: ProductSchema + quantity: int + + class DealSchema(CustomModelCamel): id: int name: str @@ -35,6 +41,7 @@ class DealSchema(CustomModelCamel): created_at: datetime.datetime current_status: int services: List[DealServiceSchema] + products: List[DealProductSchema] # total_price: int diff --git a/schemas/product.py b/schemas/product.py index 7913d72..e188839 100644 --- a/schemas/product.py +++ b/schemas/product.py @@ -4,12 +4,16 @@ from schemas.base import CustomModelCamel, PaginationInfoSchema, OkMessageSchema # region Entities +class ProductBarcodeSchema(CustomModelCamel): + barcode: str + + class ProductSchema(CustomModelCamel): id: int name: str article: str client_id: int - barcodes: list[str] + barcodes: list[ProductBarcodeSchema] # endregion diff --git a/services/deal.py b/services/deal.py index 0f4a4c5..e5f984c 100644 --- a/services/deal.py +++ b/services/deal.py @@ -172,7 +172,12 @@ class DealService(BaseService): joinedload(Deal.client), selectinload(Deal.services) .joinedload(models.secondary.DealService.service) - .joinedload(Service.category)) + .joinedload(Service.category), + selectinload(Deal.products) + .joinedload(models.secondary.DealProduct.product) + .joinedload(models.Product.client) + .selectinload(models.Product.barcodes) + ) .where(Deal.id == deal_id) ) if not deal: diff --git a/services/product.py b/services/product.py index 928b224..9006fb0 100644 --- a/services/product.py +++ b/services/product.py @@ -6,9 +6,11 @@ from models.product import Product, ProductBarcode from schemas.base import PaginationSchema from services.base import BaseService from schemas.product import * +from utils.dependecies import is_valid_pagination class ProductService(BaseService): + async def create(self, request: ProductCreateRequest) -> ProductCreateResponse: # Unique article validation existing_product_query = await self.session.execute( @@ -85,23 +87,39 @@ class ProductService(BaseService): async def get_by_client_id(self, client_id: int, pagination: PaginationSchema) -> ProductGetResponse: stmt = ( select(Product) - .options(selectinload(Product.barcodes)) + .options(selectinload(Product.barcodes) + .noload(ProductBarcode.product)) .where(Product.client_id == client_id) .order_by(Product.id) ) - total_products_query = await self.session.execute( - select( - func.cast(func.ceil(func.count() / pagination.items_per_page), Integer), - func.count() + if is_valid_pagination(pagination): + total_products_query = await self.session.execute( + select( + func.cast(func.ceil(func.count() / pagination.items_per_page), Integer), + func.count() + ) + .select_from(stmt.subquery()) ) - .select_from(stmt.subquery()) - ) - total_pages, total_items = total_products_query.first() + total_pages, total_items = total_products_query.first() + else: + total_items_query = await self.session.execute( + select(func.count()) + .select_from(stmt.subquery()) + ) + total_items = total_items_query.scalar() + total_pages = 1 pagination_info = PaginationInfoSchema(total_pages=total_pages, total_items=total_items) + + if is_valid_pagination(pagination): + stmt = ( + stmt + .offset(pagination.page * pagination.items_per_page) + .limit(pagination.items_per_page) + ) + query = await self.session.execute( stmt - .offset(pagination.page * pagination.items_per_page) - .limit(pagination.items_per_page) + .order_by(Product.id) ) products: list[ProductSchema] = [] for product in query.scalars().all(): diff --git a/test/test.py b/test/test.py index 528a0ac..a7a9cc7 100644 --- a/test/test.py +++ b/test/test.py @@ -7,7 +7,7 @@ from models import Product, ProductBarcode async def main(session: AsyncSession): - client_ids = [2, 4] + client_ids = [8, 18] for client_id in client_ids: for i in range(1, 500 + 1): product = Product( diff --git a/utils/dependecies.py b/utils/dependecies.py index 74876dd..4e82902 100644 --- a/utils/dependecies.py +++ b/utils/dependecies.py @@ -1,5 +1,14 @@ from schemas.base import PaginationSchema -async def pagination_parameters(page: int, items_per_page: int) -> PaginationSchema: +async def pagination_parameters(page: int | None = None, items_per_page: int | None = None) -> PaginationSchema: return PaginationSchema(page=page, items_per_page=items_per_page) + + +def is_valid_pagination(pagination: PaginationSchema | None) -> bool: + if not pagination: + return False + return all([ + isinstance(pagination.items_per_page, int), + isinstance(pagination.page, int) + ])