import torch
import torch.nn.functional as F
import pandas as pd
from scipy.stats import norm
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pathlib
import os
from copy import deepcopy

print("Started running marginal-binary-certificate-updated.py")

from bin_cp.helpers.storage import load_smooth_prediction
from bin_cp.helpers.tensor import get_smooth_scores, get_cal_mask, quantization_pdf, bound_tensor
from bin_cp.robust.confidence import bernstein_bound, dkw_cdf
from bin_cp.robust.confidence import clopper_pearson_lower, clopper_pearson_upper
from bin_cp.robust.bounds import mean_bounds_l2, CDF_bounds_l2

from bin_cp.cp.core import ConformalClassifier as CP
from bin_cp.cp.scores import APSScore, TPSScore, LogitScore

from bin_cp.methods.robust_cp import RobustCP, VanillaSmoothCP
from bin_cp.methods.cas import CAS
from bin_cp.methods.bincp import BinCP
from bin_cp.methods.rcp_one import RCP1, RCP1Plus
# from qrcp.methods.binary import QRCPThresholds
import time

from tqdm import tqdm

nn = 500
nk = 10000

from statsmodels.stats.proportion import proportion_confint

# IMPORTANT: sanity check of the installed scipy/statsmodels combination.
# With zero successes the exact ("beta") interval must start at 0.0; otherwise
# there is a bug due to the version of scipy.  The interval is printed because
# a bare expression is silently discarded when this file runs as a script.
_cp_lower, _cp_upper = proportion_confint(
    0.0 * nk, nk, alpha=0.001 / (nn + 2), method="beta")
print(f"Clopper-Pearson sanity check for 0/{nk} successes: lower={_cp_lower}, upper={_cp_upper}")
if not (np.isfinite(_cp_upper) and np.isclose(_cp_lower, 0.0)):
    # np.isclose is False for NaN, so a NaN lower bound is reported as well.
    print("WARNING: Clopper-Pearson lower bound for 0 successes is not 0.0; check the installed scipy version.")

#region primary configs of the experiment

# Root folder for the CSV files produced by this script.
output_dir = "./results/"

# Dataset and smoothing parameters forwarded to the stored-prediction loader.
dataset_name = "cifar10"
model_sigma = 0.5
n_classes = 10
n_datapoints = 2048
smoothing_sigma = 0.5
n_samples = 10000
n_trial_samples = 10000

# Conformal / calibration settings.
score_method = "TPS"
calibration_budget = 0.5  # fraction passed to get_cal_mask
n_iterations = 100

# Confidence level handed to the certificates via certificate_setup.
confidence = 0.999
coverage_range = [0.85, 0.9, 0.95]  # reduced alternative: [0.9]
r_range = [0.06, 0.12, 0.18, 0.25, 0.37, 0.5, 0.75]  # reduced alternative: [0.12, 0.25, 0.5]

# Defaults for the proxy-certificate experiment.
marginal_coverage = 0.8
n_0 = 50  # number of samples used to pick the top class
#endregion


