feat: async pytest, testcases for starting and finishing shifts

This commit is contained in:
2025-11-26 15:58:59 +04:00
parent c71a460170
commit ed00d1483d
14 changed files with 1933 additions and 1393 deletions

View File

@@ -35,9 +35,15 @@ dependencies = [
"starlette>=0.47.3",
"redis[hiredis]>=5.2.1",
"typing-extensions>=4.15.0",
"pytest>=9.0.1",
"pytest-asyncio>=1.3.0",
"freezegun>=1.5.5",
]
[dependency-groups]
dev = [
"deptry>=0.23.1",
]
[tool.pytest.ini_options]
asyncio_mode = "auto"

View File

@@ -1,5 +1,7 @@
import math
from datetime import date, timedelta, time
from collections import defaultdict
from datetime import date, timedelta
from typing import Generator, Any
from fastapi import HTTPException, status
from sqlalchemy import select, func, extract, literal, label, Select
@@ -80,38 +82,30 @@ class WorkShiftsService(BaseService):
if not work_shift or work_shift.finished_at:
return False, "Смена для сотрудника еще не начата"
# End active pause
if work_shift.is_paused:
await self.finish_pause_by_shift_id(work_shift.id)
work_shift.finished_at = datetime.now()
# Collect pauses
# End active pause
if work_shift.is_paused:
work_shift.is_paused = False
work_shift.pauses[-1].finished_at = datetime.now()
pauses = [
(pause.started_at, pause.finished_at)
for pause in work_shift.pauses
]
# Build raw work intervals
# Start with one interval: whole shift
raw_intervals = [(work_shift.started_at, work_shift.finished_at)]
# Subtract pauses from work intervals
work_intervals = self.subtract_pauses(raw_intervals, pauses)
# Split intervals by days
daily_hours = self.split_intervals_by_days(work_intervals)
daily_seconds = self.calculate_daily_worked(work_shift.started_at, work_shift.finished_at, pauses)
# Create work records per day
tts = TimeTrackingService(self.session)
for day, hours in daily_hours.items():
if hours < (1 / 60): # ignore <1 minute
for day, seconds in daily_seconds.items():
if seconds < 60: # ignore <1 minute
continue
data = UpdateTimeTrackingRecordRequest(
user_id=work_shift.user_id,
date=day,
hours=hours,
hours=seconds / 3600,
)
ok, msg = await tts.update_work_record(user, data, False)
if not ok:
@@ -120,7 +114,7 @@ class WorkShiftsService(BaseService):
await self.session.commit()
# Build human-readable result message
total_work_seconds = sum(hours * 3600 for hours in daily_hours.values())
total_work_seconds = sum(seconds for seconds in daily_seconds.values())
total_td = timedelta(seconds=total_work_seconds)
h, m = hours_to_hours_and_minutes(total_td)
logger.info(f"Успешное завершение смены. userID: {work_shift.user_id}. Отработано суммарно: {h} ч. {m} мин.")
@@ -131,50 +125,50 @@ class WorkShiftsService(BaseService):
await self.session.rollback()
return False, str(e)
def subtract_pauses(
self,
work_intervals: list[tuple[datetime, datetime]],
pauses: list[tuple[datetime, datetime]]
) -> list[tuple[datetime, datetime]]:
result = []
@staticmethod
def split_range_by_days(start: datetime, end: datetime) -> Generator[tuple[date, datetime, datetime], Any, None]:
"""
Yield (day_date, day_start, day_end) for each day in the datetime range.
"""
current = start
while current.date() < end.date():
day_end = datetime.combine(current.date(), datetime.max.time())
yield current.date(), current, day_end
current = day_end + timedelta(microseconds=1)
for w_start, w_end in work_intervals:
temp = [(w_start, w_end)]
for p_start, p_end in pauses:
new_temp = []
for s, e in temp:
# pause outside interval → keep original
if p_end <= s or p_start >= e:
new_temp.append((s, e))
else:
# pause cuts interval
if p_start > s:
new_temp.append((s, p_start))
if p_end < e:
new_temp.append((p_end, e))
temp = new_temp
result.extend(temp)
# final partial day
yield end.date(), current, end
return result
@staticmethod
def intersect(a_start, a_end, b_start, b_end) -> Optional[tuple[datetime, datetime]]:
start = max(a_start, b_start)
end = min(a_end, b_end)
return (start, end) if start < end else None
def split_intervals_by_days(self, intervals: list[tuple[datetime, datetime]]) -> dict[date, float]:
from collections import defaultdict
res = defaultdict(float)
@staticmethod
def calculate_daily_worked(
start_shift: datetime,
end_shift: datetime,
shift_pauses: list[tuple[datetime, datetime]]
) -> dict[date, float]:
# Step 1: break shift into days
daily_work = defaultdict(float)
for start, end in intervals:
cur = start
while cur.date() < end.date():
# end of current day
day_end = datetime.combine(cur.date(), time.max)
seconds = (day_end - cur).total_seconds()
res[cur.date()] += seconds / 3600
cur = day_end + timedelta(seconds=1)
for day, day_start, day_end in WorkShiftsService.split_range_by_days(start_shift, end_shift):
# last segment (same day)
seconds = (end - cur).total_seconds()
res[cur.date()] += seconds / 3600
# Compute raw work for the day (before pauses)
day_work_seconds = (day_end - day_start).total_seconds()
return res
# Subtract pauses intersecting with this day
for p_start, p_end in shift_pauses:
inter = WorkShiftsService.intersect(day_start, day_end, p_start, p_end)
if inter:
p_s, p_e = inter
day_work_seconds -= (p_e - p_s).total_seconds()
daily_work[day] += day_work_seconds
return daily_work
@staticmethod
def get_work_shifts_history_stmt() -> Select:

