import torch
import torch.nn.functional as F
import pandas as pd
from scipy.stats import norm
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pathlib
import os


from bin_cp.helpers.storage import load_smooth_prediction
from bin_cp.helpers.tensor import get_smooth_scores, get_cal_mask, quantization_pdf, bound_tensor
from bin_cp.robust.confidence import bernstein_bound, dkw_cdf
from bin_cp.robust.confidence import clopper_pearson_lower
from bin_cp.robust.bounds import mean_bounds_l2, CDF_bounds_l2

from bin_cp.cp.core import ConformalClassifier as CP
from bin_cp.cp.scores import APSScore, TPSScore, LogitScore

from bin_cp.methods.robust_cp import RobustCP, VanillaSmoothCP
from bin_cp.methods.cas import CAS
from bin_cp.methods.bincp import BinCP
from bin_cp.methods.rcp_one import RCP1
# from qrcp.methods.binary import QRCPThresholds

from scipy.stats import norm
import time

from tqdm import tqdm
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from sacred import Experiment

ex = Experiment('risk-control-test')

@ex.config
def config():
    #region primary configs of the experiment
    # Sacred config scope: the local variables assigned below become the
    # experiment configuration and are passed to the same-named parameters
    # of `run`.

    # Directory that `run` creates (if missing) and writes one result file
    # per evaluated image into.
    output_dir = "../../results/"

    # Dataset identifier; accepted by `run` but not otherwise used in the
    # visible code.
    dataset_name = "cityscapes"
    # Scale of the Gaussian used by `run` to derive the adversarial risk
    # level; also the key that selects which precomputed results directory
    # is read.
    smoothing_sigma = 0.25
    # Amount subtracted from the Gaussian quantile of the main risk level.
    # NOTE(review): presumably the perturbation radius — confirm.
    r = 0.06
    # Number of randomly chosen files used for calibration; the remaining
    # files are evaluated.
    n_dcal = 100
    # Risk level used to calibrate the (non-adversarial) conformal threshold.
    main_risk_level = 0.2
    # Label id whose pixels the risk is measured on.
    sensitive_class = 2




def extract_risks(exp_logits, label, sensitive_class):
    """Build the empirical score CDF of one image's sensitive-class pixels.

    Args:
        exp_logits: Score tensor. Only ``exp_logits[0]`` is used, indexed as
            ``[class, :, :]``. The scores are used as they are (no
            normalisation is applied here).
            NOTE(review): presumably shaped (1, classes, H, W) and already
            normalised upstream — confirm.
        label: Label tensor; ``label[0]`` is used as a boolean pixel mask over
            the score map of ``sensitive_class``.
        sensitive_class: Class id whose pixels are analysed.

    Returns:
        A pair ``(vals, risks)``. ``vals`` holds the sorted unique scores the
        sensitive-class pixels received for ``sensitive_class`` and
        ``risks[k]`` is the fraction of those pixels whose score is
        ``<= vals[k]``. If the image has no pixel of that class, ``vals`` is
        empty and ``risks`` is ``tensor([0.0])``.
    """
    # The input is only read, so it is moved to the compute device once and
    # never copied back to the host.
    scores = exp_logits.to(device)[0]
    class_mask = label == sensitive_class
    class_scores = scores[sensitive_class, :, :][class_mask[0]]

    # `unique` returns the values in ascending order, so the running sum of
    # the counts is the (unnormalised) empirical CDF.
    vals, counts = class_scores.unique(return_counts=True)
    total = counts.sum(0)
    if total > 0:
        risks = counts.cumsum(0) / total
    else:
        risks = torch.tensor([0.0]).to(device)
    return vals, risks

def _calibration_risk(cal_vals, cal_risks, threshold):
    """Return the corrected calibration risk ``(sum of losses + 1) / n``."""
    per_image = []
    for vals, risks in zip(cal_vals, cal_risks):
        # An entry with empty ``vals`` (no pixel of the class) never adds risk.
        # A pixel whose score equals the threshold counts as missed, matching
        # the strict ``score > threshold`` prediction rule used in this file.
        reached = risks[threshold >= vals] if len(vals) > 0 else risks[:0]
        if len(reached) > 0:
            per_image.append(reached.max())
        else:
            per_image.append(risks.new_zeros(()))
    return (torch.stack(per_image).sum(0) + 1) / len(cal_vals)


