"""Rate limiting service (in-memory sliding-window limiter; use a Redis-backed store in production)"""
from datetime import datetime, timedelta, timezone
from typing import Callable, Dict, List, Optional

from app.core.logging import get_logger
from app.core.exceptions import RateLimitExceeded

logger = get_logger(__name__)


class RateLimiter:
    """In-memory sliding-window rate limiter (use Redis in production).

    Each identifier may make at most ``max_requests`` allowed requests within
    any trailing ``time_window`` seconds. State is held in a plain in-process
    dict (``self.requests``).

    Args:
        max_requests: Maximum number of allowed requests per window.
        time_window: Length of the sliding window, in seconds.
        clock: Optional zero-argument callable returning the current time.
            Defaults to a timezone-aware UTC ``datetime.now``; inject a
            custom clock for tests.
    """

    def __init__(
        self,
        max_requests: int = 30,
        time_window: int = 60,
        clock: Optional[Callable[[], datetime]] = None,
    ):
        self.max_requests = max_requests
        self.time_window = time_window
        # identifier -> timestamps of allowed requests still inside the window.
        # Identifiers with no live timestamps are removed, so this dict does
        # not grow without bound as new identifiers appear.
        self.requests: Dict[str, List[datetime]] = {}
        # UTC rather than naive local time: a DST transition would otherwise
        # move the wall clock by an hour and stretch or shrink the window.
        self._clock: Callable[[], datetime] = (
            clock if clock is not None else self._utc_now
        )
        # Time of the last full sweep over all identifiers (None = never).
        self._last_sweep: Optional[datetime] = None

    @staticmethod
    def _utc_now() -> datetime:
        """Return the current time as a timezone-aware UTC datetime."""
        return datetime.now(timezone.utc)

    def _prune(self, identifier: str, now: datetime) -> List[datetime]:
        """Drop expired timestamps for ``identifier`` and return the live ones.

        An identifier left with no live timestamps is removed from
        ``self.requests``; the (possibly empty) list is still returned.
        """
        # Same test as ``now - ts < window``, with the boundary computed once.
        cutoff = now - timedelta(seconds=self.time_window)
        live = [ts for ts in self.requests.get(identifier, ()) if ts > cutoff]
        if live:
            self.requests[identifier] = live
        else:
            self.requests.pop(identifier, None)
        return live

    def _maybe_sweep(self, now: datetime) -> None:
        """Prune every tracked identifier, at most once per ``time_window``.

        Identifiers that stop sending requests are never looked up again, so
        without this sweep their stale entries would stay in memory forever.
        """
        window = timedelta(seconds=self.time_window)
        if self._last_sweep is not None and now - self._last_sweep < window:
            return
        self._last_sweep = now
        # Iterate over a snapshot: _prune may delete keys from the dict.
        for identifier in list(self.requests):
            self._prune(identifier, now)

    def is_allowed(self, identifier: str) -> bool:
        """Check whether a request from ``identifier`` is allowed, recording it if so.

        Returns:
            True if fewer than ``max_requests`` requests were allowed within
            the trailing window; the request is then recorded. False
            otherwise, after logging a warning. Rejected requests are not
            recorded.
        """
        now = self._clock()
        self._maybe_sweep(now)

        live = self._prune(identifier, now)
        if len(live) >= self.max_requests:
            # Lazy %-formatting: the message is only built if it is emitted.
            logger.warning("Rate limit exceeded for: %s", identifier)
            return False

        live.append(now)
        # _prune removes empty entries, so (re)store the list explicitly.
        self.requests[identifier] = live
        return True

    def check_or_raise(self, identifier: str) -> None:
        """Record a request for ``identifier`` or raise if over the limit.

        Raises:
            RateLimitExceeded: if ``is_allowed(identifier)`` returns False.
        """
        if not self.is_allowed(identifier):
            raise RateLimitExceeded(f"Rate limit exceeded for {identifier}")

    def get_remaining(self, identifier: str) -> int:
        """Return how many more requests ``identifier`` may make in the current window.

        Does not record a request; expired timestamps are pruned as a side
        effect. Never returns a negative number for a non-negative limit.
        """
        if identifier not in self.requests:
            return self.max_requests
        live = self._prune(identifier, self._clock())
        return max(0, self.max_requests - len(live))