1
test.sh Normal file
View File

@@ -0,0 +1 @@
python -m pytest --disable-warnings

105
tests/conftest.py Normal file
View File

@@ -0,0 +1,105 @@
import asyncio
import os
from pathlib import Path
from typing import AsyncGenerator
import jwt
import pytest
from dotenv import load_dotenv
from httpx import ASGITransport, AsyncClient
from sqlalchemy import StaticPool
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from backend.session import get_session
from main import app
from models import BaseModel
from tests.fixture_loader import FixtureLoader
project_root = Path(__file__).parent.parent
env_path = project_root / ".env"
load_dotenv(env_path)
SECRET_KEY = os.getenv("SECRET_KEY")
PG_LOGIN = os.environ.get("PG_LOGIN")
PG_PASSWORD = os.environ.get("PG_PASSWORD")
TEST_DATABASE_URL = f"postgresql+asyncpg://{PG_LOGIN}:{PG_PASSWORD}@127.0.0.1/test"
TEST_DATABASE_URL_FOR_ASYNCPG = f"postgresql://{PG_LOGIN}:{PG_PASSWORD}@127.0.0.1/test"
# -------------------------------------------------------------------
# Create test database + connection pool
# -------------------------------------------------------------------
@pytest.fixture(scope="function")
def event_loop():
"""Force pytest to use a session-scoped event loop (needed for asyncpg)."""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture(scope="function")
async def db_session() -> AsyncGenerator[AsyncSession, None]:
# Create test engine
test_engine = create_async_engine(
TEST_DATABASE_URL,
poolclass=StaticPool, # Useful for tests
)
# Test session factory
TestAsyncSessionLocal = async_sessionmaker(
test_engine,
class_=AsyncSession,
expire_on_commit=False,
)
async with test_engine.begin() as conn:
await conn.run_sync(BaseModel.metadata.drop_all)
await conn.run_sync(BaseModel.metadata.create_all)
session = TestAsyncSessionLocal()
await FixtureLoader().load_fixtures(session)
try:
yield session
finally:
await session.close()
async with test_engine.begin() as conn:
await conn.run_sync(BaseModel.metadata.drop_all)
await test_engine.dispose()
# -------------------------------------------------------------------
# HTTPX AsyncClient for FastAPI tests
# -------------------------------------------------------------------
@pytest.fixture(scope="function")
async def client(db_session: AsyncSession):
def override_get_session():
try:
yield db_session
finally:
db_session.close()
app.dependency_overrides[get_session] = override_get_session
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://localhost:8000/api"
) as ac:
yield ac
# -------------------------------------------------------------------
# Auth token fixture
# -------------------------------------------------------------------
@pytest.fixture(scope="function")
async def admin_client(client: AsyncClient):
payload = {
"sub": "1",
"role": "admin",
}
auth_token = jwt.encode(payload, SECRET_KEY, algorithm="HS256")
client.headers.update({"Authorization": f"Bearer {auth_token}"})
return client

