import logging
from typing import Dict, Any, Optional
import torch
from omegaconf import DictConfig

from .base_evaluator import BaseEvaluator
from .ml_evaluator import MLEvaluator
from .physics_evaluator import PhysicsEvaluator


# Module-level logger named after this module, so its level and handlers can be
# configured through the standard logging hierarchy.
logger: logging.Logger = logging.getLogger(name=__name__)


class EvaluatorFactory:
    """
    Factory class for creating evaluators from different sources.

    This factory can create evaluators from:
    1. MLflow models (ML-based evaluators)
    2. Physics-based configurations (PyRosetta, FoldX, etc.)
    3. Custom scoring functions

    All methods are static; the class only groups the constructors.
    """

    # Keys that ``create_evaluator_from_config`` passes explicitly to
    # ``create_ml_evaluator``. They are removed before ``config`` is forwarded
    # as ``**kwargs``; forwarding them too would raise
    # "TypeError: got multiple values for keyword argument".
    _ML_EXPLICIT_KEYS = frozenset(
        {"name", "model", "model_type", "task_type", "cfg", "is_seq_prob_evaluator"}
    )

    # Keys passed explicitly to ``create_physics_evaluator`` (same reason).
    _PHYSICS_EXPLICIT_KEYS = frozenset({"name", "task_type", "evaluator_type"})

    @staticmethod
    def _extra_kwargs(config: Dict[str, Any], reserved: frozenset) -> Dict[str, Any]:
        """
        Return the entries of ``config`` whose keys are not in ``reserved``.

        Used to forward free-form configuration to an evaluator constructor
        without colliding with arguments that are already passed explicitly.
        """
        return {key: value for key, value in config.items() if key not in reserved}

    @staticmethod
    def create_ml_evaluator(
        name: str,
        model: Any,
        model_type: str,
        task_type: str,
        cfg: Optional[DictConfig] = None,
        is_seq_prob_evaluator: bool = False,
        **kwargs,
    ) -> MLEvaluator:
        """
        Create an ML-based evaluator from a loaded model.

        Args:
            name: Name of the evaluator
            model: Loaded model from MLflow
            model_type: Type of model ('sklearn', 'pytorch', etc.)
            task_type: Type of task ('filter', 'score', or 'seq_prob')
            cfg: Configuration for the model (optional)
            is_seq_prob_evaluator: If True, use model's sequence_probability method
            **kwargs: Additional configuration parameters, forwarded unchanged
                to ``MLEvaluator``

        Returns:
            MLEvaluator: Configured ML evaluator
        """
        return MLEvaluator(
            name=name,
            task_type=task_type,
            model=model,
            model_type=model_type,
            cfg=cfg,
            is_seq_prob_evaluator=is_seq_prob_evaluator,
            **kwargs,
        )

    @staticmethod
    def create_physics_evaluator(
        name: str, task_type: str, evaluator_type: str = "pyrosetta", **kwargs
    ) -> PhysicsEvaluator:
        """
        Create a physics-based evaluator.

        Args:
            name: Name of the evaluator
            task_type: Type of task ('filter', 'score', or 'seq_prob')
            evaluator_type: Type of physics evaluator ('pyrosetta', 'foldx', 'custom').
                'pyrosetta' sets ``use_pyrosetta=True`` and 'foldx' sets
                ``use_foldx=True`` in the forwarded kwargs. Any other value is
                logged as a warning and the evaluator is still created.
            **kwargs: Additional configuration parameters, forwarded to
                ``PhysicsEvaluator``

        Returns:
            PhysicsEvaluator: Configured physics evaluator
        """
        if evaluator_type == "pyrosetta":
            kwargs["use_pyrosetta"] = True
        elif evaluator_type == "foldx":
            kwargs["use_foldx"] = True
        elif evaluator_type == "custom":
            # Custom scoring function should be provided in kwargs
            pass
        else:
            logger.warning("Unknown physics evaluator type: %s", evaluator_type)

        return PhysicsEvaluator(name=name, task_type=task_type, **kwargs)

    @staticmethod
    def create_evaluator_from_config(
        name: str, config: Dict[str, Any], **kwargs
    ) -> BaseEvaluator:
        """
        Create an evaluator from a configuration dictionary.

        ``config["type"]`` selects the kind ("ml" by default, or "physics").
        ``config["task_type"]`` defaults to "score". Remaining config entries
        are forwarded to the evaluator constructor. Keys that this method
        already passes explicitly are not forwarded a second time.

        Args:
            name: Name of the evaluator
            config: Configuration dictionary
            **kwargs: Additional parameters. Only ``model`` (required for ML
                evaluators) and ``cfg`` are read here.

        Returns:
            BaseEvaluator: Configured evaluator

        Raises:
            ValueError: If an ML evaluator is requested without a ``model``,
                or if ``config["type"]`` is not recognised.
        """
        evaluator_type = config.get("type", "ml")
        task_type = config.get("task_type", "score")

        if evaluator_type == "ml":
            # This would typically be used with a pre-loaded model
            model = kwargs.get("model")
            if model is None:
                raise ValueError(f"Model must be provided for ML evaluator {name}")

            return EvaluatorFactory.create_ml_evaluator(
                name=name,
                model=model,
                model_type=config.get("model_type", "pytorch"),
                task_type=task_type,
                cfg=kwargs.get("cfg"),
                is_seq_prob_evaluator=config.get("is_seq_prob_evaluator", False),
                **EvaluatorFactory._extra_kwargs(
                    config, EvaluatorFactory._ML_EXPLICIT_KEYS
                ),
            )

        if evaluator_type == "physics":
            return EvaluatorFactory.create_physics_evaluator(
                name=name,
                task_type=task_type,
                evaluator_type=config.get("physics_type", "pyrosetta"),
                **EvaluatorFactory._extra_kwargs(
                    config, EvaluatorFactory._PHYSICS_EXPLICIT_KEYS
                ),
            )

        raise ValueError(f"Unknown evaluator type: {evaluator_type}")

    @staticmethod
    def create_evaluators_from_predictors(
        predictors: Dict[str, Any],
        physics_configs: Optional[Dict[str, Dict[str, Any]]] = None,
        cfg: Optional[DictConfig] = None,
        embedding_manager: Any = None,
        compile_models: bool = False,
    ) -> Dict[str, BaseEvaluator]:
        """
        Create evaluators from loaded predictors and optional physics configurations.

        Args:
            predictors: Dictionary of loaded predictors from MLflow. Each value
                is read for the keys "model", "model_type", "task_type",
                "stage_name", "run_id" and "run_info".
            physics_configs: Optional dictionary of physics evaluator configurations
            cfg: Configuration object
            embedding_manager: Embedding manager for centralized embedding computation
            compile_models: If True, PyTorch models are wrapped with
                ``torch.compile`` before the evaluator is built. Defaults to
                False, which leaves models uncompiled.

        Returns:
            Dict[str, BaseEvaluator]: Dictionary mapping evaluator names to evaluators.
                ML evaluators are keyed by predictor key and physics evaluators
                by config name.
        """
        evaluators: Dict[str, BaseEvaluator] = {}

        # Resolve the target device once. A hard-coded "cuda" fails on hosts
        # without a usable GPU, so fall back to CPU there.
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Create ML evaluators from predictors
        for predictor_key, predictor_info in predictors.items():
            model = predictor_info["model"]
            model_type = predictor_info["model_type"]
            task_type = predictor_info["task_type"]
            stage_name = predictor_info["stage_name"]

            if model_type == "pytorch":
                if device == "cpu":
                    logger.warning(
                        "CUDA is not available; keeping PyTorch model %s on CPU",
                        predictor_key,
                    )
                model = model.to(device)
                if compile_models:
                    # torch.compile returns a new wrapper and leaves ``model``
                    # itself untouched, so the result must be kept to take effect.
                    model = torch.compile(model)

            # Create evaluator name based on stage and model type
            evaluator_name = f"{stage_name}_{model_type}"

            evaluators[predictor_key] = EvaluatorFactory.create_ml_evaluator(
                name=evaluator_name,
                model=model,
                model_type=model_type,
                task_type=task_type,
                cfg=cfg,
                run_id=predictor_info["run_id"],
                run_info=predictor_info["run_info"],
                embedding_manager=embedding_manager,
            )
            logger.info("Created ML evaluator: %s", evaluator_name)

        # Create physics evaluators if configurations are provided
        if physics_configs:
            for name, config in physics_configs.items():
                physics_evaluator: BaseEvaluator = (
                    EvaluatorFactory.create_physics_evaluator(
                        name=name,
                        task_type=config.get("task_type", "score"),
                        evaluator_type=config.get("type", "pyrosetta"),
                        **EvaluatorFactory._extra_kwargs(
                            config, EvaluatorFactory._PHYSICS_EXPLICIT_KEYS
                        ),
                    )
                )

                evaluators[name] = physics_evaluator
                logger.info("Created physics evaluator: %s", name)

        return evaluators
