






from typing import Dict, List

import torch
from torch.utils.data import DataLoader

from torchcp.classification.predictor import SplitPredictor
from torchcp.utils.common import calculate_conformal_value


class RC3PPredictor(SplitPredictor):
    """
    Rank Calibrated Class-conditional Conformal Prediction (RC3P) as described in
    "Conformal Prediction for Class-wise Coverage via Augmented Label Rank Calibration"
    by Shi et al., NeurIPS 2024.

    A class ``y`` is included in the prediction set of an instance only when
    both of the following hold:

    * its non-conformity score is at most the class-specific threshold, and
    * its rank among the instance's logits (1 = largest logit) is at most the
      class-specific rank limit.

    Args:
        score_function (callable): Non-conformity score function (e.g., APS or RAPS).
        model (torch.nn.Module, optional): A PyTorch model. Default is None.
    """

    def __init__(self, score_function, model=None):
        super().__init__(score_function, model)
        # Number of classes; taken from the second dimension of the
        # calibration logits.
        self.num_classes = None
        # Per-class score thresholds, shape (num_classes,); None until calibrated.
        self.class_thresholds = None
        # Per-class 1-based rank limits, shape (num_classes,); None until calibrated.
        self.class_rank_limits = None

    def calibrate(self, cal_dataloader, alpha):
        """
        Calibrate the RC3P predictor using class-wise conformal scores and label ranks.

        Runs the model over the whole calibration loader, collects the
        (transformed) logits and labels, and delegates to
        :meth:`calculate_threshold`.

        Args:
            cal_dataloader (DataLoader): Calibration data loader. Each batch is
                indexed as ``batch[0]`` (inputs) and ``batch[1]`` (labels).
            alpha (float): Target miscoverage rate (0 < alpha < 1).

        Raises:
            ValueError: If ``alpha`` is outside (0, 1) or no model is defined.
        """
        if not (0 < alpha < 1):
            raise ValueError("alpha should be a value in (0, 1).")

        if self._model is None:
            raise ValueError("Model is not defined. Please provide a valid model.")

        self._model.eval()
        logits_list = []
        labels_list = []
        with torch.no_grad():
            for examples in cal_dataloader:
                tmp_x = examples[0].to(self._device)
                tmp_labels = examples[1].to(self._device)
                tmp_logits = self._logits_transformation(self._model(tmp_x)).detach()
                logits_list.append(tmp_logits)
                labels_list.append(tmp_labels)
            logits = torch.cat(logits_list).float()
            labels = torch.cat(labels_list)

        self.num_classes = logits.shape[1]

        self.calculate_threshold(logits, labels, alpha)

    def calculate_threshold(self, logits, labels, alpha):
        """
        Perform class-wise calibration for conformal thresholds and label ranks.

        For every class ``y`` that has calibration samples:

        1. the top-k error is the fraction of class-``y`` samples whose true
           label is not among the k largest logits;
        2. ``k_y`` is the smallest k whose top-k error is below ``alpha`` and
           ``epsilon_y`` is that error;
        3. the score threshold is ``calculate_conformal_value(scores, alpha - epsilon_y)``.

        Classes without calibration samples keep an infinite threshold and a
        rank limit equal to the number of classes.

        Args:
            logits (torch.Tensor): Model logits for calibration data, indexed
                as (sample, class).
            labels (torch.Tensor): True labels for calibration data.
            alpha (float): Target miscoverage rate (0 < alpha < 1).

        Raises:
            ValueError: If ``alpha`` is outside (0, 1).
        """
        if not (0 < alpha < 1):
            raise ValueError("alpha should be a value in (0, 1).")

        logits = logits.to(self._device)
        labels = labels.to(self._device)

        # Derive the class count from the logits so that this method also
        # works when it is called directly, without going through calibrate().
        num_classes = logits.shape[1]
        self.num_classes = num_classes

        self.class_thresholds = torch.full(
            (num_classes,), float('inf'), device=self._device)
        self.class_rank_limits = torch.full(
            (num_classes,), num_classes, dtype=torch.long, device=self._device)

        # ranks[i, p] is the class occupying position p (0 = largest logit).
        ranks = torch.argsort(logits, dim=1, descending=True)

        for y in range(num_classes):
            mask = (labels == y)
            class_logits = logits[mask]
            class_labels = labels[mask]

            if class_logits.size(0) == 0:
                # No calibration data for this class: keep the defaults.
                continue

            scores = self.score_function(class_logits, class_labels)

            # 0-based position of the true label y in each sample's ordering.
            true_label_rank = (ranks[mask] == y).nonzero(as_tuple=True)[1]

            # Smallest k whose top-k error is below alpha. The error at
            # k == num_classes is always 0, so the fallback values below are
            # what the search would yield there.
            k_y, epsilon_y = num_classes, 0.0
            for k in range(1, num_classes + 1):
                error = (true_label_rank >= k).float().mean().item()
                if error < alpha:
                    k_y, epsilon_y = k, error
                    break

            # Miscoverage budget left for the score threshold after
            # accounting for the top-k_y error.
            alpha_y = alpha - epsilon_y

            q_hat_y = calculate_conformal_value(scores, alpha_y)

            self.class_thresholds[y] = q_hat_y
            self.class_rank_limits[y] = k_y

    def predict(self, x_batch):
        """
        Generate prediction sets for a batch of instances using RC3P.

        Args:
            x_batch (torch.Tensor): A batch of input instances.

        Returns:
            torch.Tensor: Prediction sets for each instance in the batch (as boolean tensors).

        Raises:
            ValueError: If no model is defined, or if calibration has not been
                performed.
        """
        if self._model is None:
            raise ValueError("Model is not defined. Please provide a valid model.")

        self._model.eval()
        # Inference only: do not build an autograd graph for the forward pass.
        with torch.no_grad():
            logits = self._model(x_batch.to(self._device)).float()
            logits = self._logits_transformation(logits).detach()
        return self.predict_with_logits(logits)

    def predict_with_logits(self, logits):
        """
        Generate prediction sets from logits using class-wise thresholds and rank limits.

        Args:
            logits (torch.Tensor): Model logits for test data (B, K).

        Returns:
            torch.Tensor: Boolean tensor of shape (B, K); entry ``[b, c]`` is
            True when class ``c`` is in the prediction set of instance ``b``.

        Raises:
            ValueError: If calibration has not been performed.
        """
        if self.class_thresholds is None:
            raise ValueError("Calibration not performed. Please run calibrate() first.")

        # Thresholds and rank limits live on self._device; keep everything there.
        logits = logits.to(self._device)

        # order[b, p] is the class at position p (0 = largest logit).
        order = torch.argsort(logits, dim=1, descending=True)

        scores = self.score_function(logits)

        # Each row of `order` is a permutation of the class indices, so sorting
        # it again yields the inverse permutation: the 0-based position of
        # every class. Adding 1 gives 1-based ranks, matching the rank limits.
        ranks_all = torch.argsort(order, dim=1) + 1

        prediction_sets = (scores <= self.class_thresholds.unsqueeze(0)) & \
            (ranks_all <= self.class_rank_limits.unsqueeze(0))

        return prediction_sets
