import torch
from typing import Literal
import math
from scipy.stats import norm
from bin_cp.methods.robust_cp import RobustCP
from bin_cp.robust.confidence import clopper_pearson_lower, clopper_pearson_upper
from sparse_smoothing.cert import regions_binary, compute_rho, binary_certificate_grid, compute_rho_for_many
from bin_cp.methods.rcp_one import RCP1Plus
from scipy.stats import binom


class FloRK(RCP1Plus):
    """``RCP1Plus`` variant that calibrates a threshold for a k-sample majority vote.

    Calibration bisects for the largest score threshold at which the fraction
    of calibration points whose certified value exceeds ``0.5 + bias`` reaches
    ``nominal_coverage + delta + coverage_bias``. Prediction includes a class
    when at least half of the first ``k_samples`` sampled scores are
    ``>= conformal_threshold``.
    """

    def __init__(self, *args, bias=0.1, coverage_bias=0.01, k_samples=1, **kwargs):
        """Create the calibrator.

        Args:
            *args, **kwargs: forwarded unchanged to ``RCP1Plus.__init__``.
            bias: margin above 0.5 that a certified value must exceed
                (``> 0.5 + bias``). If ``None``, it is derived from
                ``k_samples`` and ``coverage_bias`` via ``find_exceeding_point``.
            coverage_bias: extra coverage demanded on top of
                ``nominal_coverage + delta`` during calibration.
            k_samples: number of sampled scores used in the prediction-time
                majority vote.

        Raises:
            ValueError: if ``bias`` is None and no success probability in
                [0.5, 1] lets a ``k_samples``-draw majority vote, scaled by
                ``nominal_coverage + coverage_bias``, exceed ``nominal_coverage``.
        """
        super().__init__(*args, **kwargs)
        self.bias = bias
        self.k_samples = k_samples
        self.coverage_bias = coverage_bias
        if bias is None:
            required_offset = find_exceeding_point(
                self.k_samples,
                biased_cov=self.nominal_coverage + coverage_bias,
                target_coverage=self.nominal_coverage,
            )
            if required_offset is None:
                # Previously this fell through to ``None - 0.5`` and raised an opaque TypeError.
                raise ValueError(
                    f"Cannot derive bias: no success probability in [0.5, 1] makes a "
                    f"{self.k_samples}-sample majority vote (scaled by "
                    f"{self.nominal_coverage + coverage_bias}) exceed coverage "
                    f"{self.nominal_coverage}; pass an explicit bias or increase coverage_bias."
                )
            self.bias = round(required_offset - 0.5, 3) + 1e-3

    @staticmethod
    def _true_class_betas(scores, y, threshold):
        """Return, per point, the fraction of the true class's sampled scores that are >= threshold.

        ``scores`` is indexed as ``scores[point, class, sample]`` and ``y`` holds
        one class index per point.
        """
        betas = (scores >= threshold).float().mean(-1)
        return betas[torch.arange(scores.shape[0]), y]

    def _certified_fraction(self, true_betas, n_samples, n_points, device):
        """Return (as a 0-dim tensor) the fraction of certified values exceeding ``0.5 + bias``.

        ``true_betas * n_samples`` is passed as the success count to
        ``clopper_pearson_lower`` with ``n_samples`` trials at level
        ``delta / n_points`` (a Bonferroni split of ``delta`` over the points).
        The resulting lower bounds go through ``self.certificate_function``.
        """
        lower = torch.as_tensor(
            clopper_pearson_lower(true_betas.cpu() * n_samples, n_samples, alpha=self.delta / n_points),
            device=device,
        )
        # A trailing 0.0 is appended before certification; it also counts in the mean below.
        # NOTE(review): presumably a deliberate conservative padding entry — confirm.
        lower = torch.cat([lower, torch.tensor([0.0], device=lower.device)])
        certified = torch.as_tensor(self.certificate_function(lower))
        return torch.mean((certified > 0.5 + self.bias).float())

    def calibrate_from_scores(self, S_sampled, y, return_scores=False, return_guarantee=True, **kwargs):
        """Bisect for the conformal threshold and estimate its guarantee.

        The last ``self.tuning_n`` samples along the final axis of ``S_sampled``
        are used to search for the threshold; the remaining ("empirical")
        samples are used to estimate the guarantee at the chosen threshold.

        Args:
            S_sampled: scores indexed as ``[point, class, sample]``.
            y: true class index per point.
            return_scores: if True (and ``return_guarantee`` is False), also
                return ``S_sampled``.
            return_guarantee: if True, return ``(threshold, exact_guarantee)``;
                takes precedence over ``return_scores``.
            **kwargs: accepted and ignored.

        Returns:
            ``(threshold, guarantee)``, ``(threshold, S_sampled)`` or ``threshold``
            depending on the flags. Also sets ``self.conformal_threshold`` and
            ``self.exact_guarantee``.

        Raises:
            ValueError: if ``tuning_n`` would leave the tuning or the empirical
                split empty.
        """
        total_samples = S_sampled.shape[-1]
        if not 0 < self.tuning_n < total_samples:
            # tuning_n == 0 would make ``[-0:]`` take everything and ``[:-0]`` take nothing.
            raise ValueError(
                f"tuning_n must be in [1, {total_samples - 1}] so both the tuning and "
                f"empirical splits are non-empty; got {self.tuning_n}."
            )

        tuning_scores = S_sampled[:, :, -self.tuning_n:]
        empirical_scores = S_sampled[:, :, :-self.tuning_n]
        n_points = empirical_scores.shape[0]
        n_samples = empirical_scores.shape[-1]
        device = S_sampled.device
        target = self.nominal_coverage + self.delta + self.coverage_bias

        score_min = S_sampled.min().item()
        threshold_min = score_min
        threshold_max = S_sampled.max().item()

        while threshold_max - threshold_min > 1e-7:
            threshold = (threshold_min + threshold_max) / 2

            # Tuning-split inclusion frequencies, bounded as if they had been
            # observed over the empirical split's sample count.
            true_betas = self._true_class_betas(tuning_scores, y, threshold)
            certified = self._certified_fraction(true_betas, n_samples, n_points, device)

            if certified < target:
                threshold_max = threshold
            else:
                threshold_min = threshold
            # NOTE(review): threshold_min starts at score_min and only grows, so the
            # midpoint stays >= score_min; this fallback fires only if threshold_max
            # lands exactly on the minimum. Kept as-is; confirm the intended trigger.
            if threshold_max <= score_min:
                threshold_min = -1000
                break

        self.conformal_threshold = threshold_min

        empirical_true_betas = self._true_class_betas(empirical_scores, y, threshold_min)
        guarantee = self._certified_fraction(empirical_true_betas, n_samples, n_points, device)
        self.exact_guarantee = guarantee.item() - self.delta

        if return_guarantee:
            return threshold_min, self.exact_guarantee
        if return_scores:
            return threshold_min, S_sampled
        return threshold_min

    def return_worst_case_guarantee(self, S_sampled, y, **kwargs):
        """Estimate the guarantee at the threshold chosen by ``self.internal_cp``.

        The internal calibrator sees only the first sample (``S_sampled[:, :, 0]``)
        and a one-hot true-class mask. The guarantee is then estimated over all
        samples of ``S_sampled``, the same way as in ``calibrate_from_scores``.

        Returns:
            float: fraction of certified points minus ``self.delta``.
        """
        y_true_mask = torch.nn.functional.one_hot(y, num_classes=S_sampled.shape[1]).bool()
        threshold = self.internal_cp.calibrate_from_scores(S_sampled[:, :, 0], y_true_mask=y_true_mask)
        true_betas = self._true_class_betas(S_sampled, y, threshold)
        certified = self._certified_fraction(true_betas, S_sampled.shape[-1], S_sampled.shape[0], S_sampled.device)
        return certified.item() - self.delta

    def predict_from_scores(self, S_sampled, return_scores=False):
        """Return a boolean prediction-set mask by majority vote over ``k_samples`` samples.

        A class is included when at least half of its first ``k_samples``
        sampled scores are ``>= self.conformal_threshold``.

        Args:
            S_sampled: 3-D scores indexed as ``[point, class, sample]``.
            return_scores: accepted for interface compatibility; ignored.

        Raises:
            ValueError: if ``S_sampled`` is 2-dimensional.
        """
        if S_sampled.ndim == 2:
            raise ValueError("S_sampled must be 3-dimensional for derandomized prediction.")
        votes = S_sampled[:, :, :self.k_samples] >= self.conformal_threshold
        return votes.float().mean(-1) >= 0.5
    



