diff --git a/main.py b/main.py index 1171da9..22171ca 100644 --- a/main.py +++ b/main.py @@ -41,6 +41,7 @@ routers_list = [ routers.role_router, routers.marketplace_router, routers.payroll_router, + routers.time_tracking_router, ] for router in routers_list: app.include_router(router) diff --git a/routers/__init__.py b/routers/__init__.py index b3009b4..e2e6c99 100644 --- a/routers/__init__.py +++ b/routers/__init__.py @@ -10,3 +10,4 @@ from .user import user_router from .role import role_router from .marketplace import marketplace_router from .payroll import payroll_router +from .time_tracking import time_tracking_router diff --git a/routers/time_tracking.py b/routers/time_tracking.py new file mode 100644 index 0000000..0e7239f --- /dev/null +++ b/routers/time_tracking.py @@ -0,0 +1,36 @@ +from fastapi import APIRouter + +from backend.dependecies import SessionDependency, CurrentUserDependency +from schemas.time_tracking import * +from services.time_tracking import TimeTrackingService + +time_tracking_router = APIRouter( + prefix="/time-tracking", + tags=["time-tracking"] +) + + +@time_tracking_router.post( + '/get-records', + operation_id='get_time_tracking_records', + response_model=GetTimeTrackingRecordsResponse +) +async def get_data( + session: SessionDependency, + request: GetTimeTrackingRecordsRequest +): + return await TimeTrackingService(session).get_records(request) + + +@time_tracking_router.post( + '/update-record', + operation_id='update_time_tracking_record', + response_model=UpdateTimeTrackingRecordResponse +) +async def get_data( + session: SessionDependency, + request: UpdateTimeTrackingRecordRequest, + user: CurrentUserDependency +): + return await TimeTrackingService(session).update_record(user, request) + diff --git a/schemas/time_tracking.py b/schemas/time_tracking.py new file mode 100644 index 0000000..623b225 --- /dev/null +++ b/schemas/time_tracking.py @@ -0,0 +1,45 @@ +import datetime +from typing import List + +from schemas.base import BaseSchema, OkMessageSchema +from schemas.user import UserSchema + + +# region Entities +class TimeTrackingData(BaseSchema): + date: datetime.date + hours: int + amount: int + + +class TimeTrackingRecord(BaseSchema): + user: UserSchema + total_amount: int + data: List[TimeTrackingData] + + +# endregion + + +# region Requests +class GetTimeTrackingRecordsRequest(BaseSchema): + date: datetime.date + user_ids: list[int] + + +class UpdateTimeTrackingRecordRequest(BaseSchema): + user_id: int + date: datetime.date + hours: int + + +# endregion + +# region Responses +class GetTimeTrackingRecordsResponse(BaseSchema): + records: List[TimeTrackingRecord] + + +class UpdateTimeTrackingRecordResponse(OkMessageSchema): + pass +# endregion diff --git a/services/payroll.py b/services/payroll.py index 6d18596..b0150ff 100644 --- a/services/payroll.py +++ b/services/payroll.py @@ -167,11 +167,10 @@ class PayrollService(BaseService): ) return response - async def _create_payment_record_hourly( + def get_amount( self, - creator: User, user: User, - record_schema: PaymentRecordCreateSchema + work_units: int ): pay_rate: PayRate = user.pay_rate overtime_threshold = 0 @@ -183,13 +182,22 @@ class PayrollService(BaseService): overtime_rate = pay_rate.overtime_rate if overtime_threshold == 0 or overtime_rate == 0: - base_units = record_schema.work_units + base_units = work_units overtime_units = 0 else: - overtime_units = max(0, record_schema.work_units - overtime_threshold) - base_units = record_schema.work_units - overtime_units + overtime_units = max(0, work_units - overtime_threshold) + base_units = work_units - overtime_units - amount = pay_rate.base_rate * base_units + overtime_rate * overtime_units + return pay_rate.base_rate * base_units + overtime_rate * overtime_units + + async def _create_payment_record_hourly( + self, + creator: User, + user: User, + record_schema: PaymentRecordCreateSchema + ): + pay_rate: PayRate = user.pay_rate + amount = self.get_amount(user, record_schema.work_units) payment_record_dict = record_schema.model_dump() del payment_record_dict['user'] payment_record_dict.update({ @@ -202,54 +210,14 @@ class PayrollService(BaseService): stmt = ( insert( PaymentRecord - ).values( + ) + .values( **payment_record_dict ) ) await self.session.execute(stmt) await self.session.commit() - async def _create_payment_record_daily( - self, - creator: User, - user: User, - record_schema: PaymentRecordCreateSchema - ): - pay_rate: PayRate = user.pay_rate - amount = pay_rate.base_rate * record_schema.work_units - payment_record_dict = record_schema.model_dump() - del payment_record_dict['user'] - - payment_record_dict.update({ - 'created_by_user_id': creator.id, - 'created_at': datetime.datetime.now(), - 'payroll_scheme_key': pay_rate.payroll_scheme.key, - 'amount': amount, - 'user_id': record_schema.user.id - - }) - - async def _create_payment_record_monthly( - self, - creator: User, - - user: User, - record_schema: PaymentRecordCreateSchema - ) -> CreatePaymentRecordResponse: - pay_rate: PayRate = user.pay_rate - amount = pay_rate.base_rate * record_schema.work_units - payment_record_dict = record_schema.model_dump() - del payment_record_dict['user'] - - payment_record_dict.update({ - 'created_by_user_id': creator.id, - 'created_at': datetime.datetime.now(), - 'payroll_scheme_key': pay_rate.payroll_scheme.key, - 'amount': amount, - 'user_id': record_schema.user.id - - }) - async def create_payment_record( self, request: CreatePaymentRecordRequest, @@ -282,7 +250,8 @@ class PayrollService(BaseService): stmt = ( delete( PaymentRecord - ).where( + ) + .where( PaymentRecord.id == request.payment_record_id ) ) diff --git a/services/time_tracking.py b/services/time_tracking.py new file mode 100644 index 0000000..29022d4 --- /dev/null +++ b/services/time_tracking.py @@ -0,0 +1,111 @@ +from collections import defaultdict + +from sqlalchemy import select, func +from sqlalchemy.orm import joinedload + +from models import PaymentRecord, User +from schemas.time_tracking import * +from services.base import BaseService +from services.payroll import PayrollService +from services.user import UserService + + +class TimeTrackingService(BaseService): + async def get_records(self, request: GetTimeTrackingRecordsRequest) -> GetTimeTrackingRecordsResponse: + stmt = ( + select( + PaymentRecord, + ) + .options( + joinedload( + PaymentRecord.user + ) + ) + .where( + func.date(func.date_trunc('month', PaymentRecord.start_date)) == request.date, + func.date(func.date_trunc('month', PaymentRecord.end_date)) == request.date, + PaymentRecord.start_date == PaymentRecord.end_date, + # PaymentRecord.user_id.in_(request.user_ids) + ) + ) + query_result = (await self.session.scalars(stmt)).all() + records_dict = defaultdict(list) + users_dict = {} + amount_dict = defaultdict(list) + for payment_record in query_result: + user = UserSchema.model_validate(payment_record.user) + data = TimeTrackingData( + date=payment_record.start_date, + hours=payment_record.work_units, + amount=payment_record.amount + ) + users_dict[user.id] = user + records_dict[user.id].append(data) + amount_dict[user.id].append(payment_record.amount) + + records = [] + for user_id, data_list in records_dict.items(): + amount = sum(amount_dict[user_id]) + user = users_dict[user_id] + record = TimeTrackingRecord( + user=user, + data=data_list, + total_amount=amount + ) + records.append( + record + ) + return GetTimeTrackingRecordsResponse( + records=records + ) + + async def update_record(self, + user: User, + request: UpdateTimeTrackingRecordRequest + ) -> UpdateTimeTrackingRecordResponse: + try: + record_user = await UserService(self.session).get_by_id(user_id=request.user_id) + if not record_user: + return UpdateTimeTrackingRecordResponse(ok=False, message="Указанный пользователь не найден!") + if not record_user.pay_rate: + return UpdateTimeTrackingRecordResponse(ok=False, message="У пользователя не указана схема оплаты!") + existing_record_stmt = ( + select( + PaymentRecord + ) + .where( + PaymentRecord.user_id == request.user_id, + PaymentRecord.start_date == request.date, + PaymentRecord.end_date == request.date, + ) + ) + amount = ( + PayrollService( + self.session + ) + .get_amount( + user=record_user, + work_units=request.hours + ) + ) + existing_record = await self.session.scalar(existing_record_stmt) + if existing_record: + existing_record: PaymentRecord + existing_record.work_units = request.hours + existing_record.amount = amount + else: + new_record = PaymentRecord( + user_id=request.user_id, + created_by_user_id=user.id, + start_date=request.date, + end_date=request.date, + created_at=datetime.datetime.now(), + payroll_scheme_key=record_user.pay_rate.payroll_scheme_key, + amount=amount, + work_units=request.hours + ) + self.session.add(new_record) + await self.session.commit() + return UpdateTimeTrackingRecordResponse(ok=True, message="Запись успешно обновлена") + except Exception as e: + return UpdateTimeTrackingRecordResponse(ok=False, message=str(e)) diff --git a/test.py b/test.py index 0d6471d..a4735e2 100644 --- a/test.py +++ b/test.py @@ -1,25 +1,44 @@ import asyncio +import datetime -from sqlalchemy import select +from sqlalchemy import select, func from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload from backend.session import session_maker -from models import User +from models import User, PaymentRecord async def main(): - work_units = 15 - base_rate = 0 - overtime_units = max([0, work_units - base_rate]) - base_units = work_units - overtime_units - print(overtime_units, base_units) - return session: AsyncSession = session_maker() - a = await session.scalar( - select(User).where(User.first_name == "Абид") - - ) - print(a) + try: + d = datetime.date.today() + d = d.replace(day=1) + print(d) + stmt = ( + select( + PaymentRecord + ) + .select_from(PaymentRecord) + .options( + joinedload( + PaymentRecord.user + ) + ) + .where( + func.date(func.date_trunc('month', PaymentRecord.start_date)) == d, + func.date(func.date_trunc('month', PaymentRecord.end_date)) == d, + PaymentRecord.start_date == PaymentRecord.end_date, + # PaymentRecord.user_id.in_(request.user_ids) + ) + ) + print(stmt.compile(compile_kwargs={ + 'literal_binds': True + })) + query_result = (await session.scalars(stmt)).all() + print(query_result) + except Exception as e: + print(e) await session.close()