import torch
from typing import Literal
import math
from scipy.stats import norm
from bin_cp.methods.robust_cp import RobustCP
from bin_cp.robust.confidence import clopper_pearson_lower, clopper_pearson_upper
from sparse_smoothing.cert import regions_binary, compute_rho, binary_certificate_grid, compute_rho_for_many

class RCP1(RobustCP):
    """Robust conformal predictor that adjusts the coverage level for a threat radius.

    The nominal coverage is pushed through a smoothing certificate
    (:meth:`compute_threat_p`, "upper" direction) and stored as
    ``final_coverage``. Prediction thresholds the scores against
    ``self.internal_cp.quantile_threshold``.

    ``nominal_coverage``, ``r`` and ``internal_cp`` are not defined in this
    class; they are expected to be provided by ``RobustCP``.
    """

    def __init__(self, smoothing_sigma=0.0, n_dcal=None, n_classes=None,
                 scheme:Literal["guass", "sparse", "exact", "laplace-l1", "uniform", "uniform-l1", "uniform-l2"]="guass",
                 dataset_key=None,
                 **kwargs):
        """Store the smoothing configuration, then initialise the base class.

        Args:
            smoothing_sigma: Scale of the smoothing noise handed to the certificate.
            n_dcal: Stored as-is; not used inside this class.
            n_classes: Stored as-is; not used inside this class.
            scheme: Certificate family forwarded to :meth:`compute_threat_p`.
                The historical spelling ``"guass"`` is part of the public
                interface and is kept.
            dataset_key: Dataset identifier used to look up the input
                dimension for the noise-based schemes.
            **kwargs: Forwarded to ``RobustCP.__init__``.
        """
        # Attributes are assigned before calling the base initialiser, as in
        # the original code, so they are available if the base class needs them.
        self.smoothing_sigma = smoothing_sigma
        self.n_dcal = n_dcal
        self.n_classes = n_classes
        self.scheme = scheme
        self.dataset_key = dataset_key

        super().__init__(stage="calibration", **kwargs)

    def correct_coverage_guarantee(self):
        """Compute, store and return the coverage level corrected for radius ``self.r``."""
        self.final_coverage = self.compute_threat_p(
            self.nominal_coverage, sigma=self.smoothing_sigma, r=self.r, type="upper",
            scheme=self.scheme,
            # Previously omitted, which made the noise-based schemes silently
            # fall back to the default dataset dimension.
            dataset_key=self.dataset_key)
        return self.final_coverage

    def compute_lower_bound_scores(self, S_sampled: torch.Tensor, y=None):
        """Return the scores used by this method; ``y`` is accepted but unused.

        For inputs with more than two dimensions only index 0 of the third
        axis is kept (presumably the sample axis -- TODO confirm); 1-D/2-D
        inputs are returned unchanged.
        """
        return S_sampled[:, :, 0] if S_sampled.ndim > 2 else S_sampled

    def predict_from_scores(self, S_sampled, return_scores=False):
        """Return a boolean mask of scores at or above the calibrated threshold.

        ``return_scores`` is accepted for interface compatibility and ignored.
        """
        scores = S_sampled[:, :, 0] if S_sampled.ndim > 2 else S_sampled
        return scores >= self.internal_cp.quantile_threshold

    @staticmethod
    def compute_threat_p(p, r, sigma, scheme:Literal["guass", "sparse", "exact", "laplace-l1", "uniform-l1", "uniform-l2"]="guass", type:Literal["lower", "upper"]="lower", dataset_key=None):
        """Shift the probability ``p`` by the certificate for a perturbation of radius ``r``.

        Args:
            p: Probability (or array of probabilities) to shift.
            r: Perturbation radius.
            sigma: Scale of the smoothing noise.
            scheme: Certificate family. ``"guass"`` and ``"exact"`` are closed
                forms; ``"laplace-l1"``, ``"uniform-l1"`` and ``"uniform-l2"``
                delegate to ``bin_cp.robust``.
            type: ``"lower"`` for the lower bound; any other value yields the
                upper bound.
            dataset_key: Key passed to ``get_dim`` for the noise-based schemes.
                When it is ``None`` a warning is emitted and a default key is used.

        Returns:
            The shifted probability, in whatever type the selected backend
            produces.

        Raises:
            NotImplementedError: If ``scheme`` has no implementation here
                (this includes ``"sparse"`` and ``"uniform"``).
        """
        # Anything other than "lower" is treated as "upper", as before.
        lower = type == "lower"

        if scheme == "guass":
            # Phi_sigma(Phi_sigma^{-1}(p) -/+ r), with Phi_sigma the N(0, sigma^2) CDF.
            shift = -r if lower else r
            return norm.cdf(norm.ppf(p, scale=sigma) + shift, scale=sigma)

        if scheme == "exact":
            # Linear shift of p by r / (2 * sigma * sqrt(3)).
            slope = 1/(2*(sigma) * math.sqrt(3))
            return p - slope * r if lower else p + slope * r

        if scheme in ("laplace-l1", "uniform-l1", "uniform-l2"):
            import warnings
            from bin_cp.robust.noises import Laplace, Uniform, get_dim
            from bin_cp.robust.robust_bounds import compute_upper_p_from_r, compute_lower_p_from_r

            # scheme -> (noise class, certificate method name, default dataset key)
            # NOTE(review): the default keys differ ("cifar10" vs "cifar") although
            # every warning mentions CIFAR10; kept unchanged because get_dim is
            # not visible from this file -- confirm which keys it accepts.
            noise_cls, certify_name, default_key = {
                "laplace-l1": (Laplace, "certify_l1", "cifar10"),
                "uniform-l1": (Uniform, "certify_l1", "cifar"),
                "uniform-l2": (Uniform, "certify_l2", "cifar"),
            }[scheme]

            if dataset_key is None:
                # The original built a Warning(...) object and discarded it,
                # so nothing was ever reported.
                warnings.warn("Dataset key is not provided, using CIFAR10 as default", stacklevel=2)

            r_function = noise_cls(dim=get_dim(dataset_key or default_key), sigma=sigma)
            bound_fn = compute_lower_p_from_r if lower else compute_upper_p_from_r
            return bound_fn(torch.tensor(p), r, getattr(r_function, certify_name))

        raise NotImplementedError(f"Scheme {scheme!r} is not implemented yet")


