"""
Base RPC client and server functionality for chemsets tasks.
Consolidates the common RPC code that was duplicated across tasks.
"""
import requests
import traceback
import logging
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from json import loads, dumps
from typing import Any, Callable, Dict, Optional
# Try different import paths for flexibility
try:
    from lm_eval.tasks.chemsets.server_registry import get_task_port, get_base_url, get_registry
except ImportError:
    try:
        from chemsets.server_registry import get_task_port, get_base_url, get_registry
    except ImportError:
        import sys
        import os
        sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
        from server_registry import get_task_port, get_base_url, get_registry

# Module-level logger used by RPCClient, RPCServer (including its request
# handler) and rpc_client_call for all diagnostics emitted by this file.
logger = logging.getLogger(__name__)

class RPCClient:
    """Generic RPC client for chemsets tasks.

    ``call(name, **kwargs)`` POSTs ``kwargs`` as a JSON object to
    ``<base_url>/<name>`` and returns the decoded JSON response body.
    """

    # Default per-request timeout in seconds. Deliberately generous because
    # chemistry evaluation functions on the server may run subprocesses.
    DEFAULT_TIMEOUT = 300

    def __init__(self, task_name: str, base_url: Optional[str] = None,
                 timeout: Optional[float] = None):
        """
        Args:
            task_name: Task whose RPC server this client talks to.
            base_url: Explicit server URL. When not given (or empty), the URL
                is resolved with ``get_base_url(task_name)`` on each access.
            timeout: Per-request timeout in seconds; ``None`` means
                ``DEFAULT_TIMEOUT``.
        """
        self.task_name = task_name
        self._base_url = base_url
        self.timeout = self.DEFAULT_TIMEOUT if timeout is None else timeout

    @property
    def base_url(self) -> str:
        """Get the base URL for this task's RPC server."""
        if self._base_url:
            return self._base_url
        return get_base_url(self.task_name)

    def _endpoint_url(self, function_name: str) -> str:
        """Join the base URL and *function_name* with a single slash.

        A base URL ending in ``/`` would otherwise produce ``//name``, which
        does not match a server route registered as ``/name``.
        """
        return f"{self.base_url.rstrip('/')}/{function_name}"

    def call(self, function_name: str, **kwargs) -> Any:
        """
        Make an RPC call to the server.

        Args:
            function_name: Name of the function to call
            **kwargs: Arguments to pass to the function (sent as a JSON object)

        Returns:
            The decoded JSON result from the server

        Raises:
            requests.exceptions.RequestException: connection failure, timeout
                or HTTP error status; logged, then re-raised.
            Exception: anything else (e.g. a failing URL lookup) is logged as
                unexpected and re-raised.
        """
        try:
            # URL resolution stays inside the try so a registry failure is
            # logged like any other unexpected error.
            url = self._endpoint_url(function_name)
            response = requests.post(url, json=kwargs, timeout=self.timeout)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            logger.error(f"RPC call failed for {self.task_name}.{function_name}: {e}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in RPC call: {e}")
            raise


