#!/usr/bin/env python3
"""
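NLI Classifier Load Balancer.

Routes requests between two NLI classifier servers with health checking,
failover, and load balancing. Server URLs default to the CLASSIFIER_URL and
CLASSIFIER_URL_2 environment variables, loaded from a .env file.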
"""

import os
import asyncio
import aiohttp
import time
import random
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from enum import Enum
from dotenv import load_dotenv

# Pull configuration from the project-level .env file, which sits two
# directories above this module.
load_dotenv(
    os.path.join(os.path.dirname(__file__), "..", "..", ".env")
)

# Default classifier endpoints; each is None when its variable is unset.
CLASSIFIER_URL, CLASSIFIER_URL_2 = (
    os.getenv("CLASSIFIER_URL"),
    os.getenv("CLASSIFIER_URL_2"),
)

class ServerStatus(Enum):
    """Health state the load balancer tracks for each classifier server."""

    HEALTHY = "healthy"      # preferred by server selection
    BUSY = "busy"            # used only when no HEALTHY server is available
    UNHEALTHY = "unhealthy"  # excluded from selection until a health check passes

@dataclass()
class ServerInfo:
    """Mutable bookkeeping record for a single classifier server."""

    # Base URL; "/health" and "/predict" are appended to it.
    url: str
    # Every server starts out as HEALTHY.
    status: ServerStatus = ServerStatus.HEALTHY
    # time.time() of the most recent health check (0 = never checked).
    last_check: float = 0
    # Seconds taken by the latest health-check or predict request.
    response_time: float = 0
    # Failures recorded since the last successful request.
    error_count: int = 0
    # Number of /predict attempts routed to this server.
    request_count: int = 0

class NLIClassifierLoadBalancer:
    """Load balancer for NLI classifier servers with health checking and failover.

    Use it as an async context manager. ``__aenter__`` opens the shared
    ``aiohttp.ClientSession`` and starts the background health-check task.
    ``__aexit__`` cancels that task and closes the session.
    """

    def __init__(self,
                 primary_url: Optional[str] = None,
                 secondary_url: Optional[str] = None,
                 health_check_interval: int = 30,
                 max_errors: int = 3,
                 timeout: int = 120):
        """
        Initialize the load balancer.

        Args:
            primary_url: Primary NLI classifier server URL. Defaults to
                the CLASSIFIER_URL environment variable.
            secondary_url: Secondary NLI classifier server URL. Defaults to
                the CLASSIFIER_URL_2 environment variable.
            health_check_interval: How often to check server health (seconds).
            max_errors: Consecutive failures after which a server is marked
                UNHEALTHY. Fewer failures only demote it to BUSY, so it is
                still used as a fallback.
            timeout: Request timeout in seconds.
        """
        self.primary_url = primary_url or CLASSIFIER_URL
        self.secondary_url = secondary_url or CLASSIFIER_URL_2
        self.health_check_interval = health_check_interval
        self.max_errors = max_errors
        self.timeout = timeout

        # Only register servers that actually have a URL. An unset env var
        # would otherwise produce requests to "None/predict" that can never
        # succeed.
        self.servers = [
            ServerInfo(url=url)
            for url in (self.primary_url, self.secondary_url)
            if url
        ]

        self.session = None
        self._health_check_task = None
        # Rotating cursor used for round-robin selection.
        self._rr_cursor = 0

        print("🔄 NLI Load Balancer initialized:")
        print(f"   Primary: {self.primary_url}")
        print(f"   Secondary: {self.secondary_url}")

    async def __aenter__(self):
        """Open the HTTP session and start background health checking."""
        self.session = aiohttp.ClientSession(
            connector=aiohttp.TCPConnector(limit=100, limit_per_host=50),
            timeout=aiohttp.ClientTimeout(total=self.timeout)
        )

        self._health_check_task = asyncio.create_task(self._health_check_loop())
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Stop health checking and close the HTTP session."""
        if self._health_check_task:
            self._health_check_task.cancel()
            try:
                await self._health_check_task
            except asyncio.CancelledError:
                pass
            self._health_check_task = None

        if self.session:
            await self.session.close()
            # Reset so predict() after exit fails with a clear error instead of
            # blaming the servers for a closed session.
            self.session = None

    async def _health_check_loop(self):
        """Check all servers every ``health_check_interval`` seconds until cancelled."""
        while True:
            try:
                await self._check_all_servers()
                await asyncio.sleep(self.health_check_interval)
            except asyncio.CancelledError:
                break
            except Exception as e:
                print(f"⚠️ Health check error: {e}")
                # Short back-off before retrying after an unexpected error.
                await asyncio.sleep(5)

    async def _check_all_servers(self):
        """Check the health of all servers concurrently."""
        tasks = [self._check_server_health(server) for server in self.servers]
        await asyncio.gather(*tasks, return_exceptions=True)

    async def _check_server_health(self, server: ServerInfo):
        """Probe ``GET {url}/health`` and update ``server`` in place."""
        try:
            start_time = time.time()
            async with self.session.get(
                f"{server.url}/health",
                timeout=aiohttp.ClientTimeout(total=10),
            ) as resp:
                server.response_time = time.time() - start_time
                if resp.status == 200:
                    self._record_success(server)
                else:
                    self._record_failure(server)
        except Exception as e:
            self._record_failure(server)
            server.response_time = 999  # High response time for failed requests
            print(f"⚠️ Server {server.url} health check failed: {e}")

        server.last_check = time.time()

    @staticmethod
    def _record_success(server: ServerInfo) -> None:
        """Mark ``server`` HEALTHY and clear its failure count."""
        server.status = ServerStatus.HEALTHY
        server.error_count = 0

    def _record_failure(self, server: ServerInfo) -> None:
        """Count one failure on ``server``.

        The server becomes UNHEALTHY once ``max_errors`` consecutive failures
        accumulate. Before that it is demoted to BUSY, which keeps it available
        as a fallback when no HEALTHY server remains.
        """
        server.error_count += 1
        if server.error_count >= self.max_errors:
            server.status = ServerStatus.UNHEALTHY
        else:
            server.status = ServerStatus.BUSY

    def _select_server(
        self, exclude: Optional[List[ServerInfo]] = None
    ) -> Optional[ServerInfo]:
        """Select the next server using round-robin with health awareness.

        HEALTHY servers are preferred. BUSY servers are used only when no
        HEALTHY server is eligible. UNHEALTHY servers are never selected; a
        passing health check restores them.

        Args:
            exclude: Servers to skip, compared by identity. ``predict`` passes
                the servers it has already tried during the current call.

        Returns:
            The chosen server, or None if no eligible server remains.
        """
        skipped = exclude or []

        def eligible(status: ServerStatus) -> List[ServerInfo]:
            return [
                s for s in self.servers
                if s.status == status and all(s is not t for t in skipped)
            ]

        candidates = eligible(ServerStatus.HEALTHY) or eligible(ServerStatus.BUSY)
        if not candidates:
            return None

        # Rotate a cursor through the eligible servers. Comparing float
        # latencies for exact equality would almost never tie, so it would pin
        # all traffic to a single server.
        server = candidates[self._rr_cursor % len(candidates)]
        self._rr_cursor += 1
        return server

    async def predict(self, pairs: List[List[str]]) -> Dict[str, Any]:
        """
        Make a prediction request with load balancing and failover.

        Each configured server is tried at most once per call. Failing servers
        are demoted via the ``max_errors`` policy.

        Args:
            pairs: List of text pairs to classify.

        Returns:
            Parsed JSON response from the classifier.

        Raises:
            RuntimeError: If called outside ``async with`` (no open session).
            Exception: If no server is eligible, or every tried server failed.
        """
        if self.session is None:
            raise RuntimeError(
                "NLIClassifierLoadBalancer has no open session; use it as "
                "'async with NLIClassifierLoadBalancer() as lb:'"
            )

        tried: List[ServerInfo] = []
        while True:
            server = self._select_server(exclude=tried)
            if server is None:
                if tried:
                    raise Exception("All NLI classifier servers failed")
                raise Exception("No available NLI classifier servers")
            tried.append(server)

            try:
                server.request_count += 1
                start_time = time.time()

                async with self.session.post(
                    f"{server.url}/predict",
                    json={"pairs": pairs},
                    timeout=aiohttp.ClientTimeout(total=self.timeout),
                ) as resp:
                    server.response_time = time.time() - start_time

                    if resp.status == 200:
                        # Parse first: only a usable body counts as success.
                        result = await resp.json()
                        self._record_success(server)
                        return result
                    elif resp.status == 429:  # Too Many Requests
                        server.status = ServerStatus.BUSY
                        print(f"⚠️ Server {server.url} is busy (429), trying next server")
                    else:
                        self._record_failure(server)
                        print(f"⚠️ Server {server.url} returned status {resp.status}")

            except asyncio.TimeoutError:
                self._record_failure(server)
                print(f"⚠️ Server {server.url} timed out, trying next server")
            except Exception as e:
                self._record_failure(server)
                print(f"⚠️ Server {server.url} error: {e}, trying next server")

    def get_server_stats(self) -> Dict[str, Any]:
        """Return a plain-dict snapshot of per-server and aggregate statistics."""
        per_server = [
            {
                "url": s.url,
                "status": s.status.value,
                "request_count": s.request_count,
                "error_count": s.error_count,
                "response_time": s.response_time,
                "last_check": s.last_check,
            }
            for s in self.servers
        ]
        return {
            "servers": per_server,
            "total_requests": sum(s.request_count for s in self.servers),
            "healthy_servers": sum(
                1 for s in self.servers if s.status == ServerStatus.HEALTHY
            ),
            "total_servers": len(self.servers),
        }

# Example usage and testing
async def test_load_balancer():
    """Smoke-test the load balancer against the configured servers.

    Sends three sample text pairs through ``predict``. On success, prints the
    number of predictions and per-server statistics. Any exception is caught
    and reported instead of propagated.
    """
    print("🧪 Testing NLI Load Balancer")
    print("=" * 40)

    sample_pairs = [
        ["user authentication", "login process"],
        ["database connection", "query execution"],
        ["file upload", "data processing"],
    ]

    async with NLIClassifierLoadBalancer() as balancer:
        try:
            response = await balancer.predict(sample_pairs)
            prediction_count = len(response.get('predictions', []))
            print(f"✅ Prediction successful: {prediction_count} results")

            summary = balancer.get_server_stats()
            print("\n📊 Server Statistics:")
            print(f"   Total requests: {summary['total_requests']}")
            print(f"   Healthy servers: {summary['healthy_servers']}/{summary['total_servers']}")

            for entry in summary['servers']:
                print(
                    f"   {entry['url']}: {entry['status']} "
                    f"(requests: {entry['request_count']}, "
                    f"errors: {entry['error_count']}, "
                    f"response_time: {entry['response_time']:.2f}s)"
                )
        except Exception as e:
            print(f"❌ Load balancer test failed: {e}")

if __name__ == "__main__":
    # Script entry point: run the smoke test on a fresh event loop.
    asyncio.run(test_load_balancer())
