from sklearn.svm import SVR
from omegaconf import DictConfig
from haipr.predictor import BasePredictor
from haipr.data import HAIPRData
from haipr.utils import compute_regression_metrics
from typing import Dict, Any, List, Tuple
import numpy as np
import logging
import torch

logger = logging.getLogger(__name__)


class SVRPredictor(BasePredictor):
    """
    A predictor class for Support Vector Regression (SVR) models.

    This class encapsulates the training, prediction, and evaluation
    of SVR models, conforming to the BasePredictor interface.

    Features come from one of three sources:
      * a ``HAIPRData`` instance indexed by sample indices (``fit_model``);
      * an ``embedding_manager`` injected by the inference pipeline;
      * a local ``protenc`` encoder built from an ``embedder`` config.
    """

    def __init__(
        self,
        params: DictConfig = None,
        C: float = 75,
        gamma: str = "scale",
        epsilon: float = 0.1,
        kernel: str = "rbf",
        degree: int = 4,
        **kwargs,
    ):
        """
        Initialize the SVRPredictor.

        Args:
            params (DictConfig): Optional configuration object. It is not used
                to build the model; the SVR hyperparameters come from the
                explicit keyword arguments below.
            C (float): Regularization parameter forwarded to ``sklearn.svm.SVR``.
            gamma (str): Kernel coefficient forwarded to ``SVR``
                (e.g. "scale", "auto", or a float).
            epsilon (float): Epsilon forwarded to ``SVR``.
            kernel (str): Kernel name forwarded to ``SVR``.
            degree (int): Polynomial degree forwarded to ``SVR``.
            **kwargs: Additional keyword arguments. They are accepted and ignored,
                so that extra config keys do not break construction.
        """
        super().__init__()
        self.model = SVR(
            C=C, gamma=gamma, epsilon=epsilon, kernel=kernel, degree=degree
        )
        # Log the hyperparameters actually in effect. ``from_params`` does not
        # forward ``params``, so logging ``params`` alone usually printed None.
        logger.info(
            f"Initialized SVR model with parameters: {self.model.get_params()}"
        )
        if kwargs:
            logger.debug(
                f"Ignoring unused SVRPredictor arguments: {sorted(kwargs)}"
            )

    @classmethod
    def from_params(cls, params: DictConfig):
        """
        Build a predictor from a mapping of constructor keyword arguments.

        Args:
            params (DictConfig): Mapping whose keys are forwarded as keyword
                arguments to ``__init__`` (e.g. ``C``, ``gamma``, ``kernel``).
                ``None`` or an empty mapping yields the default hyperparameters.

        Returns:
            SVRPredictor: The newly constructed predictor.
        """
        return cls(**(params or {}))

    def setup_model(self, data: HAIPRData = None, cfg: DictConfig = None):
        """
        Set up the data and, if needed, the local embedder for the predictor.

        If ``data`` is given and its features are already loaded, no embedder
        is created. Otherwise, when ``cfg`` has an ``embedder`` section, a local
        protenc embedder is initialized. This is skipped if an
        ``embedding_manager`` has already been attached to this instance.

        Args:
            data (HAIPRData): The dataset to be used; stored as ``self.data``.
            cfg (DictConfig): Configuration containing embedder information.
        """
        self.embedder_instance = None
        self.embedder_config = None
        # The inference pipeline may inject an embedding manager. Keep one that
        # is already present instead of resetting it to None; the reset made the
        # "no manager available" check below always true.
        self.embedding_manager = getattr(self, "embedding_manager", None)
        if data is not None:
            self.data = data
            if self.data.features_loaded:
                logger.info("Features already loaded. Skipping embedder initialization.")
                return
        if cfg and hasattr(cfg, "embedder"):
            self.embedder_config = cfg.embedder
            # Only initialize a local embedder if no embedding manager is available.
            if self.embedding_manager is None:
                self._initialize_embedder()
        logger.info("Data setup for SVRPredictor complete.")

    def _initialize_embedder(self):
        """
        Build ``self.embedder_instance`` from ``self.embedder_config``.

        Only the ``protenc`` embedder is supported. The encoder is placed on
        CUDA when ``torch.cuda.is_available()``, otherwise on CPU. If no
        embedder configuration is set, a warning is logged and nothing happens.

        Raises:
            NotImplementedError: If the configured embedder name is not "protenc".
            ValueError: If importing or constructing the protenc encoder fails.
                The original exception is chained as ``__cause__``.
        """
        embedder_config = getattr(self, "embedder_config", None)
        if not embedder_config:
            logger.warning(
                "No embedder configuration found. Cannot initialize embedder."
            )
            return

        if embedder_config.name != "protenc":
            raise NotImplementedError(
                f"Embedder '{embedder_config.name}' not supported. Only 'protenc' is implemented."
            )

        try:
            # Imported here so that protenc is only required when a local
            # embedder is actually built.
            import protenc

            model_name = embedder_config.model
            batch_size = getattr(embedder_config, "batch_size", 32)
            device = "cuda" if torch.cuda.is_available() else "cpu"
            data_parallel = getattr(embedder_config, "data_parallel", False)

            logger.info(f"Initializing protenc embedder: {model_name}")

            self.embedder_instance = protenc.get_encoder(
                model_name,
                device=device,
                batch_size=batch_size,
                data_parallel=data_parallel,
            )

            logger.info("Protenc embedder initialized successfully")

        except Exception as e:
            logger.error(f"Failed to initialize protenc embedder: {e}")
            raise ValueError(f"Could not initialize protenc embedder: {e}") from e

    def load_model(self, model_path: str):
        """
        Load a persisted model. This is not implemented for SVRPredictor.

        The method exists for interface compatibility. It leaves ``self.model``
        unchanged and logs a warning, so a silent no-op is not mistaken for a
        successful load.

        Args:
            model_path (str): Path that would be loaded; currently ignored.
        """
        logger.warning(
            f"SVRPredictor.load_model is not implemented; ignoring '{model_path}'."
        )

    def prepare_features(self, sequences: List[str]) -> np.ndarray:
        """
        Prepare the features for the SVR model using embeddings.

        The embedding manager from the inference pipeline is used when one is
        attached. Otherwise a local protenc embedder is used; it is lazily
        initialized from ``self.data.config.embedder`` if it does not exist yet.

        Args:
            sequences: List of protein sequences to embed. For legacy callers, a
                ``HAIPRData`` instance is also accepted. Its ``prepare_features()``
                is called and the same object (not an array) is returned.

        Returns:
            np.ndarray: Embedded features for the sequences, stacked row-wise
            with ``np.vstack`` on the local-embedder path.

        Raises:
            RuntimeError: If no embedder is available, or if embedding fails.
                The embedding error is chained as ``__cause__``.
        """
        if isinstance(sequences, HAIPRData):
            # Legacy behavior: delegate to HAIPRData and return it unchanged.
            sequences.prepare_features()
            return sequences

        # Prefer the embedding manager injected by the inference pipeline.
        embedding_manager = getattr(self, "embedding_manager", None)
        if embedding_manager is not None:
            logger.info(
                f"SVRPredictor using CacheManager for {len(sequences)} sequences"
            )
            return embedding_manager.get_embeddings(sequences)

        # Fall back to a local embedder, initializing it from data if necessary.
        # getattr guards against setup_model() never having been called.
        if getattr(self, "embedder_instance", None) is None:
            data = getattr(self, "data", None)
            if data and hasattr(data, "config") and hasattr(data.config, "embedder"):
                self.embedder_config = data.config.embedder
                self._initialize_embedder()

            if getattr(self, "embedder_instance", None) is None:
                raise RuntimeError(
                    "No embedder available. Please ensure embedder is configured and initialized."
                )

        logger.debug(f"Using local embedder for {len(sequences)} sequences")

        try:
            average_embeddings = getattr(
                getattr(self, "embedder_config", None), "average_sequence", True
            )

            # The encoder yields one output per sequence; collect them all.
            embeddings_list = list(
                self.embedder_instance(
                    sequences,
                    average_sequence=average_embeddings,
                    return_format="numpy",
                )
            )

            # Stack all embeddings into a single array.
            features = np.vstack(embeddings_list)
            logger.debug(f"Generated features with shape: {features.shape}")

            return features

        except Exception as e:
            logger.error(f"Error preparing features: {e}")
            raise RuntimeError(f"Failed to prepare features: {e}") from e

    def fit_model(
        self,
        dataset: HAIPRData,
        train_indices: List[int],
        val_indices: List[int],
        trainer_instance: Any = None,  # Not used for sklearn models
        cfg: DictConfig = None,
    ) -> Dict[str, Any]:
        """
        Fit the SVR model on the training data and evaluate on the validation data.

        Args:
            dataset (HAIPRData): The full dataset. It is only used when
                ``self.data`` has not been set yet. NOTE(review): if
                ``setup_model`` was already called with a different dataset, that
                earlier ``self.data`` is used instead; confirm this is intended.
            train_indices (List[int]): Indices for the training set.
            val_indices (List[int]): Indices for the validation set.
            trainer_instance (Any, optional): A PyTorch Lightning trainer instance.
                Not used for scikit-learn models. Defaults to None.
            cfg (DictConfig, optional): Passed to ``setup_model`` when the data
                has not been set up yet.

        Returns:
            Dict[str, Any]: A dictionary containing 'metrics' and 'predictions'.
                'metrics': A dictionary of evaluation metrics.
                'predictions': A dictionary with 'indices', 'predictions',
                and 'true_values'.

        Raises:
            Exception: Any error raised by ``SVR.fit`` is logged and re-raised.
        """
        if not hasattr(self, "data"):
            logger.warning("Data not set up for SVRPredictor. Setting it up now.")
            self.setup_model(dataset, cfg)

        # Get train and validation data; self.data[indices] returns (X, y).
        X_train, y_train = self.data[train_indices]
        X_val, y_val = self.data[val_indices]

        logger.info(f"Training SVR model on {len(train_indices)} samples. validation len {len(X_val)}.")
        try:
            self.model.fit(X_train, y_train)
            logger.info("SVR model training complete.")
        except Exception as e:
            logger.error(f"Error during SVR model training: {e}", exc_info=True)
            raise

        logger.info(f"Evaluating SVR model on {len(val_indices)} validation samples.")
        y_pred = self.model.predict(X_val)

        metrics = compute_regression_metrics(y_val, y_pred)
        logger.info(f"SVR validation metrics: {metrics}")

        pred_dict = {
            "indices": val_indices,  # Keep as list of ints if that's the original type
            "predictions": y_pred.tolist(),
            # np.asarray makes .tolist() work for ndarray, tensor or plain list targets.
            "true_values": np.asarray(y_val).tolist(),
        }

        return {"metrics": metrics, "predictions": pred_dict}

    def predict(self, sequences: List[str], batch_size: int = 1) -> np.ndarray:
        """
        Make predictions using the trained SVR model.

        Args:
            sequences (List[str]): List of protein sequences for prediction. A
                single string is treated as a one-element list.
            batch_size (int): Batch size (not used for SVR but kept for
                interface consistency).

        Returns:
            np.ndarray: Predicted values, one per sequence.

        Raises:
            RuntimeError: If the model has not been fitted yet.
        """
        if not hasattr(self.model, "support_"):
            raise RuntimeError("Model has not been trained yet. Call fit_model first.")

        # A bare string would otherwise be embedded character by character.
        if isinstance(sequences, str):
            sequences = [sequences]

        logger.info(f"Making predictions with SVR model on {len(sequences)} samples.")

        # Prepare features from sequences using the embedder.
        X = self.prepare_features(sequences)
        return self.model.predict(X)

    def predict_with_embeddings(self, embeddings: np.ndarray) -> np.ndarray:
        """
        Make predictions using pre-computed embeddings.

        Args:
            embeddings: Pre-computed embeddings for sequences, in the same
                feature space the model was fitted on.

        Returns:
            np.ndarray: Predicted values.

        Raises:
            RuntimeError: If the model has not been fitted yet.
        """
        if not hasattr(self.model, "support_"):
            raise RuntimeError("Model has not been trained yet. Call fit_model first.")

        logger.info(
            f"Making predictions with SVR model on {len(embeddings)} pre-computed embeddings"
        )

        # Use embeddings directly for prediction.
        predictions = self.model.predict(embeddings)
        logger.info("SVR model prediction complete")
        return predictions

    def forward(self, batch: Any) -> Any:
        """
        Forward pass for the SVR model.

        For scikit-learn models, this means calling predict.

        Args:
            batch (Any): Input data. Either a list/tuple of sequence strings,
                which goes through ``predict`` and is embedded first, or
                pre-computed features passed straight to ``SVR.predict``.

        Returns:
            Any: Predictions from the model. An empty list or tuple yields an
            empty float array.
        """
        if isinstance(batch, (list, tuple)):
            # Guard the batch[0] lookup: an empty batch used to raise IndexError.
            if len(batch) == 0:
                return np.empty(0, dtype=np.float64)
            if isinstance(batch[0], str):
                return self.predict(list(batch))
        # Anything else is treated as pre-computed features.
        return self.model.predict(batch)
