187 lines
		
	
	
		
			7.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			187 lines
		
	
	
		
			7.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from datetime import datetime
 | 
						||
 | 
						||
import math
 | 
						||
from fastapi import HTTPException
 | 
						||
from fastapi import status
 | 
						||
from sqlalchemy import delete, select, func, update, insert, and_
 | 
						||
 | 
						||
from models import User
 | 
						||
from models.transaction import Transaction, TransactionTag
 | 
						||
from schemas.base import PaginationSchema
 | 
						||
from schemas.transaction import *
 | 
						||
from services.base import BaseService
 | 
						||
from utils.dependecies import is_valid_pagination
 | 
						||
 | 
						||
 | 
						||
class TransactionsService(BaseService):
    """Service layer for transactions (incomes/expenses) and their tags.

    All database access goes through ``self.session`` (supplied by
    ``BaseService``; every call on it is awaited, so presumably an SQLAlchemy
    ``AsyncSession`` — confirm). Mutating methods commit their own changes.
    """

    async def get_all(self, pagination: PaginationSchema, request: GetAllTransactionsRequest) -> GetAllTransactionsResponse:
        """Return one page of transactions of a single kind, newest first.

        Args:
            pagination: 1-based ``page`` number and ``items_per_page`` size.
            request: its ``is_income`` flag selects incomes or expenses.

        Returns:
            The requested page of transactions together with the total item
            and page counts. If no transaction of that kind exists, an empty
            list with both counts set to 0.

        Raises:
            HTTPException: 400 when ``is_valid_pagination`` rejects *pagination*.
        """
        if not is_valid_pagination(pagination):
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail='Invalid pagination')

        # The API uses 1-based page numbers; OFFSET needs a 0-based index.
        page = max(0, pagination.page - 1)
        is_income = request.is_income

        total_records = await self.session.scalar(
            select(func.count())
            .select_from(Transaction)
            .where(Transaction.is_income == is_income)
        )
        if not total_records:
            return GetAllTransactionsResponse(
                transactions=[],
                pagination_info=PaginationInfoSchema(
                    total_pages=0,
                    total_items=0
                )
            )

        stmt = (
            select(Transaction)
            .where(Transaction.is_income == is_income)
            # ``id`` breaks ties between rows sharing a spent_date. Without a
            # total order, OFFSET/LIMIT paging may repeat or skip rows.
            .order_by(Transaction.spent_date.desc(), Transaction.id.desc())
            .offset(page * pagination.items_per_page)
            .limit(pagination.items_per_page)
        )
        result = await self.session.execute(stmt)

        return GetAllTransactionsResponse(
            transactions=result.scalars().all(),
            pagination_info=PaginationInfoSchema(
                total_items=total_records,
                total_pages=math.ceil(total_records / pagination.items_per_page)
            )
        )

    async def get_by_id(self, transaction_id: int) -> Optional[Transaction]:
        """Return the transaction with primary key *transaction_id*, or None."""
        return await self.session.get(Transaction, transaction_id)

    async def add_tags(self, transaction: Transaction, tag_names: list[str]) -> None:
        """Replace the tags of *transaction* with the tags named in *tag_names*.

        Each name is looked up among tags of the same income/expense kind as
        the transaction. Names without an existing tag get a new
        ``TransactionTag`` added to the session. Nothing is committed here;
        the caller is responsible for committing.
        """
        tags = []
        # dict.fromkeys drops repeated names while keeping first-seen order,
        # so the same tag is never attached to the transaction twice.
        for tag_name in dict.fromkeys(tag_names):
            tag = await self.get_tag_by_name(tag_name, transaction.is_income)
            if tag is None:
                tag = TransactionTag(name=tag_name, is_income=transaction.is_income)
                self.session.add(tag)
            tags.append(tag)

        transaction.tags = tags

    async def update_transaction(self, user: User, request: UpdateTransactionRequest) -> UpdateTransactionResponse:
        """Create a new transaction, or update the one whose id is given.

        If ``request.transaction.id`` is falsy or matches no row, a new
        transaction is created and attributed to *user*. Otherwise the
        existing row's name, amount, comment, spent_date and tags are
        overwritten. On update, ``is_income`` and the creating user are left
        unchanged. Changes are committed before returning.
        """
        data = request.transaction
        # A missing id is mapped to -1 so the lookup finds nothing and the
        # create branch below is taken.
        transaction = await self.get_by_id(data.id or -1)

        if not transaction:
            new_transaction = Transaction(
                # ``datetime`` is the class (``from datetime import datetime``),
                # so ``now`` is called on it directly.
                created_at=datetime.now(),
                name=data.name,
                comment=data.comment,
                amount=data.amount,
                spent_date=data.spent_date,
                created_by_user_id=user.id,
                is_income=data.is_income,
            )
            self.session.add(new_transaction)
            await self.add_tags(new_transaction, data.tags)
            await self.session.commit()
            return UpdateTransactionResponse(ok=True, message='Запись успешно создана')

        transaction.name = data.name
        transaction.amount = data.amount
        transaction.comment = data.comment
        transaction.spent_date = data.spent_date
        self.session.add(transaction)
        await self.add_tags(transaction, data.tags)
        await self.session.commit()
        return UpdateTransactionResponse(ok=True, message='Запись успешно изменена')

    async def delete_transaction(self, expense_id) -> DeleteTransactionResponse:
        """Delete the transaction with id *expense_id* and commit.

        Returns ``ok=False`` when the database reports that no row was
        deleted. This mirrors the not-found handling in ``delete_tag``.
        """
        stmt = (
            delete(Transaction)
            .where(Transaction.id == expense_id)
        )
        result = await self.session.execute(stmt)
        # Read rowcount before committing. Drivers that cannot report it
        # return -1, which is deliberately treated as success.
        deleted_rows = result.rowcount
        await self.session.commit()

        if deleted_rows == 0:
            return DeleteTransactionResponse(ok=False, message='Запись не найдена.')
        return DeleteTransactionResponse(ok=True, message='Запись успешно удалена')

    async def get_all_tags(self) -> GetAllTransactionTagsResponse:
        """Return every tag, both income and expense, ordered by id."""
        stmt = (
            select(TransactionTag)
            .order_by(TransactionTag.id)
        )
        tags = await self.session.execute(stmt)
        return GetAllTransactionTagsResponse(tags=tags.scalars().all())

    async def get_tags(self, is_income: bool) -> GetTransactionTagsResponse:
        """Return the tags of one kind (income if *is_income*), ordered by id."""
        stmt = (
            select(TransactionTag)
            .where(TransactionTag.is_income == is_income)
            .order_by(TransactionTag.id)
        )
        tags = await self.session.execute(stmt)
        return GetTransactionTagsResponse(tags=tags.scalars().all())

    async def get_tag_by_id(self, tag_id: int) -> Optional[TransactionTag]:
        """Return the tag with primary key *tag_id*, or None."""
        return await self.session.get(TransactionTag, tag_id)

    async def get_tag_by_name(self, tag_name: str, is_income: bool) -> Optional[TransactionTag]:
        """Return the tag with this exact name and kind, or None if absent.

        ``Session.scalar`` yields the first matching row if several exist.
        """
        stmt = (
            select(TransactionTag)
            .where(and_(TransactionTag.name == tag_name, TransactionTag.is_income == is_income))
        )
        return await self.session.scalar(stmt)

    async def create_tag(self, request: CreateTransactionTagRequest) -> CreateTransactionTagResponse:
        """Insert a new tag unless one with the same name and kind exists.

        All fields of ``request.tag`` are inserted as column values, so its
        field names are assumed to match ``TransactionTag`` columns.
        """
        existing = await self.get_tag_by_name(request.tag.name, request.tag.is_income)
        if existing:
            return CreateTransactionTagResponse(ok=False, message='Такой тег уже есть.')

        stmt = (
            insert(TransactionTag)
            .values(**request.tag.model_dump())
        )
        await self.session.execute(stmt)
        await self.session.commit()
        return CreateTransactionTagResponse(ok=True, message='Тег успешно создан.')

    async def update_tag(self, request: UpdateTransactionTagRequest) -> UpdateTransactionTagResponse:
        """Update the tag ``request.tag.id`` with the other fields of ``request.tag``.

        Fails if a *different* tag already has the requested name and kind,
        or if the target tag does not exist.
        """
        duplicate = await self.get_tag_by_name(request.tag.name, request.tag.is_income)
        # Finding the tag being edited is not a conflict; it happens when the
        # name and kind are submitted unchanged.
        if duplicate and duplicate.id != request.tag.id:
            return UpdateTransactionTagResponse(ok=False, message='Тег с таким названием уже есть.')

        tag = await self.get_tag_by_id(request.tag.id)
        if not tag:
            return UpdateTransactionTagResponse(ok=False, message='Тег не найден.')

        tag_dict = request.tag.model_dump()
        # The id selects the row; it must not be written as a new value.
        del tag_dict['id']
        stmt = (
            update(TransactionTag)
            .where(TransactionTag.id == request.tag.id)
            .values(**tag_dict)
        )
        await self.session.execute(stmt)
        await self.session.commit()
        return UpdateTransactionTagResponse(ok=True, message='Тег успешно изменен.')

    async def delete_tag(self, tag_id: int) -> DeleteTransactionTagResponse:
        """Delete the tag *tag_id* unless it is missing or still in use.

        NOTE(review): reading ``tag.transactions`` on an async session only
        works if that relationship is eagerly loaded; presumably it is
        configured that way on the model. Confirm.
        """
        tag = await self.get_tag_by_id(tag_id)
        if not tag:
            return DeleteTransactionTagResponse(ok=False, message='Тег не найден.')
        if tag.transactions:
            return DeleteTransactionTagResponse(ok=False, message='Тег прикреплен к записи о расходах/доходах.')

        stmt = (
            delete(TransactionTag)
            .where(TransactionTag.id == tag_id)
        )
        await self.session.execute(stmt)
        await self.session.commit()
        return DeleteTransactionTagResponse(ok=True, message='Тег удален')
 |