import logging
from typing import List, Any, Optional
import numpy as np
import torch
from omegaconf import DictConfig
from .base_evaluator import BaseEvaluator
from haipr.predictor import BasePredictor

logger = logging.getLogger(__name__)


class MLEvaluator(BaseEvaluator):
    """
    ML-based evaluator that uses trained models from MLflow.

    This evaluator wraps trained models (ESM, SVR, SVC, etc.) and provides
    a unified interface for making predictions on protein sequences.
    """

    def __init__(
        self,
        name: str,
        task_type: str,
        model: Any,
        model_type: str,
        cfg: Optional[DictConfig] = None,
        embedding_manager: Any = None,
        is_seq_prob_evaluator: bool = False,
        threshold: float = 0.5,
        **kwargs,
    ):
        """
        Initialize the ML evaluator and prepare the model via ``setup()``.

        Args:
            name: Name of the evaluator
            task_type: Type of task ('filter', 'score', or 'seq_prob')
            model: Loaded model from MLflow
            model_type: Type of model ('sklearn', 'pytorch', etc.)
            cfg: Configuration for the model (optional). A truthy ``ddp``
                entry enables DataParallel wrapping for pytorch models.
            embedding_manager: Embedding manager for centralized embedding
                computation; handed to sklearn models exposing
                ``prepare_features``.
            is_seq_prob_evaluator: If True, ``predict`` asks the model for
                perplexities instead of regular predictions.
            threshold: For sequence-probability evaluators with
                ``task_type == "filter"``, a sequence passes when its reported
                perplexity is strictly below this value.
                NOTE(review): if the model reports true perplexities (>= 1),
                the default of 0.5 rejects every sequence — set explicitly.
            **kwargs: Additional configuration parameters forwarded to
                ``BaseEvaluator``.
        """
        super().__init__(name, task_type, **kwargs)
        self.model: BasePredictor = model
        self.model_type = model_type
        # Not referenced by the methods below; exposed as an attribute.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.cfg = cfg
        self.embedding_manager = embedding_manager
        self.is_seq_prob_evaluator = is_seq_prob_evaluator
        self.threshold = threshold

        # Set up the model for evaluation
        self.setup()

    # ------------------------------------------------------------------ #
    # Setup
    # ------------------------------------------------------------------ #
    def setup(self, **kwargs) -> None:
        """
        Prepare the wrapped model for inference, dispatching on ``model_type``.

        * ``"pytorch"``: puts ``self.model.model`` into eval mode and, when
          enabled, wraps it in ``torch.nn.DataParallel`` (at most once, so a
          repeated ``setup()`` call does not nest wrappers).
        * ``"sklearn"``: hands ``embedding_manager`` to models that expose
          ``prepare_features``.
        * anything else: logs a warning.

        Args:
            **kwargs: Ignored.
        """
        if self.model_type == "pytorch":
            self._setup_pytorch()
        elif self.model_type == "sklearn":
            self._setup_sklearn()
        else:
            logger.warning(f"Unknown model type: {self.model_type}")

    def _setup_pytorch(self) -> None:
        """Switch the inner torch module to eval mode; optionally wrap in DataParallel."""
        inner = getattr(self.model, "model", None)
        if inner is not None and hasattr(inner, "eval"):
            logger.info(f"Set Model {self.name} eval")
            inner.eval()

        # NOTE(review): the flag is named ``ddp`` but this uses single-process
        # DataParallel, not DistributedDataParallel.
        wants_data_parallel = bool(
            self.cfg
            and getattr(self.cfg, "ddp", False)
            and torch.cuda.is_available()
            and torch.cuda.device_count() > 1
        )
        if not wants_data_parallel:
            return
        # DataParallel can only run an nn.Module; skip if already wrapped so a
        # second setup() call does not produce DataParallel(DataParallel(...)).
        if isinstance(inner, torch.nn.Module) and not isinstance(
            inner, torch.nn.DataParallel
        ):
            self.model.model = torch.nn.DataParallel(inner)
            logger.info(f"Wrapped {self.name} model in DataParallel")

    def _setup_sklearn(self) -> None:
        """Attach the embedding manager to sklearn-style models that can use it."""
        if self.embedding_manager and hasattr(self.model, "prepare_features"):
            self.model.embedding_manager = self.embedding_manager
            cache_id = getattr(self.embedding_manager,
                               "instance_id", "unknown")
            logger.info(
                f"Assigned CacheManager instance {cache_id} to {self.name}")
        elif self.embedding_manager:
            logger.warning(
                f"Model {self.name} does not support embedding_manager (no prepare_features method)"
            )
        else:
            logger.warning(
                f"No embedding_manager provided for {self.name}")

    # ------------------------------------------------------------------ #
    # Output normalisation
    # ------------------------------------------------------------------ #
    @staticmethod
    def _tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """Detach a tensor, upcast bfloat16 (unsupported by NumPy) and move it to host."""
        # detach(): .numpy() raises on tensors that require grad.
        tensor = tensor.detach()
        if tensor.dtype == torch.bfloat16:
            logger.debug("Converting predictions from bfloat16 to float32")
            tensor = tensor.float()
        return tensor.cpu().numpy()

    def _format_predictions(self, predictions: Any):
        """
        Normalize raw model output into the evaluator's return format.

        - ``torch.Tensor`` -> ``np.ndarray`` (for any ``model_type``)
        - ``dict`` -> its ``"predictions"`` entry (a tensor entry is converted
          to ``np.ndarray``), or the dict itself if that key is absent
        - ``np.ndarray`` -> returned unchanged
        - anything else -> ``np.array(predictions)``
        """
        if isinstance(predictions, torch.Tensor):
            return self._tensor_to_numpy(predictions)
        if isinstance(predictions, dict):
            extracted = predictions.get("predictions", predictions)
            if isinstance(extracted, torch.Tensor):
                return self._tensor_to_numpy(extracted)
            return extracted
        if isinstance(predictions, np.ndarray):
            return predictions
        return np.array(predictions)

    # ------------------------------------------------------------------ #
    # Prediction
    # ------------------------------------------------------------------ #
    def predict(self, sequences: List[str], batch_size: int = 1):
        """
        Make predictions on a list of sequences.

        Args:
            sequences: Sequences to evaluate.
            batch_size: Batch size forwarded to the model's ``predict``.

        Returns:
            For sequence-probability evaluators: a boolean ``np.ndarray`` mask
            (``perplexity < threshold``) when ``task_type == "filter"``,
            otherwise the model's raw result. For all other evaluators, the
            model output normalized by ``_format_predictions``.
        """
        if self.is_seq_prob_evaluator:
            prob_results = self.model.predict(
                sequences=sequences, batch_size=batch_size, perplexities=True
            )
            if self.task_type != "filter":
                # Non-filter tasks get the model's full result untouched.
                return prob_results
            perplexities = prob_results["perplexities"]
            if isinstance(perplexities, torch.Tensor):
                # np.asarray cannot handle CUDA / bfloat16 tensors directly.
                perplexities = self._tensor_to_numpy(perplexities)
            # Outer asarray keeps an ndarray (0-d) even for a scalar input.
            return np.asarray(np.asarray(perplexities) < self.threshold)

        # Always use the regular predict method for sequences;
        # predict_with_embeddings is for pre-computed embeddings only.
        predictions = self.model.predict(
            sequences=sequences, batch_size=batch_size)
        return self._format_predictions(predictions)

    def predict_with_embeddings(self, embeddings: np.ndarray):
        """
        Make predictions using pre-computed embeddings.

        Args:
            embeddings: Pre-computed embeddings for sequences

        Returns:
            Predictions from the model, normalized by ``_format_predictions``.

        Raises:
            AttributeError: If the wrapped model has no
                ``predict_with_embeddings`` method.
        """
        if not hasattr(self.model, "predict_with_embeddings"):
            raise AttributeError(
                f"Model {self.name} does not support predict_with_embeddings method")

        predictions = self.model.predict_with_embeddings(embeddings)
        return self._format_predictions(predictions)
