"""Redis-based rate limiting service"""
import redis.asyncio as redis
from app.core.logging import get_logger
from app.core.exceptions import RateLimitExceeded

# Module-level logger used by RedisRateLimiter for limit-exceeded warnings and Redis error reports.
logger = get_logger(__name__)


class RedisRateLimiter:
    """Redis-backed fixed-window rate limiter for production.

    Each identifier owns one Redis counter stored under ``<prefix><identifier>``.
    The counter is incremented on every request and expires ``time_window``
    seconds after the request that created it, which starts the next window.
    A request is allowed while the counter is at most ``max_requests``.
    Redis failures fail open: the request is allowed and the error is logged.
    """

    def __init__(self, redis_client: redis.Redis, max_requests: int = 30, time_window: int = 60):
        """
        Args:
            redis_client: Async Redis client used for all counter operations.
            max_requests: Requests allowed per identifier within one window.
            time_window: Window length in seconds (used as the counter key's TTL).
        """
        self.redis = redis_client
        self.max_requests = max_requests
        self.time_window = time_window
        self.prefix = "ratelimit:"

    def _key(self, identifier: str) -> str:
        """Return the Redis key holding the request counter for *identifier*."""
        return f"{self.prefix}{identifier}"

    async def is_allowed(self, identifier: str) -> bool:
        """Record one request for *identifier* and report whether it fits the window.

        INCR and TTL execute together in a MULTI/EXEC transaction. Every
        concurrent caller therefore receives a distinct, up-to-date count, and
        bursts cannot slip past the limit. Rejected requests still increment
        the counter. This is harmless: the window's expiry is never extended,
        and ``get_remaining`` clamps at zero.

        Returns:
            True if the request is within the limit, or if Redis raised
            (fail-open). False once the limit for the current window is exceeded.
        """
        key = self._key(identifier)

        try:
            async with self.redis.pipeline(transaction=True) as pipe:
                count, ttl = await pipe.incr(key).ttl(key).execute()

            # TTL == -1 means the key exists but has no expiry. Either this
            # request just created it, or an earlier EXPIRE never ran (e.g. a
            # crash between the two calls). Setting it here ensures a counter
            # can never outlive its window and lock the client out for good.
            if ttl == -1:
                await self.redis.expire(key, self.time_window)

            if count > self.max_requests:
                logger.warning(f"Rate limit exceeded for: {identifier}")
                return False

            return True

        except Exception as e:
            logger.error(f"Redis error in rate limiter: {e}")
            # Fail open - allow request on Redis error
            return True

    async def check_or_raise(self, identifier: str) -> None:
        """Record a request for *identifier*, raising if it exceeds the limit.

        Raises:
            RateLimitExceeded: if ``is_allowed`` returns False.
        """
        if not await self.is_allowed(identifier):
            raise RateLimitExceeded(f"Rate limit exceeded for {identifier}")

    async def get_remaining(self, identifier: str) -> int:
        """Return how many requests *identifier* has left in the current window.

        This only reads the counter and does not consume a request. It returns
        ``max_requests`` when no counter exists or when Redis raises, and it
        never returns a negative number.
        """
        key = self._key(identifier)

        try:
            current = await self.redis.get(key)
            if current is None:
                return self.max_requests

            count = int(current)
            # The counter can exceed max_requests because rejected requests are counted too.
            return max(0, self.max_requests - count)

        except Exception as e:
            logger.error(f"Redis error getting remaining: {e}")
            return self.max_requests
