from abc import ABC, abstractmethod
from omegaconf import DictConfig, OmegaConf
from typing import Any, Dict, Union, List, Optional
import numpy as np
import torch
import logging
import lightning.pytorch as pl
from haipr.data import HAIPRData
from mlflow.pyfunc.model import PythonModel

# Module-scoped logger named after this module.
logger = logging.getLogger(name=__name__)


class InferenceWrapper(PythonModel):
    """
    Lightweight wrapper for MLflow pyfunc inference.
    Does not inherit from PyTorch Lightning - only for inference.

    All work is delegated to the wrapped predictor: ``predict`` forwards to
    ``predictor.predict_sequences`` and ``load_context`` forwards to
    ``predictor.load_context``.
    """

    def __init__(self, predictor_instance):
        """
        Create inference wrapper from a trained predictor.

        Args:
            predictor_instance: Object exposing ``predict_sequences(sequences, params)``
                and ``load_context(context)`` (e.g. a ``BasePredictor`` subclass).
        """
        # Store the predictor instance
        self._predictor = predictor_instance
        # Default device attribute kept for callers that read it; nothing in this
        # class updates it (load_context only delegates to the predictor).
        self.device = "cpu"

    def predict(self, context, model_input: list[str], params=None):
        """
        Pyfunc predict interface - MLflow standard signature.

        The params handed to the predictor are ``context.model_config``, with any
        non-empty per-call ``params`` overlaid on top (per-call values win). When
        ``params`` is None/empty, ``context.model_config`` is passed through as-is.

        Args:
            context: MLflow model context; may be ``None``, in which case no
                model_config is available.
            model_input: Sequences to predict on.
            params: Optional per-call inference parameters.

        Returns:
            Whatever the wrapped predictor's ``predict_sequences`` returns.
        """
        logger.debug("InferenceWrapper context: %s", context)
        # getattr: tolerate a missing/None context instead of raising AttributeError.
        model_config = getattr(context, "model_config", None)
        if params:
            # Previously caller-supplied params were discarded; merge them instead.
            effective_params = {**(model_config or {}), **params}
        else:
            effective_params = model_config
        logger.debug("InferenceWrapper params: %s", effective_params)
        return self._predictor.predict_sequences(model_input, effective_params)

    def load_context(self, context):
        """Load artifacts when model is loaded from MLflow.

        Delegates entirely to the wrapped predictor's ``load_context``.
        """
        logger.info("InferenceWrapper delegating to predictor.load_context()")
        self._predictor.load_context(context)


