import logging
from typing import Dict, Any, List

import numpy as np
import torch
from omegaconf import DictConfig
from sklearn.svm import SVC

from haipr.data import HAIPRData
from haipr.predictor import BasePredictor
from haipr.utils import compute_classification_metrics

logger = logging.getLogger(__name__)


class SVCPredictor(BasePredictor):
    """
    A predictor class for Support Vector Classification (SVC) models.

    Wraps a scikit-learn ``SVC`` and exposes training (``fit_model``),
    sequence-level inference (``predict_sequences``) and probability
    prediction (``predict_proba``) on top of the BasePredictor interface.

    Training reads feature/label pairs by indexing a ``HAIPRData`` object.
    Inference turns raw sequences into embeddings through either an
    ``embedding_manager`` attribute (when one has been attached to the
    instance) or a locally created ``protenc`` encoder configured by
    ``embedder_config``.
    """

    def __init__(
        self,
        params: DictConfig = None,
        C: float = 1.0,
        gamma: str = "scale",
        kernel: str = "rbf",
        degree: int = 3,
        probability: bool = True,
        **kwargs,
    ):
        """
        Initialize the SVCPredictor.

        Args:
            params (DictConfig): Accepted for interface compatibility; it is
                not used to configure the underlying SVC.
            C: Regularization parameter forwarded to ``SVC``.
            gamma: Kernel coefficient forwarded to ``SVC``.
            kernel: Kernel type forwarded to ``SVC``.
            degree: Polynomial kernel degree forwarded to ``SVC``.
            probability: Whether ``SVC`` should enable probability estimates.
                ``fit_model`` and ``predict_proba`` call
                ``SVC.predict_proba``, so they need this to be True.
            **kwargs: Additional keyword arguments; accepted and ignored.
        """
        super().__init__()
        svc_kwargs = {
            "C": C,
            "gamma": gamma,
            "kernel": kernel,
            "degree": degree,
            "probability": probability,
        }
        self.model = SVC(**svc_kwargs)
        if kwargs:
            # Surface dropped options so config typos are discoverable.
            logger.debug(
                "Ignoring unsupported SVCPredictor arguments: %s", sorted(kwargs)
            )
        # Log what the model was actually built with, not the unused `params`.
        logger.info("Initialized SVC model with parameters: %s", svc_kwargs)

    @classmethod
    def from_params(cls, params: DictConfig):
        """
        Build a predictor by expanding ``params`` into constructor keywords.

        Args:
            params (DictConfig): Mapping of constructor keyword arguments.
                ``None`` or an empty mapping yields a predictor with defaults.

        Returns:
            SVCPredictor: The newly constructed predictor.
        """
        return cls(**(params or {}))

    @staticmethod
    def _tensor_to_numpy(values: Any) -> Any:
        """Return ``values`` as a NumPy array if it is a torch tensor, else unchanged."""
        if isinstance(values, torch.Tensor):
            # detach()/cpu() make the conversion valid for tensors that track
            # gradients or live on an accelerator; plain .numpy() raises there.
            return values.detach().cpu().numpy()
        return values

    def _svc_is_fitted(self) -> bool:
        """Return True once the wrapped SVC exposes ``support_`` (set by fitting)."""
        return hasattr(self.model, "support_")

    def setup_model(self, data: HAIPRData = None, cfg: DictConfig = None):
        """
        Sets up the data and embedder config for the predictor.

        Args:
            data (HAIPRData): Dataset to store on ``self.data``; left untouched
                when None.
            cfg (DictConfig): Configuration; when it has an ``embedder`` entry,
                that entry is stored on ``self.embedder_config``.
        """
        # Always reset here; the inference pipeline is expected to attach its
        # own manager afterwards.
        self.embedding_manager = None  # Will be set by inference pipeline
        if data is not None:
            self.data = data
        if cfg and hasattr(cfg, "embedder"):
            self.embedder_config = cfg.embedder
        logger.info("Data setup for SVCPredictor complete.")

    def load_model(self, model_path: str):
        """Placeholder: currently does nothing with ``model_path``."""
        pass

    def prepare_training_features(
        self, dataset: HAIPRData, indices: np.ndarray
    ) -> Dict[str, np.ndarray]:
        """
        Get pre-computed embeddings from HAIPRData for training.

        Args:
            dataset (HAIPRData): Dataset indexed with ``indices``; indexing is
                expected to yield a ``(features, labels)`` pair.
            indices: Indices of the samples to fetch.

        Returns:
            Dict with ``"features"`` and ``"labels"``. Torch tensors are
            converted to NumPy arrays (each independently); other values are
            passed through unchanged.
        """
        X, y = dataset[indices]
        return {
            "features": self._tensor_to_numpy(X),
            "labels": self._tensor_to_numpy(y),
        }

    def prepare_batch_features(self, batch_items: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Prepare features for inference batch.

        Args:
            batch_items: Items that each carry a ``"sequence"`` entry.

        Returns:
            Dict with a single ``"features"`` entry holding the embeddings of
            the batch sequences.
        """
        sequences = [item["sequence"] for item in batch_items]
        return {"features": self._get_embeddings_for_sequences(sequences)}

    def prepare_features(self, data: HAIPRData):
        """
        Prepare the features for the SVC model.

        Delegates entirely to ``data.prepare_features()``.
        """
        data.prepare_features()

    def _get_embeddings_for_sequences(self, sequences: List[str]) -> np.ndarray:
        """
        Get embeddings for sequences using available embedding source.
        Priority: embedding_manager (inference) > local embedder (fallback)

        Args:
            sequences: Sequences to embed.

        Returns:
            The manager's result when an embedding manager is attached;
            otherwise the local embedder's outputs stacked with ``np.vstack``.

        Raises:
            RuntimeError: If neither an embedding manager nor an embedder
                configuration is available.
            ValueError: If the local embedder yields no embeddings.
        """
        # Try embedding manager from inference pipeline first
        manager = getattr(self, "embedding_manager", None)
        if manager is not None:
            logger.debug("Using EmbeddingManager for %d sequences", len(sequences))
            return manager.get_embeddings(sequences)

        # Fallback to local embedder (for standalone usage)
        embedder_config = getattr(self, "embedder_config", None)
        if getattr(self, "embedder_instance", None) is None:
            if not embedder_config:
                raise RuntimeError(
                    "No embedding source available. Need either embedding_manager "
                    "(from inference) or embedder_config (for local embedder)."
                )
            self._initialize_embedder()

        logger.debug("Using local embedder for %d sequences", len(sequences))
        # Falls back to True when no config (or no such option) is present.
        average_embeddings = getattr(embedder_config, "average_sequence", True)

        embeddings_list = list(
            self.embedder_instance(
                sequences,
                average_sequence=average_embeddings,
                return_format="numpy",
            )
        )
        if not embeddings_list:
            # np.vstack([]) would raise a far less helpful ValueError.
            raise ValueError("Embedder returned no embeddings for the given sequences.")

        return np.vstack(embeddings_list)

    def _initialize_embedder(self):
        """
        Initialize the embedder instance from configuration.

        Stores the created encoder on ``self.embedder_instance``. Logs a
        warning and returns without creating anything when no embedder
        configuration is set.

        Raises:
            NotImplementedError: If the configured embedder name is not
                ``"protenc"``.
            ValueError: If importing or constructing the protenc encoder
                fails; the original exception is chained as the cause.
        """
        embedder_config = getattr(self, "embedder_config", None)
        if not embedder_config:
            logger.warning(
                "No embedder configuration found. Cannot initialize embedder."
            )
            return

        if embedder_config.name != "protenc":
            raise NotImplementedError(
                f"Embedder '{embedder_config.name}' not supported. Only 'protenc' is implemented."
            )

        try:
            # Imported lazily so the module loads without protenc installed.
            import protenc

            model_name = embedder_config.model
            batch_size = getattr(embedder_config, "batch_size", 32)
            device = "cuda" if torch.cuda.is_available() else "cpu"
            data_parallel = getattr(embedder_config, "data_parallel", False)

            logger.info("Initializing protenc embedder: %s", model_name)

            self.embedder_instance = protenc.get_encoder(
                model_name,
                device=device,
                batch_size=batch_size,
                data_parallel=data_parallel,
            )

            logger.info("Protenc embedder initialized successfully")

        except Exception as e:
            logger.error(f"Failed to initialize protenc embedder: {e}")
            # Chain the cause so the original traceback is preserved.
            raise ValueError(f"Could not initialize protenc embedder: {e}") from e

    def fit_model(
        self,
        dataset: HAIPRData,
        train_indices: List[int],
        val_indices: List[int],
        trainer_instance: Any = None,  # Not used for sklearn models
        cfg: DictConfig = None,
    ) -> Dict[str, Any]:
        """
        Fit SVC model using pre-computed embeddings from HAIPRData.

        Args:
            dataset (HAIPRData): Dataset used when no dataset has been stored
                on the predictor yet; an already stored ``self.data`` takes
                precedence.
            train_indices: Indices of the training samples.
            val_indices: Indices of the validation samples.
            trainer_instance: Unused; present for interface compatibility.
            cfg (DictConfig): Forwarded to ``setup_model`` when setup is needed.

        Returns:
            Dict with ``"metrics"`` (result of
            ``compute_classification_metrics``) and ``"predictions"`` (the
            validation indices, predictions, true values and probabilities).

        Raises:
            RuntimeError: If no dataset is available for training.
        """
        # Also covers a `data` attribute that exists but is still None.
        if getattr(self, "data", None) is None:
            self.setup_model(dataset, cfg)
        if getattr(self, "data", None) is None:
            raise RuntimeError("No dataset available for training the SVC model.")

        # Get pre-computed features from HAIPRData
        # HAIPRData already has embeddings stored in feature columns
        train = self.prepare_training_features(self.data, train_indices)
        val = self.prepare_training_features(self.data, val_indices)
        X_train, y_train = train["features"], train["labels"]
        X_val, y_val = val["features"], val["labels"]

        logger.info("Training SVC model on %d samples", len(train_indices))
        self.model.fit(X_train, y_train)

        logger.info(
            "Evaluating SVC model on %d validation samples", len(val_indices)
        )
        y_pred = self.model.predict(X_val)
        y_prob = self.model.predict_proba(X_val)

        metrics = compute_classification_metrics(y_val, y_pred, y_prob)
        pred_dict = {
            "indices": val_indices,
            "predictions": y_pred.tolist(),
            # asarray keeps this working when labels arrive as a plain sequence.
            "true_values": np.asarray(y_val).tolist(),
            "probabilities": y_prob.tolist(),
        }

        return {"metrics": metrics, "predictions": pred_dict}

    def predict_sequences(self, sequences: List[str], batch_size: int = 1, **kwargs) -> np.ndarray:
        """
        Make predictions on sequences (pyfunc interface).

        Args:
            sequences: Sequences to classify.
            batch_size: Unused; present for interface compatibility.
            **kwargs: Unused.

        Returns:
            np.ndarray: Predicted labels, one per sequence. An empty array is
            returned for empty input.

        Raises:
            RuntimeError: If the model has not been trained yet.
        """
        if not self._svc_is_fitted():
            raise RuntimeError("Model has not been trained yet.")

        if len(sequences) == 0:
            # Nothing to embed or classify; keep the label dtype of the model.
            return np.empty(0, dtype=self.model.classes_.dtype)

        logger.info(
            "Making predictions with SVC model on %d sequences", len(sequences)
        )
        embeddings = self._get_embeddings_for_sequences(sequences)
        return self.model.predict(embeddings)

    def _load_artifacts(self, context):
        """
        Load SVC model from MLflow artifacts.

        Replaces ``self.model`` with the object unpickled from the ``"model"``
        artifact when present, and copies ``self.cfg.embedder`` to
        ``self.embedder_config`` when such a config exists.

        Args:
            context: Object exposing an ``artifacts`` mapping of names to paths.
        """
        import pickle

        if "model" in context.artifacts:
            model_path = context.artifacts["model"]
            # SECURITY: pickle.load executes arbitrary code from the file;
            # only load artifacts that come from a trusted source.
            with open(model_path, "rb") as f:
                self.model = pickle.load(f)
            logger.info("Loaded SVC model from %s", model_path)

        if hasattr(self, "cfg") and hasattr(self.cfg, "embedder"):
            self.embedder_config = self.cfg.embedder

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """
        Make probability predictions using the trained SVC model.

        Args:
            X (np.ndarray): Input features for prediction.

        Returns:
            np.ndarray: Probability predictions.

        Raises:
            RuntimeError: If the model has not been trained yet.
        """
        if not self._svc_is_fitted():
            raise RuntimeError("Model has not been trained yet. Call fit_model first.")
        logger.info(
            "Making probability predictions with SVC model on %s samples.", X.shape[0]
        )
        return self.model.predict_proba(X)

    def forward(self, batch: Any) -> Any:
        """
        Forward pass for the SVC model.
        For scikit-learn models, this typically means calling predict.

        Args:
            batch (Any): Input data. Expected to be a NumPy array or similar
                         that self.model.predict can handle.

        Returns:
            Any: Predictions from the model.
        """
        # NOTE(review): this class does not define `predict`; the call relies
        # on BasePredictor (or a subclass) providing it - confirm.
        # Assuming batch is X for sklearn models
        return self.predict(batch)

    def get_trained_model(self):
        """
        Get the trained model for logging purposes.

        Returns:
            The trained sklearn model instance.

        Raises:
            RuntimeError: If the model has not been trained yet.
        """
        if not self._svc_is_fitted():
            raise RuntimeError("Model has not been trained yet. Call fit_model first.")
        return self.model