def compute_conformal_threshold(cal_vals, cal_risks, risk_level):
    """Pick the largest candidate threshold whose calibration risk is below
    ``risk_level``.

    Candidates are all values in ``cal_vals``, pooled and sorted. For one
    candidate, the loss of an entry is the largest ``cal_risks`` value among
    its ``cal_vals`` that are ``<=`` the candidate (0 if there is none); the
    risk is ``(sum of losses + 1) / len(cal_vals)``. A binary search over the
    sorted candidates keeps the invariant that the risk at ``start_idx`` is
    below ``risk_level``; this relies on the risk not decreasing with the
    threshold, which holds when each ``cal_risks`` entry is a cumulative
    fraction over ascending ``cal_vals``.

    Args:
        cal_vals: List of 1-D tensors of scores, one per calibration entry
            (an entry may be empty). The caller in this file appends a final
            ``[1.0]`` entry.
        cal_risks: List of 1-D tensors aligned with ``cal_vals``.
        risk_level: Risk bound the returned threshold has to stay below.

    Returns:
        The selected threshold as a Python float. The smallest candidate is
        taken as feasible without being evaluated, so it is what is returned
        when every evaluated candidate violates ``risk_level``.

    Raises:
        ValueError: If ``cal_vals`` and ``cal_risks`` differ in length.
    """
    if len(cal_vals) != len(cal_risks):
        raise ValueError(
            f"cal_vals and cal_risks must have the same length, "
            f"got {len(cal_vals)} and {len(cal_risks)}"
        )

    threshold_levels, _ = torch.cat(cal_vals).sort()

    start_idx = 0
    end_idx = threshold_levels.shape[0]
    while end_idx - start_idx > 1:
        mid_idx = (start_idx + end_idx) // 2
        risk = _calibration_risk(cal_vals, cal_risks, threshold_levels[mid_idx])
        if risk < risk_level:
            start_idx = mid_idx
        else:
            end_idx = mid_idx

    return threshold_levels[start_idx].item()



def return_conformal(exp_logits, sensitive_class, conformal_threshold):
    """Threshold the score map of ``sensitive_class``.

    Args:
        exp_logits: Score tensor. Only ``exp_logits[0]`` is used, indexed as
            ``[class, :, :]``; the scores are used without normalisation.
            NOTE(review): presumably shaped (1, classes, H, W) — confirm.
        sensitive_class: Class id whose score map is thresholded.
        conformal_threshold: Scores strictly greater than this value are
            marked ``True``.

    Returns:
        Boolean tensor on the module-level ``device`` with the spatial shape
        of the selected score map.
    """
    # The input is only read, so no copy back to the host is needed.
    scores = exp_logits.to(device)[0]
    return scores[sensitive_class, :, :] > conformal_threshold

def compute_risk(output, label, sensitive_class):
    """Return the share of sensitive-class pixels that ``output`` leaves out.

    Args:
        output: Boolean prediction map; ``True`` marks a pixel as predicted.
        label: Label tensor. ``label[0]`` is compared pixel-wise with
            ``output``, while the denominator counts the class over the whole
            ``label`` tensor.
        sensitive_class: Class id the rate is computed for.

    Returns:
        0-d CPU tensor holding missed pixels / class pixels. It is NaN when
        ``label`` contains no pixel of ``sensitive_class``.
    """
    is_sensitive = label == sensitive_class
    # Pixels that belong to the class but were not predicted.
    missed = is_sensitive[0].to(output.device) & ~output
    n_missed = missed.sum()
    n_sensitive = is_sensitive.sum()
    return (n_missed / n_sensitive).cpu()


