






from typing import Dict, List

import torch
from torch.utils.data import DataLoader

from torchcp.classification.predictor.base import BasePredictor
from torchcp.utils.common import calculate_conformal_value


class SplitPredictor(BasePredictor):
    """
    Split Conformal Prediction (Vovk et al., 2005).
    Book: https://link.springer.com/book/10.1007/978-3-031-06649-8.

    Args:
        score_function (callable): Non-conformity score function.
        model (torch.nn.Module, optional): A PyTorch model. Default is None.
        temperature (float, optional): The temperature of Temperature Scaling. Default is 1.
        alpha (float, optional): The significance level. Default is 0.1.
        class_conditional (bool, optional): Whether to use class-conditional conformal prediction. Default is False.
    """

    def __init__(self, score_function, model=None, temperature=1, alpha=0.1, class_conditional: bool = False):
        super().__init__(score_function, model, temperature)
        self.alpha = alpha
        self.class_conditional = class_conditional
        # Per-class thresholds keyed by class id; filled by calculate_threshold
        # when class_conditional is True, empty otherwise.
        self.q_hat_by_class: Dict[int, float] = {}
        # Scores and labels of the most recent calibration run.
        self.calibration_scores: torch.Tensor | None = None
        self.calibration_labels: torch.Tensor | None = None

    def _resolve_alpha(self, alpha=None):
        """
        Return the significance level to use, falling back to ``self.alpha``.

        Args:
            alpha (float, optional): Explicit significance level. Default is None.

        Returns:
            float: ``alpha`` if given, otherwise ``self.alpha``.

        Raises:
            ValueError: If the resolved value is not strictly inside (0, 1).
        """
        if alpha is None:
            alpha = self.alpha
        if not (0 < alpha < 1):
            raise ValueError("alpha should be a value in (0, 1).")
        return alpha

    def calibrate(self, cal_dataloader, alpha=None):
        """
        Calibrate the predictor by running the model over a calibration dataloader.

        Each batch is indexed as ``batch[0]`` (inputs) and ``batch[1]`` (labels).
        An explicit ``alpha`` is used for this calibration only; ``self.alpha``
        is not modified, and the evaluation methods keep using ``self.alpha``.

        Args:
            cal_dataloader (iterable): Dataloader for the calibration set.
            alpha (float, optional): The significance level. Default is None,
                which means ``self.alpha``.

        Raises:
            ValueError: If alpha is outside (0, 1), if no model was provided,
                or if the dataloader yields no batches.
        """
        # Validate before the (potentially expensive) forward passes.
        alpha = self._resolve_alpha(alpha)

        if self._model is None:
            raise ValueError("Model is not defined. Please provide a valid model.")

        self._model.eval()
        logits_list = []
        labels_list = []
        with torch.no_grad():
            for examples in cal_dataloader:
                tmp_x = examples[0].to(self._device)
                tmp_labels = examples[1].to(self._device)
                logits_list.append(self._logits_transformation(self._model(tmp_x)))
                labels_list.append(tmp_labels)

        # torch.cat on an empty list fails with an opaque error; be explicit.
        if not logits_list:
            raise ValueError("cal_dataloader yielded no batches; cannot calibrate.")

        logits = torch.cat(logits_list).float()
        labels = torch.cat(labels_list)
        self.calculate_threshold(logits, labels, alpha)

    def calibrate_with_logits(self, cal_logits, cal_labels, alpha=None):
        """
        Calibrate using pre-computed logits and labels.

        Args:
            cal_logits (torch.Tensor): Pre-computed calibration logits.
            cal_labels (torch.Tensor): Calibration labels.
            alpha (float, optional): The significance level. Default is None,
                which means ``self.alpha``.

        Raises:
            ValueError: If alpha is outside (0, 1).
        """
        self.calculate_threshold(cal_logits, cal_labels, self._resolve_alpha(alpha))

    def calculate_threshold(self, logits, labels, alpha=None):
        """
        Compute and store the conformal threshold(s) from calibration logits.

        Stores the calibration scores and labels, the marginal threshold
        ``self.q_hat`` and, when ``class_conditional`` is True, one threshold
        per class present in ``labels`` in ``self.q_hat_by_class``.

        Args:
            logits (torch.Tensor): Calibration logits.
            labels (torch.Tensor): Calibration labels.
            alpha (float, optional): The significance level. Default is None,
                which means ``self.alpha``.

        Raises:
            ValueError: If alpha is outside (0, 1).
        """
        alpha = self._resolve_alpha(alpha)
        logits = logits.to(self._device)
        labels = labels.to(self._device)
        scores = self.score_function(logits, labels)

        self.calibration_scores = scores
        self.calibration_labels = labels

        # The marginal threshold must honour the same alpha as the per-class
        # ones; previously it silently fell back to self.alpha.
        self.q_hat = self._calculate_conformal_value(scores, alpha)

        # Always rebuild the per-class table so thresholds from an earlier
        # calibration can never leak into a later one.
        per_class = {}
        if self.class_conditional:
            for class_id in labels.unique():
                class_scores = scores[labels == class_id]
                per_class[class_id.item()] = self._calculate_conformal_value(class_scores, alpha)
        self.q_hat_by_class = per_class

    def _calculate_conformal_value(self, scores, alpha=None):
        """
        Compute the conformal threshold of ``scores`` at level ``alpha``.

        Args:
            scores (torch.Tensor): Non-conformity scores.
            alpha (float, optional): The significance level. Default is None,
                which means ``self.alpha``.

        Returns:
            The value returned by ``calculate_conformal_value(scores, alpha)``.
        """
        if alpha is None:
            alpha = self.alpha
        return calculate_conformal_value(scores, alpha)

    def predict(self, x_batch):
        """
        Generate prediction sets for a batch of instances.

        Args:
            x_batch (torch.Tensor): A batch of instances.

        Returns:
            The prediction sets produced by ``predict_with_logits`` for the
            transformed model outputs, one row per instance.

        Raises:
            ValueError: If no model was provided or calibration was not performed.
        """
        if self._model is None:
            raise ValueError("Model is not defined. Please provide a valid model.")

        self._model.eval()
        # Inference only: skip building an autograd graph for the forward pass.
        with torch.no_grad():
            logits = self._model(x_batch.to(self._device)).float()
            logits = self._logits_transformation(logits)
        return self.predict_with_logits(logits)

    def predict_with_logits(self, logits, q_hat=None):
        """
        Generate prediction sets from logits.

        In the default mode every score is compared against a single threshold.
        In class-conditional mode each row is compared against the threshold of
        that row's predicted (argmax) class, falling back to the marginal
        threshold for classes without a calibrated value.

        Args:
            logits (torch.Tensor): Model output before softmax.
            q_hat (torch.Tensor, optional): The conformal threshold. Default is
                None, which means ``self.q_hat``. In class-conditional mode it
                is only used as the fallback threshold and must be a scalar.

        Returns:
            torch.Tensor: Prediction sets with one row per instance. In
            class-conditional mode this is a boolean mask; otherwise it is
            whatever ``_generate_prediction_set`` returns.

        Raises:
            ValueError: If no threshold is available (calibration not performed).
        """
        if q_hat is None:
            q_hat = getattr(self, "q_hat", None)
            if q_hat is None:
                raise ValueError("Ensure self.q_hat is not None. Please perform calibration first.")

        scores = self.score_function(logits).to(self._device)

        if not self.class_conditional:
            return self._generate_prediction_set(scores, q_hat)

        # NOTE(review): the threshold is selected by the *predicted* class of
        # each row and applied to all of that row's scores. Standard
        # class-conditional conformal prediction compares each class score to
        # that class's own threshold instead - confirm this is intended.
        predicted = logits.argmax(dim=1).to(scores.device)
        row_thresholds = torch.empty(predicted.shape[0], dtype=scores.dtype, device=scores.device)
        # One masked assignment per distinct predicted class instead of a
        # Python-level loop over every row; also handles an empty batch.
        for class_id in predicted.unique().tolist():
            threshold = self.q_hat_by_class.get(class_id, q_hat)
            row_thresholds[predicted == class_id] = float(threshold)
        return scores <= row_thresholds.unsqueeze(1)

    def evaluate(self, val_dataloader: DataLoader) -> Dict[str, float]:
        """
        Evaluate prediction sets on validation dataset.

        Each batch is indexed as ``batch[0]`` (inputs) and ``batch[1]`` (labels).
        Metrics that take a significance level use ``self.alpha``.

        Args:
            val_dataloader (torch.utils.data.DataLoader): Dataloader for validation set.

        Returns:
            dict: Dictionary containing evaluation metrics:
                - coverage_rate: Empirical coverage rate on validation set
                - average_size: Average size of prediction sets
                - CovGap: Coverage gap metric

        Raises:
            ValueError: If no model was provided, calibration was not
                performed, or the dataloader yields no batches.
        """
        if self._model is None:
            raise ValueError("Model is not defined. Please provide a valid model.")

        prediction_sets_list: List[torch.Tensor] = []
        labels_list: List[torch.Tensor] = []

        self._model.eval()
        with torch.no_grad():
            for batch in val_dataloader:
                # predict() moves the inputs to the predictor's device itself.
                prediction_sets_list.append(self.predict(batch[0]))
                labels_list.append(batch[1].to(self._device))

        # torch.cat on an empty list fails with an opaque error; be explicit.
        if not prediction_sets_list:
            raise ValueError("val_dataloader yielded no batches; nothing to evaluate.")

        val_prediction_sets = torch.cat(prediction_sets_list, dim=0)
        val_labels = torch.cat(labels_list, dim=0)
        num_classes = val_prediction_sets.shape[1]

        return {
            "coverage_rate": self._metric('coverage_rate')(val_prediction_sets, val_labels),
            "average_size": self._metric('average_size')(val_prediction_sets, val_labels),
            "CovGap": self._metric('CovGap')(val_prediction_sets, val_labels, self.alpha, num_classes),
        }

    def evaluate_with_logits(self, test_logits: torch.Tensor, test_labels: torch.Tensor, diff_violation=False) -> Dict[str, float]:
        """
        Evaluate prediction sets using pre-computed logits and labels.

        Metrics that take a significance level use ``self.alpha``.

        Args:
            test_logits (torch.Tensor): Pre-computed test logits.
            test_labels (torch.Tensor): Test labels.
            diff_violation (bool, optional): Whether to also compute the
                ``DiffViolation`` metric. Default is False.

        Returns:
            dict: Dictionary containing evaluation metrics:
                - coverage_rate: Empirical coverage rate on test set
                - MacroCoverageRate: Coverage rate computed with coverage_type="macro"
                - average_size: Average size of prediction sets
                - CovGap: Coverage gap metric
                - SSCV: Value of the 'SSCV' metric
                - VioClasses: Value of the 'VioClasses' metric
                - EmptySetsPercentage: Value of the 'empty_sets_percentage' metric
                - DiffViolation: Element at index 1 of the 'DiffViolation'
                  metric result (only when ``diff_violation`` is True)

        Raises:
            ValueError: If calibration was not performed.
        """
        # Keep logits, labels and prediction sets on the same device, as
        # calculate_threshold and evaluate already do.
        test_logits = test_logits.to(self._device)
        test_labels = test_labels.to(self._device)

        test_prediction_sets = self.predict_with_logits(test_logits)
        num_classes = test_prediction_sets.shape[1]

        metrics = {
            "coverage_rate": self._metric('coverage_rate')(test_prediction_sets, test_labels),
            "MacroCoverageRate": self._metric('coverage_rate')(
                test_prediction_sets, test_labels, coverage_type="macro", num_classes=num_classes
            ),
            "average_size": self._metric('average_size')(test_prediction_sets, test_labels),
            "CovGap": self._metric('CovGap')(test_prediction_sets, test_labels, self.alpha, num_classes),
            "SSCV": self._metric('SSCV')(test_prediction_sets, test_labels, self.alpha),
            "VioClasses": self._metric('VioClasses')(test_prediction_sets, test_labels, self.alpha, num_classes),
            "EmptySetsPercentage": self._metric('empty_sets_percentage')(test_prediction_sets),
        }

        if diff_violation:
            metrics["DiffViolation"] = self._metric('DiffViolation')(
                test_logits, test_prediction_sets, test_labels, self.alpha
            )[1]

        return metrics
    