diff --git a/models/expense.py b/models/expense.py index d06cee5..8be0ff3 100644 --- a/models/expense.py +++ b/models/expense.py @@ -1,7 +1,7 @@ from datetime import datetime, date from typing import TYPE_CHECKING -from sqlalchemy import ForeignKey +from sqlalchemy import ForeignKey, Table, Column from sqlalchemy.orm import Mapped, mapped_column, relationship from models import BaseModel @@ -10,6 +10,14 @@ if TYPE_CHECKING: from models import User +expenses_expense_tags = Table( + 'expenses_expense_tags', + BaseModel.metadata, + Column('expense_id', ForeignKey('expenses.id', ondelete='CASCADE'), primary_key=True), + Column('expense_tag_id', ForeignKey('expense_tags.id'), primary_key=True), +) + + class Expense(BaseModel): __tablename__ = 'expenses' @@ -22,3 +30,22 @@ class Expense(BaseModel): created_by_user_id: Mapped[int] = mapped_column(ForeignKey("users.id"), nullable=False) created_by_user: Mapped["User"] = relationship(foreign_keys=[created_by_user_id]) + + tags: Mapped[list["ExpenseTag"]] = relationship( + secondary=expenses_expense_tags, + lazy='selectin', + back_populates='expenses', + cascade='all, delete', + ) + + +class ExpenseTag(BaseModel): + __tablename__ = 'expense_tags' + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(unique=True) + + expenses: Mapped[list["Expense"]] = relationship( + secondary=expenses_expense_tags, + lazy='selectin', + back_populates='tags', + ) diff --git a/routers/expense.py b/routers/expense.py index 5e685b1..8ae9df7 100644 --- a/routers/expense.py +++ b/routers/expense.py @@ -1,7 +1,7 @@ from fastapi import APIRouter, Depends from backend.dependecies import SessionDependency, CurrentUserDependency, PaginationDependency -from schemas.expense import GetAllExpensesResponse, UpdateExpenseResponse, UpdateExpenseRequest, DeleteExpenseResponse +from schemas.expense import * from services.auth import authorized_user from services.expenses import ExpensesService @@ -12,6 +12,8 @@ expense_router = APIRouter( ) +# region Expenses + @expense_router.get( '/get-all', operation_id='get_all_expenses', @@ -47,3 +49,55 @@ async def delete_expense( expense_id: int, ): return await ExpensesService(session).delete_expense(expense_id) + +# endregion + +# region Expense tags + +@expense_router.get( + '/get-all-tags', + operation_id='get_all_expense_tags', + response_model=GetAllExpenseTagsResponse, +) +async def get_all( + session: SessionDependency, +): + return await ExpensesService(session).get_all_tags() + + +@expense_router.post( + '/create-tag', + operation_id='create_expense_tag', + response_model=UpdateExpenseTagResponse, +) +async def update_expense( + session: SessionDependency, + request: CreateExpenseTagRequest, +): + return await ExpensesService(session).create_tag(request) + + +@expense_router.post( + '/update-tag', + operation_id='update_expense_tag', + response_model=UpdateExpenseTagResponse, +) +async def update_expense( + session: SessionDependency, + request: UpdateExpenseTagRequest, +): + return await ExpensesService(session).update_tag(request) + + +@expense_router.delete( + '/delete-tag/{tag_id}', + operation_id='delete_expense_tag', + response_model=DeleteExpenseTagResponse, +) +async def update_expense( + session: SessionDependency, + tag_id: int, +): + return await ExpensesService(session).delete_tag(tag_id) + +# endregion \ No newline at end of file diff --git a/schemas/expense.py b/schemas/expense.py index 6e99cd8..bf62d13 100644 --- a/schemas/expense.py +++ b/schemas/expense.py @@ -7,6 +7,14 @@ from schemas.user import UserSchema # region Entities +class BaseExpenseTagSchema(BaseSchema): + name: str + + +class ExpenseTagSchema(BaseExpenseTagSchema): + id: int + + class ExpenseSchemaBase(BaseSchema): id: int name: str @@ -14,6 +22,7 @@ class ExpenseSchemaBase(BaseSchema): amount: float created_by_user: UserSchema spent_date: datetime.date + tags: list[ExpenseTagSchema] class UpdateExpenseSchema(BaseSchema): @@ -22,6 +31,7 @@ class UpdateExpenseSchema(BaseSchema): comment: Optional[str] = "" amount: float spent_date: datetime.date + tags: list[str] = [] # endregion @@ -33,6 +43,14 @@ class UpdateExpenseRequest(BaseSchema): expense: UpdateExpenseSchema +class CreateExpenseTagRequest(BaseSchema): + tag: BaseExpenseTagSchema + + +class UpdateExpenseTagRequest(BaseSchema): + tag: ExpenseTagSchema + + # endregion # region Responses @@ -49,4 +67,20 @@ class UpdateExpenseResponse(OkMessageSchema): class DeleteExpenseResponse(OkMessageSchema): pass + +class GetAllExpenseTagsResponse(BaseSchema): + tags: list[ExpenseTagSchema] + + +class CreateExpenseTagResponse(OkMessageSchema): + pass + + +class UpdateExpenseTagResponse(OkMessageSchema): + pass + + +class DeleteExpenseTagResponse(OkMessageSchema): + pass + # endregion diff --git a/schemas/statistics.py b/schemas/statistics.py index c4fccc3..0c5fd86 100644 --- a/schemas/statistics.py +++ b/schemas/statistics.py @@ -1,5 +1,4 @@ import datetime -from optparse import Option from typing import List, Tuple, Optional from enums.profit_table_group_by import ProfitTableGroupBy @@ -32,6 +31,7 @@ class CommonProfitFilters(BaseSchema): base_marketplace_key: str deal_status_id: int manager_id: int + tag_id: int class GetProfitChartDataRequest(CommonProfitFilters): diff --git a/services/expenses.py b/services/expenses.py index 8e601bb..a7ad6d0 100644 --- a/services/expenses.py +++ b/services/expenses.py @@ -1,15 +1,14 @@ from datetime import datetime -from typing import Optional import math from fastapi import HTTPException -from sqlalchemy import delete, select, func from fastapi import status +from sqlalchemy import delete, select, func, update, insert from models import User -from models.expense import Expense -from schemas.base import PaginationSchema, PaginationInfoSchema -from schemas.expense import UpdateExpenseResponse, UpdateExpenseRequest, DeleteExpenseResponse, GetAllExpensesResponse +from models.expense import Expense, ExpenseTag +from schemas.base import PaginationSchema +from schemas.expense import * from services.base import BaseService from utils.dependecies import is_valid_pagination @@ -47,23 +46,37 @@ class ExpensesService(BaseService): ) return response - async def get_by_id(self, expense_id) -> Optional[Expense]: + async def get_by_id(self, expense_id: int) -> Optional[Expense]: expense = await self.session.get(Expense, expense_id) return expense + async def add_tags(self, expense: Expense, tag_names: list[str]): + tags = [] + for tag_name in tag_names: + existing_tag = await self.get_tag_by_name(tag_name) + if existing_tag: + tags.append(existing_tag) + else: + tag = ExpenseTag(name=tag_name) + self.session.add(tag) + tags.append(tag) + + expense.tags = tags + async def update_expense(self, user: User, request: UpdateExpenseRequest) -> UpdateExpenseResponse: - expense = await self.get_by_id(request.expense.id) + expense = await self.get_by_id(request.expense.id or -1) if not expense: - expense = Expense( - created_at=datetime.now(), + new_expense = Expense( + created_at=datetime.datetime.now(), name=request.expense.name, comment=request.expense.comment, amount=request.expense.amount, spent_date=request.expense.spent_date, created_by_user_id=user.id, ) - self.session.add(expense) + self.session.add(new_expense) + await self.add_tags(new_expense, request.expense.tags) await self.session.commit() return UpdateExpenseResponse(ok=True, message='Запись о расходах успешно создана') @@ -72,6 +85,7 @@ class ExpensesService(BaseService): expense.comment = request.expense.comment expense.spent_date = request.expense.spent_date self.session.add(expense) + await self.add_tags(expense, request.expense.tags) await self.session.commit() return UpdateExpenseResponse(ok=True, message='Запись о расходах успешно изменена') @@ -83,3 +97,71 @@ class ExpensesService(BaseService): await self.session.execute(stmt) await self.session.commit() return DeleteExpenseResponse(ok=True, message='Запись о расходах успешно удалена') + + async def get_all_tags(self) -> GetAllExpenseTagsResponse: + stmt = ( + select(ExpenseTag) + .order_by(ExpenseTag.id) + ) + tags = await self.session.execute(stmt) + return GetAllExpenseTagsResponse(tags=tags.scalars().all()) + + async def get_tag_by_id(self, expense_tag_id: int) -> Optional[ExpenseTag]: + return await self.session.get(ExpenseTag, expense_tag_id) + + async def get_tag_by_name(self, expense_tag_name: str) -> Optional[ExpenseTag]: + stmt = ( + select(ExpenseTag) + .where(ExpenseTag.name == expense_tag_name) + ) + tag = await self.session.scalar(stmt) + return tag + + async def create_tag(self, request: CreateExpenseTagRequest) -> CreateExpenseTagResponse: + tag = await self.get_tag_by_name(request.tag.name) + if tag: + return UpdateExpenseResponse(ok=False, message='Ошибка. Такой тег уже есть.') + + tag_dict = request.tag.model_dump() + stmt = ( + insert(ExpenseTag) + .values(**tag_dict) + ) + await self.session.execute(stmt) + await self.session.commit() + return UpdateExpenseResponse(ok=True, message='Тег успешно создан.') + + async def update_tag(self, request: UpdateExpenseTagRequest) -> UpdateExpenseTagResponse: + tag = await self.get_tag_by_name(request.tag.name) + if tag: + return UpdateExpenseTagResponse(ok=False, message='Ошибка. Тег с таким названием уже есть.') + + tag = await self.get_tag_by_id(request.tag.id) + if not tag: + return UpdateExpenseTagResponse(ok=False, message='Ошибка. Тег не найден.') + + tag_dict = request.tag.model_dump() + del tag_dict['id'] + stmt = ( + update(ExpenseTag) + .where(ExpenseTag.id == request.tag.id) + .values(**tag_dict) + ) + await self.session.execute(stmt) + await self.session.commit() + return UpdateExpenseResponse(ok=True, message='Тег успешно изменен.') + + async def delete_tag(self, tag_id: int) -> DeleteExpenseTagResponse: + tag = await self.get_tag_by_id(tag_id) + if not tag: + return DeleteExpenseTagResponse(ok=False, message='Ошибка. Тег не найден.') + if len(tag.expenses) > 0: + return DeleteExpenseTagResponse(ok=False, message='Ошибка. Тег прикреплен к записи о расходах.') + + stmt = ( + delete(ExpenseTag) + .where(ExpenseTag.id == tag_id) + ) + await self.session.execute(stmt) + await self.session.commit() + return DeleteExpenseTagResponse(ok=True, message='Тег удален') diff --git a/services/statistics/expenses_statistics.py b/services/statistics/expenses_statistics.py index 00cb3bf..40304fe 100644 --- a/services/statistics/expenses_statistics.py +++ b/services/statistics/expenses_statistics.py @@ -1,7 +1,10 @@ from datetime import date -from sqlalchemy import select, func, Subquery, cast + +from sqlalchemy import select, func, Subquery, cast, CTE from sqlalchemy.dialects.postgresql import TIMESTAMP -from models import PaymentRecord, Expense + +from models import PaymentRecord, Expense, expenses_expense_tags +from schemas.statistics import CommonProfitFilters from services.base import BaseService from services.statistics.common import generate_date_range @@ -10,19 +13,9 @@ class ExpensesStatisticsService(BaseService): date_from: date date_to: date - def _get_expenses_sub(self, model, date_column, amount_column) -> Subquery: - all_dates = generate_date_range(self.date_from, self.date_to, ["expenses"]) - - expenses = ( - select( - func.sum(getattr(model, amount_column)).label("expenses"), - cast(getattr(model, date_column), TIMESTAMP(timezone=False)).label("date"), - ) - .group_by("date") - .subquery() - ) - - expenses_with_gaps_filled = ( + @staticmethod + def _fill_date_gaps(expenses: Subquery, all_dates: CTE) -> Subquery: + return ( select( all_dates.c.date, (all_dates.c.expenses + func.coalesce(expenses.c.expenses, 0)).label("expenses"), @@ -31,9 +24,53 @@ class ExpensesStatisticsService(BaseService): .order_by(all_dates.c.date) .subquery() ) - return expenses_with_gaps_filled - def _apply_expenses(self, deals_by_dates: Subquery, expenses_subquery: Subquery): + def _get_payment_records_sub(self) -> Subquery: + all_dates = generate_date_range(self.date_from, self.date_to, ["expenses"]) + + expenses = ( + select( + func.sum(PaymentRecord.amount).label("expenses"), + cast(PaymentRecord.start_date, TIMESTAMP(timezone=False)).label("date"), + ) + .group_by("date") + .subquery() + ) + + expenses_with_filled_gaps = self._fill_date_gaps(expenses, all_dates) + return expenses_with_filled_gaps + + def _get_additional_expenses_sub(self, tag_id: int) -> Subquery: + all_dates = generate_date_range(self.date_from, self.date_to, ["expenses"]) + + expenses = ( + select(Expense) + ) + + if tag_id != -1: + expenses = ( + expenses + .join(expenses_expense_tags) + .where(expenses_expense_tags.c.expense_tag_id == tag_id) + ) + + expenses = expenses.subquery() + + expenses = ( + select( + func.sum(expenses.c.amount).label("expenses"), + cast(expenses.c.spent_date, TIMESTAMP(timezone=False)).label("date"), + ) + .where(expenses.c.spent_date.between(self.date_from, self.date_to)) + .group_by("date") + .subquery() + ) + + expenses_with_filled_gaps = self._fill_date_gaps(expenses, all_dates) + return expenses_with_filled_gaps + + @staticmethod + def _apply_expenses(deals_by_dates: Subquery, expenses_subquery: Subquery): return ( select( deals_by_dates.c.date, @@ -46,15 +83,15 @@ class ExpensesStatisticsService(BaseService): .join(expenses_subquery, expenses_subquery.c.date == deals_by_dates.c.date) ) - def apply_expenses(self, date_from: date, date_to: date, deals_by_dates: Subquery): - self.date_from, self.date_to = date_from, date_to + def apply_expenses(self, filters: CommonProfitFilters, deals_by_dates: Subquery): + self.date_from, self.date_to = filters.date_range # Apply salary expenses - salary_expenses = self._get_expenses_sub(PaymentRecord, "start_date", "amount") + salary_expenses = self._get_payment_records_sub() deals_by_dates = self._apply_expenses(deals_by_dates, salary_expenses) # Apply additional expenses - additional_expenses = self._get_expenses_sub(Expense, "spent_date", "amount") + additional_expenses = self._get_additional_expenses_sub(filters.tag_id) deals_by_dates = self._apply_expenses(deals_by_dates, additional_expenses) return deals_by_dates diff --git a/services/statistics/profit_statistics.py b/services/statistics/profit_statistics.py index dfded38..ed3891f 100644 --- a/services/statistics/profit_statistics.py +++ b/services/statistics/profit_statistics.py @@ -294,8 +294,7 @@ class ProfitStatisticsService(BaseService): expenses_statistics_service = ExpensesStatisticsService(self.session) stmt_deals_applied_expenses = expenses_statistics_service.apply_expenses( - self.date_from, - self.date_to, + self.filters, sub_deals_grouped_by_date ) @@ -305,6 +304,7 @@ class ProfitStatisticsService(BaseService): async def _get_data_grouped_by_date(self, request: CommonProfitFilters, is_chart: bool = True): self.date_from, self.date_to = request.date_range + self.filters = request sub_deals_dates = self._get_deals_dates(request.deal_status_id) @@ -345,6 +345,7 @@ class ProfitStatisticsService(BaseService): def _get_common_table_grouped(self, request: GetProfitTableDataRequest): self.date_from, self.date_to = request.date_range + self.filters = request sub_deals_dates = self._get_deals_dates(request.deal_status_id)