"""
Base processor classes for chemsets tasks.
Provides common patterns and structure for task-specific processors.
"""
import os
import statistics
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Dict, List, Optional, Any
import datasets

# Try different import paths for flexibility  
try:
    from lm_eval.tasks.chemsets.common.rpc_base import rpc_client_call
except ImportError:
    try:
        from chemsets.common.rpc_base import rpc_client_call
    except ImportError:
        try:
            from rpc_base import rpc_client_call
        except ImportError:
            # Last resort - direct import from same directory
            import os
            import sys
            current_dir = os.path.dirname(__file__)
            if current_dir not in sys.path:
                sys.path.insert(0, current_dir)
            from rpc_base import rpc_client_call


class BaseTaskProcessor(ABC):
    """
    Base class for task-specific processors.

    Each chemsets subtask should inherit from this and implement the abstract
    methods ``doc_to_text``, ``process_docs`` and ``process_results``.
    Remote evaluation is performed through :meth:`rpc_call`.
    """

    def __init__(self, task_name: str, rpc_base_url: Optional[str] = None):
        """
        Args:
            task_name: Name of the chemsets subtask; passed as the first
                argument to ``rpc_client_call`` on every RPC.
            rpc_base_url: Base URL forwarded to ``rpc_client_call``. When
                ``None``, presumably the RPC client's own default is used
                (NOTE(review): confirm in rpc_base).
        """
        self.task_name = task_name
        self.rpc_base_url = rpc_base_url

    def rpc_call(self, function_name: str, **kwargs) -> Any:
        """
        Invoke ``function_name`` on this task's RPC server.

        Args:
            function_name: Name of the remote function to call.
            **kwargs: Forwarded unchanged to ``rpc_client_call``.

        Returns:
            Whatever ``rpc_client_call`` returns.

        Note:
            No server auto-launch is attempted here. An earlier hook that
            called ``auto_launcher.ensure_server_running(task_name)`` is
            disabled, so the server presumably has to be started separately.
        """
        return rpc_client_call(self.task_name, function_name, self.rpc_base_url, **kwargs)

    @abstractmethod
    def doc_to_text(self, doc: dict) -> str:
        """Convert a document to the text that should be fed to the model."""

    @abstractmethod
    def process_docs(self, dataset: datasets.Dataset) -> datasets.Dataset:
        """Process the raw dataset documents into the format needed by the task."""

    @abstractmethod
    def process_results(self, doc: dict, results: List[str]) -> Dict[str, Any]:
        """Process the model results for ``doc`` and compute metrics."""


class StandardChemProcessor(BaseTaskProcessor):
    """
    Standard processor for chemistry tasks with common patterns.

    Scores each model response by sending it, together with the source
    document, to the task's RPC evaluation function, then averages the
    returned rewards per problem category. Subclasses still implement
    ``doc_to_text`` and ``process_docs``, and may add task-specific metrics
    by overriding ``_extract_additional_metrics``.
    """

    def __init__(self, task_name: str, rpc_base_url: Optional[str] = None,
                 evaluation_function: Optional[str] = None):
        """
        Args:
            task_name: Name of the chemsets subtask.
            rpc_base_url: Base URL forwarded to the RPC client.
            evaluation_function: Name of the remote function that scores a
                single answer; defaults to ``"<task_name>_bencheval"``.
        """
        super().__init__(task_name, rpc_base_url)
        self.evaluation_function = evaluation_function or f"{task_name}_bencheval"

    def process_results(self, doc: dict, results: List[str]) -> Dict[str, Any]:
        """
        Score model responses for one document via RPC and aggregate rewards.

        Each response is sent to ``self.evaluation_function`` as
        ``row=doc, answer=<response>``. From each returned dict the keys
        ``problem_cat`` (default ``'unknown'``), ``reward`` (default ``0``)
        and ``extracted_answer`` (default ``''``) are read. Subclasses can
        add metrics by overriding ``_extract_additional_metrics``.

        Args:
            doc: The dataset row being evaluated.
            results: Model responses as a list of strings. If the first
                element is itself a list, that inner list is used instead.

        Returns:
            A dict with ``'extracted_answers'`` (one entry per response), one
            key per problem category mapping to that category's mean reward,
            plus anything added by ``_extract_additional_metrics``.

        Raises:
            ValueError: If ``results`` is empty, not a list, or contains a
                non-string element. All elements are validated before any
                RPC call is made.
        """
        # Responses may arrive wrapped in one extra list level; unwrap it.
        if results and isinstance(results[0], list):
            results = results[0]

        if not isinstance(results, list):
            raise ValueError(f"Results must be a list of strings, got {type(results)}!")
        if not results:
            raise ValueError("Results must be a non-empty list of strings!")
        # Validate everything up front so a bad element doesn't waste RPCs
        # already issued for earlier elements.
        for result in results:
            if not isinstance(result, str):
                raise ValueError(f"Result must be string, got {type(result)}!")

        metrics: Dict[str, Any] = defaultdict(list)
        per_category_rewards: Dict[Any, List[Any]] = defaultdict(list)

        for result in results:
            eval_result = self.rpc_call(self.evaluation_function, row=doc, answer=result)

            category = eval_result.get('problem_cat', 'unknown')
            per_category_rewards[category].append(eval_result.get('reward', 0))
            metrics['extracted_answers'].append(eval_result.get('extracted_answer', ''))

            # Subclass hook: may add entries to `metrics` in place.
            self._extract_additional_metrics(eval_result, metrics, doc)

        # Every category list has at least one reward (created on append),
        # so mean() is always defined here.
        for category, rewards in sorted(per_category_rewards.items()):
            avg_reward = statistics.mean(rewards)
            metrics[category] = avg_reward
            print(f"In category {category!r} of {len(rewards)} questions, "
                  f"average reward was {avg_reward:.3f}.")

        return dict(metrics)

    def _extract_additional_metrics(self, eval_result: dict, metrics: Dict[str, Any],
                                    doc: dict) -> None:
        """
        Hook for extracting additional metrics from one RPC evaluation result.

        Called once per response by ``process_results``. ``metrics`` is the
        ``defaultdict(list)`` being built and should be mutated in place.
        The default implementation does nothing; override in subclasses to
        add task-specific metrics.
        """

    def _get_document_fields(self, doc: dict, required_fields: List[str],
                             optional_fields: Optional[Dict[str, Any]] = None) -> dict:
        """
        Helper to extract and validate document fields.

        Args:
            doc: The document dictionary.
            required_fields: Field names that must be present and not ``None``.
            optional_fields: Mapping of ``{field_name: default_value}``; the
                default is used only when the key is absent from ``doc``.

        Returns:
            Dictionary with the extracted required and optional fields.

        Raises:
            ValueError: If a required field is missing or its value is ``None``.
        """
        result = {}

        for field in required_fields:
            value = doc.get(field)
            # A key present with value None is treated the same as missing.
            if value is None:
                raise ValueError(f"{field} must be provided in the document!")
            result[field] = value

        for field, default in (optional_fields or {}).items():
            result[field] = doc.get(field, default)

        return result