"""High-Fidelity Blackbox Module."""
from typing import Dict, List, Any, Optional
from pathlib import Path

# Import evaluators for all tasks
from .fecr import FeCrEvaluator
from .cof import COFEvaluator
from .sandwich import SandwichEvaluator
from .fullerene import FullereneEvaluator
from .p3ht import P3HTEvaluator
from .pce10 import PCE10Evaluator

# Registry of task name -> evaluator class. Its keys are exactly the task names
# HighFidelityBlackbox accepts; any other name is rejected with ValueError.
TASK_EVALUATORS = dict(
    FeCr=FeCrEvaluator,
    COF=COFEvaluator,
    Sandwich=SandwichEvaluator,
    Fullerene=FullereneEvaluator,
    P3HT=P3HTEvaluator,
    PCE10=PCE10Evaluator,
)

# Default result key read by HighFidelityBlackbox.evaluate for each task (used
# when no explicit target_col is supplied). These values are hand-maintained,
# not derived at runtime: each must match a key of the dict-like result the
# task's evaluator returns from evaluate_from_dict(). FeCr's "f" mirrors
# FeCrEvaluator.METRIC_KEY.
TASK_TARGET_COLS = dict(
    FeCr="f",
    COF="gcmc_y",
    Sandwich="Total_Score",
    Fullerene="mole_fraction",
    P3HT="conductivity",
    PCE10="Degradation",
)

# Explicit per-task renames from configuration parameter names to evaluator
# parameter names; names not listed for a task pass through unchanged.
# Tasks without an entry here are handled by HighFidelityBlackbox._map_param_names,
# which pairs feature_names with the evaluator module's PARAM_NAMES positionally
# (when both are known and have equal length) and otherwise passes names through.
TASK_PARAM_MAPPING = {
    "FeCr": dict(fen="Fe", crn="Cr", hn="H"),
}