class BasePredictor(PythonModel, ABC):
    """
    Abstract base class for all HAIPR predictors.

    Inherits from mlflow.pyfunc.PythonModel to provide native pyfunc compatibility.

    Subclasses implement setup, training, feature preparation, saving and
    ``predict_sequences``. This base class supplies embedding helpers that use
    an injected ``embedding_manager`` when present and otherwise fall back to a
    lazily built local protenc encoder.
    """

    _inference_trainer = None

    def __init__(self):
        """Initialize base predictor attributes."""
        super().__init__()
        # Local protenc encoder, built on demand by _initialize_embedder().
        self.embedder_instance = None
        # Embedder config section (read for name/model/batch_size/... below).
        self.embedder_config = None
        self.embedding_manager = None  # Will be set by inference pipeline
        self.data = None
        self.cfg = None

    def _initialize_embedder(self):
        """Initialize the embedder instance from ``self.embedder_config``.

        Only the ``protenc`` embedder is supported; the encoder is placed on CUDA
        when available, otherwise on CPU. Logs a warning and returns without
        building anything when no embedder config is set.

        Raises:
            NotImplementedError: If ``embedder_config.name`` is not ``"protenc"``.
            ValueError: If importing protenc or building the encoder fails
                (the original exception is chained as ``__cause__``).
        """
        if not self.embedder_config:
            logger.warning(
                "No embedder configuration found. Cannot initialize embedder."
            )
            return

        if self.embedder_config.name != "protenc":
            raise NotImplementedError(
                f"Embedder '{self.embedder_config.name}' not supported. Only 'protenc' is implemented."
            )

        try:
            import protenc

            model_name = self.embedder_config.model
            batch_size = getattr(self.embedder_config, "batch_size", 32)
            device = "cuda" if torch.cuda.is_available() else "cpu"
            data_parallel = getattr(self.embedder_config, "data_parallel", False)

            logger.info(f"Initializing protenc embedder: {model_name}")

            self.embedder_instance = protenc.get_encoder(
                model_name,
                device=device,
                batch_size=batch_size,
                data_parallel=data_parallel,
            )
        except Exception as e:
            logger.error(f"Failed to initialize protenc embedder: {e}")
            # Chain the cause so the underlying traceback is not lost.
            raise ValueError(f"Could not initialize protenc embedder: {e}") from e

        logger.info("Protenc embedder initialized successfully")

    def _embed_with_local_embedder(self, sequences: List[str]) -> np.ndarray:
        """Embed ``sequences`` with the already-initialized local embedder.

        Calls the encoder with ``average_sequence`` taken from
        ``embedder_config`` (default ``True``) and ``return_format="numpy"``,
        then vertically stacks every array it yields into one ``np.ndarray``.
        Shared by both embedding entry points below.
        """
        logger.debug(f"Using local embedder for {len(sequences)} sequences")
        average_embeddings = getattr(self.embedder_config, "average_sequence", True)
        embeddings_list = list(
            self.embedder_instance(
                sequences,
                average_sequence=average_embeddings,
                return_format="numpy",
            )
        )
        # NOTE(review): np.vstack raises ValueError if the encoder yields
        # nothing (e.g. for an empty ``sequences`` list).
        return np.vstack(embeddings_list)

    def _get_embeddings_for_sequences(self, sequences: List[str]) -> np.ndarray:
        """
        Get embeddings for sequences using available embedding source.
        Priority: embedding_manager (inference) > local embedder (fallback)

        Raises:
            RuntimeError: If there is no embedding_manager, no local embedder,
                and no embedder_config to build one from.
        """
        # Try embedding manager from inference pipeline first
        embedding_manager = getattr(self, "embedding_manager", None)
        if embedding_manager is not None:
            logger.debug(f"Using EmbeddingManager for {len(sequences)} sequences")
            return embedding_manager.get_embeddings(sequences)

        # Fallback to local embedder (for standalone usage)
        if getattr(self, "embedder_instance", None) is None:
            if not getattr(self, "embedder_config", None):
                raise RuntimeError(
                    "No embedding source available. Need either embedding_manager "
                    "(from inference) or embedder_config (for local embedder)."
                )
            self._initialize_embedder()

        return self._embed_with_local_embedder(sequences)

    @abstractmethod
    def setup_model(self, data: HAIPRData, cfg: DictConfig) -> None:
        """Initialize predictor with configuration.

        Subclasses must override; the base body stores ``data`` and ``cfg`` and
        can be reused via ``super().setup_model(data, cfg)``.
        """
        self.data = data
        self.cfg = cfg

    @abstractmethod
    def fit_model(
        self,
        dataset: HAIPRData,
        train_indices: Any,
        val_indices: Any,
    ) -> Dict[str, Any]:
        """Train the predictor on given data.
        Returns a dictionary with the following keys:
        - metrics: a dictionary of metrics
        - predictions: a dictionary of predictions indices, predictions, true_values and optionally probabilities
        """
        pass

    @abstractmethod
    def prepare_training_features(
        self, dataset: HAIPRData, indices: np.ndarray
    ) -> Dict[str, Any]:
        """Pre-compute all features for given indices.
        Returns dict of features ready for model training.
        - return Dict[str, torch.Tensor] for TensorDataset for nn.Module models.
        - return Dict[str, np.ndarray] for direct use
        """
        pass

    @abstractmethod
    def prepare_batch_features(
        self, batch_items: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Prepare features for a single batch during inference.
        Returns dict of features ready for model prediction.
        """
        pass

    # Backward compatibility alias
    def prepare_features(self, batch_items: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Backward compatibility alias for prepare_batch_features."""
        return self.prepare_batch_features(batch_items)

    def _prepare_features_from_sequences(
        self, sequences: Union[List[str], HAIPRData]
    ) -> Union[np.ndarray, HAIPRData]:
        """
        Prepare features from sequences using embeddings.
        Uses the embedding manager from inference if available, otherwise fallback to local embedder.

        Args:
            sequences: List of protein sequences to embed. A ``HAIPRData`` is
                also accepted (legacy): its ``prepare_features()`` is called and
                the same object is returned.

        Returns:
            np.ndarray: Embedded features for the sequences (or the HAIPRData
            object itself on the legacy path).

        Raises:
            RuntimeError: If no embedder can be obtained, or embedding fails
                (the original exception is chained as ``__cause__``).
        """
        if isinstance(sequences, HAIPRData):
            # Legacy behavior: if HAIPRData is passed, use its prepare_features method
            sequences.prepare_features()
            return sequences

        # Try to use embedding manager from inference pipeline first
        embedding_manager = getattr(self, "embedding_manager", None)
        if embedding_manager is not None:
            model_name = self.__class__.__name__
            logger.info(
                f"{model_name} using CacheManager for {len(sequences)} sequences"
            )
            return embedding_manager.get_embeddings(sequences)

        # Fallback to local embedder
        if getattr(self, "embedder_instance", None) is None:
            data = getattr(self, "data", None)
            data_config = getattr(data, "config", None) if data else None
            if data_config is not None and hasattr(data_config, "embedder"):
                # Prefer the embedder section of the training data config.
                self.embedder_config = data_config.embedder
                self._initialize_embedder()
            elif getattr(self, "embedder_config", None):
                # No data config (e.g. after load_context): use the config
                # already stored on the predictor instead of failing.
                self._initialize_embedder()

            if getattr(self, "embedder_instance", None) is None:
                raise RuntimeError(
                    "No embedder available. Please ensure embedder is configured and initialized."
                )

        try:
            features = self._embed_with_local_embedder(sequences)
        except Exception as e:
            logger.error(f"Error preparing features: {e}")
            raise RuntimeError(f"Failed to prepare features: {e}") from e

        logger.debug(f"Generated features with shape: {features.shape}")
        return features

    @abstractmethod
    def save_model(self, save_dir: str) -> str:
        """Save the model (for checkpoint restoration during training).
        Returns the path to the saved model.
        """

    def load_context(self, context):
        """
        Load model artifacts from MLflow context (pyfunc requirement).

        This is called when loading a saved pyfunc model. Subclasses can override
        to load specific artifacts, but should call super().load_context(context).

        Loads ``self.cfg`` from the ``"config"`` artifact when present and, if the
        resulting config has an ``embedder`` section, stores it in
        ``self.embedder_config``. The embedder itself is not built here.
        """
        # Tolerate contexts with no/None artifacts mapping.
        artifacts = getattr(context, "artifacts", None) or {}
        if "config" in artifacts:
            self.cfg = OmegaConf.load(artifacts["config"])
            logger.info("Loaded config from artifacts")

        # Initialize embedder config if available
        cfg = getattr(self, "cfg", None)
        if hasattr(cfg, "embedder"):
            self.embedder_config = cfg.embedder
            logger.info("Set embedder config from loaded config")
            # Don't initialize embedder here - it will be set by inference pipeline
            # via embedding_manager or initialized on-demand during predict

    @abstractmethod
    def predict_sequences(
        self, sequences: List[str], params: Dict[str, Any] | None
    ) -> Union[np.ndarray, Dict[str, Any]]:
        """Make predictions on sequences (model-specific implementation)."""
        pass

    def predict(self, context, model_input, params=None) -> Any:
        """
        Unified pyfunc predict interface.

        Forwards ``model_input`` and the per-call ``params`` unchanged to
        ``predict_sequences``; ``context`` is not used.
        """
        return self.predict_sequences(model_input, params)