class RPCServer:
    """Generic RPC server for chemsets tasks.

    Registered functions are exposed as ``POST /<path>`` endpoints. The JSON
    request body must be an object; its keys are passed to the function as
    keyword arguments and the function's return value is sent back as JSON.
    """

    # Number of attempts made to run a handler function before replying 500.
    # Read through the instance, so subclasses may override it.
    N_RETRIES = 5

    def __init__(self, task_name: str, port: Optional[int] = None):
        self.task_name = task_name
        # Fall back to the port assigned to this task by the server registry.
        self.port = port or get_task_port(task_name)
        self.routes: Dict[str, Callable] = {}
        self.server = None

    def register_function(self, path: str, func: Callable):
        """Register a function to handle requests at a specific path."""
        self.routes[f"/{path}"] = func
        logger.info(f"Registered function '{path}' for task '{self.task_name}'")

    def route(self, path: str):
        """Decorator to register a function as a route handler; returns it unchanged."""
        def decorator(func: Callable):
            self.register_function(path, func)
            return func
        return decorator

    def start_server(self):
        """Bind to ``localhost:<port>`` and serve requests until shut down.

        Blocks the calling thread in ``serve_forever()``.

        Raises:
            Exception: whatever binding or serving raised, after it has been
                logged, the socket closed and the task marked stopped.
        """
        handler_class = self._create_handler_class()
        server = None
        try:
            server = ThreadingHTTPServer(('localhost', self.port), handler_class)
            self.server = server
            logger.info(f"Starting RPC server for '{self.task_name}' on port {self.port}")
            logger.info(f"Server bound to port {self.port}, starting to serve requests...")

            # serve_forever() blocks, so registry bookkeeping must happen
            # first. It is best-effort: a registry failure must not keep the
            # server from serving.
            try:
                get_registry().mark_server_started(self.task_name, server)
                logger.info(f"Server marked as started in registry")
            except Exception as registry_error:
                logger.warning(f"Failed to mark server as started in registry: {registry_error}")

            server.serve_forever()
        except Exception as e:
            logger.error(f"Failed to start server for '{self.task_name}': {e}")
            if server is not None:
                # Release the listening socket; closing twice is harmless.
                server.server_close()
            # Don't let a registry error mask the original failure.
            try:
                get_registry().mark_server_stopped(self.task_name)
            except Exception as registry_error:
                logger.warning(f"Failed to mark server as stopped in registry: {registry_error}")
            raise

    def stop_server(self):
        """Stop the RPC server, close its socket and mark the task stopped.

        A no-op when no server is running, so it is safe to call twice. Must
        not be called from a request handler of this same server, because
        ``shutdown()`` waits for ``serve_forever()`` to exit.
        """
        server = self.server
        if server is None:
            return
        logger.info(f"Stopping RPC server for '{self.task_name}'")
        self.server = None
        server.shutdown()
        server.server_close()
        get_registry().mark_server_stopped(self.task_name)

    @staticmethod
    def _decode_payload(raw_data: bytes) -> dict:
        """Parse a request body into a keyword-argument dict.

        Raises:
            ValueError: the body is not valid JSON (``json.JSONDecodeError``
                and ``UnicodeDecodeError`` are ``ValueError`` subclasses) or
                does not decode to a JSON object.
        """
        data = loads(raw_data)
        if not isinstance(data, dict):
            raise ValueError(f"expected a JSON object, got {type(data).__name__}")
        return data

    @staticmethod
    def _status_line_safe(text: str) -> str:
        """Make *text* safe to use as an HTTP status-line reason phrase.

        ``http.server`` writes the reason verbatim and encodes it as strict
        latin-1, so line breaks would corrupt the response and non-latin-1
        characters would raise ``UnicodeEncodeError`` while reporting an error.
        """
        single_line = " ".join(text.splitlines())
        return single_line.encode('latin-1', 'replace').decode('latin-1')

    def _create_handler_class(self):
        """Create a request handler class with access to routes."""
        routes = self.routes
        task_name = self.task_name
        # Honour subclass overrides of N_RETRIES; always make at least one attempt.
        n_retries = max(1, self.N_RETRIES)
        decode_payload = self._decode_payload
        status_line_safe = self._status_line_safe

        class RPCHandler(BaseHTTPRequestHandler):
            def _read_body(self):
                """Read the request body, or send 411/400 and return None."""
                length_header = self.headers.get('Content-Length')
                if length_header is None:
                    self.send_error(411, "Content-Length required")
                    return None
                try:
                    content_length = int(length_header)
                except ValueError:
                    content_length = -1
                if content_length < 0:
                    # read(-1) would block until the client closes the socket.
                    self.send_error(400, "Invalid Content-Length")
                    return None
                return self.rfile.read(content_length)

            def do_POST(self):
                """Dispatch one JSON POST request to its registered function."""
                raw_data = self._read_body()
                if raw_data is None:
                    return  # an error response has already been sent

                func = routes.get(self.path)
                if not func:
                    logger.warning(f"Endpoint not found: {self.path} for task {task_name}")
                    self.send_error(404, "Endpoint not found")
                    return

                # A malformed request can never succeed, so it is rejected
                # immediately instead of being retried.
                try:
                    data = decode_payload(raw_data)
                except ValueError as e:
                    logger.warning(f"Invalid request body for {self.path} (task {task_name}): {e}")
                    self.send_error(400, status_line_safe(f"Invalid request body: {e}"))
                    return

                for attempt in range(1, n_retries + 1):
                    try:
                        # Serialize inside the try, before any status line is
                        # written, so an unserializable result is handled like
                        # any other failure instead of corrupting a 200 reply.
                        body = dumps(func(**data)).encode()
                        break
                    except Exception as e:
                        if attempt == n_retries:
                            logger.error(f"Request failed after {n_retries} retries: {e}")
                            logger.error(traceback.format_exc())
                            self.send_error(500, status_line_safe(f"Internal server error: {str(e)}"))
                            return
                        logger.warning(f"Request failed, retrying ({attempt}/{n_retries}): {e}")

                self.send_response(200)
                self.send_header('Content-Type', 'application/json')
                self.send_header('Content-Length', str(len(body)))
                self.end_headers()
                self.wfile.write(body)

            def log_message(self, format, *args):
                """Override to use our logger instead of stderr."""
                logger.info(f"{self.address_string()} - {format % args}")

        return RPCHandler