103
tests/fixture_loader.py Normal file
View File

@@ -0,0 +1,103 @@
import json
import os
from datetime import datetime
from pathlib import Path
from typing import Any
from sqlalchemy import text, Table, insert
from sqlalchemy.ext.asyncio import AsyncSession
from models import User, Role, WorkShift, PayrollScheme, PayRate, user_pay_rate
from models.work_shifts import WorkShiftPause
class FixtureLoader:
def __init__(self, fixture_path: str = "fixtures"):
project_root = Path(__file__).parent
self.fixture_path = project_root / fixture_path
@staticmethod
def _fixtures_to_load() -> list[tuple[str, Any]]:
return [
("roles", Role),
("payroll_schemas", PayrollScheme),
("pay_rates", PayRate),
("users", User),
("work_shifts", WorkShift),
("work_shift_pauses", WorkShiftPause),
]
@staticmethod
def _many_to_many_fixtures() -> list[tuple[str, Table]]:
return [
("user_pay_rates", user_pay_rate),
]
async def load_fixtures(self, db: AsyncSession):
file_postfix = ".json"
for fixture_file, model in self._fixtures_to_load():
await self._load_model_fixtures(db, fixture_file + file_postfix, model)
for fixture_file, table in self._many_to_many_fixtures():
await self._load_m2m_fixtures(db, fixture_file + file_postfix, table)
async def _load_model_fixtures(
self,
db: AsyncSession,
fixture_file: str,
model: Any,
):
"""Load fixtures for a specific model"""
fixture_path = os.path.join(self.fixture_path, fixture_file)
if not os.path.exists(fixture_path):
print(f"Fixture file {fixture_path} not found")
return 0
with open(fixture_path, "r") as f:
data = json.load(f)
for item_data in data:
converted_data = {}
for key, value in item_data.items():
if isinstance(value, str) and len(value) == 19:
try:
# Try to parse as datetime
converted_data[key] = datetime.strptime(value, "%Y-%m-%d %H:%M:%S")
except ValueError:
# If it fails, keep the original value
converted_data[key] = value
else:
converted_data[key] = value
db_item = model(**converted_data)
db.add(db_item)
await db.commit()
async def _load_m2m_fixtures(
self,
db: AsyncSession,
fixture_file: str,
table: Table,
):
"""Load fixtures for many-to-many association tables"""
fixture_path = os.path.join(self.fixture_path, fixture_file)
if not os.path.exists(fixture_path):
print(f"Fixture file {fixture_path} not found")
return 0
with open(fixture_path, "r") as f:
data = json.load(f)
# Use SQLAlchemy insert for association tables
if data:
await db.execute(insert(table), data)
await db.commit()
async def clear_fixtures(self, db: AsyncSession):
"""Clear all fixture data (useful for testing)"""
for fixture_file, _ in self._fixtures_to_load()[::-1]:
await db.execute(text("DELETE FROM " + fixture_file))
await db.commit()

