89 lines
		
	
	
		
			3.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			89 lines
		
	
	
		
			3.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import asyncio
import logging
import os
from datetime import datetime

from redis import asyncio as aioredis

import backend.config
from constants import APP_PATH


class RedisConnectionManager:
    """Holds one async Redis client shared by every caller in the process.

    The client is created lazily on the first :meth:`get_redis_connection`
    call, and the same object is returned on every later call.
    """

    # Shared client instance; None until the first successful connection.
    _redis_connection = None
    # Serializes first-time creation so concurrent callers share one client.
    # NOTE(review): this lock is created at import time. On Python < 3.10 an
    # asyncio.Lock binds to the event loop current at construction; confirm
    # the application runs a single event loop.
    _redis_lock = asyncio.Lock()

    @classmethod
    async def get_redis_connection(cls):
        """Return the shared Redis client, creating it on first use.

        The connection URL is read from ``backend.config.REDIS_URL``. When the
        client is created, a timestamped line is appended to
        ``<APP_PATH>/redis.log``. If that file cannot be written, the problem
        is logged and the client is still returned.
        """
        # Fast path: once set, the client is never replaced, and asyncio runs
        # this check without interleaving, so no lock is needed here.
        if cls._redis_connection is not None:
            return cls._redis_connection
        async with cls._redis_lock:
            # Re-check: another task may have created the client while we
            # were waiting for the lock.
            if cls._redis_connection is None:
                cls._redis_connection = await aioredis.from_url(backend.config.REDIS_URL)
                cls._log_connection_created()
            return cls._redis_connection

    @staticmethod
    def _log_connection_created():
        """Append a timestamped "Created connection to redis" line to redis.log.

        This is best-effort: an unwritable log file must not make the caller
        fail after the client has already been created and cached.
        """
        path = os.path.join(APP_PATH, "redis.log")
        current_datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        line = f"[{current_datetime}] Created connection to redis\n"
        try:
            # Small synchronous write, done once per process.
            with open(path, "a", encoding="utf-8") as f:
                f.write(line)
        except OSError as e:
            logging.error("Could not write %s: %s", path, e)
class BatchLimiter:
    """Fixed-window request limiter whose state is stored in Redis.

    For a limiter ``key`` two Redis entries hold the window state:

    * ``key``: number of requests admitted in the current window;
    * ``"{key}:start_time"``: ISO-8601 local timestamp of the window start.

    Reads and updates of both entries are serialized across processes with
    the Redis lock ``"{key}_lock"``.
    """

    def __init__(self):
        # NOTE(review): no method of this class uses this in-process lock; it
        # is kept in case external code references ``limiter.lock``.
        self.lock = asyncio.Lock()

    @staticmethod
    def _parse_start_time(raw):
        """Turn a stored window start (bytes or str, ISO-8601) into a datetime.

        Returns None when nothing is stored. Accepting ``str`` as well as
        ``bytes`` keeps this working if the client decodes responses.
        """
        if not raw:
            return None
        if isinstance(raw, bytes):
            raw = raw.decode()
        return datetime.fromisoformat(raw)

    @staticmethod
    def _plan(start_time, current_requests, max_requests, period, now):
        """Decide how to handle one request given the stored window state.

        Returns:
            ``("reset", 0.0)`` when there is no window, it has expired, or its
            start lies in the future (wall clock moved backwards). Without
            that last case the caller could sleep far longer than ``period``.
            ``("incr", 0.0)`` when the current window still has room.
            ``("wait", seconds)`` with the time left in a full window.
        """
        if start_time is None:
            return "reset", 0.0
        elapsed = (now - start_time).total_seconds()
        if elapsed < 0 or elapsed >= period:
            return "reset", 0.0
        if current_requests < max_requests:
            return "incr", 0.0
        return "wait", period - elapsed

    async def acquire(self, key, max_requests, period):
        """Wait until a request under ``key`` is allowed, then record it.

        At most ``max_requests`` requests are admitted per ``period``-second
        window. The first request of a fresh window is always admitted.
        Redis errors are logged and the attempt is retried after one second.

        Args:
            key: Redis key holding the request counter for this limiter.
            max_requests: Requests allowed per window.
            period: Window length in seconds.
        """
        redis = await RedisConnectionManager.get_redis_connection()
        start_key = f"{key}:start_time"
        while True:
            recorded = False
            wait = 0.0
            try:
                async with redis.lock(f"{key}_lock"):
                    start_time = self._parse_start_time(await redis.get(start_key))
                    raw_count = await redis.get(key)
                    current_requests = int(raw_count) if raw_count else 0
                    now = datetime.now()
                    action, wait = self._plan(
                        start_time, current_requests, max_requests, period, now
                    )
                    if action == "reset":
                        # Start a new window; this request is its first.
                        await redis.set(key, 1)
                        await redis.set(start_key, now.isoformat())
                        recorded = True
                    elif action == "incr":
                        await redis.incr(key)
                        recorded = True
            except aioredis.RedisError as e:
                logging.error("Redis error: %s", e)
                if recorded:
                    # Only releasing the lock failed. The request is already
                    # counted, and retrying would count it a second time.
                    return
                await asyncio.sleep(1)
                continue
            if recorded:
                return
            # Window is full: sleep without holding the lock so other callers
            # are not blocked, then re-check. The next pass resets the
            # expired window and counts this request exactly once.
            await asyncio.sleep(wait)

    async def clear_locks(self):
        """Delete every Redis key matching ``*_lock*``.

        Meant to remove ``"{key}_lock"`` entries left behind by processes
        that died while holding them (``acquire`` creates its locks without a
        timeout).

        NOTE(review): the pattern also matches unrelated keys containing
        ``_lock``. It will also delete locks currently held by live workers,
        so run it only when no limiter is active.
        """
        redis = await RedisConnectionManager.get_redis_connection()
        keys = [key async for key in redis.scan_iter('*_lock*')]
        if keys:
            await redis.delete(*keys)

    async def acquire_wildberries(self, key):
        """Acquire a Wildberries slot: 300 requests per 60 s under ``wildberries:<key>``."""
        max_requests = 300
        period = 60
        await self.acquire('wildberries:' + key, max_requests, period)

    async def acquire_ozon(self, key):
        """Acquire an Ozon slot: 80 requests per 60 s under ``ozon:<key>``."""
        max_requests = 80
        period = 60
        await self.acquire('ozon:' + key, max_requests, period)

    async def acquire_yandexmarket(self, key):
        """Acquire a Yandex Market slot: 50 requests per 60 s under ``yandexmarket:<key>``."""
        max_requests = 50
        period = 60
        await self.acquire('yandexmarket:' + key, max_requests, period)