Files
Fulfillment-Backend/services/payroll.py

263 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import datetime
import math
from typing import Optional
from fastapi import HTTPException
from sqlalchemy import select, insert, update, delete, func
from sqlalchemy.orm import joinedload
from starlette import status
import enums.payroll
from models import PayrollScheme, PayRate, user_pay_rate, PaymentRecord, User
from schemas.base import PaginationSchema, PaginationInfoSchema
from schemas.finances import CreatePayRateRequest, UpdatePayRateRequest, DeletePayRateRequest, \
GetAllPayrollSchemeResponse, GetAllPayRatesResponse, CreatePayRateResponse, UpdatePayRateResponse, \
DeletePayRateResponse
from schemas.payment_record import GetPaymentRecordsResponse, PaymentRecordGetSchema, CreatePaymentRecordRequest, \
CreatePaymentRecordResponse, PaymentRecordCreateSchema, DeletePaymentRecordRequest, DeletePaymentRecordResponse
from services.base import BaseService
from schemas.payroll import *
from utils.dependecies import is_valid_pagination
class PayrollService(BaseService):
    """Service layer for payroll schemes, pay rates and payment records.

    Every database call goes through ``self.session`` and is awaited, so the
    session is expected to be an async SQLAlchemy session (presumably
    provided by ``BaseService`` -- confirm there).

    Mutating methods report failures through the ``ok`` / ``message`` fields
    of their response object instead of raising; on any exception the
    session is rolled back so it is not left in a failed transaction.
    """

    @staticmethod
    def _pay_rate_values(data) -> dict:
        """Convert a pay-rate request payload into ``PayRate`` column values.

        The nested ``payroll_scheme`` object is replaced by its key under
        ``payroll_scheme_key``; all other dumped fields are passed through.
        """
        values = data.model_dump()
        del values['payroll_scheme']
        values['payroll_scheme_key'] = data.payroll_scheme.key
        return values

    async def get_all_schemas(self) -> GetAllPayrollSchemeResponse:
        """Return every payroll scheme, ordered by key."""
        stmt = select(PayrollScheme).order_by(PayrollScheme.key)
        payroll_schemas = (await self.session.scalars(stmt)).all()
        return GetAllPayrollSchemeResponse(
            payroll_schemas=PayrollSchemeSchema.from_orm_list(payroll_schemas)
        )

    async def get_all_pay_rates(self) -> GetAllPayRatesResponse:
        """Return every pay rate ordered by id, with its payroll scheme loaded."""
        stmt = (
            select(PayRate)
            .order_by(PayRate.id)
            .options(joinedload(PayRate.payroll_scheme))
        )
        pay_rates = (await self.session.scalars(stmt)).all()
        return GetAllPayRatesResponse(pay_rates=pay_rates)

    async def create_pay_rate(self, request: CreatePayRateRequest) -> CreatePayRateResponse:
        """Insert a new pay rate unless one with the same name already exists.

        Returns a response with ``ok=False`` and an explanatory message when
        the name is taken or when any exception occurs.
        """
        try:
            # Reject duplicates by name before inserting.
            duplicate = await self.session.scalar(
                select(PayRate).where(PayRate.name == request.data.name)
            )
            if duplicate:
                return CreatePayRateResponse(ok=False, message="Тариф с таким названием уже существует")

            stmt = insert(PayRate).values(**self._pay_rate_values(request.data))
            await self.session.execute(stmt)
            await self.session.commit()
            return CreatePayRateResponse(ok=True, message='Тариф успешно создан')
        except Exception as e:
            # Leave the session usable after a failed statement/commit.
            await self.session.rollback()
            return CreatePayRateResponse(ok=False, message=str(e))

    async def update_pay_rate(self, request: UpdatePayRateRequest) -> UpdatePayRateResponse:
        """Overwrite the pay rate identified by ``request.data.id``.

        Returns a response with ``ok=False`` when the pay rate does not
        exist or when any exception occurs.
        """
        try:
            # Make sure the pay rate being updated actually exists.
            pay_rate = await self.session.scalar(
                select(PayRate).where(PayRate.id == request.data.id)
            )
            if not pay_rate:
                return UpdatePayRateResponse(ok=False, message="Указанный тариф несуществует")

            stmt = (
                update(PayRate)
                .values(**self._pay_rate_values(request.data))
                .where(PayRate.id == request.data.id)
            )
            await self.session.execute(stmt)
            await self.session.commit()
            return UpdatePayRateResponse(ok=True, message='Тариф успешно обновлен')
        except Exception as e:
            await self.session.rollback()
            return UpdatePayRateResponse(ok=False, message=str(e))

    async def delete_pay_rate(self, request: DeletePayRateRequest) -> DeletePayRateResponse:
        """Delete a pay rate unless it is still linked to a user.

        The link is looked up in the ``user_pay_rate`` association table by
        ``pay_rate_id``. Returns ``ok=False`` when a link exists or when any
        exception occurs.
        """
        try:
            linked = await self.session.scalar(
                select(user_pay_rate)
                .where(user_pay_rate.c.pay_rate_id == request.pay_rate_id)
            )
            if linked:
                return DeletePayRateResponse(ok=False, message="Указанный тариф привязан к пользователю")

            stmt = delete(PayRate).where(PayRate.id == request.pay_rate_id)
            await self.session.execute(stmt)
            await self.session.commit()
            return DeletePayRateResponse(ok=True, message="Тариф успешно удален")
        except Exception as e:
            await self.session.rollback()
            return DeletePayRateResponse(ok=False, message=str(e))

    async def get_payment_records(self, pagination: PaginationSchema) -> GetPaymentRecordsResponse:
        """Return one page of payment records, newest first.

        ``pagination.page`` is treated as 1-based (it is decremented and
        clamped at 0 before computing the offset).

        Raises:
            HTTPException: 400 when ``is_valid_pagination`` rejects the input.
        """
        if not is_valid_pagination(pagination):
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail='Invalid pagination')

        # Count first so an empty table short-circuits without a page query.
        total_records = await self.session.scalar(
            select(func.count()).select_from(PaymentRecord)
        )
        if not total_records:
            return GetPaymentRecordsResponse(
                payment_records=[],
                pagination_info=PaginationInfoSchema(
                    total_pages=0,
                    total_items=0
                )
            )

        page = max(0, pagination.page - 1)
        stmt = (
            select(PaymentRecord)
            .options(
                joinedload(PaymentRecord.payroll_scheme),
                # Do not pull each user's own payment records along.
                joinedload(PaymentRecord.user).noload(User.payment_records),
                joinedload(PaymentRecord.created_by_user),
            )
            .order_by(PaymentRecord.created_at.desc())
            .offset(page * pagination.items_per_page)
            .limit(pagination.items_per_page)
        )
        payment_records = (await self.session.scalars(stmt)).all()

        return GetPaymentRecordsResponse(
            payment_records=PaymentRecordGetSchema.from_orm_list(payment_records),
            pagination_info=PaginationInfoSchema(
                total_items=total_records,
                total_pages=math.ceil(total_records / pagination.items_per_page)
            )
        )

    def get_amount(
            self,
            user: User,
            work_units: float
    ):
        """Compute the payment for ``work_units`` under the user's pay rate.

        Units up to ``overtime_threshold`` are paid at ``base_rate``; units
        above it are paid at ``overtime_rate``. If either overtime field is
        unset or zero, every unit is paid at ``base_rate``.
        """
        pay_rate: PayRate = user.pay_rate
        # Missing (None) overtime settings count as zero, i.e. "no overtime".
        overtime_threshold = pay_rate.overtime_threshold or 0
        overtime_rate = pay_rate.overtime_rate or 0

        if overtime_threshold == 0 or overtime_rate == 0:
            base_units = work_units
            overtime_units = 0
        else:
            overtime_units = max(0.0, work_units - overtime_threshold)
            base_units = work_units - overtime_units
        return pay_rate.base_rate * base_units + overtime_rate * overtime_units

    async def _create_payment_record_hourly(
            self,
            creator: User,
            user: User,
            record_schema: PaymentRecordCreateSchema
    ):
        """Insert and commit a payment record for ``user``.

        Despite the name, ``create_payment_record`` calls this for the
        hourly, daily and monthly schemes alike. The amount is derived from
        ``record_schema.work_units`` via ``get_amount``.
        """
        pay_rate: PayRate = user.pay_rate
        amount = self.get_amount(user, record_schema.work_units)

        payment_record_dict = record_schema.model_dump()
        # The nested user object is stored as a plain ``user_id`` column.
        del payment_record_dict['user']
        payment_record_dict.update({
            'created_by_user_id': creator.id,
            # NOTE(review): naive local timestamp -- confirm the column is
            # not expected to be timezone-aware.
            'created_at': datetime.datetime.now(),
            'payroll_scheme_key': pay_rate.payroll_scheme.key,
            'amount': amount,
            'user_id': record_schema.user.id
        })

        stmt = insert(PaymentRecord).values(**payment_record_dict)
        await self.session.execute(stmt)
        await self.session.commit()

    async def create_payment_record(
            self,
            request: CreatePaymentRecordRequest,
            creator: User
    ) -> CreatePaymentRecordResponse:
        """Create a payment record for the user named in the request.

        Returns ``ok=False`` when the user is missing, has no pay rate, has
        a payroll scheme this service does not handle, or when any exception
        occurs.
        """
        try:
            user: Optional[User] = await self.session.scalar(
                select(User).where(User.id == request.data.user.id)
            )
            if not user:
                return CreatePaymentRecordResponse(ok=False, message='Указанный пользователь не найден')
            if not user.pay_rate:
                return CreatePaymentRecordResponse(ok=False, message='У пользователя не указан тариф')

            # All supported schemes currently share one creation routine.
            supported_schemes = (
                enums.payroll.PaySchemeType.hourly,
                enums.payroll.PaySchemeType.daily,
                enums.payroll.PaySchemeType.monthly,
            )
            if user.pay_rate.payroll_scheme.key not in supported_schemes:
                # Previously this fell through and reported success even
                # though nothing had been inserted.
                return CreatePaymentRecordResponse(ok=False, message='Неподдерживаемая схема оплаты')

            await self._create_payment_record_hourly(creator, user, request.data)
            return CreatePaymentRecordResponse(ok=True, message='Запись успешно добавлена')
        except Exception as e:
            await self.session.rollback()
            return CreatePaymentRecordResponse(ok=False, message=str(e))

    async def delete_payment_record(
            self,
            request: DeletePaymentRecordRequest
    ) -> DeletePaymentRecordResponse:
        """Delete the payment record identified by ``request.payment_record_id``.

        Reports success even when no row matched the id; returns ``ok=False``
        only when an exception occurs.
        """
        try:
            stmt = delete(PaymentRecord).where(
                PaymentRecord.id == request.payment_record_id
            )
            await self.session.execute(stmt)
            await self.session.commit()
            return DeletePaymentRecordResponse(ok=True, message="Начисление успешно удалено")
        except Exception as e:
            await self.session.rollback()
            return DeletePaymentRecordResponse(ok=False, message=str(e))