class ProxyRobustnessCertificate(object):
    """Calibrate a cheap proxy certificate against a sampling-based one.

    The "hard" certificate (``hard_certify``) votes over many noisy samples per
    point.  The proxy (``proxy_certificate``) is any callable
    ``(logits, lambd, y_top=...) -> (certified, y_top)``.
    ``tune_proxy_certificate`` bisects ``lambd`` on calibration logits so that
    the empirical risk of the proxy stays at or below
    ``1 - marginal_guarantee``.

    Samples are always along ``dim=1`` of ``logits``; the indexing used below
    treats ``logits`` as ``(n_points, n_samples, n_classes)``.
    """

    def __init__(self, r=0, smoothing_sigma=0, marginal_guarantee=0.9, proxy_certificate=None, risk_setup="marginal", proxy_setup="vanilla", **kwargs):
        """Store the configuration and derive the effective guarantee.

        Args:
            r: Perturbation radius used in the Gaussian shift of the bounds.
            smoothing_sigma: Scale passed to ``norm.ppf`` / ``norm.cdf``.
            marginal_guarantee: Nominal guarantee; the tolerated risk is
                ``1 - marginal_guarantee`` (after adjustment, see below).
            proxy_certificate: Callable ``(logits, lambd, y_top=...)`` returning
                ``(robust_labels, y_top)``.
            risk_setup: ``"marginal"`` selects the marginal risk; any other
                value selects the risk restricted to non-certified points.
            proxy_setup: ``"vanilla"`` keeps ``marginal_guarantee`` as is; any
                other value applies ``adjust_robust_marginal_guarantee``.
            **kwargs: May contain ``certificate_setup``, a dict with optional
                keys ``"confidence"`` (default 0.999) and ``"n_0"`` (default 100).
        """
        self.r = r
        self.smoothing_sigma = smoothing_sigma
        self._marginal_guarantee = marginal_guarantee
        self.proxy_setup = proxy_setup

        if self.proxy_setup == "vanilla":
            self.marginal_guarantee = deepcopy(self._marginal_guarantee)
        else:
            self.marginal_guarantee = self.adjust_robust_marginal_guarantee()

        # Always define the certificate attributes: hard_certify reads
        # certificate_n_0 unconditionally, so leaving them unset when
        # certificate_setup is omitted would raise AttributeError later on.
        self.certificate_setup = kwargs.get("certificate_setup") or {}
        self.certificate_confidence = self.certificate_setup.get("confidence", 0.999)
        self.certificate_n_0 = self.certificate_setup.get("n_0", 100)

        self.proxy_certificate = proxy_certificate  # proxy certificate receives two arguments: logits, lambd and returns two elements: robust_labels, y_top
        self.risk_setup = risk_setup
        # for logits samples are always along the dim=1

        # Result of the most recent tune_proxy_certificate call.
        self._acceptable_risk = None
        self._acceptable_lambd = None

    def adjust_robust_marginal_guarantee(self):
        """Return the guarantee shifted upwards by ``r`` in Gaussian quantile space.

        Computes ``cdf(ppf(p, scale=sigma) + r, scale=sigma)`` for the nominal
        guarantee ``p``; with ``r == 0`` the nominal value is returned unchanged.
        """
        if self.r == 0:
            return deepcopy(self._marginal_guarantee)
        shifted_quantile = norm.ppf(self._marginal_guarantee, scale=self.smoothing_sigma) + self.r
        return norm.cdf(shifted_quantile, scale=self.smoothing_sigma)

    def hard_certify(self, logits, y_top=None, dim=1, r_coef=1):
        """Sampling-based certificate: is the top class still a majority at radius ``r``?

        Args:
            logits: Tensor indexed as ``(n_points, n_samples, n_classes)``.
            y_top: Optional per-point class index.  When omitted, it is the
                most frequent arg-max among the first ``certificate_n_0``
                samples, and only the remaining samples are used for voting.
            dim: Axis handed to ``torch.argmax`` for the selection samples.
                Callers in this file pass ``dim=2`` (the class axis).
            r_coef: Multiplier applied to ``r`` in the downward shift.

        Returns:
            ``(robust_labels, y_top)``: a boolean tensor that is True where the
            shifted lower bound is at least 0.5, and the class indices used.
        """
        precomputed = (y_top is not None)
        if y_top is None:
            y_pred = torch.argmax(logits[:, :self.certificate_n_0, :], dim=dim)
            votes, y_top = torch.nn.functional.one_hot(y_pred, num_classes=logits.shape[-1]).float().mean(1).max(-1)

        # With a caller-supplied y_top every sample votes; otherwise the
        # selection samples are excluded from the vote.
        vote_start = 0 if precomputed else self.certificate_n_0
        n_votes = logits.shape[1] - vote_start
        votes_intensive = (torch.argmax(logits[:, vote_start:, :], dim=-1) == y_top.reshape(-1, 1)).sum(-1)

        # NOTE(review): alpha is derived from marginal_guarantee while
        # certificate_confidence is never read by this class - confirm that
        # this is the intended confidence level for the lower bound.
        probs = clopper_pearson_lower(votes_intensive.cpu(), n_votes, alpha=1 - self.marginal_guarantee)
        prob_lower = norm.cdf(norm.ppf(probs, scale=self.smoothing_sigma) - r_coef * self.r, scale=self.smoothing_sigma)
        robust_labels = torch.tensor(prob_lower >= 0.5)
        return robust_labels, y_top

    def tune_proxy_certificate(self, logits, dim=2, verbose=True):
        """Bisect the proxy threshold on calibration ``logits``.

        The hard certificate is evaluated once (with ``r_coef=1`` for the
        ``"vanilla"`` setup and ``2`` otherwise) and used as ground truth.  The
        search over ``lambd`` in ``[0, 1]`` moves up while the risk exceeds
        ``1 - marginal_guarantee`` and down otherwise.  The accepted threshold
        and its risk are stored in ``_acceptable_lambd`` / ``_acceptable_risk``.
        """
        # Each calibration run starts from scratch; otherwise a threshold
        # accepted on an earlier calibration split could survive a re-tune.
        self._acceptable_risk = None
        self._acceptable_lambd = None

        certified, y_top_certified = self.hard_certify(logits, r_coef=1 if self.proxy_setup == "vanilla" else 2, dim=dim)
        # Keep the reference labels next to the logits instead of relying on a
        # module-level `device` variable.
        certified = certified.to(logits.device)
        y_top_certified = y_top_certified.to(logits.device)

        risk_limit = 1 - self.marginal_guarantee
        lambd_min = 0.0
        lambd_max = 1.0
        while lambd_max - lambd_min > 1e-7:
            lambd_mid = (lambd_min + lambd_max) / 2
            certified_estim, y_top_estim = self.proxy_certificate(logits, lambd_mid, y_top=y_top_certified)

            if self.risk_setup == "marginal":
                # risk = ((~certified) & (certified_estim)) | (y_top_estim != y_top_certified)
                risk = ((~certified) & (certified_estim)) # | (y_top_estim != y_top_certified) TODO: RETURN
                # The +1 terms give a finite-sample (conservative) estimate.
                risk = (risk.float().sum().item() + 1) / (logits.shape[0] + 1)
            else:
                risk = (y_top_estim != y_top_certified) | (certified_estim)
                risk = risk[~certified]
                risk = (risk.float().sum().item() + 1) / ((~certified).sum().item() + 1)

            # Among admissible thresholds keep the one with the largest risk.
            if risk <= risk_limit and (self._acceptable_risk is None or risk >= self._acceptable_risk):
                self._acceptable_risk = risk
                self._acceptable_lambd = lambd_mid

            if risk > risk_limit:
                lambd_min = lambd_mid
            else:
                lambd_max = lambd_mid

            if verbose:
                y_top_acc = (y_top_estim == y_top_certified).float().mean().item()
                robust_acc = ((certified_estim == certified).float().mean().item())
                print(f"lambda: {lambd_mid}, robust accuracy: {robust_acc:.4f}, y_top accuracy: {y_top_acc:.4f}, risk: {risk:.4f}")
        if verbose:
            if self._acceptable_lambd is not None:
                print(f"Found acceptable lambda: {self._acceptable_lambd} with risk: {self._acceptable_risk}")
            else:
                print("No acceptable lambda found.")

    def certify_proxy(self, logits, y_top=None, dim=2):
        """Apply the tuned proxy certificate to ``logits``.

        Returns ``(certified, labels)`` from ``proxy_certificate`` evaluated at
        the tuned threshold.  When no threshold has been accepted, a warning is
        printed and nothing is certified (all False, labels ``-1``).  ``dim``
        is accepted for signature symmetry and is not used.
        """
        if self._acceptable_lambd is None:
            print(Warning("Proxy certificate not tuned yet. Please run tune_proxy_certificate first."))
            result, result_label = torch.tensor([False]*logits.shape[0]), torch.tensor([-1]*logits.shape[0])
        else:
            result, result_label = self.proxy_certificate(logits, self._acceptable_lambd, y_top=y_top)
        return result, result_label