class HighFidelityBlackbox:
    """Unified high-fidelity blackbox wrapper.

    Wraps one of the evaluators registered in ``TASK_EVALUATORS`` behind a common
    ``dict -> float`` interface: configuration parameter names are translated to
    the evaluator's names, the evaluator's ``evaluate_from_dict`` is called, and
    the task's target column is read from its result.

    Attributes:
        task_name: Key into ``TASK_EVALUATORS``.
        evaluator: The underlying task evaluator instance.
        feature_names: Configuration-side parameter names, or ``None`` if unknown.
        target_col: Key read from every evaluator result.
        bounds: Taken verbatim from ``evaluator.bounds``.
        X, y: The evaluator's dataset when it could be extracted (see
            ``_try_extract_data``), otherwise ``None``.
    """

    def __init__(
        self,
        task_name: str,
        csv_path: Optional[str] = None,
        feature_names: Optional[List[str]] = None,
        target_col: Optional[str] = None
    ) -> None:
        """Initialize the high-fidelity blackbox.

        Args:
            task_name: Task identifier; must be a key of ``TASK_EVALUATORS``.
            csv_path: Forwarded to the evaluator constructor as ``csv_path``; when
                ``None`` the evaluator is constructed with no arguments.
            feature_names: Configuration-side parameter names. Defaults to the
                ``PARAM_NAMES`` of the task's evaluator module, if it defines one.
            target_col: Result key returned by evaluations. Defaults to
                ``TASK_TARGET_COLS[task_name]``, or ``"target"`` if unlisted.

        Raises:
            ValueError: If ``task_name`` is not a registered task.
        """
        if task_name not in TASK_EVALUATORS:
            available = ", ".join(TASK_EVALUATORS.keys())
            raise ValueError(
                f"Unknown task: {task_name}. Available tasks: {available}"
            )

        self.task_name = task_name
        evaluator_cls = TASK_EVALUATORS[task_name]
        self.evaluator = (
            evaluator_cls() if csv_path is None else evaluator_cls(csv_path=csv_path)
        )

        # Resolve the evaluator module's PARAM_NAMES once; _map_param_names reuses
        # it instead of re-importing the module on every evaluation.
        self._evaluator_param_names = self._load_module_param_names(task_name)
        self.feature_names = (
            self._evaluator_param_names if feature_names is None else feature_names
        )

        self.target_col = (
            TASK_TARGET_COLS.get(task_name, "target")
            if target_col is None
            else target_col
        )

        self.bounds = self.evaluator.bounds

        self.X = None
        self.y = None
        self._try_extract_data()

    @staticmethod
    def _load_module_param_names(task_name: str) -> Optional[List[str]]:
        """Return ``PARAM_NAMES`` from the task's evaluator module, or ``None``.

        The module is this package's ``<task_name.lower()>`` submodule, i.e. the
        one the evaluator class is imported from at the top of this file.
        ``__package__`` is used so the lookup still works when the package is not
        importable under the top-level name ``high_fidelity``; that name is kept
        as the fallback when the package is unknown.
        """
        import importlib

        package = __package__ or "high_fidelity"
        try:
            module = importlib.import_module(f"{package}.{task_name.lower()}")
        except ImportError:
            # A missing module just means "no PARAM_NAMES available".
            return None
        return getattr(module, "PARAM_NAMES", None)

    def _try_extract_data(self) -> None:
        """Best-effort copy of the evaluator's dataset into ``self.X`` / ``self.y``.

        Two evaluator layouts are recognised:

        * tabular: attributes ``X_scaled``, ``y``, ``lo`` and ``span``; inputs are
          recovered as ``X_scaled * span + lo`` and ``y`` is used as-is.
        * grid: attributes ``levels`` (one 1-D sequence per dimension) and
          ``values``; ``X`` is the Cartesian product of the levels (last dimension
          varying fastest) and ``y`` is ``values`` flattened in row-major order,
          so row ``i`` of ``X`` lines up with ``y[i]``.

        Evaluators matching neither layout leave ``X``/``y`` untouched. If
        extraction fails, both stay unchanged and a warning is logged instead of
        raising, so a partially-populated pair is never left behind.
        """
        ev = self.evaluator
        try:
            if hasattr(ev, 'X_scaled') and hasattr(ev, 'y'):
                if not (hasattr(ev, 'lo') and hasattr(ev, 'span')):
                    return
                X = ev.X_scaled * ev.span + ev.lo
                y = ev.y
            elif hasattr(ev, 'levels') and hasattr(ev, 'values'):
                import numpy as np
                from itertools import product

                grids = [np.asarray(g).tolist() for g in ev.levels]
                X = np.array(list(product(*grids)))
                y = np.asarray(ev.values).flatten()
                if len(X) != len(y):
                    raise ValueError(
                        f"grid has {len(X)} points but values has {len(y)} entries"
                    )
            else:
                return
        except Exception:
            # Data extraction is optional for this wrapper; report, don't crash.
            import logging

            logging.getLogger(__name__).warning(
                "Could not extract data from the %s evaluator",
                self.task_name,
                exc_info=True,
            )
            return
        self.X, self.y = X, y

    def evaluate(self, x: Dict[str, float]) -> float:
        """Evaluate a single input point.

        Args:
            x: Parameter values keyed by configuration (or evaluator) names.

        Returns:
            ``result[self.target_col]`` converted to ``float``.

        Raises:
            KeyError: If the evaluator's result has no ``target_col`` entry; the
                message lists the keys that were available.
        """
        x_mapped = self._map_param_names(x)
        result = self.evaluator.evaluate_from_dict(x_mapped)
        try:
            value = result[self.target_col]
        except KeyError:
            keys = list(result.keys()) if hasattr(result, "keys") else []
            raise KeyError(
                f"Task {self.task_name!r}: evaluator result has no "
                f"{self.target_col!r} entry (available: {keys})"
            ) from None
        return float(value)

    def _map_param_names(self, x: Dict[str, float]) -> Dict[str, float]:
        """Map parameter names from configuration to evaluator names.

        Resolution order:

        1. If the task has an entry in ``TASK_PARAM_MAPPING``, keys are renamed
           through it; keys without an entry pass through unchanged.
        2. Otherwise, when both ``feature_names`` and the evaluator module's
           ``PARAM_NAMES`` are known: if ``x`` already uses exactly the evaluator
           names it is returned as-is; if the two name lists have equal length
           they are paired positionally and only keys of ``x`` that appear in
           ``feature_names`` are kept.
        3. In every other case ``x`` is returned unchanged.

        The returned dict may be ``x`` itself rather than a copy.
        """
        mapping = TASK_PARAM_MAPPING.get(self.task_name)
        if mapping is not None:
            return {mapping.get(name, name): value for name, value in x.items()}

        evaluator_names = self._evaluator_param_names
        if self.feature_names is None or evaluator_names is None:
            return x
        if set(x) == set(evaluator_names):
            return x
        if len(self.feature_names) == len(evaluator_names):
            # Positional pairing: the i-th configuration name feeds the i-th
            # evaluator parameter; names missing from x are simply omitted.
            return {
                ev_name: x[cfg_name]
                for cfg_name, ev_name in zip(self.feature_names, evaluator_names)
                if cfg_name in x
            }
        return x

    def evaluate_batch(self, X: List[Dict[str, float]]) -> List[float]:
        """Evaluate each point of ``X`` in order; equivalent to ``[evaluate(x) for x in X]``."""
        return [self.evaluate(x) for x in X]

    @classmethod
    def from_config(cls, config) -> 'HighFidelityBlackbox':
        """Create a blackbox from a legacy TaskConfig.

        Reads ``config.name``, ``config.space.names`` and the ``config.hf``
        section. The CSV path is ``hf.csv_path`` when present; otherwise, if
        ``hf.full_data_filename`` is set, ``<data_root>/<name>/<filename>`` is used
        but only when that file exists. ``hf.target_col`` is optional.
        """
        hf = config.hf
        csv_path = getattr(hf, 'csv_path', None)
        if csv_path is None and hasattr(hf, 'full_data_filename'):
            candidate = Path(config.data_root) / config.name / hf.full_data_filename
            if candidate.exists():
                csv_path = str(candidate)

        return cls(
            task_name=config.name,
            csv_path=csv_path,
            feature_names=config.space.names,
            target_col=getattr(hf, 'target_col', None)
        )

    def __call__(self, x: Dict[str, float]) -> float:
        """Alias for :meth:`evaluate`."""
        return self.evaluate(x)


# Public names exported by ``from <this package> import *``.
__all__ = [
    "HighFidelityBlackbox",
    "TASK_EVALUATORS",
    "TASK_TARGET_COLS",
]

