import logging
from typing import Any, Dict, List, Optional, Union

import numpy as np
from omegaconf import DictConfig
from sklearn.svm import SVC

from haipr.data import HAIPRData
from haipr.predictor import BasePredictor
from haipr.utils import compute_classification_metrics

# Module-level logger named after this module, used by SVCPredictor for all of its log output.
logger = logging.getLogger(__name__)


class SVCPredictor(BasePredictor):
    """
    A predictor class for Support Vector Classification (SVC) models.

    This class encapsulates the training, prediction, and evaluation
    of a scikit-learn ``SVC`` estimator, conforming to the BasePredictor
    interface. The wrapped estimator is available as ``self.model``.
    """

    def __init__(
        self,
        params: Optional[DictConfig] = None,
        C: float = 1.0,
        gamma: Union[str, float] = "scale",
        kernel: str = "rbf",
        degree: int = 3,
        probability: bool = True,
        **kwargs,
    ):
        """
        Initialize the SVCPredictor.

        Args:
            params (DictConfig, optional): Accepted for interface
                compatibility. It is not applied to the estimator; the
                hyperparameters below are what configure the model.
            C (float): Regularization parameter forwarded to ``SVC``.
            gamma (str | float): Kernel coefficient forwarded to ``SVC``.
            kernel (str): Kernel type forwarded to ``SVC``.
            degree (int): Polynomial kernel degree forwarded to ``SVC``.
            probability (bool): Whether ``SVC`` should enable probability
                estimates (required for ``predict_proba``).
            **kwargs: Additional keyword arguments. They are ignored; their
                names are logged at debug level so typos can be spotted.
        """
        super().__init__()
        hyperparameters = {
            "C": C,
            "gamma": gamma,
            "kernel": kernel,
            "degree": degree,
            "probability": probability,
        }
        self.model = SVC(**hyperparameters)
        if kwargs:
            logger.debug(
                "Ignoring unsupported SVCPredictor arguments: %s", sorted(kwargs)
            )
        # Log what the estimator was really built with; `params` is usually
        # None here because from_params unpacks the config into keywords.
        logger.info("Initialized SVC model with parameters: %s", hyperparameters)

    @classmethod
    def from_params(cls, params: DictConfig):
        """
        Build a predictor by unpacking ``params`` into constructor keywords.

        Args:
            params (DictConfig): Mapping of constructor argument names to values.

        Returns:
            SVCPredictor: A new instance configured from ``params``.
        """
        return cls(**params)

    def setup_model(self, data: HAIPRData, cfg: Optional[DictConfig] = None):
        """
        Sets up the data for the predictor.

        Args:
            data (HAIPRData): The dataset to be used.
            cfg (DictConfig, optional): Not used by this predictor.
        """
        self.data = data
        logger.info("Data setup for SVCPredictor complete.")

    def load_model(self, model_path: str):
        """
        Placeholder for loading a persisted model.

        Loading is not implemented for this predictor: the call leaves
        ``self.model`` untouched and only logs a warning.

        Args:
            model_path (str): Path to a saved model (currently ignored).
        """
        logger.warning(
            "SVCPredictor.load_model is not implemented; ignoring model_path=%s",
            model_path,
        )

    def prepare_features(self, data: HAIPRData):
        """
        Prepare the features for the SVC model.

        Delegates to ``data.prepare_features()``.

        Args:
            data (HAIPRData): The dataset whose features should be prepared.
        """
        data.prepare_features()

    def fit_model(
        self,
        dataset: HAIPRData,
        train_indices: List[int],
        val_indices: List[int],
        trainer_instance: Any = None,  # Not used for sklearn models
        cfg: Optional[DictConfig] = None,
    ) -> Dict[str, Any]:
        """
        Fit the SVC model on the training data and evaluate on the validation data.

        If no dataset has been attached via ``setup_model`` yet, ``dataset`` is
        attached first. The data are then read from ``self.data``.

        Args:
            dataset (HAIPRData): The full dataset.
            train_indices (List[int]): Indices for the training set.
            val_indices (List[int]): Indices for the validation set.
            trainer_instance (Any, optional): A PyTorch Lightning trainer instance.
                                             Not used for scikit-learn models. Defaults to None.
            cfg (DictConfig, optional): Forwarded to ``setup_model`` when the
                                        dataset has to be attached here.

        Returns:
            Dict[str, Any]: A dictionary containing 'metrics' and 'predictions'.
                            'metrics': A dictionary of evaluation metrics.
                            'predictions': A dictionary with 'indices', 'predictions',
                                           'true_values', and 'probabilities'
                                           ('probabilities' is None when the model
                                           was built with probability=False).

        Raises:
            Exception: Any error raised by ``SVC.fit`` is logged and re-raised.
        """
        # Treat both "attribute missing" and "attribute is None" as not set up.
        if getattr(self, "data", None) is None:
            logger.warning("Data not set up for SVCPredictor. Setting it up now.")
            self.setup_model(dataset, cfg)

        # Indexing the dataset with a list of indices yields an (X, y) pair.
        X_train, y_train = self.data[train_indices]
        X_val, y_val = self.data[val_indices]

        logger.info("Training SVC model on %s samples.", len(train_indices))
        try:
            self.model.fit(X_train, y_train)
        except Exception as e:
            logger.exception("Error during SVC model training: %s", e)
            raise
        logger.info("SVC model training complete.")

        logger.info(
            "Evaluating SVC model on %s validation samples.", len(val_indices)
        )
        y_pred = self.model.predict(X_val)
        # scikit-learn's SVC only offers predict_proba when it was constructed
        # with probability=True; without this guard a probability=False model
        # would fail here after training had already succeeded.
        if getattr(self.model, "probability", True):
            y_prob = self.model.predict_proba(X_val)
        else:
            # NOTE(review): assumes compute_classification_metrics accepts
            # None for the probabilities argument - confirm against haipr.utils.
            y_prob = None

        metrics = compute_classification_metrics(y_val, y_pred, y_prob)
        logger.info("SVC validation metrics: %s", metrics)

        pred_dict = {
            "indices": val_indices,  # Keep as list of ints if that's the original type
            "predictions": np.asarray(y_pred).tolist(),
            # np.asarray keeps this working when labels arrive as a plain sequence.
            "true_values": np.asarray(y_val).tolist(),
            "probabilities": None if y_prob is None else np.asarray(y_prob).tolist(),
        }

        return {"metrics": metrics, "predictions": pred_dict}

    def _require_fitted_model(self) -> None:
        """Raise RuntimeError unless the SVC has fitted state (``support_``)."""
        if not hasattr(self.model, "support_"):
            raise RuntimeError("Model has not been trained yet. Call fit_model first.")

    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        Make predictions using the trained SVC model.

        Args:
            X (np.ndarray): Input features for prediction.

        Returns:
            np.ndarray: Predicted values.

        Raises:
            RuntimeError: If the model has not been fitted yet.
        """
        self._require_fitted_model()
        logger.info("Making predictions with SVC model on %s samples.", X.shape[0])
        return self.model.predict(X)

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """
        Make probability predictions using the trained SVC model.

        Args:
            X (np.ndarray): Input features for prediction.

        Returns:
            np.ndarray: Probability predictions.

        Raises:
            RuntimeError: If the model has not been fitted yet.
        """
        self._require_fitted_model()
        logger.info(
            "Making probability predictions with SVC model on %s samples.",
            X.shape[0],
        )
        return self.model.predict_proba(X)

    def forward(self, batch: Any) -> Any:
        """
        Forward pass for the SVC model.
        For scikit-learn models, this typically means calling predict.

        Args:
            batch (Any): Input data. Expected to be a NumPy array or similar
                         that self.model.predict can handle.

        Returns:
            Any: Predictions from the model.
        """
        # Assuming batch is X for sklearn models
        return self.predict(batch)

    def get_trained_model(self):
        """
        Get the trained model for logging purposes.

        Returns:
            The trained sklearn model instance.

        Raises:
            RuntimeError: If the model has not been fitted yet.
        """
        self._require_fitted_model()
        return self.model