18
tests/fixtures/pay_rates.json vendored Normal file
View File

@@ -0,0 +1,18 @@
[
{
"id": 1,
"name": "Старший упаковщик",
"payroll_scheme_key": "hourly",
"base_rate": 350,
"overtime_rate": 450,
"overtime_threshold": 8
},
{
"id": 2,
"name": "Менеджер по продажам",
"payroll_scheme_key": "hourly",
"base_rate": 450,
"overtime_rate": 550,
"overtime_threshold": 8
}
]

14
tests/fixtures/payroll_schemas.json vendored Normal file
View File

@@ -0,0 +1,14 @@
[
{
"key": "hourly",
"name": "Почасовая"
},
{
"key": "daily",
"name": "Подневная"
},
{
"key": "monthly",
"name": "Помесячная"
}
]

14
tests/fixtures/roles.json vendored Normal file
View File

@@ -0,0 +1,14 @@
[
{
"key": "admin",
"name": "Админ"
},
{
"key": "user",
"name": "Базовый пользователь"
},
{
"key": "employee",
"name": "Сотрудник"
}
]

10
tests/fixtures/user_pay_rates.json vendored Normal file
View File

@@ -0,0 +1,10 @@
[
{
"pay_rate_id": 1,
"user_id": 2
},
{
"pay_rate_id": 2,
"user_id": 1
}
]

30
tests/fixtures/users.json vendored Normal file
View File

@@ -0,0 +1,30 @@
[
{
"id": 1,
"first_name": "Алексей",
"second_name": "Васильев",
"patronymic": "Алексеевич",
"comment": "First admin user",
"telegram_id": 123123123,
"phone_number": "88005553535",
"passport_data": "3443556677",
"is_admin": true,
"is_blocked": false,
"is_deleted": false,
"role_key": "admin"
},
{
"id": 2,
"first_name": "Магаджан",
"second_name": "Хузургалиев",
"patronymic": "Татариевич",
"comment": "First employee user",
"telegram_id": 33322211122,
"phone_number": "88005553536",
"passport_data": "8899123321",
"is_admin": false,
"is_blocked": false,
"is_deleted": false,
"role_key": "employee"
}
]

20
tests/fixtures/work_shift_pauses.json vendored Normal file
View File

@@ -0,0 +1,20 @@
[
{
"id": 100,
"started_at": "2024-11-21 18:55:00",
"finished_at": "2024-11-21 19:25:00",
"work_shift_id": 100
},
{
"id": 101,
"started_at": "2024-11-11 23:30:00",
"finished_at": "2024-11-12 00:30:00",
"work_shift_id": 101
},
{
"id": 102,
"started_at": "2024-11-13 09:30:00",
"finished_at": null,
"work_shift_id": 101
}
]

16
tests/fixtures/work_shifts.json vendored Normal file
View File

@@ -0,0 +1,16 @@
[
{
"id": 100,
"started_at": "2024-11-21 11:55:00",
"finished_at": null,
"user_id": 2,
"is_paused": false
},
{
"id": 101,
"started_at": "2024-11-11 12:00:00",
"finished_at": null,
"user_id": 1,
"is_paused": true
}
]

126
tests/test_time_tracking.py Normal file
View File