def binomial_cdf(k, n, p):
    """Return the Binomial(n, p) cumulative probability P(X <= k).

    A thin wrapper over ``scipy.stats.binom.cdf``. When ``k`` or ``p`` is
    array-like, SciPy broadcasts the arguments and an ndarray comes back.

    Parameters
    ----------
    k : int or array-like
        Inclusive upper bound on the number of successes.
    n : int
        Number of independent trials.
    p : float or array-like
        Per-trial success probability.

    Returns
    -------
    float or ndarray
        The cumulative distribution function evaluated at ``k``.
    """
    cumulative = binom.cdf(k, n=n, p=p)
    return cumulative


if __name__ == "__main__":
    # Example: P(X <= 3) where X ~ Binomial(n=7, p=0.6).
    # Guarded so that importing this module does not print.
    print(binomial_cdf(3, 7, 0.6))


def kl_binomial(p, q):
    """Return the Bernoulli KL divergence KL(Bern(p) || Bern(q)) as a tensor.

    Computes ``p * log(p / q) + (1 - p) * log((1 - p) / (1 - q))`` elementwise,
    with the standard convention ``0 * log(0) = 0``. As a result, p = 0 and p = 1
    give finite values instead of NaN.

    Args:
        p, q: scalars, sequences or tensors (broadcast elementwise). Integer
            inputs are promoted to the default floating dtype.

    Returns:
        torch.Tensor with the elementwise divergence.
    """
    p = torch.as_tensor(p)
    q = torch.as_tensor(q)
    if not p.is_floating_point():
        p = p.to(torch.get_default_dtype())
    if not q.is_floating_point():
        q = q.to(torch.get_default_dtype())
    # xlogy(x, y) = x * log(y), defined as 0 where x == 0, avoiding 0 * log(0) = NaN.
    return torch.xlogy(p, p / q) + torch.xlogy(1 - p, (1 - p) / (1 - q))