def single_sample_proxy_certificate(logits, lambd, y_top=None, vectorized=False):
    """Certify points whose softmax confidence in the top class reaches ``lambd``.

    Args:
        logits: ``(n_points, n_classes)`` or ``(n_points, n_samples, n_classes)``.
            In non-vectorized mode only sample 0 of a 3-D tensor is used.
        lambd: Confidence threshold.
        y_top: Optional per-point class index whose confidence is thresholded.
            When omitted, the arg-max class of the selected sample is used.
        vectorized: If True, threshold every sample of a 3-D tensor and return
            one vote per ``(point, sample)``; requires ``y_top``.

    Returns:
        ``(votes, y_top)`` where ``votes`` is a boolean tensor of shape
        ``(n_points,)`` or, in vectorized mode, ``(n_points, n_samples)``.

    Raises:
        NotImplementedError: If ``vectorized`` is True and ``y_top`` is None.
    """
    if logits.ndim == 3 and not vectorized:
        computable_logits = logits[:, 0, :]
    else:
        computable_logits = logits

    softmaxes = F.softmax(computable_logits, dim=-1)
    if y_top is None:
        if vectorized:
            raise NotImplementedError("Vectorized proxy certificate requires y_top to be provided.")
        conf, y_top = softmaxes.max(dim=-1)
    else:
        point_idx = torch.arange(softmaxes.shape[0])
        if vectorized:
            # Two index tensors separated by a slice: the indexed (point) axis
            # comes first and the sliced (sample) axis second, so
            # conf[n, s] == softmaxes[n, s, y_top[n]].  This replaces a Python
            # loop over every sample index followed by stack + permute.
            conf = softmaxes[point_idx, :, y_top]
        else:
            conf = softmaxes[point_idx, y_top]

    votes = (conf >= lambd)
    return votes, y_top

