import logging
from typing import Dict, Any, Optional
import torch
from omegaconf import DictConfig

from .base_evaluator import BaseEvaluator
from .ml_evaluator import MLEvaluator
from .physics_evaluator import PhysicsEvaluator


# Module-scoped logger, named after this module so its records can be filtered per module.
logger: logging.Logger = logging.getLogger(__name__)


class EvaluatorFactory:
    """
    Factory class for creating evaluators from different sources.

    This factory can create evaluators from:
    1. MLflow models (ML-based evaluators)
    2. Physics-based configurations (PyRosetta, etc.)

    Configuration dictionaries are forwarded to the evaluator constructors as
    keyword arguments. Keys that the factory already passes explicitly (for
    example ``name`` or ``task_type``) are removed before forwarding, because
    passing them twice raises ``TypeError``.
    """

    @staticmethod
    def _without_keys(config: Dict[str, Any], *keys: str) -> Dict[str, Any]:
        """
        Return a shallow copy of ``config`` with ``keys`` removed.

        Used before splatting a config dict into a call that already receives
        some of the same values as explicit keyword arguments. Without this,
        Python raises ``TypeError: got multiple values for keyword argument``.
        """
        return {key: value for key, value in config.items() if key not in keys}

    @staticmethod
    def create_ml_evaluator(
        name: str,
        model: Any,  # mlflow.pyfunc.PyFuncModel wrapping BasePredictor
        task_type: str,
        cfg: Optional[DictConfig] = None,
        is_seq_prob_evaluator: bool = False,
        **kwargs,
    ) -> MLEvaluator:
        """
        Create an ML-based evaluator from a loaded pyfunc model.

        Args:
            name: Name of the evaluator
            model: Loaded pyfunc model from MLflow
            task_type: Type of task ('filter', 'score', or 'seq_prob')
            cfg: Configuration for the model (optional)
            is_seq_prob_evaluator: If True, use model's sequence_probability method
            **kwargs: Additional configuration parameters, forwarded unchanged
                to ``MLEvaluator``

        Returns:
            MLEvaluator: Configured ML evaluator
        """
        return MLEvaluator(
            name=name,
            task_type=task_type,
            model=model,
            cfg=cfg,
            is_seq_prob_evaluator=is_seq_prob_evaluator,
            **kwargs,
        )

    @staticmethod
    def create_physics_evaluator(
        name: str, task_type: str, evaluator_type: str = "pyrosetta", **kwargs
    ) -> PhysicsEvaluator:
        """
        Create a physics-based evaluator.

        Args:
            name: Name of the evaluator
            task_type: Type of task ('filter', 'score', or 'seq_prob')
            evaluator_type: Type of physics evaluator ('pyrosetta', 'foldx', 'custom').
                'pyrosetta' and 'foldx' set ``use_pyrosetta`` / ``use_foldx`` to
                True in the forwarded kwargs. 'custom' expects the scoring
                function to be supplied in ``kwargs``. Any other value is logged
                as a warning but the evaluator is still constructed.
            **kwargs: Additional configuration parameters, forwarded to
                ``PhysicsEvaluator``

        Returns:
            PhysicsEvaluator: Configured physics evaluator
        """
        # Built-in backends are enabled through a boolean flag on PhysicsEvaluator.
        backend_flags = {"pyrosetta": "use_pyrosetta", "foldx": "use_foldx"}

        if evaluator_type in backend_flags:
            # ``kwargs`` is a fresh dict for this call, so mutating it is safe.
            kwargs[backend_flags[evaluator_type]] = True
        elif evaluator_type != "custom":
            logger.warning("Unknown physics evaluator type: %s", evaluator_type)

        return PhysicsEvaluator(name=name, task_type=task_type, **kwargs)

    @staticmethod
    def create_evaluator_from_config(
        name: str, config: Dict[str, Any], **kwargs
    ) -> BaseEvaluator:
        """
        Create an evaluator from a configuration dictionary.

        Keys read from ``config``: ``type`` ('ml' by default, or 'physics'),
        ``task_type`` ('score' by default). For ML evaluators it also reads
        ``model_type`` ('pytorch' by default) and ``is_seq_prob_evaluator``
        (False by default). For physics evaluators it reads ``physics_type``
        ('pyrosetta' by default). All remaining keys are forwarded to the
        evaluator constructor.

        Args:
            name: Name of the evaluator
            config: Configuration dictionary
            **kwargs: Additional parameters; only ``model`` (required for ML
                evaluators) and ``cfg`` are read.

        Returns:
            BaseEvaluator: Configured evaluator

        Raises:
            ValueError: If an ML evaluator is requested without a ``model``, or
                if ``config["type"]`` is not a known evaluator type.
        """
        evaluator_type = config.get("type", "ml")
        task_type = config.get("task_type", "score")

        if evaluator_type == "ml":
            # This would typically be used with a pre-loaded model
            model = kwargs.get("model")
            if model is None:
                raise ValueError(f"Model must be provided for ML evaluator {name}")

            forwarded = EvaluatorFactory._without_keys(
                config,
                "name",
                "model",
                "model_type",
                "task_type",
                "cfg",
                "is_seq_prob_evaluator",
            )
            return EvaluatorFactory.create_ml_evaluator(
                name=name,
                model=model,
                model_type=config.get("model_type", "pytorch"),
                task_type=task_type,
                cfg=kwargs.get("cfg"),
                is_seq_prob_evaluator=config.get("is_seq_prob_evaluator", False),
                **forwarded,
            )

        if evaluator_type == "physics":
            forwarded = EvaluatorFactory._without_keys(
                config, "name", "task_type", "evaluator_type"
            )
            return EvaluatorFactory.create_physics_evaluator(
                name=name,
                task_type=task_type,
                evaluator_type=config.get("physics_type", "pyrosetta"),
                **forwarded,
            )

        raise ValueError(f"Unknown evaluator type: {evaluator_type}")

    @staticmethod
    def create_evaluators_from_predictors(
        predictors: Dict[str, Any],
        physics_configs: Optional[Dict[str, Dict[str, Any]]] = None,
        cfg: Optional[DictConfig] = None,
        embedding_managers: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, BaseEvaluator]:
        """
        Create evaluators from loaded predictors and optional physics configurations.

        Args:
            predictors: Dictionary of loaded predictors from MLflow. Each value
                must provide ``model``, ``task_type`` and ``stage_name`` keys.
                It may also provide ``run_info``, which is used to select an
                embedding manager.
            physics_configs: Optional dictionary of physics evaluator
                configurations keyed by evaluator name. Each config may set
                ``task_type`` (default 'score') and ``type`` (default 'pyrosetta').
            cfg: Configuration object
            embedding_managers: Dictionary of embedding managers keyed by embedder config

        Returns:
            Dict[str, BaseEvaluator]: Dictionary mapping evaluator names to evaluators.
            ML evaluators are keyed by predictor key and physics evaluators by
            config name. A physics evaluator whose name matches a predictor key
            replaces the ML evaluator, and a warning is logged.
        """
        evaluators: Dict[str, BaseEvaluator] = {}

        # Create ML evaluators from predictors
        for predictor_key, predictor_info in predictors.items():
            # Pick the embedding manager matching this predictor's embedder config.
            predictor_embedding_manager = None
            if embedding_managers:
                predictor_embedding_manager = (
                    EvaluatorFactory._get_embedding_manager_for_predictor(
                        predictor_info, embedding_managers
                    )
                )

            # Evaluator name is derived from the pipeline stage and predictor key.
            evaluator_name = f"{predictor_info['stage_name']}_{predictor_key}"

            evaluators[predictor_key] = EvaluatorFactory.create_ml_evaluator(
                name=evaluator_name,
                model=predictor_info["model"],
                task_type=predictor_info["task_type"],
                cfg=cfg,
                embedding_manager=predictor_embedding_manager,
            )
            logger.info("Created ML evaluator: %s", evaluator_name)

        # Create physics evaluators if configurations are provided
        for name, config in (physics_configs or {}).items():
            if name in evaluators:
                logger.warning(
                    "Physics evaluator %s replaces an existing evaluator with the same key",
                    name,
                )

            # Drop keys passed explicitly below to avoid duplicate-keyword TypeErrors.
            forwarded = EvaluatorFactory._without_keys(
                config, "name", "task_type", "evaluator_type"
            )
            evaluators[name] = EvaluatorFactory.create_physics_evaluator(
                name=name,
                task_type=config.get("task_type", "score"),
                evaluator_type=config.get("type", "pyrosetta"),
                **forwarded,
            )
            logger.info("Created physics evaluator: %s", name)

        return evaluators

    @staticmethod
    def _get_embedding_manager_for_predictor(
        predictor_info: Dict[str, Any], embedding_managers: Dict[str, Any]
    ) -> Optional[Any]:
        """
        Get the appropriate embedding manager for a predictor based on its embedder config.

        The key is derived from the MLflow run parameters whose names start with
        ``embedder.``. It is the first 8 hex characters of the MD5 of the
        string form of the sorted ``(name, value)`` pairs.

        Returns:
            The matching manager from ``embedding_managers``, or None when the
            predictor has no usable ``run_info``, has no ``embedder.*``
            parameters, or no manager is registered under the derived key.
        """
        import hashlib

        run_info = predictor_info.get("run_info")
        if (
            not run_info
            or not hasattr(run_info, "data")
            or not hasattr(run_info.data, "params")
        ):
            return None

        embedder_params = {
            key: value
            for key, value in run_info.data.params.items()
            if key.startswith("embedder.")
        }
        if not embedder_params:
            return None

        # NOTE: must stay byte-for-byte identical to the key derivation in
        # inference.py, otherwise lookups silently miss. MD5 is used only as a
        # short fingerprint here, not for security.
        config_str = str(sorted(embedder_params.items()))
        embedder_key = hashlib.md5(config_str.encode()).hexdigest()[:8]

        return embedding_managers.get(embedder_key)
