91 lines
3.1 KiB
Python
91 lines
3.1 KiB
Python
from typing import Union, Annotated
|
|
|
|
from fastapi import Depends, HTTPException
|
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
from jose import jwt, JWTError
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from starlette import status
|
|
|
|
import backend.config
|
|
import constants
|
|
from backend.session import get_session
|
|
from enums.user import UserRole
|
|
from models import User
|
|
from schemas.auth import *
|
|
from services.base import BaseService
|
|
|
|
# FastAPI security dependency that extracts the ``Authorization: Bearer <token>``
# header and hands it to endpoints as HTTPAuthorizationCredentials.
oauth2_schema: HTTPBearer = HTTPBearer()

# JWT signing algorithm; used both when encoding tokens and when decoding them,
# so the two sides always agree.
algorithm: str = "HS256"
|
|
|
|
|
|
async def get_current_user(
    session: Annotated[AsyncSession, Depends(get_session)],
    token: Annotated[HTTPAuthorizationCredentials, Depends(oauth2_schema)],
) -> Union[User, dict]:
    """Resolve the bearer token of the current request to a principal.

    Args:
        session: Database session used to load the user record.
        token: Bearer credentials extracted by ``oauth2_schema``.

    Returns:
        The ``User`` row whose primary key is the token's ``sub`` claim, or the
        decoded token payload (a ``dict``) when ``sub`` is ``'guest'``.

    Raises:
        HTTPException: 401 when the token is empty, fails JWT decoding, has a
            non-numeric subject, or refers to a user that does not exist;
            400 when the token has no ``sub`` claim.
    """
    if not token.credentials:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='Invalid token')

    # Keep the try body to the only call that raises JWTError.
    try:
        payload = jwt.decode(token.credentials, backend.config.SECRET_KEY, algorithms=[algorithm])
    except JWTError as err:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='Invalid token') from err

    user_id = payload.get('sub')
    if not user_id:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail='Invalid credentials')

    # Guest tokens are not backed by a User row; callers receive the raw claims.
    if user_id == 'guest':
        return payload

    # A subject that is neither 'guest' nor an integer used to escape as a 500.
    try:
        user_pk = int(user_id)
    except (TypeError, ValueError) as err:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='Invalid credentials') from err

    user = await session.get(User, user_pk)
    if not user:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='Invalid credentials')
    return user
|
|
|
|
|
|
async def authorized_user(
    user: Annotated[Union[User, dict], Depends(get_current_user)],
) -> User:
    """FastAPI dependency that admits only registered users.

    Args:
        user: Principal resolved by ``get_current_user``; either a ``User`` or,
            for guest tokens, the decoded token payload ``dict``.

    Returns:
        The authenticated ``User``.

    Raises:
        HTTPException: 401 when the principal is not a ``User`` (e.g. a guest).
    """
    if isinstance(user, User):
        return user
    # Guest payloads (plain dicts) are rejected on endpoints guarded by this dependency.
    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='Invalid token')
|
|
|
|
|
|
async def guest_user(
    user: Annotated[Union[User, dict], Depends(get_current_user)],
) -> Union[User, dict]:
    """FastAPI dependency that admits registered users and guest-token holders.

    Args:
        user: Principal resolved by ``get_current_user``.

    Returns:
        The ``User`` for regular tokens, or the decoded payload ``dict`` for
        guest tokens.

    Raises:
        HTTPException: 401 when the principal is neither a ``User`` nor a ``dict``.
    """
    if isinstance(user, (User, dict)):
        return user
    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='Invalid token')
|
|
|
|
|
|
class AuthService(BaseService):
    """Issue signed JWTs for whitelisted Telegram users and for deal guests.

    ``self.session`` is presumably an ``AsyncSession`` provided by
    ``BaseService`` (it is awaited below) — confirm against BaseService.
    """

    @staticmethod
    def _generate_jwt_token(payload: dict) -> str:
        """Sign *payload* with the application secret using the module's ``algorithm``."""
        return jwt.encode(payload, backend.config.SECRET_KEY, algorithm=algorithm)

    async def authenticate(self, request: AuthLoginRequest) -> AuthLoginResponse:
        """Log in a Telegram user, creating their account on first login.

        Args:
            request: Login request; ``request.id`` is the Telegram user id.

        Returns:
            An ``AuthLoginResponse`` carrying a JWT whose ``sub`` is the user's
            primary key (as a string) and whose ``role`` is the user's role key.

        Raises:
            HTTPException: 401 when the Telegram id is not in
                ``constants.allowed_telegram_ids``.
        """
        if request.id not in constants.allowed_telegram_ids:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='Invalid credentials')

        user: Union[User, None] = await self.session.scalar(
            select(User).where(User.telegram_id == request.id)
        )
        if not user:
            user = User(
                telegram_id=request.id,
                is_admin=False,
                role_key=UserRole.user,
            )
            self.session.add(user)
            await self.session.commit()
            # Unless the session was created with expire_on_commit=False, commit()
            # expires the instance, and reading user.id below would trigger an
            # implicit lazy load that AsyncSession cannot perform. Reload explicitly.
            await self.session.refresh(user)

        # NOTE(review): no 'exp' claim is set, so these tokens never expire —
        # consider adding one.
        payload = {
            'sub': str(user.id),
            'role': user.role_key,
        }
        return AuthLoginResponse(access_token=self._generate_jwt_token(payload))

    def create_deal_guest_token(self, deal_id: int) -> str:
        """Return a guest JWT scoped to a single deal.

        ``get_current_user`` in this module returns the decoded payload of such
        a token (``{'sub': 'guest', 'deal_id': ...}``) instead of a ``User``.

        NOTE(review): like login tokens, this carries no 'exp' claim.
        """
        payload = {
            'sub': 'guest',
            'deal_id': deal_id,
        }
        return self._generate_jwt_token(payload)
|