def find_exceeding_point(n, biased_cov=0.907, target_coverage=0.9, steps=1000):
    """Return the smallest grid probability at which a scaled n-draw majority vote exceeds a target.

    For each p on ``torch.linspace(0.5, 1.0, steps)`` this evaluates
    ``biased_cov * P(X > (n - 1) // 2)`` with ``X ~ Binomial(n, p)``. That is the
    probability that more than ``(n - 1) // 2`` of ``n`` independent draws
    succeed, scaled by ``biased_cov``. The function returns the first p for
    which this value is strictly greater than ``target_coverage``.

    Args:
        n: number of draws (binomial trials).
        biased_cov: multiplicative factor applied to the vote-success probability.
        target_coverage: value the scaled probability must strictly exceed.
        steps: number of grid points in [0.5, 1.0] (default 1000, as before).

    Returns:
        The first qualifying grid probability as a Python float, or ``None``
        if no grid point qualifies. Because the vote-success probability is at
        most 1, this happens e.g. whenever ``biased_cov <= target_coverage``.

    Example: ``find_exceeding_point(51, biased_cov=0.91, target_coverage=0.9)``.
    """
    prob_range = torch.linspace(0.5, 1.0, steps=steps)
    # binomial_cdf goes through SciPy and returns a NumPy array; convert it back
    # so the mask used to index the torch grid is a torch bool tensor.
    vote_success = 1 - torch.as_tensor(binomial_cdf((n - 1) // 2, n, prob_range.numpy()))
    scaled = vote_success * biased_cov
    exceeding_points = prob_range[scaled > target_coverage]
    if exceeding_points.numel() == 0:
        return None
    return exceeding_points[0].item()