def estimate_y_top(logits, n_0, dim=2):
    """Return, per point, the class that wins most arg-max votes.

    For a 3-D ``logits`` tensor only the first ``n_0`` samples (axis 1) take
    part in the vote; tensors of any other rank are used as they are.  ``dim``
    is the axis handed to ``argmax``.
    """
    selection = logits[:, :n_0, :] if logits.ndim == 3 else logits
    per_sample_class = selection.argmax(dim=dim)

    # One-hot encode every vote and average over axis 1 to get class frequencies.
    class_frequency = F.one_hot(per_sample_class, num_classes=logits.shape[-1]).float().mean(1)
    _, y_top = class_frequency.max(-1)
    return y_top


class ProxyNonconvexRobustnessCertificate(ProxyRobustnessCertificate):
    """Proxy certificate whose risk is bounded per point over all samples.

    Differs from the parent class in two ways: the nominal guarantee is never
    adjusted, and the calibration risk is built from a per-point upper bound on
    the false-certification vote rate that is then shifted by ``r``.
    """

    def __init__(self, ci=0.999, **kwargs):
        """Forward ``kwargs`` to the parent and store ``ci``.

        ``1 - ci`` is the ``alpha`` passed to ``clopper_pearson_upper`` during
        tuning.
        """
        super().__init__(**kwargs)
        self.ci = ci

    def adjust_robust_marginal_guarantee(self):
        """Return the nominal guarantee unchanged, whatever ``r`` is."""
        adjusted_guarantee = deepcopy(self._marginal_guarantee)
        return adjusted_guarantee

    def tune_proxy_certificate(self, logits, dim=2, verbose=True):
        """Bisect the proxy threshold using the per-point bounded risk.

        For every candidate ``lambd`` the proxy is evaluated on all samples
        (``vectorized=True``).  Per point, the number of samples on which the
        proxy certifies a point the hard certificate rejects is passed through
        ``clopper_pearson_upper`` and shifted upwards by ``r`` in Gaussian
        quantile space.  The risk is ``(sum of these values + 1) / (n + 1)``.

        Raises:
            NotImplementedError: If ``risk_setup`` is not ``"marginal"``.
        """
        # Each calibration run starts from scratch; otherwise a threshold
        # accepted on an earlier calibration split could survive a re-tune.
        self._acceptable_risk = None
        self._acceptable_lambd = None

        certified, y_top_certified = self.hard_certify(logits, r_coef=1 if self.proxy_setup == "vanilla" else 2, dim=dim)
        # Keep the reference labels next to the logits instead of relying on a
        # module-level `device` variable.
        certified = certified.to(logits.device)
        y_top_certified = y_top_certified.to(logits.device)
        certified_col = certified.reshape(-1, 1)  # broadcasts against (n_points, n_samples)

        risk_limit = 1 - self.marginal_guarantee
        lambd_min = 0.0
        lambd_max = 1.0
        while lambd_max - lambd_min > 1e-7:
            lambd_mid = (lambd_min + lambd_max) / 2
            certified_estim, y_top_estim = self.proxy_certificate(logits, lambd_mid, y_top=y_top_certified, vectorized=True)

            if self.risk_setup == "marginal":
                # risk = ((~certified) & (certified_estim)) | (y_top_estim != y_top_certified)
                risk = ((~certified_col) & (certified_estim)) # | (y_top_estim != y_top_certified) TODO: RETURN
                risk_votes = risk.float().sum(-1)
                # presumably an upper confidence bound on the per-point vote
                # rate at level `ci` - verify against clopper_pearson_upper
                upper_vote_proba = clopper_pearson_upper(risk_votes.cpu(), logits.shape[1], alpha=1 - self.ci)
                threat_proba = norm.cdf(norm.ppf(upper_vote_proba, scale=self.smoothing_sigma) + self.r, scale=self.smoothing_sigma)

                # risk = (risk.float().sum().item() + 1) / (logits.shape[0] + 1)
                risk = (threat_proba.sum().item() + 1.0) / (logits.shape[0] + 1)
            else:
                raise NotImplementedError("Non-convex risk setup only implemented for marginal risk.")

            # Among admissible thresholds keep the one with the largest risk.
            if risk <= risk_limit and (self._acceptable_risk is None or risk >= self._acceptable_risk):
                self._acceptable_risk = risk
                self._acceptable_lambd = lambd_mid

            if risk > risk_limit:
                lambd_min = lambd_mid
            else:
                lambd_max = lambd_mid

            if verbose:
                robust_acc = ((certified_estim == certified_col).float().mean(0).mean().item())
                print(f"lambda: {lambd_mid}, robust accuracy: {robust_acc:.4f}, risk: {risk:.4f}")
        if verbose:
            if self._acceptable_lambd is not None:
                print(f"Found acceptable lambda: {self._acceptable_lambd} with risk: {self._acceptable_risk}")
            else:
                print("No acceptable lambda found.")
    