@ex.automain
def run(
    output_dir,
    dataset_name,
    smoothing_sigma,
    r,
    main_risk_level,
    n_dcal,
    sensitive_class,
):
    """Calibrate two thresholds on a random split and evaluate the rest.

    Steps:
      1. Derive the adversarial risk level from ``main_risk_level`` by
         shifting its Gaussian quantile (scale ``smoothing_sigma``) by ``r``.
      2. Randomly pick ``n_dcal`` of the stored ``<index>_output.pt`` files
         for calibration; all remaining files are evaluated.
      3. Calibrate one threshold per risk level on the calibration files.
      4. For every evaluation file that contains ``sensitive_class``, apply
         both thresholds, compute both risks and save the results to
         ``output_dir``.

    Args:
        output_dir: Directory for the per-image result files (created if
            missing).
        dataset_name: Not used by this function.
        smoothing_sigma: Gaussian scale; also selects the results directory.
        r: Shift applied to the Gaussian quantile.
            NOTE(review): presumably the perturbation radius — confirm.
        main_risk_level: Risk level of the non-adversarial threshold.
        n_dcal: Number of calibration files.
        sensitive_class: Label id the risk is measured on.

    Raises:
        ValueError: If no results directory is registered for
            ``smoothing_sigma``.
    """
    # NOTE(review): for smoothing_sigma == 0.0 (a registered key below)
    # `scale=0` is not a valid SciPy scale, so this presumably evaluates to
    # NaN — confirm what the adversarial level should be without smoothing.
    adversarial_risk_level = norm.cdf(norm.ppf(main_risk_level, scale=smoothing_sigma) - r, scale=smoothing_sigma)

    print(f"Main risk level: {main_risk_level:.4f}")
    print(f"Adversarial risk level: {adversarial_risk_level:.4f}")

    imagefiles_dir = {
        0.25: "./cityscapes-results/cityscapes_2025_04_30_17:43:33",
        0.0: "./cityscapes-results/cityscapes_2025_04_30_19:48:00",
    }

    # Fail loudly on an unknown sigma: a missing key would otherwise yield
    # None, and os.listdir(None) silently lists the current directory.
    if smoothing_sigma not in imagefiles_dir:
        raise ValueError(
            f"No results directory registered for smoothing_sigma={smoothing_sigma}; "
            f"known values: {sorted(imagefiles_dir)}"
        )
    city_scape_results = imagefiles_dir[smoothing_sigma]

    # Sorted so that the random split below depends only on the RNG state,
    # not on directory-listing or set-iteration order.
    file_indices = sorted({
        int(item.split("_")[0])
        for item in os.listdir(city_scape_results)
        if item.split("_")[0].isnumeric()
    })

    pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Plain Python ints keep the file names below independent of how tensors
    # are formatted.
    rand_idxs = torch.randperm(len(file_indices)).tolist()
    cal_file_indices = [file_indices[i] for i in rand_idxs[:n_dcal]]
    eval_file_indices = [file_indices[i] for i in rand_idxs[n_dcal:]]

    cal_vals = []
    cal_risks = []
    for file_index in tqdm(cal_file_indices):
        loaded_output = torch.load(os.path.join(city_scape_results, f"{file_index}_output.pt"))
        vals, risks = extract_risks(loaded_output["soft_output"], loaded_output["label"], sensitive_class)
        cal_vals.append(vals)
        cal_risks.append(risks)
    # Extra entry: adds 1.0 to the candidate thresholds and makes the
    # calibration count one larger than the number of calibration files.
    cal_vals.append(torch.tensor([1.0]).to(device))
    cal_risks.append(torch.tensor([1.0]).to(device))

    conformal_threshold = compute_conformal_threshold(cal_vals, cal_risks, main_risk_level)
    print(f"Conformal threshold: {conformal_threshold}")
    adversarial_threshold = compute_conformal_threshold(cal_vals, cal_risks, adversarial_risk_level)
    print(f"Adversarial threshold: {adversarial_threshold}")

    for eval_idx in tqdm(eval_file_indices):
        loaded_output = torch.load(os.path.join(city_scape_results, f"{eval_idx}_output.pt"))
        label = loaded_output["label"]
        class_mask = (label == sensitive_class)
        # The risk is undefined for an image without the sensitive class.
        if class_mask.sum() == 0:
            print(f"No samples for class {sensitive_class} in file {eval_idx}")
            continue
        exp_logits = loaded_output["soft_output"]

        output = return_conformal(exp_logits, sensitive_class, conformal_threshold)
        adversarial_output = return_conformal(exp_logits, sensitive_class, adversarial_threshold)
        risk = compute_risk(output, label, sensitive_class)
        adversarial_risk = compute_risk(adversarial_output, label, sensitive_class)
        print(f"Risk: {risk:.4f}")
        print(f"Adversarial risk: {adversarial_risk:.4f}")

        state_dict = {
            "label": class_mask[0].cpu(),
            "output": output.bool().cpu(),
            "adversarial_output": adversarial_output.bool().cpu(),
            "conformal_threshold": conformal_threshold,
            "adversarial_threshold": adversarial_threshold,
            "risk": risk,
            "adversarial_risk": adversarial_risk,
            "main_risk_level": main_risk_level,
            "adversarial_risk_level": adversarial_risk_level,
            "sensitive_class": sensitive_class,
        }

        torch.save(state_dict, os.path.join(output_dir, f"{eval_idx}-class_{sensitive_class}_sigma-{smoothing_sigma}_r-{r}-output.pt"))


