"""
FastAPI middleware:
rate limiting, request logging, security headers, and admin audit logging.
"""
import logging
import math
import time
from collections import defaultdict, deque
from datetime import datetime, timedelta, timezone

from fastapi import Request, HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse, Response

from config import RATE_LIMIT_ENABLED, RATE_LIMIT_CALLS, RATE_LIMIT_PERIOD

# Module logger, named after this module (__name__) so its level and
# handlers can be configured independently of other modules.
logger = logging.getLogger(__name__)

# ==================== RATE LIMITING ====================

class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Limit how many requests each client IP may make, to curb abuse and
    request floods.

    Sliding-window limiter: a client may make at most ``RATE_LIMIT_CALLS``
    requests within any ``RATE_LIMIT_PERIOD``-second window. A rejected
    request gets a 429 JSON response (``{"detail": ...}``) with a
    ``Retry-After`` header. Accepted requests get ``X-RateLimit-*`` headers.

    The counters live in memory on this instance. They are not shared
    between worker processes and are lost on restart.
    """

    def __init__(self, app):
        super().__init__(app)
        # Maps client IP -> time.monotonic() stamps of its accepted
        # requests, oldest first.
        self.requests = defaultdict(deque)
        self.enabled = RATE_LIMIT_ENABLED
        self.max_calls = RATE_LIMIT_CALLS
        self.period = RATE_LIMIT_PERIOD  # seconds
        # time.monotonic() value of the last idle-client sweep.
        self._last_sweep = time.monotonic()

    async def dispatch(self, request: Request, call_next):
        # Pass through when disabled. The health check is never throttled.
        if not self.enabled or request.url.path == "/api/health":
            return await call_next(request)

        # request.client is None for some transports (e.g. test clients).
        client_ip = request.client.host if request.client else "unknown"
        # Use a monotonic clock so window arithmetic is immune to
        # wall-clock jumps.
        now = time.monotonic()
        self._sweep_idle_clients(now)

        window = self.requests[client_ip]
        # Evict stamps that have slid out of the window. The oldest are on
        # the left, because stamps are appended in monotonic order.
        while window and now - window[0] >= self.period:
            window.popleft()

        if len(window) >= self.max_calls:
            logger.warning(f"⚠️ Rate limit exceeded for IP: {client_ip}")
            # Seconds until the oldest counted request leaves the window.
            reset_in = (window[0] if window else now) + self.period - now
            headers = self._limit_headers(0, reset_in)
            headers["Retry-After"] = str(max(1, math.ceil(reset_in)))
            # Return the response rather than raising HTTPException.
            # Exceptions raised inside BaseHTTPMiddleware bypass FastAPI's
            # exception handlers and would reach the client as a 500.
            return JSONResponse(
                status_code=429,
                content={
                    "detail": f"تعداد درخواست‌های شما بیش از حد مجاز است. لطفاً {self.period} ثانیه صبر کنید."
                },
                headers=headers,
            )

        window.append(now)
        # Build the headers before awaiting. Concurrent requests may change
        # this window while call_next runs.
        headers = self._limit_headers(
            self.max_calls - len(window),
            window[0] + self.period - now,
        )

        response = await call_next(request)
        for name, value in headers.items():
            response.headers[name] = value
        return response

    def _limit_headers(self, remaining, reset_in):
        """
        Build the X-RateLimit-* response headers.

        remaining: requests still allowed in the current window (clamped at 0).
        reset_in: seconds until the oldest counted request leaves the window.
            It is reported as a Unix timestamp in X-RateLimit-Reset.
        """
        return {
            "X-RateLimit-Limit": str(self.max_calls),
            "X-RateLimit-Remaining": str(max(0, remaining)),
            "X-RateLimit-Reset": str(math.ceil(time.time() + reset_in)),
        }

    def _sweep_idle_clients(self, now):
        """
        Forget clients with no request inside the window. Runs at most once
        per period.

        Without this sweep, every IP ever seen would keep a dict entry forever.
        """
        if now - self._last_sweep < self.period:
            return
        self._last_sweep = now
        idle = [
            ip
            for ip, stamps in self.requests.items()
            if not stamps or now - stamps[-1] >= self.period
        ]
        for ip in idle:
            del self.requests[ip]

# ==================== SECURITY HEADERS ====================

class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """
    Add a fixed set of security headers to every response.

    Headers are assigned unconditionally, so they overwrite any header of the
    same name that a route handler already set.
    """

    # Content-Security-Policy directives, joined with "; " when sent.
    # NOTE(review): 'unsafe-inline' and 'unsafe-eval' in script-src remove
    # most of CSP's XSS protection. They are kept because the frontend may
    # rely on them; tighten to nonces/hashes once confirmed.
    CSP_DIRECTIVES = (
        "default-src 'self'",
        "script-src 'self' 'unsafe-inline' 'unsafe-eval'",
        "style-src 'self' 'unsafe-inline'",
        "img-src 'self' data: https:",
        "font-src 'self' data:",
        "connect-src 'self' https:",
        # CSP counterpart of X-Frame-Options: DENY. frame-ancestors does not
        # fall back to default-src, so it must be listed explicitly.
        "frame-ancestors 'none'",
    )

    SECURITY_HEADERS = {
        "X-Content-Type-Options": "nosniff",
        "X-Frame-Options": "DENY",
        # "0" disables the legacy browser XSS auditor. Modern browsers have
        # removed it, and "1; mode=block" could be abused for cross-site
        # leaks in older ones. OWASP and MDN recommend "0" and relying on CSP.
        "X-XSS-Protection": "0",
        "Referrer-Policy": "strict-origin-when-cross-origin",
        "Content-Security-Policy": "; ".join(CSP_DIRECTIVES),
    }

    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        for name, value in self.SECURITY_HEADERS.items():
            response.headers[name] = value
        return response

# ==================== REQUEST LOGGING ====================

class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """
    Log every request when it arrives and when it completes.

    The handling time, in seconds, is returned to the client in an
    ``X-Process-Time`` header. An exception raised further down the stack is
    logged with the elapsed time and then re-raised unchanged.
    """

    async def dispatch(self, request: Request, call_next):
        # perf_counter is monotonic and high-resolution. Durations are
        # therefore not skewed by wall-clock adjustments, as time.time()
        # deltas can be.
        start_time = time.perf_counter()
        method, path = request.method, request.url.path
        # request.client is None for some transports (e.g. test clients).
        # Reading .host unguarded would fail the whole request.
        client_host = request.client.host if request.client else "unknown"

        logger.info("➡️ %s %s from %s", method, path, client_host)

        try:
            response = await call_next(request)
        except Exception as e:
            process_time = time.perf_counter() - start_time
            logger.error(
                "❌ %s %s failed in %.3fs: %s", method, path, process_time, e
            )
            raise

        process_time = time.perf_counter() - start_time
        response.headers["X-Process-Time"] = str(process_time)
        logger.info(
            "⬅️ %s %s [%s] in %.3fs",
            method, path, response.status_code, process_time,
        )
        return response

# ==================== AUDIT LOG ====================

class AuditLogMiddleware(BaseHTTPMiddleware):
    """
    Record state-changing admin operations in the audit log, for security
    and compliance.

    An entry is written to ``db.audit_logs`` for any mutating request
    (POST/PUT/PATCH/DELETE) whose path starts with one of ADMIN_PATHS. It is
    written *before* the request is handled, so attempts are recorded
    whatever their outcome. Auditing is best-effort: a failure is logged and
    the request still proceeds.
    """

    ADMIN_PATHS = [
        "/api/admin/users",
        "/api/admin/test",
        "/api/admin/settings",
        "/api/admin/announcements",
    ]

    # State-changing HTTP methods that are audited. PATCH is included
    # because partial updates are admin writes too.
    AUDITED_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})

    async def dispatch(self, request: Request, call_next):
        if self._should_audit(request):
            await self._write_audit_entry(request)
        return await call_next(request)

    def _should_audit(self, request: Request) -> bool:
        """Return True for a mutating request under one of the admin path prefixes."""
        return (
            request.method in self.AUDITED_METHODS
            and request.url.path.startswith(tuple(self.ADMIN_PATHS))
        )

    async def _write_audit_entry(self, request: Request) -> None:
        """Insert one audit record for *request*. Any failure is logged and swallowed."""
        try:
            # Imported at call time rather than at module top. Presumably
            # this picks up a db assigned during startup, and it keeps an
            # import failure inside this best-effort block. Confirm against
            # database.py.
            from database import db

            audit_entry = {
                # Timezone-aware UTC, so entries are unambiguous across
                # servers and DST changes.
                "timestamp": datetime.now(timezone.utc).isoformat(),
                # request.client is None for some transports.
                "ip": request.client.host if request.client else "unknown",
                "method": request.method,
                "path": request.url.path,
                "user_agent": request.headers.get("user-agent", ""),
                # TODO: add user_id from JWT token
            }

            # Compare against None instead of using truthiness. db may be a
            # PyMongo/Motor database, and PyMongo 4+ Database objects raise
            # NotImplementedError on bool(). The except below would then
            # silently drop every entry.
            if db is not None:
                await db.audit_logs.insert_one(audit_entry)
                logger.info(f"📝 Audit log: {request.method} {request.url.path}")
        except Exception as e:
            logger.error(f"Failed to create audit log: {e}")
