

import os, sys, inspect
sys.path.insert(1, os.path.join(sys.path[0], '..'))

import numpy as np
import torch
import torch.utils.data as tdata
import torchvision.transforms as tf
import random
import torch.backends.cudnn as cudnn
import itertools
from tqdm import tqdm
import pandas as pd

import os 
import pathlib
from sklearn.model_selection import train_test_split
import os

from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import argparse

from models.connetor import build_common_model
from datasets.connector import get_mask
from lib.utils import *
from lib.metrics import *
from lib.calibration import LogitNorm
from lib.post_process import*
from lib.loss_function import _ECELoss,cal_loss
from lib.predictor import *
from lib.optimal_function import get_optimal_parameters_RRA

import joblib


class experiment:
    """Repeated conformal-prediction experiment for one model on one dataset.

    Every trial (``run`` -> ``trial``) splits the cached logits into a
    calibration part and a validation part, optionally fits a post-hoc
    transformation of the logits on the calibration part, builds the requested
    conformal predictor on the (transformed) calibration logits and evaluates
    coverage / set-size / stratified-coverage metrics on the validation part.
    """

    def __init__(self, model_name, alpha, predictor, dataset_name, post_hoc, num_trials, cal_wsc) -> None:
        """Load the model, its cached logits and, if needed, input features.

        Args:
            model_name: name passed to ``build_common_model`` / ``get_logits_dataset``.
            alpha: miscoverage level; the target coverage is ``1 - alpha``.
            predictor: one of "APS", "RAPS", "RAPS_K1", "LAPS".
            dataset_name: one of "imagenet", "imagenetv2", "imagenet-a",
                "imagenet-r", "cifar10", "cifar100".
            post_hoc: post-hoc transformation name (see ``trial``).
            num_trials: number of trials performed by ``run``.
            cal_wsc: if truthy, input features are loaded (for the WSC metric).

        Raises:
            NotImplementedError: for an unsupported ``dataset_name``.
        """
        self.model_name = model_name
        self.alpha = alpha
        self.predictor = predictor
        self.dataset_name = dataset_name
        self.mask = None
        # ``num_calsses`` (sic) is the attribute name used throughout this class.
        if self.dataset_name == "imagenet" or self.dataset_name == "imagenetv2":
            self.num_calsses = 1000
        elif self.dataset_name == "imagenet-a":
            self.mask = get_mask("imagenet-a")
            self.num_calsses = 200
        elif self.dataset_name == "imagenet-r":
            self.mask = get_mask("imagenet-r")
            self.num_calsses = 200
        elif self.dataset_name == "cifar10":
            self.num_calsses = 10
        elif self.dataset_name == "cifar100":
            self.num_calsses = 100
        else:
            raise NotImplementedError
        self.post_hoc = post_hoc
        self.cal_wsc = cal_wsc

        # Instantiate and wrap the model.
        self.model = build_common_model(self.model_name, dataset_name)

        # Cached logits; ``self.mask`` is only set for imagenet-a / imagenet-r.
        self.logits = get_logits_dataset(self.model_name, self.dataset_name, self.mask)
        if self.cal_wsc or self.post_hoc == "IA":
            # Input features, plus a PCA(50)-reduced and standardised copy.
            self.featureX = get_featureX_dataset(self.model_name, dataset_name)
            pca = PCA(n_components=50)
            self.featureX_rd = pca.fit_transform(self.featureX)
            scaler = StandardScaler()
            self.featureX_rd = scaler.fit_transform(self.featureX_rd)

        self.num_trials = num_trials
        # Set-size strata (inclusive bounds) used for SSCV.
        self.strata = [[0, 1], [2, 3], [4, 6], [7, 10], [11, 100], [101, 1000]]
        # Difficulty strata: inclusive bounds on the 1-based rank of the true label.
        self.strata_diff = [[1, 1], [2, 3], [4, 6], [7, 10], [11, 100], [101, 1000]]
        self.seed = 0
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def run(self, randomized, n_data_conf, n_data_val, pct_paramtune, bsz, Kreg):
        """Run ``self.num_trials`` trials and aggregate their metrics.

        Trial ``i`` is seeded with ``i`` (see ``_fix_randomness``); all arguments
        are forwarded unchanged to ``trial``.

        Returns:
            dict: medians over trials of the scalar metrics (plus the standard
            deviation of coverage), the mean wall-clock time per trial under
            ``"Time"``, and the per-stratum tables ``"CCSS"`` / ``"CCSS_diff"``
            of the trial whose SSCV / SSCV_diff is the (upper) median one.
        """
        n_trials = self.num_trials
        top1s = np.zeros((n_trials,))
        top5s = np.zeros((n_trials,))
        coverages = np.zeros((n_trials,))
        sizes = np.zeros((n_trials,))
        sscvs = np.zeros((n_trials,))
        sscvos = np.zeros((n_trials,))
        sscvs_diff = np.zeros((n_trials,))
        sscvos_diff = np.zeros((n_trials,))
        MCs = np.zeros((n_trials,))
        ViolPers = np.zeros((n_trials,))
        MSs = np.zeros((n_trials,))
        ece_list = np.zeros((n_trials,))

        ccss_list = []
        ccss_diff_list = []
        sTime = time.time()
        for i in tqdm(range(n_trials), postfix="实验次数"):
            self.seed = i
            self._fix_randomness(self.seed)
            (top1_avg, top5_avg, cvg_avg, sz_avg, sscv, ccss, MC, ViolPer, MS,
             sscvo, sscv_diff, ccss_diff, sscvo_diff) = self.trial(
                randomized, n_data_conf, n_data_val, pct_paramtune, bsz, Kreg)
            top1s[i] = top1_avg
            top5s[i] = top5_avg
            coverages[i] = cvg_avg
            sizes[i] = sz_avg
            sscvs[i] = sscv
            sscvos[i] = sscvo
            sscvs_diff[i] = sscv_diff
            sscvos_diff[i] = sscvo_diff
            ccss_diff_list.append(ccss_diff)
            ece_list[i] = self.ece_cal
            ccss_list.append(ccss)
            MCs[i] = MC
            ViolPers[i] = ViolPer
            MSs[i] = MS
            # Running medians over the trials completed so far (including ECE,
            # which previously also counted the zero slots of future trials).
            print(
                f'\n\tTop1: {np.median(top1s[0:i+1]):.3f}, '
                f'Top5: {np.median(top5s[0:i+1]):.3f}, '
                f'Coverage: {np.median(coverages[0:i+1]):.3f}, '
                f'Size: {np.median(sizes[0:i+1]):.3f}, '
                f'sscv: {np.median(sscvs[0:i+1]):.3f}, '
                f'MC: {np.median(MCs[0:i+1]):.3f},'
                f'ViolPer: {np.median(ViolPers[0:i+1]):.2f},'
                f'MS: {np.median(MSs[0:i+1]):.3f},'
                f'ECE:{np.round(np.median(ece_list[0:i+1]),4):.4f}  \033[F',
                end='',
            )
        print('\n')
        eTime = time.time()

        # Show the per-stratum table of the trial with the (upper) median SSCV.
        choose_show_index = np.argsort(sscvs)[len(sscvs) // 2]
        choose_show_index_diff = np.argsort(sscvs_diff)[len(sscvs_diff) // 2]
        res_dict = {
            "Model": self.model_name,
            "Predictor": self.predictor,
            "alpha": self.alpha,
            "post_hc": self.post_hoc,
            "Top1": np.round(np.median(top1s), 4),
            "Top5": np.round(np.median(top5s), 4),
            "Coverage": np.round(np.median(coverages), 4),
            "Coverage_std": np.round(np.std(coverages), 4),
            "Size": np.round(np.median(sizes), 4),
            "ECE": np.round(np.median(ece_list), 4),
            "SSCV_diff": np.round(np.median(sscvs_diff), 4),
            "SSCVO_diff": np.round(np.median(sscvos_diff), 4),
            "SSCV": np.round(np.median(sscvs), 4),
            "SSCVO": np.round(np.median(sscvos), 4),
            "Time": (eTime - sTime) / self.num_trials,
            "CCSS_diff": ccss_diff_list[choose_show_index_diff],
            "CCSS": ccss_list[choose_show_index],
            # WSC evaluation is currently disabled (see ``validate``).
            "WSC": 0,
            "WSC_rd": 0,
            "MC": np.median(MCs),
            "ViolPer": np.median(ViolPers),
            "MS": np.median(MSs),
        }
        return res_dict

    def trial(self, randomized, n_data_conf, n_data_val, pct_paramtune, bsz, Kreg):
        """Run one calibration/validation trial.

        Args:
            randomized: NOTE(review): currently ignored; every predictor is
                built with ``randomized=True``.
            n_data_conf: number of calibration samples.
            n_data_val: NOTE(review): currently ignored; the validation split
                is every sample not used for calibration.
            pct_paramtune: passed to RAPS / LAPS as ``pct_paramtune``.
            bsz: DataLoader batch size.
            Kreg: ``kreg`` of the "RAPS" predictor.

        Returns:
            The 13-tuple produced by ``validate``.

        Raises:
            NotImplementedError: for an unknown ``post_hoc`` or ``predictor``.
        """
        alpha = self.alpha
        logits_cal, logits_val, self.cal_indices, self.val_indices = split2(
            self.logits, n_data_conf, len(self.logits) - n_data_conf)

        # ---- Select the post-hoc transformation (independent of the CP method). ----
        if self.post_hoc == "oTS":
            # Temperature scaling initialised at 1.3; fitted below if it has trainable params.
            transformation = OptimalTeamperatureScaling(1.3)
        elif self.post_hoc == "Identity":
            transformation = PostHoc()
        elif self.post_hoc == "LN":
            transformation = LogitNormalization()
        elif self.post_hoc == "LNo":
            transformation = LogitNormalizationwOptimal()
        elif self.post_hoc == "IA":
            transformation = IA(self.num_calsses)
            cal_a, val_a = self._atypicality_scores()
            # Standardise with calibration statistics only, then append the
            # score to every sample's logits as one extra input column.
            scaler = StandardScaler()
            cal_a = cal_a.reshape(-1, 1)
            scaler.fit(cal_a)
            logits_cal = self._append_column(logits_cal, scaler.transform(cal_a))
            logits_val = self._append_column(logits_val, scaler.transform(val_a.reshape(-1, 1)))
        elif self.post_hoc == "TS":
            transformation = FixedTeamperatureScaling(0.5)
        elif self.post_hoc == "VS":
            # Vector scaling.
            transformation = VectorScaling(self.num_calsses)
        else:
            raise NotImplementedError

        loader_cal = torch.utils.data.DataLoader(logits_cal, batch_size=bsz, shuffle=False, pin_memory=True)
        loader_val = torch.utils.data.DataLoader(logits_val, batch_size=bsz, shuffle=False, pin_memory=True)

        # Fit the transformation on the calibration logits if it is trainable.
        if any(param.requires_grad for param in transformation.parameters()):
            if self.post_hoc == "IA":
                transformation = get_optimal_parameters_RRA(transformation, loader_cal, self.device)
            else:
                transformation = self.get_optimal_parameters(transformation, loader_cal)

        # Replace both splits by their transformed logits.
        if self.post_hoc != "Identity":
            logits_cal = postHocLogits(transformation, loader_cal, self.device, self.num_calsses, self.mask)
            logits_val = postHocLogits(transformation, loader_val, self.device, self.num_calsses, self.mask)

        loader_cal = torch.utils.data.DataLoader(logits_cal, batch_size=bsz, shuffle=False, pin_memory=True)
        loader_val = torch.utils.data.DataLoader(logits_val, batch_size=bsz, shuffle=False, pin_memory=True)

        # Calibration loss of the validation logits; reported as "ECE" by ``run``.
        self.ece_cal = cal_loss(loader_val, self.device)

        allow_zero_sets = True
        if self.predictor == "APS":
            self.conformal_model = APS(loader_cal, alpha=alpha, randomized=True, allow_zero_sets=allow_zero_sets)
        elif self.predictor == "RAPS":
            self.conformal_model = RAPS(loader_cal, alpha=alpha, kreg=Kreg, lamda=None, randomized=True,
                                        allow_zero_sets=allow_zero_sets, pct_paramtune=pct_paramtune,
                                        batch_size=bsz, lamda_criterion='size')
            self.lamda = self.conformal_model.lamda
        elif self.predictor == "RAPS_K1":
            self.conformal_model = RAPS(loader_cal, alpha=alpha, kreg=1, lamda=None, randomized=True,
                                        allow_zero_sets=allow_zero_sets, pct_paramtune=pct_paramtune,
                                        batch_size=bsz, lamda_criterion='size')
            # Keep ``self.lamda`` in sync, as for "RAPS", instead of leaving a stale value.
            self.lamda = self.conformal_model.lamda
        elif self.predictor == "LAPS":
            self.conformal_model = LAPS(loader_cal, alpha=alpha, rank_pen=0, randomized=True,
                                        allow_zero_sets=allow_zero_sets, kreg=0, lamda=0,
                                        batch_size=bsz, pct_paramtune=pct_paramtune)
        else:
            # Fail here rather than later in ``validate`` with an unset/stale model.
            raise NotImplementedError(f"Unknown predictor: {self.predictor}")

        return self.validate(loader_val)

    def _atypicality_scores(self):
        """Return ``(cal_a, val_a)`` KNN input-atypicality scores for the current split.

        Scores are cached with joblib per model and seed under
        ``.cache/<dataset>/IA/pkl`` next to this file; the directory is created
        on first write.
        NOTE(review): the cache key does not include ``n_data_conf``; changing
        the split size reuses stale scores.
        """
        cache_path = str(pathlib.Path(__file__).parent.absolute()) + '/.cache/' + self.dataset_name + "/IA/pkl"
        file_path = os.path.join(cache_path, "IA_KNN_{}_seed={}.pkl".format(self.model_name, self.seed))
        if os.path.exists(file_path):
            all_a = joblib.load(file_path)
            return all_a['cal_a'], all_a['val_a']

        cal_X = self.featureX[self.cal_indices]
        val_X = self.featureX[self.val_indices]
        # Presumably (reference set, query set): both splits are scored against cal_X.
        cal_a = computeInputyAtypicalityKNN(cal_X, cal_X)
        val_a = computeInputyAtypicalityKNN(cal_X, val_X)
        os.makedirs(cache_path, exist_ok=True)
        joblib.dump({'cal_a': cal_a, 'val_a': val_a}, file_path)
        return cal_a, val_a

    @staticmethod
    def _append_column(dataset, column):
        """Return a TensorDataset of (inputs with ``column`` appended, flattened targets).

        ``column`` is a numpy array of shape (len(dataset), 1) aligned with the
        samples of ``dataset``.
        """
        inputs = torch.stack([sample[0] for sample in dataset])
        inputs = torch.concatenate([inputs, torch.from_numpy(column)], dim=1)
        targets = torch.stack([sample[1] for sample in dataset]).reshape(-1)
        return torch.utils.data.TensorDataset(inputs, targets)

    def get_optimal_parameters(self, transformation, calib_loader):
        """Fit the parameters of ``transformation`` by minimising NLL on ``calib_loader``.

        * "oTS": mini-batch SGD on ``transformation.temperature`` (lr 0.01), for
          at most 10 epochs, stopping once an epoch changes T by < 0.01.
        * otherwise: full-batch LBFGS on all parameters (step size / iteration
          budget depend on ``post_hoc`` and, for "VS", the dataset).

        Returns:
            The same ``transformation`` object, moved to ``self.device`` and fitted.
        """
        device = self.device
        transformation.to(device)
        if self.post_hoc == "oTS":
            max_iters = 10
            lr = 0.01
            epsilon = 0.01
            nll_criterion = nn.CrossEntropyLoss().to(device)
            optimizer = optim.SGD([transformation.temperature], lr=lr)
            for _ in range(max_iters):
                T_old = transformation.temperature.item()
                for x, targets in calib_loader:
                    optimizer.zero_grad()
                    # Only the temperature needs gradients; the logits do not.
                    x = x.to(device)
                    out = x / transformation.temperature
                    loss = nll_criterion(out, targets.long().to(device))
                    loss.backward()
                    optimizer.step()
                if abs(T_old - transformation.temperature.item()) < epsilon:
                    break
        else:
            nll_criterion = nn.CrossEntropyLoss().to(device)
            # Collect the whole calibration set for full-batch LBFGS.
            logits_list = []
            labels_list = []
            with torch.no_grad():
                for examples in calib_loader:
                    logits_list.append(examples[0])
                    labels_list.append(examples[1])
                logits = torch.cat(logits_list).to(device)
                labels = torch.cat(labels_list).to(device)

            if self.post_hoc == "VS":
                if self.dataset_name == "cifar10":
                    optimizer = optim.LBFGS(transformation.parameters(), lr=0.1, max_iter=1000)
                else:
                    optimizer = optim.LBFGS(transformation.parameters(), lr=0.3, max_iter=100)
            else:
                optimizer = optim.LBFGS(transformation.parameters(), lr=0.01, max_iter=50)

            def closure():
                optimizer.zero_grad()
                loss = nll_criterion(transformation(logits), labels)
                loss.backward()
                return loss

            optimizer.step(closure)
        return transformation

    def validate(self, val_loader):
        """Evaluate ``self.conformal_model`` on ``val_loader``.

        Returns:
            tuple: (top1, top5, coverage, mean set size, sscv, ccss, macro
            coverage, coverage violation, macro inefficiency, sscvo, sscv_diff,
            ccss_diff, sscvo_diff).
        """
        device = self.device
        with torch.no_grad():
            top1 = AverageMeter('top1')
            top5 = AverageMeter('top5')
            # Switch to evaluate mode.
            self.conformal_model.eval()
            all_S = []
            targets = []
            size_array = []
            correct_array = []
            topk = []
            for logits, target in val_loader:
                # 1-based position of the true label in the ordering returned by
                # ``sort_sum`` (presumably descending score) — the "difficulty".
                I, _, _ = sort_sum(logits.numpy())
                topk.append(np.where((I - target.view(-1, 1).numpy()) == 0)[1] + 1)
                target = target.to(device)
                logits = logits.to(device)
                # Prediction sets for this batch.
                _, S = self.conformal_model(logits)
                all_S.extend(S)
                targets.extend(list(target.detach().cpu().numpy()))
                for j in range(target.shape[0]):
                    size_array.append(len(S[j]))
                    correct_array.append(1 if target[j].item() in S[j] else 0)

                prec1, prec5 = accuracy(logits, target, topk=(1, 5))
                top1.update(prec1.item() / 100.0, n=logits.shape[0])
                top5.update(prec5.item() / 100.0, n=logits.shape[0])

        MacroCoverage, CoverViolation = cal_MacroCoverAndCoverViolation(correct_array, targets, self.num_calsses, self.alpha)
        MacroIneff = cal_MacroIneff(size_array, targets, self.num_calsses)
        topk = np.concatenate(topk)
        sscv, ccss, sscvo, sscv_diff, ccss_diff, sscvo_diff = self.cal_sscv(size_array, correct_array, topk)

        # WSC evaluation is currently disabled; to re-enable:
        # if self.seed == 0 and self.cal_wsc:
        #     self.wsc = self.wscUnbiased(self.featureX[self.val_indices], targets, all_S)
        #     self.wsc_rd = self.wscUnbiased(self.featureX_rd[self.val_indices], targets, all_S)

        return (top1.avg, top5.avg, np.mean(correct_array), np.mean(size_array), sscv, ccss,
                MacroCoverage, CoverViolation, MacroIneff, sscvo, sscv_diff, ccss_diff, sscvo_diff)

    def _fix_randomness(self, seed=0):
        """Seed numpy, torch (CPU and CUDA) and Python's ``random`` with ``seed``."""
        np.random.seed(seed=seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        random.seed(seed)

    def cal_sscv(self, size_array, correct_array, topk):
        """Compute worst-case coverage violations over set-size and difficulty strata.

        Args:
            size_array: prediction-set size per validation sample.
            correct_array: 1 if the true label is in the set, else 0.
            topk: 1-based rank of the true label per sample (difficulty).

        Returns:
            tuple: (sscv, ccss, sscvo, sscv_diff, ccss_diff, sscvo_diff), where
            ``sscv`` / ``sscv_diff`` are the largest ``max(0, 1-alpha - coverage)``
            over ``self.strata`` / ``self.strata_diff`` (-1 if every stratum is
            empty), ``sscvo`` / ``sscvo_diff`` the same over every single size /
            rank in 1..num_calsses (0 if none), and ``ccss`` / ``ccss_diff`` map
            ``str(stratum)`` to its count, coverage (rounded to 3 decimals) and,
            for difficulty strata, mean set size. Lower violations are better.
        """
        size_array = np.array(size_array)
        correct_array = np.array(correct_array)
        topk = np.asarray(topk)
        target_cvg = 1 - self.alpha

        # Coverage conditioned on set-size strata.
        wc_violation = -1
        ccss = {}
        for stratum in self.strata:
            idx = np.flatnonzero((size_array >= stratum[0]) & (size_array <= stratum[1]))
            entry = {'cnt': len(idx)}
            if len(idx) == 0:
                entry['cvg'] = 0
            else:
                mean_cvg = np.mean(correct_array[idx])
                entry['cvg'] = np.round(mean_cvg, 3)
                wc_violation = max(wc_violation, max(0, target_cvg - mean_cvg))
            ccss[str(stratum)] = entry

        # Coverage conditioned on every individual set size.
        wc_violation_one = 0
        for k in range(1, self.num_calsses + 1):
            idx = np.flatnonzero(size_array == k)
            if len(idx) > 0:
                wc_violation_one = max(wc_violation_one, max(0, target_cvg - np.mean(correct_array[idx])))

        # Coverage conditioned on difficulty strata.
        ccss_diff = {}
        wc_diff_violation = -1
        for stratum in self.strata_diff:
            idx = np.flatnonzero((topk >= stratum[0]) & (topk <= stratum[1]))
            entry = {'cnt': len(idx)}
            if len(idx) == 0:
                entry['cvg'] = 0
                entry['sz'] = 0
            else:
                mean_cvg = np.mean(correct_array[idx])
                entry['cvg'] = np.round(mean_cvg, 3)
                entry['sz'] = np.round(np.mean(size_array[idx]), 3)
                # Use the unrounded coverage, as for the set-size strata; the
                # rounded display value could hide violations below 5e-4.
                wc_diff_violation = max(wc_diff_violation, max(0, target_cvg - mean_cvg))
            ccss_diff[str(stratum)] = entry

        # Coverage conditioned on every individual rank.
        wc_diff_violation_one = 0
        for k in range(1, self.num_calsses + 1):
            idx = np.flatnonzero(topk == k)
            if len(idx) > 0:
                wc_diff_violation_one = max(wc_diff_violation_one, max(0, target_cvg - np.mean(correct_array[idx])))

        return wc_violation, ccss, wc_violation_one, wc_diff_violation, ccss_diff, wc_diff_violation_one

    def calWSC(self, X, y, S, delta=0.1, M=1000, random_state=2020, verbose=True):
        """Worst-slab coverage over ``M`` random unit directions.

        For each direction ``v`` the samples are sorted by ``X @ v`` and the
        contiguous slab holding at least ``round(delta * n)`` samples with the
        lowest empirical coverage is found.

        Returns:
            tuple: (worst coverage, its direction v, slab lower bound a, slab upper bound b).
        """
        rng = np.random.default_rng(random_state)

        def wsc_v(X, y, S, delta, v):
            n = len(y)
            cover = np.array([y[i] in S[i] for i in range(n)])
            z = np.dot(X, v)
            z_order = np.argsort(z)
            z_sorted = z[z_order]
            cover_ordered = cover[z_order]
            ai_max = int(np.round((1.0 - delta) * n))
            ai_best = 0
            bi_best = n - 1
            cover_min = 1
            for ai in np.arange(0, ai_max):
                bi_min = np.minimum(ai + int(np.round(delta * n)), n - 1)
                coverage = np.cumsum(cover_ordered[ai:n]) / np.arange(1, n - ai + 1)
                # Slabs shorter than the minimum mass are not eligible.
                coverage[np.arange(0, bi_min - ai)] = 1
                bi_star = ai + np.argmin(coverage)
                cover_star = coverage[bi_star - ai]
                if cover_star < cover_min:
                    ai_best = ai
                    bi_best = bi_star
                    cover_min = cover_star
            return cover_min, z_sorted[ai_best], z_sorted[bi_best]

        def sample_sphere(n, p):
            v = rng.normal(size=(p, n))
            v /= np.linalg.norm(v, axis=0)
            return v.T

        V = sample_sphere(M, p=X.shape[1])
        wsc_list = [None] * M
        a_list = [None] * M
        b_list = [None] * M
        directions = tqdm(range(M)) if verbose else range(M)
        for m in directions:
            wsc_list[m], a_list[m], b_list[m] = wsc_v(X, y, S, delta, V[m])

        idx_star = np.argmin(np.array(wsc_list))
        return wsc_list[idx_star], V[idx_star], a_list[idx_star], b_list[idx_star]

    def wscUnbiased(self, X, y, S, delta=0.1, M=1000, test_size=0.75, random_state=2020, verbose=True):
        """Worst-slab coverage estimated on held-out data.

        The worst slab is searched on a ``1 - test_size`` split (``calWSC``) and
        its coverage is then measured on the remaining ``test_size`` split.
        """
        def wsc_vab(X, y, S, v, a, b):
            n = len(y)
            cover = np.array([y[i] in S[i] for i in range(n)])
            z = np.dot(X, v)
            idx = np.where((z >= a) * (z <= b))
            return np.mean(cover[idx])

        X_train, X_test, y_train, y_test, S_train, S_test = train_test_split(
            X, y, S, test_size=test_size, random_state=random_state)
        # Find the adversarial slab on the training split.
        wsc_star, v_star, a_star, b_star = self.calWSC(
            X_train, y_train, S_train, delta=delta, M=M, random_state=random_state, verbose=verbose)
        # Estimate its coverage on the test split.
        return wsc_vab(X_test, y_test, S_test, v_star, a_star, b_star)
        



if __name__ == "__main__":
    # Ablation over the RAPS regularisation rank ``Kreg``: for every model,
    # record the median prediction-set size and the tuned lambda for
    # Kreg = 1..10 and dump the results to <cache>/paper_Ablation_RAPS.pkl.
    parser = argparse.ArgumentParser(description='Evaluates conformal predictors',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--dataset_name', '-s', type=str, default='imagenet', help='dataset name.')
    parser.add_argument('--gpu', type=int, default=0, help='chose gpu id')
    args = parser.parse_args()
    # Set before any CUDA call made by this script (experiment() is built later).
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    dataset_name = args.dataset_name
    cache_path = str(pathlib.Path(__file__).parent.absolute()) + '/.cache/' + dataset_name

    # ---- Configure the experiment ----
    if dataset_name == "imagenet" or dataset_name == "imagenetv2":
        modelnames = ['ResNeXt101', 'ResNet152', 'ResNet101', 'ResNet50', 'ResNet18', 'DenseNet161',
                      'VGG16', 'Inception', 'ShuffleNet', "ViT", 'DeiT', "CLIP"]
    else:
        modelnames = ['ResNet18', 'ResNet50', 'ResNet101', 'DenseNet161', 'VGG16', 'Inception', "ViT", "CLIP"]

    alphas = [0.1, 0.05, 0.01]
    post_hocs = ["oTS"]
    predictors = ["RAPS"]
    num_trials = 1
    cal_wsc = 0
    randomized = True

    # (n_data_conf, n_data_val) per dataset.
    split_sizes = {
        "imagenetv2": (5000, 5000),
        "imagenet-a": (4500, 3000),
        "imagenet-r": (15000, 15000),
        "imagenet": (30000, 20000),
        "cifar10": (5000, 5000),
        "cifar100": (5000, 5000),
    }
    if dataset_name not in split_sizes:
        raise NotImplementedError
    n_data_conf, n_data_val = split_sizes[dataset_name]

    pct_paramtune = 0.2
    bsz = 320
    # Let cuDNN benchmark and pick the fastest convolution algorithms.
    cudnn.benchmark = True
    filename = "paper_Ablation_RAPS.pkl"
    alpha = alphas[0]
    post_hoc = post_hocs[0]
    predictor = predictors[0]
    res_filepath = os.path.join(cache_path, filename)
    # joblib.dump does not create missing directories.
    os.makedirs(cache_path, exist_ok=True)

    Res = {"Kreg": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]}
    for model_name in modelnames:
        Res[model_name] = {"size": [], "lambda": []}
        print(f'Model: {model_name} | Desired coverage: {1-alpha} | Predictor: {predictor}| Calibration: {post_hoc}')
        # Build once per model: nothing loaded in __init__ depends on Kreg and
        # ``run`` re-seeds every trial. (Assumes split2 does not mutate the
        # cached logits — confirm if that helper changes.)
        this_experiment = experiment(model_name, alpha, predictor, dataset_name, post_hoc, num_trials, cal_wsc)
        for Kreg in Res["Kreg"]:
            out = this_experiment.run(randomized, n_data_conf, n_data_val, pct_paramtune, bsz, Kreg)
            Res[model_name]["lambda"].append(this_experiment.lamda)
            Res[model_name]["size"].append(out["Size"])

        joblib.dump(Res, res_filepath)