import torch
import torch.nn.functional as F
from scipy.stats import norm

from qrcp.methods.robust_cp import RobustCP
from qrcp.robust.confidence import dkw_cdf, clopper_pearson_lower

class QRCPThresholds(RobustCP):
    """Quantile-based robust conformal prediction over smoothing samples.

    Every input is represented by scores evaluated on Monte-Carlo smoothing
    samples, with the sample axis last. Calibration summarizes each calibration
    point by a ``cutoff_prob``-level quantile of its sampled scores and hands
    those summaries to ``self.internal_cp``. Prediction keeps a class when the
    fraction of samples whose score is at or below the conformal threshold
    does not exceed ``self.confidence_adv``.

    With ``error_correction`` enabled, Monte-Carlo error is accounted for by a
    DKW bound at calibration time, a Clopper-Pearson lower bound at prediction
    time, and by adding ``eta = 1 - confidence_level`` to the nominal coverage.

    NOTE(review): ``self.r``, ``self.nominal_coverage`` and ``self.internal_cp``
    are not set here; presumably they are provided by ``RobustCP`` — confirm.
    """

    def __init__(self, smoothing_sigma=0.0, confidence_level=0.999, n_dcal=None, n_classes=None, cutoff_prob=0.5,
                 error_correction=True, **kwargs):
        """Store smoothing / error-correction settings and derive ``confidence_adv``.

        Args:
            smoothing_sigma: Smoothing noise scale. It is used as a divisor when
                computing ``confidence_adv``, so it must be non-zero in practice
                (the 0.0 default only works if ``self.r`` follows NumPy
                division semantics — TODO confirm intent).
            confidence_level: Confidence of the Monte-Carlo bounds;
                ``eta = 1 - confidence_level`` is the total failure budget.
            n_dcal: Number of calibration points (used by the error-correction
                bounds).
            n_classes: Number of classes (used by the error-correction bounds).
            cutoff_prob: Quantile level used to summarize the sampled scores.
            error_correction: Whether to apply the DKW / Clopper-Pearson
                corrections.
            **kwargs: Forwarded to ``RobustCP.__init__``.
        """
        self.smoothing_sigma = smoothing_sigma
        self.eta = 1 - confidence_level
        self.cutoff_prob = cutoff_prob
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.error_correction = error_correction

        super().__init__(**kwargs)
        # Inflated quantile level Phi(Phi^{-1}(cutoff_prob) + r / sigma) for an
        # adversary of radius r. The simpler Phi(r / sigma) ignores cutoff_prob.
        self.confidence_adv = norm.cdf(norm.ppf(self.cutoff_prob) + self.r / self.smoothing_sigma)
        self.n_dcal = n_dcal
        self.n_classes = n_classes

    def correct_coverage_guarantee(self):
        """Set ``self.final_coverage`` from ``self.nominal_coverage``.

        With error correction enabled, the Monte-Carlo failure budget ``eta`` is
        added on top of the nominal coverage. Otherwise the nominal coverage is
        used unchanged.
        """
        if not self.error_correction:
            self.final_coverage = self.nominal_coverage
            return
        self.final_coverage = self.nominal_coverage + self.eta

    def _dkw_quantile(self, point_scores, confidence):
        """Return the last support score whose DKW upper CDF bound is below ``cutoff_prob``.

        Assumes ``dkw_cdf`` returns ``(cdf_bound, support, ...)`` with the support
        in ascending order, so the selected score is the largest qualifying one —
        TODO confirm against ``dkw_cdf``.

        Raises:
            ValueError: If no support point qualifies (e.g. too few smoothing
                samples for the requested confidence).
        """
        result = dkw_cdf(point_scores, confidence=confidence, bonferroni_tasks=1, type="upper")
        cdf_upper, support = result[0], result[1]
        qualifying = support[cdf_upper < self.cutoff_prob]
        if len(qualifying) == 0:
            raise ValueError(
                "DKW upper CDF bound never falls below cutoff_prob="
                f"{self.cutoff_prob}; increase the number of smoothing samples "
                "or lower the confidence level."
            )
        return qualifying[-1]

    def calibrate_from_scores(self, S_sample, y, return_scores=False):
        """Compute and store the conformal threshold from sampled calibration scores.

        Args:
            S_sample: Sampled scores, indexed as ``S_sample[i, y_i]`` with the
                smoothing samples on the last axis; presumably
                ``(n, n_classes, n_samples)`` — TODO confirm.
            y: True labels, one per calibration point.
            return_scores: If True, also return the per-point calibration scores.

        Returns:
            The conformal threshold, or ``(threshold, calibration_scores)`` when
            ``return_scores`` is True.
        """
        if not self.error_correction:
            # Empirical cutoff_prob-quantile over the samples, for every class.
            calibration_scores = S_sample.quantile(self.cutoff_prob, dim=-1)
        else:
            # Only the true-class samples are needed; each is summarized by a
            # DKW-corrected quantile at per-point confidence 1 - eta / (2 * n_dcal).
            true_smooth_scores = S_sample[torch.arange(S_sample.shape[0]), y]
            confidence = 1 - self.eta / (2 * self.n_dcal)
            calibration_scores = torch.tensor(
                [self._dkw_quantile(point_scores, confidence) for point_scores in true_smooth_scores]
            )
        self.conformal_threshold = self.calibrate_from_refined_scores(calibration_scores, y)
        if return_scores:
            return self.conformal_threshold, calibration_scores
        return self.conformal_threshold

    def predict_from_scores(self, S_sample, return_scores=False):
        """Build prediction sets from sampled test scores.

        The per-class test score is the fraction of smoothing samples whose score
        is at or below the conformal threshold. With error correction, that
        fraction is replaced by its Clopper-Pearson lower bound. Either way, the
        score is on the same [0, 1] scale as ``confidence_adv``.

        Args:
            S_sample: Sampled scores with smoothing samples on the last axis.
            return_scores: If True, also return ``S_sample``.

        Returns:
            Boolean prediction sets, or ``(pred_sets, S_sample)``.
        """
        n_samples = S_sample.shape[-1]
        pred_votes = (S_sample <= self.conformal_threshold).sum(dim=-1)
        if self.error_correction:
            # eta is split over n_dcal + n_classes bounds (Bonferroni-style).
            test_scores = torch.as_tensor(
                clopper_pearson_lower(pred_votes.cpu(), n_samples, alpha=self.eta / (self.n_dcal + self.n_classes)),
                device=self.device,
            )
        else:
            # Convert counts to a fraction so the comparison with the
            # probability confidence_adv is meaningful.
            test_scores = pred_votes / n_samples

        pred_sets = self.predict_from_refined_scores(test_scores)
        if return_scores:
            return pred_sets, S_sample
        return pred_sets

    def pre_compute_predict(self, test_mask):
        """Predict sets for the subset of cached test scores selected by ``test_mask``.

        Delegates to ``predict_from_scores`` so the ``error_correction`` setting
        is honored identically on both prediction paths.
        """
        return self.predict_from_scores(self.precompute__test_scores[test_mask])

    def predict_from_refined_scores(self, test_scores):
        """Return a boolean mask that is True where ``test_scores`` does not exceed ``confidence_adv``.

        Written as ``~(x > c)`` rather than ``x <= c`` so NaN scores are kept.
        """
        pred_set = ~(test_scores > self.confidence_adv)
        return pred_set

    def calibrate_from_refined_scores(self, calibration_scores, y):
        """Calibrate ``self.internal_cp`` on refined scores and store the threshold.

        A 1-D input holds one (true-class) score per point, so it is reshaped to
        a single column with an all-True label mask. A 2-D input holds one score
        per class, and the mask is the one-hot encoding of ``y``.
        """
        if calibration_scores.ndim == 1:
            y_true_mask = torch.ones(
                size=(calibration_scores.shape[0], 1), dtype=torch.bool, device=calibration_scores.device
            )
            self.conformal_threshold = self.internal_cp.calibrate_from_scores(
                calibration_scores.reshape(-1, 1), y_true_mask
            )
        else:
            y_true_mask = F.one_hot(y, num_classes=calibration_scores.shape[1]).bool()
            self.conformal_threshold = self.internal_cp.calibrate_from_scores(calibration_scores, y_true_mask)
        return self.conformal_threshold