class RCP1Plus(RobustCP):
    """Robust conformal predictor whose threshold is tuned on the certified coverage.

    The threshold is found by bisection on one part of the sampled scores
    (the "tuning" samples); the reported guarantee is then computed on the
    remaining ("empirical") samples.

    ``nominal_coverage``, ``r`` and ``internal_cp`` are not defined in this
    class; they are expected to be provided by ``RobustCP``.
    """

    def __init__(self, *args, smoothing_sigma, tuning_n=200, confidence=0.99, **kwargs):
        """Initialise the base class and store the tuning configuration.

        Args:
            *args: Forwarded to ``RobustCP.__init__``.
            smoothing_sigma: Scale of the Gaussian smoothing noise.
            tuning_n: Number of trailing samples (last axis) reserved for
                tuning the threshold.
            confidence: Confidence level of the Clopper-Pearson bounds;
                ``delta`` is ``1 - confidence``.
            **kwargs: Forwarded to ``RobustCP.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.smoothing_sigma = smoothing_sigma
        self.tuning_n = tuning_n
        # Gaussian certificate: Phi_sigma(Phi_sigma^{-1}(p) - r), evaluated on CPU via scipy.
        self.certificate_function = lambda p: norm.cdf(norm.ppf(p.cpu(), scale=self.smoothing_sigma) - self.r, scale=self.smoothing_sigma)
        self.exact_guarantee = None
        self.delta = 1 - confidence

    def _certified_coverage(self, scores, y, threshold, n_samples):
        """Return the mean certified lower bound for the true class at ``threshold``.

        Args:
            scores: Tensor indexed as ``[point, class, sample]``.
            y: Integer labels used to pick each point's true class.
            threshold: Score threshold to evaluate.
            n_samples: Number of trials assumed by the Clopper-Pearson bound.
                This may differ from ``scores.shape[-1]``.

        Returns:
            The mean of ``certificate_function`` over the per-point lower
            bounds plus one extra entry fixed at 0.0.
        """
        # Fraction of samples at or above the threshold, for every (point, class).
        betas = (scores >= threshold).float().mean(-1)
        true_betas = betas[torch.arange(scores.shape[0]), y]
        # delta is split evenly over the points (alpha = delta / n_points).
        lower = torch.tensor(
            clopper_pearson_lower(true_betas.cpu() * n_samples, n_samples, alpha=self.delta/scores.shape[0]),
            device=scores.device)
        # One extra 0.0 entry, so the mean below is taken over n_points + 1
        # values (presumably standing in for the unseen test point -- TODO confirm).
        lower = torch.cat([lower, torch.tensor([0.0], device=lower.device)])
        return self.certificate_function(lower).mean()

    def calibrate_from_scores(self, S_sampled, y, return_scores=False, return_guarantee=True, **kwargs):
        """Tune the conformal threshold by bisection and compute its guarantee.

        The last ``tuning_n`` samples along the last axis tune the threshold
        and the remaining samples evaluate it. ``tuning_n`` is assumed to be
        positive and smaller than the number of samples: with ``tuning_n == 0``
        the ``-0:`` slice selects every sample for tuning and leaves the
        evaluation split empty.

        Args:
            S_sampled: Tensor indexed as ``[point, class, sample]``.
            y: Integer labels, one per point.
            return_scores: If true (and ``return_guarantee`` is false), also
                return ``S_sampled``.
            return_guarantee: If true, also return the guarantee. Takes
                precedence over ``return_scores``.
            **kwargs: Ignored.

        Returns:
            ``(threshold, guarantee)``, ``(threshold, S_sampled)`` or
            ``threshold``, depending on the two flags.
        """
        tuning_scores = S_sampled[:, :, -self.tuning_n:]
        empirical_scores = S_sampled[:, :, :-self.tuning_n]
        n_empirical = empirical_scores.shape[-1]

        # Loop invariants, hoisted out of the bisection.
        score_floor = S_sampled.min().item()
        target = self.nominal_coverage + self.delta + 0.005

        threshold_min = score_floor
        threshold_max = S_sampled.max().item()

        while threshold_max - threshold_min > 1e-7:
            bounds_before = (threshold_min, threshold_max)
            threshold = (threshold_min + threshold_max) / 2

            # Estimate on the tuning samples what the confidence bounds would
            # be with the evaluation split's sample count.
            certified_values = self._certified_coverage(tuning_scores, y, threshold, n_empirical)

            if certified_values < target:
                threshold_max = threshold
            else:
                threshold_min = threshold

            # NOTE(review): the midpoint only reaches the smallest score in
            # floating-point corner cases, so this fallback is rarely taken --
            # confirm that this is the intended handling of "no threshold works".
            if threshold_max <= score_floor:
                threshold_min = -1000
                break

            # If neither bound moved, the midpoint is no longer strictly between
            # them (bounds adjacent in floating point, or infinite scores) and
            # the original loop would never terminate.
            if (threshold_min, threshold_max) == bounds_before:
                break

        self.conformal_threshold = threshold_min

        self.exact_guarantee = self._certified_coverage(
            empirical_scores, y, threshold_min, n_empirical).item() - self.delta

        if return_guarantee:
            return threshold_min, self.exact_guarantee

        if return_scores:
            return threshold_min, S_sampled
        return threshold_min

    def return_worst_case_guarantee(self, S_sampled, y, **kwargs):
        """Return the guarantee obtained with the internal predictor's threshold.

        The threshold is calibrated by ``self.internal_cp`` on the first
        sample of every (point, class) pair; the guarantee is then computed
        over all samples. ``**kwargs`` is ignored.
        """
        threshold = self.internal_cp.calibrate_from_scores(S_sampled[:, :, 0], y_true_mask=torch.nn.functional.one_hot(y, num_classes=S_sampled.shape[1]).bool())
        exact_guarantee = self._certified_coverage(
            S_sampled, y, threshold, S_sampled.shape[-1]).item() - self.delta
        return exact_guarantee

    def predict_from_scores(self, S_sampled, return_scores=False):
        """Return a boolean mask of scores at or above ``self.conformal_threshold``.

        For inputs with more than two dimensions only index 0 of the last
        axis is used. ``return_scores`` is accepted for interface
        compatibility and ignored. ``conformal_threshold`` is assigned by
        ``calibrate_from_scores``.
        """
        scores = S_sampled[:, :, 0] if S_sampled.ndim > 2 else S_sampled
        return scores >= self.conformal_threshold
    
    