from sklearn.svm import SVR
from omegaconf import DictConfig, OmegaConf
from haipr.predictor import BasePredictor
from haipr.data import HAIPRData
from haipr.utils import compute_regression_metrics
from typing import Dict, Any, List, Tuple
import numpy as np
import os
import pickle
import logging
import torch

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


class SVRPredictor(BasePredictor):
    """
    A predictor class for Support Vector Regression (SVR) models.

    This class encapsulates the training, prediction, and evaluation
    of SVR models, conforming to the BasePredictor interface.

    The wrapped estimator lives in ``self.model`` and is persisted with
    ``pickle`` (see ``save_model`` / ``load_model`` / ``load_context``).
    """

    def __init__(
        self,
        params: DictConfig = None,
        C: float = 75,
        gamma: str | float = "scale",
        epsilon: float = 0.1,
        kernel: str = "rbf",
        degree: int = 4,
        **kwargs,
    ):
        """
        Initialize the SVRPredictor.

        Args:
            params (DictConfig): Accepted for interface compatibility. It is
                NOT applied to the estimator; pass hyperparameters through
                the explicit keyword arguments below (or ``from_params``).
            C: Regularization parameter forwarded to ``SVR``.
            gamma: Kernel coefficient forwarded to ``SVR``.
            epsilon: Epsilon-tube width forwarded to ``SVR``.
            kernel: Kernel type forwarded to ``SVR``.
            degree: Polynomial degree forwarded to ``SVR``.
            **kwargs: Additional keyword arguments. They are accepted and
                ignored so that configs with extra keys do not fail.
        """
        super().__init__()
        hyperparameters = {
            "C": C,
            "gamma": gamma,
            "epsilon": epsilon,
            "kernel": kernel,
            "degree": degree,
        }
        self.model = SVR(**hyperparameters)
        if params is not None or kwargs:
            # Make silently dropped configuration visible when debugging.
            logger.debug(
                "SVRPredictor ignores 'params' (%s) and extra keyword arguments %s",
                params,
                sorted(kwargs),
            )
        # Log what the estimator was actually built with, not the unused
        # `params` argument (which is usually None).
        logger.info("Initialized SVR model with parameters: %s", hyperparameters)

    @staticmethod
    def _as_numpy(values: Any) -> Any:
        """Return a NumPy array for torch tensors; pass anything else through.

        ``detach().cpu()`` is applied first because ``Tensor.numpy()`` raises
        for tensors that require grad or that live on a non-CPU device.
        """
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return values

    @classmethod
    def from_params(cls, params: DictConfig):
        """Build a predictor by unpacking ``params`` as constructor keywords.

        Args:
            params (DictConfig): Mapping of constructor keyword arguments.
                ``None`` yields a predictor with default hyperparameters.
        """
        if params is None:
            return cls()
        return cls(**params)

    def setup_model(self, data: HAIPRData = None, cfg: DictConfig = None):
        """
        Sets up the data and embedder for the predictor.

        If ``data`` reports ``features_loaded``, no embedder is initialized.
        Otherwise, when ``cfg`` has an ``embedder`` entry, it is stored and a
        local embedder is initialized unless an embedding manager has already
        been attached to this instance.

        Args:
            data (HAIPRData): The dataset to be used.
            cfg (DictConfig): Configuration containing embedder information.
        """
        self.embedder_instance = None
        self.embedder_config = None
        # Keep a manager that was attached before setup (e.g. by the inference
        # pipeline). Resetting it unconditionally would make the check below
        # always true and discard the injected manager.
        self.embedding_manager = getattr(self, "embedding_manager", None)
        if data is not None:
            self.data = data
            if self.data.features_loaded:
                logger.info(
                    "Features already loaded. Skipping embedder initialization."
                )
                return
        if cfg and hasattr(cfg, "embedder"):
            self.embedder_config = cfg.embedder
            # Only initialize local embedder if no embedding manager is available
            if self.embedding_manager is None:
                self._initialize_embedder()
        logger.info("Data setup for SVRPredictor complete.")

    def prepare_training_features(
        self, dataset: HAIPRData, indices: np.ndarray
    ) -> Dict[str, np.ndarray]:
        """Get features and labels from HAIPRData for the given indices.

        Args:
            dataset (HAIPRData): Dataset indexed as ``dataset[indices]``,
                which is unpacked into a ``(features, labels)`` pair.
            indices (np.ndarray): Row indices to extract.

        Returns:
            Dict with keys ``"features"`` and ``"labels"``. Torch tensors are
            converted to NumPy arrays; other values are returned unchanged.
        """
        X, y = dataset[indices]
        return {
            "features": self._as_numpy(X),
            "labels": self._as_numpy(y),
        }

    def prepare_batch_features(
        self, batch_items: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Embed the ``"sequence"`` entry of each batch item.

        For sklearn models, this is not used during training.
        Only used if someone manually calls it during inference.

        Args:
            batch_items: Items that each carry a ``"sequence"`` key.

        Returns:
            Dict with a single ``"features"`` key holding the embeddings.
        """
        sequences = [item["sequence"] for item in batch_items]
        embeddings = self._get_embeddings_for_sequences(sequences)
        return {"features": embeddings}

    def prepare_features(self, sequences: List[str]) -> np.ndarray:
        """Delegate feature preparation for raw sequences to the base helper."""
        return self._prepare_features_from_sequences(sequences)

    def fit_model(
        self,
        dataset: HAIPRData,
        train_indices: List[int],
        val_indices: List[int],
        trainer_instance: Any = None,  # N
        cfg: DictConfig = None,
    ) -> Dict[str, Any]:
        """Fit the SVR model and evaluate it on the validation split.

        Args:
            dataset (HAIPRData): Dataset used to set up the predictor when no
                dataset has been attached yet.
            train_indices: Indices of the training samples.
            val_indices: Indices of the validation samples.
            trainer_instance: Unused by this predictor.
            cfg (DictConfig): Forwarded to ``setup_model`` when setup is needed.

        Returns:
            Dict with ``"metrics"`` (output of ``compute_regression_metrics``)
            and ``"predictions"`` (validation indices, predictions and true
            values as lists).
        """
        # A `None` placeholder must also trigger setup, not only a missing
        # attribute.
        if getattr(self, "data", None) is None:
            self.setup_model(dataset, cfg)

        # NOTE(review): features are read from `self.data`, so a previously
        # attached dataset takes precedence over the `dataset` argument.
        X_train, y_train = self.data[train_indices]
        X_val, y_val = self.data[val_indices]

        # Convert each array independently: features and labels are not
        # guaranteed to share a container type.
        X_train = self._as_numpy(X_train)
        y_train = self._as_numpy(y_train)
        X_val = self._as_numpy(X_val)
        y_val = self._as_numpy(y_val)

        logger.info(f"Training SVR model on {len(train_indices)} samples")
        self.model.fit(X_train, y_train)

        logger.info(f"Evaluating SVR model on {len(val_indices)} validation samples")
        y_pred = self.model.predict(X_val)

        metrics = compute_regression_metrics(y_val, y_pred)
        pred_dict = {
            "indices": val_indices,
            "predictions": y_pred.tolist(),
            # np.asarray keeps this working when labels come back as a
            # plain Python sequence rather than an array.
            "true_values": np.asarray(y_val).tolist(),
        }

        return {"metrics": metrics, "predictions": pred_dict}

    def predict_sequences(
        self, sequences: List[str], params: Dict[str, Any] | None = None
    ) -> np.ndarray:
        """Make predictions on sequences (pyfunc interface).

        Args:
            sequences: Raw sequences to embed and score.
            params: Unused; accepted for interface compatibility.

        Returns:
            Predictions from the fitted SVR model.

        Raises:
            RuntimeError: If the model has no ``support_`` attribute, i.e. it
                has not been trained yet.
        """
        if not hasattr(self.model, "support_"):
            raise RuntimeError("Model has not been trained yet.")

        logger.info(f"Making predictions with SVR model on {len(sequences)} sequences")

        # Get embeddings using the unified method
        embeddings = self._get_embeddings_for_sequences(sequences)
        return self.model.predict(embeddings)

    def load_context(self, context):
        """Load SVR model from MLflow artifacts.

        Calls the base implementation first, then replaces ``self.model``
        with the unpickled object when a ``"model"`` artifact is present.
        """
        super().load_context(context)

        # Load SVR model from pickle
        if "model" in context.artifacts:
            model_path = context.artifacts["model"]
            # SECURITY: pickle.load executes arbitrary code from the file;
            # only load artifacts from a trusted source.
            with open(model_path, "rb") as f:
                self.model = pickle.load(f)
            logger.info(f"Loaded SVR model from {model_path}")

    def forward(self, batch: Any) -> Any:
        """
        Forward pass for the SVR model.
        For scikit-learn models, this typically means calling predict.

        Args:
            batch (Any): Input data. Expected to be a list of sequences or pre-computed features.

        Returns:
            Any: Predictions from the model.
        """
        # A non-empty list whose first element is a string is treated as raw
        # sequences; everything else is treated as pre-computed features.
        if isinstance(batch, list) and len(batch) > 0 and isinstance(batch[0], str):
            return self.predict_sequences(batch, params=None)
        # Tensors are moved to CPU NumPy arrays before reaching scikit-learn.
        return self.model.predict(self._as_numpy(batch))

    def save_model(self, save_dir: str) -> str:
        """Save the SVR model to a pickle file.

        Args:
            save_dir: Directory that will contain ``svr_model.pkl``. It is
                created if it does not exist.

        Returns:
            Path of the written pickle file.
        """
        if save_dir:
            # An empty string means "current directory"; makedirs("") raises.
            os.makedirs(save_dir, exist_ok=True)
        model_path = os.path.join(save_dir, "svr_model.pkl")
        with open(model_path, "wb") as f:
            pickle.dump(self.model, f)
        return model_path

    def load_model(self, model: str) -> None:
        """Load the SVR model from a pickle file.

        Args:
            model: Path to a pickle file, such as one written by
                ``save_model``.
        """
        # SECURITY: pickle.load executes arbitrary code from the file; only
        # load model files from a trusted source.
        with open(model, "rb") as f:
            self.model = pickle.load(f)