# Results are written to a per-dataset sub-folder, created on demand.
output_dir = pathlib.Path(output_dir, dataset_name)
output_dir.mkdir(parents=True, exist_ok=True)

# Locations forwarded to load_smooth_prediction; all three are left as the
# empty path here.
models_dir, dataset_dir, logits_dir = (pathlib.Path("") for _ in range(3))


def _validity_metrics(proxy_prediction, true_labels, n_eval):
    """Return ``(conditional_validity, marginal_validity)`` of a proxy certificate.

    Both are one minus the rate of points the proxy certifies although the
    hard certificate does not: the marginal rate is taken over all ``n_eval``
    points, the conditional one over the non-certified points only (and is
    1.0 when there are none).
    """
    not_robust = (true_labels == False)
    n_false_certified = ((proxy_prediction == True) & not_robust).sum().item()
    n_not_robust = not_robust.sum().item()
    conditional_validity = 1 - n_false_certified / n_not_robust if n_not_robust > 0 else 1.0
    marginal_validity = 1 - n_false_certified / n_eval
    return conditional_validity, marginal_validity


def _run_marginal_trials(certificate_cls, r, clean_logits, pert_logits, coverages, n_trials=10):
    """Tune and evaluate ``certificate_cls`` on random calibration/evaluation splits.

    For every nominal coverage in ``coverages`` and both proxy setups
    (``"vanilla"``, ``"robust"``) the certificate is tuned ``n_trials`` times on
    a fresh calibration split of ``clean_logits`` and evaluated on the
    complementary points of ``clean_logits`` and ``pert_logits``.

    Returns:
        A list with one result dict per (coverage, setup, trial).
    """
    rows = []
    for nominal_coverage in coverages:
        for setup in ["vanilla", "robust"]:
            certificate = certificate_cls(
                r=r,
                smoothing_sigma=smoothing_sigma,
                marginal_guarantee=nominal_coverage,
                certificate_setup={"confidence": confidence, "n_0": n_0},
                proxy_certificate=single_sample_proxy_certificate,
                proxy_setup=setup, risk_setup="marginal"
            )

            print(f"\n\n nominal marginal guarantee: {nominal_coverage}, adjusted robust marginal guarantee: {certificate.marginal_guarantee}")
            print("\n\n")

            for i in tqdm(range(n_trials)):
                cal_mask = get_cal_mask(clean_logits, fraction=calibration_budget)
                eval_mask = ~cal_mask
                n_eval = eval_mask.sum().item()

                certificate.tune_proxy_certificate(clean_logits[cal_mask], dim=2, verbose=False)

                # --- clean evaluation points (sliced once, reused three times) ---
                eval_logits = clean_logits[eval_mask]
                eval_y_top = estimate_y_top(eval_logits, n_0=n_0, dim=2)
                proxy_prediction_clean, proxy_labels = certificate.certify_proxy(eval_logits, dim=2, y_top=eval_y_top)
                eval_true_labels, hard_certificate_labels = certificate.hard_certify(eval_logits[:, n_0:, :], dim=2, y_top=eval_y_top)

                certified_ratio_clean = eval_true_labels.float().mean().item()
                certified_ratio_clean_proxy = proxy_prediction_clean.float().mean().item()
                proxy_label_acc = (proxy_labels == hard_certificate_labels).float().mean().item()
                conditional_validity, marginal_validity = _validity_metrics(proxy_prediction_clean, eval_true_labels, n_eval)
                print(f"Certified ratio (proxy, clean): {certified_ratio_clean_proxy:.4f}, proxy label accuracy: {proxy_label_acc:.4f}, conditional validity: {conditional_validity:.4f} , marginal validity: {marginal_validity:.4f}")

                # --- perturbed evaluation points ---
                eval_logits_pert = pert_logits[eval_mask]
                eval_y_top_pert = estimate_y_top(eval_logits_pert, n_0=n_0, dim=2)
                proxy_prediction_pert, proxy_labels_pert = certificate.certify_proxy(eval_logits_pert, y_top=eval_y_top_pert, dim=2)
                # NOTE(review): unlike the clean branch, y_top is not passed
                # here, so hard_certify re-estimates the top class from the
                # first n_0 of the remaining samples - confirm this is intended.
                eval_true_labels_pert, hard_certificate_labels_pert = certificate.hard_certify(eval_logits_pert[:, n_0:, :], dim=2)

                certified_ratio_pert = eval_true_labels_pert.float().mean().item()
                certified_ratio_pert_proxy = proxy_prediction_pert.float().mean().item()
                proxy_label_acc_pert = (proxy_labels_pert == hard_certificate_labels_pert).float().mean().item()
                conditional_validity_pert, marginal_validity_pert = _validity_metrics(proxy_prediction_pert, eval_true_labels_pert, n_eval)
                print(f"Certified ratio (proxy, perturbed): {certified_ratio_pert_proxy:.4f}, proxy label accuracy: {proxy_label_acc_pert:.4f}, conditional validity: {conditional_validity_pert:.4f} , marginal validity: {marginal_validity_pert:.4f}")

                rows.append({
                    "setup": setup,
                    "risk_type": "marginal",
                    "trial": i,
                    "r": r,
                    "smoothing_sigma": smoothing_sigma,
                    "n_samples": n_samples,

                    "nominal_coverage": nominal_coverage,
                    "robust_nominal_coverage": certificate.marginal_guarantee,

                    "certified_ratio_clean_proxy": certified_ratio_clean_proxy,
                    "certified_ratio_pert_proxy": certified_ratio_pert_proxy,

                    "certified_ratio_clean": certified_ratio_clean,
                    "certified_ratio_pert": certified_ratio_pert,

                    "conditional_validity": conditional_validity,
                    "marginal_validity": marginal_validity,
                    "proxy_label_acc_pert": proxy_label_acc_pert,
                    "conditional_validity_pert": conditional_validity_pert,
                    "marginal_validity_pert": marginal_validity_pert,
                    "acceptable_risk": certificate._acceptable_risk,
                    "acceptable_lambd": certificate._acceptable_lambd,
                })
    return rows