@@ -0,0 +1,126 @@
from datetime import datetime
from typing import Optional
import pytest
from freezegun import freeze_time
from httpx import AsyncClient, Response
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql.functions import count
from models import WorkShift, PaymentRecord
from tests.conftest import db_session
@pytest.mark.asyncio
async def test_start_shift(admin_client: AsyncClient, db_session: AsyncSession):
now = datetime.now()
user_id = 2
response: Response = await admin_client.post(f"/work-shifts/start-shift/{user_id}")
# Assert response
assert response.status_code == 200
assert response.json().get("ok") is True
get_workshift = select(WorkShift).where(WorkShift.user_id == user_id, func.date(WorkShift.started_at) == now.date())
work_shift: Optional[WorkShift] = (await db_session.execute(get_workshift)).scalars().one_or_none()
# Assert database
assert work_shift is not None
assert work_shift.finished_at is None
assert work_shift.started_at.hour == now.hour
assert work_shift.started_at.minute == now.minute
@pytest.mark.asyncio
async def test_forbidden_starting_second_shift_per_day(admin_client: AsyncClient, db_session: AsyncSession):
now = datetime.now()
user_id = 2
response: Response = await admin_client.post(f"/work-shifts/start-shift/{user_id}")
# Assert first response
assert response.status_code == 200
assert response.json().get("ok") is True
response: Response = await admin_client.post(f"/work-shifts/start-shift/{user_id}")
# Assert second response
assert response.status_code == 200
assert response.json().get("ok") is False
get_count = (
select(count(WorkShift.id))
.where(WorkShift.user_id == user_id, func.date(WorkShift.started_at) == now.date())
)
work_shift_count: int = (await db_session.execute(get_count)).scalar()
# Assert database
assert work_shift_count == 1
@pytest.mark.asyncio
@freeze_time("2024-11-21 22:25:00")
async def test_finish_one_day_shift(admin_client: AsyncClient, db_session: AsyncSession):
fixed_now = datetime(2024, 11, 21, 22, 25, 0)
user_id = 2
response: Response = await admin_client.post(f"/work-shifts/finish-shift/{user_id}")
# Assert response
assert response.status_code == 200
assert response.json().get("ok") is True
# Assert database
get_workshift = select(WorkShift).where(WorkShift.id == 100)
work_shift: Optional[WorkShift] = (await db_session.execute(get_workshift)).scalars().one_or_none()
assert work_shift is not None
assert work_shift.finished_at == fixed_now
assert work_shift.user_id == user_id
get_payments = select(PaymentRecord).where(PaymentRecord.user_id == user_id,
PaymentRecord.start_date == fixed_now.date())
payment: Optional[PaymentRecord] = (await db_session.execute(get_payments)).scalars().one_or_none()
assert payment is not None
assert payment.created_by_user_id == 1
# работа: 8 * 350;
# переработка: 2.5 * 450;
# из них был перерыв: 0.5 * 450;
assert abs(payment.amount - 3700) < 0.01
@pytest.mark.asyncio
@freeze_time("2024-11-13 10:00:00")
async def test_finish_three_days_shift(admin_client: AsyncClient, db_session: AsyncSession):
fixed_now = datetime(2024, 11, 13, 10, 00, 0)
user_id = 1
response: Response = await admin_client.post(f"/work-shifts/finish-shift/{user_id}")
# Assert response
assert response.status_code == 200
assert response.json().get("ok") is True
# Assert database
get_workshift = select(WorkShift).where(WorkShift.id == 101)
work_shift: Optional[WorkShift] = (await db_session.execute(get_workshift)).scalars().one_or_none()
assert work_shift is not None
assert work_shift.finished_at == fixed_now
get_payments = select(PaymentRecord).where(
PaymentRecord.user_id == user_id,
PaymentRecord.start_date.between(
datetime(2024, 11, 11, 00, 00),
datetime(2024, 11, 14, 00, 00, 00)
)
)
payments = (await db_session.execute(get_payments)).scalars().all()
assert len(payments) == 3
# работа: 8 * 450;
# переработка: 4 * 550;
# из них был перерыв: 0.5 * 550;
assert abs(payments[0].amount - 5525) < 0.01
# работа: 8 * 450;
# переработка: 16 * 550;
# из них был перерыв: 0.5 * 550;
assert abs(payments[1].amount - 12125) < 0.01
# работа: 8 * 450;
# переработка: 2 * 550;
# из них был перерыв: 0.5 * 550;
assert abs(payments[2].amount - 4425) < 0.01

2755
uv.lock generated

File diff suppressed because it is too large Load Diff