import logging
from typing import List, Any, Optional, Dict
import numpy as np
from omegaconf import DictConfig
from .base_evaluator import BaseEvaluator

logger = logging.getLogger(__name__)


class MLEvaluator(BaseEvaluator):
    """
    ML-based evaluator that uses trained models from MLflow.

    This evaluator wraps trained models (ESM, SVR, SVC, etc.) and provides
    a unified interface for making predictions on protein sequences.
    Whatever the wrapped model returns is normalized into a dictionary
    carrying a ``"predictions"`` entry.
    """

    def __init__(
        self,
        name: str,
        task_type: str,
        model: Any,  # mlflow.pyfunc.PyFuncModel wrapping BasePredictor
        cfg: Optional[DictConfig] = None,
        embedding_manager: Any = None,
        is_seq_prob_evaluator: bool = False,
        threshold: float = 0.5,
        **kwargs,
    ):
        """
        Initialize the ML evaluator.

        Args:
            name: Name of the evaluator
            task_type: Type of task ('filter', 'score', or 'seq_prob')
            model: Loaded pyfunc model from MLflow
            cfg: Configuration for the model (optional)
            embedding_manager: Embedding manager for centralized embedding computation
            is_seq_prob_evaluator: If True, use model's sequence_probability method
            threshold: Cut-off applied to the model's ``"perplexities"`` output
                when ``is_seq_prob_evaluator`` is True and ``task_type`` is
                ``'filter'``; values strictly below it become ``True``
            **kwargs: Additional configuration parameters
        """
        super().__init__(name, task_type, **kwargs)
        self.model = model  # PyFuncModel instance
        self.cfg = cfg
        self.embedding_manager = embedding_manager
        self.is_seq_prob_evaluator = is_seq_prob_evaluator
        self.threshold = threshold
        self.setup()

    def setup(self, **kwargs) -> None:
        """Setup - inject embedding manager into predictor if available.

        The manager is only assigned when the model exposes ``_model_impl``
        and that object already has an ``embedding_manager`` attribute.
        Otherwise nothing is injected and a debug message records why.
        """
        if self.embedding_manager is not None:
            # Access the underlying predictor through pyfunc wrapper.
            # NOTE(review): for custom PythonModel flavors MLflow may nest the
            # user model one level deeper (see PyFuncModel.unwrap_python_model)
            # - confirm that `_model_impl` is the predictor itself here.
            predictor = getattr(self.model, "_model_impl", None)
            if predictor is not None and hasattr(predictor, "embedding_manager"):
                predictor.embedding_manager = self.embedding_manager
                logger.info("Injected embedding manager into %s", self.name)
            else:
                # Do not fail: some predictors may simply not use embeddings,
                # but leave a trace so a missed injection can be diagnosed.
                logger.debug(
                    "Embedding manager not injected into %s: "
                    "no predictor with an 'embedding_manager' attribute found",
                    self.name,
                )

        logger.info("Setup complete for evaluator: %s", self.name)

    @staticmethod
    def _describe_shape(values: Any) -> Any:
        """Return the shape of *values* for logging, never raising."""
        shape = getattr(values, "shape", None)
        if shape is not None:
            return shape
        try:
            # Plain Python sequences have no `.shape`; let numpy infer it.
            return np.shape(values)
        except (ValueError, TypeError):
            # Ragged or otherwise non-array-like data: logging must not crash.
            return "unknown"

    def predict(self, sequences: List[str]) -> Dict[str, Any]:
        """Make predictions using the pyfunc model.

        Args:
            sequences: Sequences passed unchanged to ``self.model.predict``.

        Returns:
            A dictionary with a ``"predictions"`` entry. A model result that
            is already a dict containing ``"predictions"`` is returned as-is.

        Raises:
            ValueError: If this is not a sequence-probability evaluator and
                the model result is neither such a dict, a numpy array, nor
                a list.
        """
        raw = self.model.predict(sequences)

        if self.is_seq_prob_evaluator:
            is_mapping = isinstance(raw, dict)
            if self.task_type == "filter" and is_mapping and "perplexities" in raw:
                # Filtering keeps sequences whose perplexity is below threshold.
                return {
                    "predictions": np.array(raw["perplexities"]) < self.threshold
                }
            if is_mapping and "predictions" in raw:
                return raw
            return {"predictions": np.array(raw)}

        if isinstance(raw, dict) and "predictions" in raw:
            # The dict value may be a list, so avoid reading `.shape` directly.
            logger.debug(
                "Predictions shape: %s", self._describe_shape(raw["predictions"])
            )
            return raw

        if isinstance(raw, np.ndarray):
            logger.debug("Predictions shape: %s", raw.shape)
            return {"predictions": raw}

        if isinstance(raw, list):
            # Convert once and reuse the array for both logging and the result.
            as_array = np.array(raw)
            logger.debug("Predictions shape: %s", as_array.shape)
            return {"predictions": as_array}

        raise ValueError(f"Unknown predictions type: {type(raw)}")