# Nominal guarantees swept by both experiments below.
guarantee_range = [0.8, 0.9, 0.95]

# Computation is pinned to the CPU.
device = torch.device("cpu")

for r in [0.06, 0.12, 0.18, 0.25, 0.5]:
    print(f"\n\n\nStarting experiments for r = {r}\n\n\n")
    #region loading smooth logit predictions
    if n_samples < n_trial_samples:
        # NOTE(review): this only prints; n_trial_samples is not changed here.
        print(f"Number of trial samples is set to {n_trial_samples} as it is smaller than the number of samples.")

    try:
        smooth_prediction = load_smooth_prediction(dataset_name=dataset_name,
            model_sigma=model_sigma,
            n_datapoints=n_datapoints,
            smoothing_sigma=smoothing_sigma,
            n_samples=n_samples,
            models_dir=models_dir,
            dataset_dir=dataset_dir,
            logits_dir=logits_dir,)
        n_classes = 10 if dataset_name == "cifar10" else None
    except FileNotFoundError as e:
        print("Smooth predictions not found, you can generate them using bin/smooth_logits_clean.py")
        print("Full description of the error: ", e)
        # Nothing below can run without the clean logits.
        continue

    try:
        smooth_prediction_pert = load_smooth_prediction(dataset_name=dataset_name,
            model_sigma=model_sigma,
            n_datapoints=n_datapoints,
            smoothing_sigma=smoothing_sigma,
            models_dir=models_dir,
            dataset_dir=dataset_dir,
            logits_dir=logits_dir,
            n_samples=n_samples, r=r, attack="pgd_rs")
    except FileNotFoundError as e:
        print("Smooth predictions for perturbation not found, you can generate them using bin/smooth_logits_pert.py")
        print("Full description of the error: ", e)
        # Skip this radius: continuing would either crash or silently reuse
        # the perturbed logits loaded for the previous radius.
        continue

    #endregion

    # Reference numbers: how many points the hard certificate accepts at the
    # configured marginal_coverage.
    proxy_certificate = ProxyRobustnessCertificate(
        r=r,
        smoothing_sigma=smoothing_sigma,
        marginal_guarantee=marginal_coverage,
        certificate_setup={"confidence": confidence, "n_0": n_0},
        proxy_certificate=single_sample_proxy_certificate,
        proxy_setup="vanilla", risk_setup="marginal"
    )
    robust_labels, y_top = proxy_certificate.hard_certify(smooth_prediction.logits, r_coef=1, dim=2)
    robust_labels_pert, y_top_pert = proxy_certificate.hard_certify(smooth_prediction_pert.logits, r_coef=1, dim=2)
    n_robust = robust_labels.sum().item()
    n_robust_pert = robust_labels_pert.sum().item()
    print(f"Number of robustly classified points (clean): {n_robust} out of {n_datapoints}, ratio: {n_robust/n_datapoints:.4f}")
    print(f"Number of robustly classified points (perturbed): {n_robust_pert} out of {n_datapoints}, ratio: {n_robust_pert/n_datapoints:.4f}")
    certified_ratio_clean = n_robust / n_datapoints
    certified_ratio_pert = n_robust_pert / n_datapoints

    smooth_prediction.logits = smooth_prediction.logits.to(device)

    # Experiment 1: ProxyRobustnessCertificate ("convex" result file).
    results = _run_marginal_trials(
        ProxyRobustnessCertificate, r,
        smooth_prediction.logits, smooth_prediction_pert.logits, guarantee_range)
    results_convex_df = pd.DataFrame(results)
    results_convex_df.to_csv(output_dir/f"convex_marginal_binary_certificate_proxy_r{r}_sigma{smoothing_sigma}_n{n_samples}.csv", index=False)
    # Show the per-configuration averages instead of discarding them.
    print(results_convex_df.groupby(["nominal_coverage", "setup", "risk_type"]).mean(numeric_only=True))

    # Experiment 2: ProxyNonconvexRobustnessCertificate ("beta" result file).
    results = _run_marginal_trials(
        ProxyNonconvexRobustnessCertificate, r,
        smooth_prediction.logits, smooth_prediction_pert.logits, guarantee_range)
    results_non_convex_df = pd.DataFrame(results)
    results_non_convex_df.to_csv(output_dir/f"beta_marginal_binary_certificate_proxy_r{r}_sigma{smoothing_sigma}_n{n_samples}.csv", index=False)