def create_rpc_client(task_name: str, base_url: Optional[str] = None) -> RPCClient:
    """Build an :class:`RPCClient` for *task_name*.

    Args:
        task_name: Task whose RPC server the client should talk to.
        base_url: Optional explicit server URL; when omitted the client
            resolves the URL for *task_name* itself.

    Returns:
        A new client instance.
    """
    client = RPCClient(task_name, base_url=base_url)
    return client

def _import_ensure_server_running() -> Callable:
    """Import ``ensure_server_running`` from the auto-launcher module.

    Tries the package import paths first, then falls back to the directory
    above this file. That directory is added to ``sys.path`` at most once.

    Raises:
        ImportError: the auto-launcher cannot be found on any path.
    """
    try:
        from lm_eval.tasks.chemsets.auto_launcher import ensure_server_running
    except ImportError:
        try:
            from chemsets.auto_launcher import ensure_server_running
        except ImportError:
            import os
            import sys
            parent_dir = os.path.join(os.path.dirname(__file__), '..')
            # This runs on every RPC call; an unconditional append would grow
            # sys.path by one entry per call.
            if parent_dir not in sys.path:
                sys.path.append(parent_dir)
            from auto_launcher import ensure_server_running
    return ensure_server_running


def _format_doc_info(kwargs: Dict[str, Any]) -> str:
    """Return ``" [doc_id: X]"`` for the document in ``kwargs['row']``, or ``""``.

    Uses the first of ``uuid``, ``id`` and ``doc_id`` whose value is neither
    None nor an empty string, so a legitimate id of ``0`` is still reported.
    """
    row = kwargs.get('row')
    if not isinstance(row, dict):
        return ""
    for key in ('uuid', 'id', 'doc_id'):
        doc_id = row.get(key)
        if doc_id is not None and doc_id != "":
            return f" [doc_id: {doc_id}]"
    return ""


def rpc_client_call(task_name: str, function_name: str, base_url: Optional[str] = None, **kwargs) -> Any:
    """
    Convenience function for making a single RPC call.
    Compatible with the existing rpc_client function signature.
    Includes auto-launch support if enabled.

    When *base_url* is None, first asks the optional auto-launcher to make
    sure the task's server is running. A missing auto-launcher is ignored and
    any other launch failure is only logged, so a manually started server
    still works.

    Args:
        task_name: Task whose RPC server to call.
        function_name: Remote function (route) name.
        base_url: Explicit server URL; disables auto-launch when given.
        **kwargs: Arguments forwarded to the remote function.

    Returns:
        The decoded JSON result. For functions whose name contains ``"eval"``,
        a timeout, connection failure or undecodable response is converted to
        a default failure result (``reward`` 0.0) instead of raising.

    Raises:
        requests.exceptions.RequestException: any other failure from the call,
            including HTTP error statuses.
    """
    if base_url is None:  # Only auto-launch for default URLs
        try:
            ensure_server_running = _import_ensure_server_running()
            ensure_server_running(task_name)
        except ImportError:
            pass  # Auto-launcher not available, proceed anyway
        except Exception as e:
            # Log warning but don't fail - server might be manually started
            logger.warning(f"Auto-launch attempt failed for {task_name}: {e}")

    client = create_rpc_client(task_name, base_url)
    try:
        return client.call(function_name, **kwargs)
    except (requests.exceptions.JSONDecodeError,
            # Timeout also covers ConnectTimeout and ReadTimeout.
            requests.exceptions.Timeout,
            requests.exceptions.ConnectionError) as e:
        # Handle errors at this level to prevent auto-launcher restart loops
        if 'eval' not in function_name:
            raise

        # Include the document ID for debugging
        doc_info = _format_doc_info(kwargs)
        error_msg = f"RPC call failed for {task_name}.{function_name}{doc_info}: {e}. Returning default failure result (reward=0)"
        logger.warning(error_msg)

        return {
            'problem_cat': 'unknown',
            'reward': 0.0,
            'extracted_answer': f'ERROR: {type(e).__name__} - {str(e)}{doc_info}'
        }