Compare commits

...

4 Commits

17 changed files with 2026 additions and 1419 deletions

View File

@@ -17,4 +17,4 @@ class JsonFormatter(logging.Formatter):
if record.exc_info:
log_record["exception"] = self.formatException(record.exc_info)
return json.dumps(log_record)
return json.dumps(log_record, ensure_ascii=False)

View File

@@ -1,5 +1,5 @@
from datetime import datetime, date
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
from sqlalchemy import ForeignKey, Table, Column
from sqlalchemy.sql import expression
@@ -15,12 +15,8 @@ class WorkShift(BaseModel):
__tablename__ = "work_shifts"
id: Mapped[int] = mapped_column(primary_key=True)
started_at: Mapped[datetime] = mapped_column(
nullable=False,
)
finished_at: Mapped[datetime] = mapped_column(
nullable=True,
)
started_at: Mapped[datetime] = mapped_column()
finished_at: Mapped[Optional[datetime]] = mapped_column()
is_paused: Mapped[bool] = mapped_column(
default=False,
server_default=expression.false(),
@@ -48,12 +44,8 @@ class WorkShiftPause(BaseModel):
__tablename__ = "work_shifts_pauses"
id: Mapped[int] = mapped_column(primary_key=True)
started_at: Mapped[datetime] = mapped_column(
nullable=False,
)
finished_at: Mapped[datetime] = mapped_column(
nullable=True,
)
started_at: Mapped[datetime] = mapped_column()
finished_at: Mapped[Optional[datetime]] = mapped_column()
work_shift_id: Mapped[int] = mapped_column(
ForeignKey("work_shifts.id"),
@@ -77,10 +69,10 @@ class PlannedWorkShift(BaseModel):
__tablename__ = "planned_work_shifts"
id: Mapped[int] = mapped_column(primary_key=True)
shift_date: Mapped[date] = mapped_column(nullable=False, index=True)
created_at: Mapped[datetime] = mapped_column(nullable=False)
shift_date: Mapped[date] = mapped_column(index=True)
created_at: Mapped[datetime] = mapped_column()
user_id: Mapped[int] = mapped_column(ForeignKey('users.id'), nullable=False, index=True)
user_id: Mapped[int] = mapped_column(ForeignKey('users.id'), index=True)
user: Mapped["User"] = relationship(lazy="selectin", backref="planned_work_shifts")
positions: Mapped[list["Position"]] = relationship(

View File

@@ -35,9 +35,15 @@ dependencies = [
"starlette>=0.47.3",
"redis[hiredis]>=5.2.1",
"typing-extensions>=4.15.0",
"pytest>=9.0.1",
"pytest-asyncio>=1.3.0",
"freezegun>=1.5.5",
]
[dependency-groups]
dev = [
"deptry>=0.23.1",
]
[tool.pytest.ini_options]
asyncio_mode = "auto"

View File

@@ -59,54 +59,64 @@ class TimeTrackingService(BaseService):
records.append(record)
return GetTimeTrackingRecordsResponse(records=records)
async def update_work_record(
self,
user: User,
request: UpdateTimeTrackingRecordRequest,
commit: bool = True,
) -> tuple[bool, str]:
record_user = await UserService(self.session).get_by_id(user_id=request.user_id)
if not record_user:
return False, "Указанный пользователь не найден!"
if not record_user.pay_rate:
return False, "У пользователя не указана схема оплаты!"
existing_record_stmt = (
select(
PaymentRecord
)
.where(
PaymentRecord.user_id == request.user_id,
PaymentRecord.start_date == request.date,
PaymentRecord.end_date == request.date,
)
)
amount = (
PayrollService(
self.session
)
.get_amount(
user=record_user,
work_units=request.hours
)
)
existing_record = await self.session.scalar(existing_record_stmt)
if existing_record:
existing_record: PaymentRecord
existing_record.work_units = request.hours
existing_record.amount = amount
else:
new_record = PaymentRecord(
user_id=request.user_id,
created_by_user_id=user.id,
start_date=request.date,
end_date=request.date,
created_at=datetime.datetime.now(),
payroll_scheme_key=record_user.pay_rate.payroll_scheme_key,
amount=amount,
work_units=request.hours
)
self.session.add(new_record)
if commit:
await self.session.commit()
return True, "Запись успешно добавлена"
async def update_record(
self,
user: User,
request: UpdateTimeTrackingRecordRequest
) -> UpdateTimeTrackingRecordResponse:
try:
record_user = await UserService(self.session).get_by_id(user_id=request.user_id)
if not record_user:
return UpdateTimeTrackingRecordResponse(ok=False, message="Указанный пользователь не найден!")
if not record_user.pay_rate:
return UpdateTimeTrackingRecordResponse(ok=False, message="У пользователя не указана схема оплаты!")
existing_record_stmt = (
select(
PaymentRecord
)
.where(
PaymentRecord.user_id == request.user_id,
PaymentRecord.start_date == request.date,
PaymentRecord.end_date == request.date,
)
)
amount = (
PayrollService(
self.session
)
.get_amount(
user=record_user,
work_units=request.hours
)
)
existing_record = await self.session.scalar(existing_record_stmt)
if existing_record:
existing_record: PaymentRecord
existing_record.work_units = request.hours
existing_record.amount = amount
else:
new_record = PaymentRecord(
user_id=request.user_id,
created_by_user_id=user.id,
start_date=request.date,
end_date=request.date,
created_at=datetime.datetime.now(),
payroll_scheme_key=record_user.pay_rate.payroll_scheme_key,
amount=amount,
work_units=request.hours
)
self.session.add(new_record)
await self.session.commit()
return UpdateTimeTrackingRecordResponse(ok=True, message="Запись успешно обновлена")
ok, message = await self.update_work_record(user, request)
return UpdateTimeTrackingRecordResponse(ok=ok, message=message)
except Exception as e:
return UpdateTimeTrackingRecordResponse(ok=False, message=str(e))

View File

@@ -1,10 +1,13 @@
from datetime import date, timedelta
import math
from collections import defaultdict
from datetime import date, timedelta
from typing import Generator, Any
from fastapi import HTTPException, status
from sqlalchemy import select, func, extract, literal, label, Select
from sqlalchemy.orm import joinedload, selectinload
from logger import logger_builder
from models import WorkShift, User
from models.work_shifts import WorkShiftPause
from schemas.base import PaginationSchema
@@ -74,33 +77,98 @@ class WorkShiftsService(BaseService):
return FinishShiftByIdResponse(ok=ok, message=message)
async def _finish_shift_common(self, user: User, work_shift: Optional[WorkShift]) -> tuple[bool, str]:
if not work_shift or work_shift.finished_at:
return False, "Смена для сотрудника еще не начата"
logger = logger_builder.get_logger()
try:
if not work_shift or work_shift.finished_at:
return False, "Смена для сотрудника еще не начата"
if work_shift.is_paused:
await self.finish_pause_by_shift_id(work_shift.id)
work_shift.finished_at = datetime.now()
work_shift.finished_at = datetime.now()
await self.session.commit()
# End active pause
if work_shift.is_paused:
work_shift.is_paused = False
work_shift.pauses[-1].finished_at = datetime.now()
pause_time = timedelta()
for pause in work_shift.pauses:
pause_time += pause.finished_at - pause.started_at
pauses = [
(pause.started_at, pause.finished_at)
for pause in work_shift.pauses
]
total_work_time: timedelta = work_shift.finished_at - work_shift.started_at
pure_work_seconds = total_work_time.total_seconds() - pause_time.total_seconds()
hours = pure_work_seconds / 3600
daily_seconds = self.calculate_daily_worked(work_shift.started_at, work_shift.finished_at, pauses)
if pure_work_seconds >= 60:
data = UpdateTimeTrackingRecordRequest(
user_id=work_shift.user_id,
date=work_shift.started_at.date(),
hours=hours,
)
await TimeTrackingService(self.session).update_record(user, data)
# Create work records per day
tts = TimeTrackingService(self.session)
for day, seconds in daily_seconds.items():
if seconds < 60: # ignore <1 minute
continue
hours, minutes = hours_to_hours_and_minutes(total_work_time)
return True, f"Смена закончена. Отработано {hours} ч. {minutes} мин."
data = UpdateTimeTrackingRecordRequest(
user_id=work_shift.user_id,
date=day,
hours=seconds / 3600,
)
ok, msg = await tts.update_work_record(user, data, False)
if not ok:
raise Exception(msg)
await self.session.commit()
# Build human-readable result message
total_work_seconds = sum(seconds for seconds in daily_seconds.values())
total_td = timedelta(seconds=total_work_seconds)
h, m = hours_to_hours_and_minutes(total_td)
logger.info(f"Успешное завершение смены. userID: {work_shift.user_id}. Отработано суммарно: {h} ч. {m} мин.")
return True, f"Смена закончена. Отработано {h} ч. {m} мин."
except Exception as e:
logger.error(f"Ошибка завершения смены. userID: {work_shift.user_id}. Ошибка: {str(e)}")
await self.session.rollback()
return False, str(e)
@staticmethod
def split_range_by_days(start: datetime, end: datetime) -> Generator[tuple[date, datetime, datetime], Any, None]:
"""
Yield (day_date, day_start, day_end) for each day in the datetime range.
"""
current = start
while current.date() < end.date():
day_end = datetime.combine(current.date(), datetime.max.time())
yield current.date(), current, day_end
current = day_end + timedelta(microseconds=1)
# final partial day
yield end.date(), current, end
@staticmethod
def intersect(a_start, a_end, b_start, b_end) -> Optional[tuple[datetime, datetime]]:
start = max(a_start, b_start)
end = min(a_end, b_end)
return (start, end) if start < end else None
@staticmethod
def calculate_daily_worked(
start_shift: datetime,
end_shift: datetime,
shift_pauses: list[tuple[datetime, datetime]]
) -> dict[date, float]:
# Step 1: break shift into days
daily_work = defaultdict(float)
for day, day_start, day_end in WorkShiftsService.split_range_by_days(start_shift, end_shift):
# Compute raw work for the day (before pauses)
day_work_seconds = (day_end - day_start).total_seconds()
# Subtract pauses intersecting with this day
for p_start, p_end in shift_pauses:
inter = WorkShiftsService.intersect(day_start, day_end, p_start, p_end)
if inter:
p_s, p_e = inter
day_work_seconds -= (p_e - p_s).total_seconds()
daily_work[day] += day_work_seconds
return daily_work
@staticmethod
def get_work_shifts_history_stmt() -> Select:

1
test.sh Normal file
View File

@@ -0,0 +1 @@
python -m pytest --disable-warnings

104
tests/conftest.py Normal file
View File

@@ -0,0 +1,104 @@
import asyncio
import os
from pathlib import Path
from typing import AsyncGenerator
import jwt
import pytest
from dotenv import load_dotenv
from httpx import ASGITransport, AsyncClient
from sqlalchemy import StaticPool
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from backend.session import get_session
from main import app
from models import BaseModel
from tests.fixture_loader import FixtureLoader
project_root = Path(__file__).parent.parent
env_path = project_root / ".env"
load_dotenv(env_path)
SECRET_KEY = os.getenv("SECRET_KEY")
PG_LOGIN = os.environ.get("PG_LOGIN")
PG_PASSWORD = os.environ.get("PG_PASSWORD")
TEST_DATABASE_URL = f"postgresql+asyncpg://{PG_LOGIN}:{PG_PASSWORD}@127.0.0.1/test"
TEST_DATABASE_URL_FOR_ASYNCPG = f"postgresql://{PG_LOGIN}:{PG_PASSWORD}@127.0.0.1/test"
# -------------------------------------------------------------------
# Create test database and session
# -------------------------------------------------------------------
@pytest.fixture(scope="function")
def event_loop():
"""Force pytest to use a session-scoped event loop (needed for asyncpg)."""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture(scope="function")
async def db_session() -> AsyncGenerator[AsyncSession, None]:
test_engine = create_async_engine(
TEST_DATABASE_URL,
poolclass=StaticPool,
)
# Test session factory
TestAsyncSessionLocal = async_sessionmaker(
test_engine,
class_=AsyncSession,
expire_on_commit=False,
)
async with test_engine.begin() as conn:
await conn.run_sync(BaseModel.metadata.drop_all)
await conn.run_sync(BaseModel.metadata.create_all)
session = TestAsyncSessionLocal()
await FixtureLoader().load_fixtures(session)
try:
yield session
finally:
await session.close()
async with test_engine.begin() as conn:
await conn.run_sync(BaseModel.metadata.drop_all)
await test_engine.dispose()
# -------------------------------------------------------------------
# HTTPX AsyncClient for FastAPI tests
# -------------------------------------------------------------------
@pytest.fixture(scope="function")
async def client(db_session: AsyncSession):
def override_get_session():
try:
yield db_session
finally:
db_session.close()
app.dependency_overrides[get_session] = override_get_session
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://localhost:8000/api"
) as ac:
yield ac
# -------------------------------------------------------------------
# Authorized client fixture
# -------------------------------------------------------------------
@pytest.fixture(scope="function")
async def admin_client(client: AsyncClient):
payload = {
"sub": "1",
"role": "admin",
}
auth_token = jwt.encode(payload, SECRET_KEY, algorithm="HS256")
client.headers.update({"Authorization": f"Bearer {auth_token}"})
return client

95
tests/fixture_loader.py Normal file
View File

@@ -0,0 +1,95 @@
import json
import os
from datetime import datetime
from pathlib import Path
from typing import Any
from sqlalchemy import Table, insert
from sqlalchemy.ext.asyncio import AsyncSession
from models import User, Role, WorkShift, PayrollScheme, PayRate, user_pay_rate
from models.work_shifts import WorkShiftPause
class FixtureLoader:
def __init__(self, fixture_path: str = "fixtures"):
project_root = Path(__file__).parent
self.fixture_path = project_root / fixture_path
@staticmethod
def _fixtures_to_load() -> list[tuple[str, Any]]:
return [
("roles", Role),
("payroll_schemas", PayrollScheme),
("pay_rates", PayRate),
("users", User),
("work_shifts", WorkShift),
("work_shift_pauses", WorkShiftPause),
]
@staticmethod
def _many_to_many_fixtures() -> list[tuple[str, Table]]:
return [
("user_pay_rates", user_pay_rate),
]
async def load_fixtures(self, db: AsyncSession):
file_postfix = ".json"
for fixture_file, model in self._fixtures_to_load():
await self._load_model_fixtures(db, fixture_file + file_postfix, model)
for fixture_file, table in self._many_to_many_fixtures():
await self._load_m2m_fixtures(db, fixture_file + file_postfix, table)
async def _load_model_fixtures(
self,
db: AsyncSession,
fixture_file: str,
model: Any,
):
"""Load fixtures for a specific model"""
fixture_path = os.path.join(self.fixture_path, fixture_file)
if not os.path.exists(fixture_path):
print(f"Fixture file {fixture_path} not found")
return
with open(fixture_path, "r") as f:
data = json.load(f)
for item_data in data:
converted_data = {}
for key, value in item_data.items():
converted_data[key] = value
if isinstance(value, str) and len(value) == 19:
try:
converted_data[key] = datetime.strptime(value, "%Y-%m-%d %H:%M:%S")
except ValueError:
converted_data[key] = value
db_item = model(**converted_data)
db.add(db_item)
await db.commit()
async def _load_m2m_fixtures(
self,
db: AsyncSession,
fixture_file: str,
table: Table,
):
"""Load fixtures for many-to-many association tables"""
fixture_path = os.path.join(self.fixture_path, fixture_file)
if not os.path.exists(fixture_path):
print(f"Fixture file {fixture_path} not found")
return
with open(fixture_path, "r") as f:
data = json.load(f)
if not data:
return
await db.execute(insert(table), data)
await db.commit()

18
tests/fixtures/pay_rates.json vendored Normal file
View File

@@ -0,0 +1,18 @@
[
{
"id": 1,
"name": "Старший упаковщик",
"payroll_scheme_key": "hourly",
"base_rate": 350,
"overtime_rate": 450,
"overtime_threshold": 8
},
{
"id": 2,
"name": "Менеджер по продажам",
"payroll_scheme_key": "hourly",
"base_rate": 450,
"overtime_rate": 550,
"overtime_threshold": 8
}
]

14
tests/fixtures/payroll_schemas.json vendored Normal file
View File

@@ -0,0 +1,14 @@
[
{
"key": "hourly",
"name": "Почасовая"
},
{
"key": "daily",
"name": "Подневная"
},
{
"key": "monthly",
"name": "Помесячная"
}
]

14
tests/fixtures/roles.json vendored Normal file
View File

@@ -0,0 +1,14 @@
[
{
"key": "admin",
"name": "Админ"
},
{
"key": "user",
"name": "Базовый пользователь"
},
{
"key": "employee",
"name": "Сотрудник"
}
]

10
tests/fixtures/user_pay_rates.json vendored Normal file
View File

@@ -0,0 +1,10 @@
[
{
"pay_rate_id": 1,
"user_id": 2
},
{
"pay_rate_id": 2,
"user_id": 1
}
]

30
tests/fixtures/users.json vendored Normal file
View File

@@ -0,0 +1,30 @@
[
{
"id": 1,
"first_name": "Алексей",
"second_name": "Васильев",
"patronymic": "Алексеевич",
"comment": "First admin user",
"telegram_id": 123123123,
"phone_number": "88005553535",
"passport_data": "3443556677",
"is_admin": true,
"is_blocked": false,
"is_deleted": false,
"role_key": "admin"
},
{
"id": 2,
"first_name": "Магаджан",
"second_name": "Хузургалиев",
"patronymic": "Татариевич",
"comment": "First employee user",
"telegram_id": 33322211122,
"phone_number": "88005553536",
"passport_data": "8899123321",
"is_admin": false,
"is_blocked": false,
"is_deleted": false,
"role_key": "employee"
}
]

20
tests/fixtures/work_shift_pauses.json vendored Normal file
View File

@@ -0,0 +1,20 @@
[
{
"id": 100,
"started_at": "2024-11-21 18:55:00",
"finished_at": "2024-11-21 19:25:00",
"work_shift_id": 100
},
{
"id": 101,
"started_at": "2024-11-11 23:30:00",
"finished_at": "2024-11-12 00:30:00",
"work_shift_id": 101
},
{
"id": 102,
"started_at": "2024-11-13 09:30:00",
"finished_at": null,
"work_shift_id": 101
}
]

16
tests/fixtures/work_shifts.json vendored Normal file
View File

@@ -0,0 +1,16 @@
[
{
"id": 100,
"started_at": "2024-11-21 11:55:00",
"finished_at": null,
"user_id": 2,
"is_paused": false
},
{
"id": 101,
"started_at": "2024-11-11 12:00:00",
"finished_at": null,
"user_id": 1,
"is_paused": true
}
]

126
tests/test_time_tracking.py Normal file
View File

@@ -0,0 +1,126 @@
from datetime import datetime
from typing import Optional
import pytest
from freezegun import freeze_time
from httpx import AsyncClient, Response
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql.functions import count
from models import WorkShift, PaymentRecord
from tests.conftest import db_session
@pytest.mark.asyncio
async def test_start_shift(admin_client: AsyncClient, db_session: AsyncSession):
now = datetime.now()
user_id = 2
response: Response = await admin_client.post(f"/work-shifts/start-shift/{user_id}")
# Assert response
assert response.status_code == 200
assert response.json().get("ok") is True
get_workshift = select(WorkShift).where(WorkShift.user_id == user_id, func.date(WorkShift.started_at) == now.date())
work_shift: Optional[WorkShift] = (await db_session.execute(get_workshift)).scalars().one_or_none()
# Assert database
assert work_shift is not None
assert work_shift.finished_at is None
assert work_shift.started_at.hour == now.hour
assert work_shift.started_at.minute == now.minute
@pytest.mark.asyncio
async def test_forbidden_starting_second_shift_per_day(admin_client: AsyncClient, db_session: AsyncSession):
now = datetime.now()
user_id = 2
response: Response = await admin_client.post(f"/work-shifts/start-shift/{user_id}")
# Assert first response
assert response.status_code == 200
assert response.json().get("ok") is True
response: Response = await admin_client.post(f"/work-shifts/start-shift/{user_id}")
# Assert second response
assert response.status_code == 200
assert response.json().get("ok") is False
get_count = (
select(count(WorkShift.id))
.where(WorkShift.user_id == user_id, func.date(WorkShift.started_at) == now.date())
)
work_shift_count: int = (await db_session.execute(get_count)).scalar()
# Assert database
assert work_shift_count == 1
@pytest.mark.asyncio
@freeze_time("2024-11-21 22:25:00")
async def test_finish_one_day_shift(admin_client: AsyncClient, db_session: AsyncSession):
fixed_now = datetime(2024, 11, 21, 22, 25, 0)
user_id = 2
response: Response = await admin_client.post(f"/work-shifts/finish-shift/{user_id}")
# Assert response
assert response.status_code == 200
assert response.json().get("ok") is True
# Assert database
get_workshift = select(WorkShift).where(WorkShift.id == 100)
work_shift: Optional[WorkShift] = (await db_session.execute(get_workshift)).scalars().one_or_none()
assert work_shift is not None
assert work_shift.finished_at == fixed_now
assert work_shift.user_id == user_id
get_payments = select(PaymentRecord).where(PaymentRecord.user_id == user_id,
PaymentRecord.start_date == fixed_now.date())
payment: Optional[PaymentRecord] = (await db_session.execute(get_payments)).scalars().one_or_none()
assert payment is not None
assert payment.created_by_user_id == 1
# работа: 8 * 350;
# переработка: 2.5 * 450;
# из них был перерыв: 0.5 * 450;
assert abs(payment.amount - 3700) < 0.01
@pytest.mark.asyncio
@freeze_time("2024-11-13 10:00:00")
async def test_finish_three_days_shift(admin_client: AsyncClient, db_session: AsyncSession):
fixed_now = datetime(2024, 11, 13, 10, 00, 0)
user_id = 1
response: Response = await admin_client.post(f"/work-shifts/finish-shift/{user_id}")
# Assert response
assert response.status_code == 200
assert response.json().get("ok") is True
# Assert database
get_workshift = select(WorkShift).where(WorkShift.id == 101)
work_shift: Optional[WorkShift] = (await db_session.execute(get_workshift)).scalars().one_or_none()
assert work_shift is not None
assert work_shift.finished_at == fixed_now
get_payments = select(PaymentRecord).where(
PaymentRecord.user_id == user_id,
PaymentRecord.start_date.between(
datetime(2024, 11, 11, 00, 00),
datetime(2024, 11, 14, 00, 00, 00)
)
)
payments = (await db_session.execute(get_payments)).scalars().all()
assert len(payments) == 3
# работа: 8 * 450;
# переработка: 4 * 550;
# из них был перерыв: 0.5 * 550;
assert abs(payments[0].amount - 5525) < 0.01
# работа: 8 * 450;
# переработка: 16 * 550;
# из них был перерыв: 0.5 * 550;
assert abs(payments[1].amount - 12125) < 0.01
# работа: 8 * 450;
# переработка: 2 * 550;
# из них был перерыв: 0.5 * 550;
assert abs(payments[2].amount - 4425) < 0.01

2755
uv.lock generated

File diff suppressed because it is too large Load Diff