from tqdm import tqdm
import os
import argparse
import pickle
import warnings
import math

import numpy as np
import torch

from nnlib.nnlib import utils

import methods

from scripts.fcmi_train_classifier import mnist_ld_schedule, \
    cifar_resnet50_ld_schedule  # for pickle to be able to load LD methods

from methods.calibration import CalibrationMethod, TemperatureScaling, HistogramBinning, GPCalibration
from utils.metrics import compute_acc, compute_acc_l2, compute_ece, compute_ece2, compute_bins, get_ece_kde, get_bandwidth, idx_bins, _ECELoss, _LabelLoss, compute_lsloss
from utils.bounds import estimate_cmi_bound, estimate_fcmi, estimate_semipara_bound, estimate_cmi_ece_bound, estimate_cmi_recal_bound

from modules.data_utils import load_data_from_arguments
from nnlib.nnlib.data_utils.wrappers import SubsetDataWrapper, LabelSubsetWrapper, ResizeImagesWrapper
from nnlib.nnlib.data_utils.base import get_loaders_from_datasets, get_input_shape

class NestedDict(dict):
    """Dictionary that grows nested levels on first access.

    Looking up a key that is not present stores a new, empty instance of the
    same class under that key and returns it, so ``d[a][b] = value`` works
    without initialising ``d[a]`` beforehand.
    """

    def __missing__(self, key):
        # ``type(self)`` keeps subclasses nesting instances of their own type.
        child = type(self)()
        self[key] = child
        return child


def load_all_data(saved_data, test_data=False):
    """Load one dataset split for the run described by ``saved_data['args']``.

    ``load_data_from_arguments(args, build_loaders=False)`` is expected to
    return four items; the third one is used when ``test_data`` is true and
    the first one otherwise (presumably the test and train splits
    respectively -- confirm against ``modules.data_utils``).

    The selected split is then wrapped according to the saved arguments:
    restricted to ``args.which_labels`` if that is not ``None`` and resized to
    224x224 if ``args.resize_to_imagenet`` is truthy.

    Returns:
        The (possibly wrapped) dataset.
    """
    args = saved_data['args']
    first_split, _, third_split, _ = load_data_from_arguments(args, build_loaders=False)
    examples = third_split if test_data else first_split

    # Keep only the requested classes, if any were specified.
    label_subset = args.which_labels
    if label_subset is not None:
        examples = LabelSubsetWrapper(examples, which_labels=label_subset)

    # Match the ImageNet input resolution when requested.
    if args.resize_to_imagenet:
        examples = ResizeImagesWrapper(examples, size=(224, 224))

    return examples

def preds_recalibrate(saved_data, all_data, n_data=100, test_data=False):
    """Draw ``n_data`` examples from ``all_data`` to use as recalibration data.

    NumPy's global RNG is re-seeded with ``saved_data['args'].seed`` before
    sampling, so the same seed always yields the same subset.

    Args:
        saved_data: dict with at least the keys ``'all_examples'`` and ``'args'``.
        all_data: dataset to sample from.
        n_data: number of examples to draw (without replacement).
        test_data: if true, every index of ``all_data`` is a candidate;
            otherwise indices listed in
            ``saved_data['all_examples'].include_indices`` are excluded first.

    Returns:
        ``SubsetDataWrapper`` over ``all_data`` restricted to the drawn indices.

    Raises:
        ValueError: if ``n_data`` is not strictly smaller than the number of
            candidate indices.
    """
    all_examples = saved_data['all_examples']  # train/val examples of the run
    args = saved_data['args']

    assert len(all_data) >= n_data
    np.random.seed(args.seed)

    candidates = np.arange(len(all_data))
    if not test_data:
        # Drop everything that already belongs to the train/val pool.
        already_used = np.isin(candidates, all_examples.include_indices)
        candidates = candidates[~already_used]

    if n_data >= len(candidates):
        raise ValueError(" The number of recalibration data is larger than that of all data")

    chosen = np.random.choice(candidates, size=n_data, replace=False)
    return SubsetDataWrapper(all_data, include_indices=chosen)

def check_softmax(preds):
    """Return ``preds`` as row-wise probabilities, applying softmax if needed.

    ``preds`` is treated as already normalised when every entry is
    non-negative and every row (dim 1) sums to one within a tolerance that
    scales with the precision of ``preds.dtype``. A fixed ``1e-10`` tolerance
    is below float32 rounding error, so float32 softmax outputs would fail the
    check and be softmaxed a second time.

    Args:
        preds: 2-D tensor of logits or probabilities, one row per example.

    Returns:
        ``preds`` unchanged if it already holds probabilities, otherwise
        ``preds.softmax(1)``.
    """
    if preds.is_floating_point():
        eps = torch.finfo(preds.dtype).eps
        # ~1e-5 for float32, 1e-10 for float64; capped for half precision.
        tol = min(max(1e-10, 100 * eps), 1e-2)
    else:
        tol = 1e-10

    row_sums = torch.sum(preds, dim=1)
    sums_to_one = bool(torch.all(torch.abs(row_sums - 1) < tol))
    # Logits can sum to one by coincidence; probabilities are never negative.
    non_negative = bool(torch.all(preds >= 0))

    if not (sums_to_one and non_negative):
        print("make softmax prob.")
        preds = preds.softmax(1)
    else:
        print("already softmax prob.")

    return preds

def multi_pred_to_binary(preds, labels, target_label=3):
    """Collapse multi-class predictions into a one-vs-rest binary problem.

    Args:
        preds: 2-D tensor of class probabilities, one row per example.
        labels: integer class labels (torch tensor or NumPy array). It is not
            modified; a new object is returned.
        target_label: index of the class that becomes the positive class.

    Returns:
        Tuple ``(binary_preds, binary_labels)``. ``binary_preds`` is a
        ``(len(preds), 2)`` tensor created with ``torch.zeros`` whose column 1
        is ``preds[:, target_label]`` and column 0 its complement.
        ``binary_labels`` has the dtype of ``labels`` and is 1 where
        ``labels == target_label`` and 0 elsewhere.

    Raises:
        ValueError: if ``target_label`` is not a valid column index of ``preds``.
    """
    num_classes = preds.shape[1]
    if not 0 <= target_label < num_classes:
        raise ValueError(" This target label is not included.")

    target_prob = preds[:, target_label]
    binary_preds = torch.zeros(len(preds), 2)
    binary_preds[:, 0] = 1 - target_prob
    binary_preds[:, 1] = target_prob

    # Build the binary labels from a single comparison. Overwriting `labels`
    # in two steps (non-target -> 0, then target -> 1) turns every label into
    # 1 when target_label == 0, and also mutates the caller's data.
    is_target = labels == target_label
    if isinstance(labels, torch.Tensor):
        binary_labels = is_target.to(labels.dtype)
    else:
        binary_labels = is_target.astype(labels.dtype)

    return binary_preds, binary_labels

# Extra keyword arguments that distinguish the Gaussian-process calibration variants.
_GP_CALIBRATION_VARIANTS = {
    'GPcalib': {},
    'GPcalib_approx': {'inf_mean_approx': True},
    'GPcalib_pac': {'pac': True},
    'GPcalib_pac_root': {'pac': True, 'loss_type': 'root'},
    'GPcalib_pac_total': {'pac': True, 'likelihood_type': 'total'},
}


def _build_recal_method(recal_method, n_bins, labels):
    """Instantiate the recalibration model named by ``recal_method``.

    ``labels`` is only used by the GP variants, where the number of classes is
    taken as the number of distinct label values.
    """
    if recal_method == 'temp':
        return TemperatureScaling()
    if recal_method == 'hist_uwb':
        return HistogramBinning(mode='equal_width', n_bins=n_bins)
    if recal_method == 'hist_umb':
        return HistogramBinning(mode='equal_freq', n_bins=n_bins)
    if recal_method in _GP_CALIBRATION_VARIANTS:
        num_classes = len(torch.unique(labels))
        return GPCalibration(n_classes=num_classes, maxiter=1000, num_inducing=10, logits=False,
                             random_state=1, **_GP_CALIBRATION_VARIANTS[recal_method])
    raise ValueError(f"Unexpected recalibratin method: {recal_method}.")


def get_fcmi_results_for_fixed_z(n, lr, epoch, seed, args, num_bins_eval=None, recalibrate=False, traindata_reuse=False, n_caldata=100, recal_method='hist_umb', recal_ece=False, strategy='label', make_binary=False, binary_target=None):
    """
    Evaluate all available ``S_seed`` runs of one ``(n, lr, seed)`` setting.

    For every ``S_seed`` in ``range(args.n_S_seeds)`` whose result directory
    exists, the checkpoint ``epoch{epoch - 1}.mdl`` is loaded, predictions are
    computed on the saved examples, optionally recalibrated, and train/val
    accuracy and binned ECE (uniform-width "UWB" and uniform-mass "UMB" bins)
    are collected. After the loop the pooled predictions, masks, labels and
    bin indices are passed to the CMI bound estimators.

    Args:
        n: number of training examples; used in the directory name, for the
            default bin count and by the bound estimators.
        lr: learning rate, used only to build the directory name.
        epoch: 1-based epoch whose checkpoint is evaluated.
        seed: seed component of the directory name.
        args: namespace providing ``n_S_seeds``, ``results_dir``, ``exp_name``,
            ``device`` and ``batch_size``.
        num_bins_eval: number of bins; defaults to ``floor(n ** (1/3))``.
        recalibrate: whether to fit a recalibration map and evaluate the
            recalibrated predictions.
        traindata_reuse: fit the recalibration map on the training half of the
            saved examples instead of on separately drawn data.
        n_caldata: number of separately drawn recalibration examples.
        recal_method: one of ``'temp'``, ``'hist_uwb'``, ``'hist_umb'`` or a
            ``'GPcalib*'`` variant.
        recal_ece, strategy: forwarded to ``compute_ece``.
        make_binary: convert predictions/labels to one-vs-rest on
            ``binary_target`` before evaluation.
        binary_target: target class, also forwarded to the metric helpers.

    Returns:
        dict of scalar summaries (means over ``S_seed`` runs, gaps, summed
        mutual-information estimates and bounds). Entries whose inputs were
        not computed (e.g. the ``*_kde`` and, without recalibration, the
        ``*_recal_*`` keys) are NaN; the ``mis_reuse_*`` keys are ``None``
        when ``traindata_reuse`` is false.

    Raises:
        ValueError: if ``recal_method`` is unknown (raised once a run is found
            and ``recalibrate`` is true).
    """
    train_accs = []
    val_accs = []
    train_accs_l2 = []
    val_accs_l2 = []
    train_recal_accs = []
    val_recal_accs = []
    train_recal_accs_l2 = []
    val_recal_accs_l2 = []

    train_eces_uwb = []
    val_eces_uwb = []
    train_eces_umb = []
    val_eces_umb = []
    train_eces_uwb2 = []
    val_eces_uwb2 = []
    train_eces_umb2 = []
    val_eces_umb2 = []
    # NOTE: the kernel (KDE) ECE evaluation is disabled, so these stay empty
    # and the corresponding entries of the returned dict are NaN.
    train_eces_kde = []
    val_eces_kde = []

    preds = []
    masks = []
    labels = []
    bins = []
    bins_umb = []
    preds_recal = []
    labels_recal = []
    lsloss_uwb = []
    lsloss_umb = []

    all_examples = None
    all_data = None  # loaded lazily, the first time recalibration data is needed

    ## number of bins for CMI evaluation
    if num_bins_eval is None:
        print('set our optimal numper of bins.')
        n_bins_eval = math.floor(n ** (1/3))
    else:
        n_bins_eval = num_bins_eval

    for S_seed in range(args.n_S_seeds):
        dir_name = f'n={n},lr={lr},seed={seed},S_seed={S_seed}'
        dir_path = os.path.join(args.results_dir, args.exp_name, dir_name)
        if not os.path.exists(dir_path):
            print(f"Did not find results for {dir_name}")
            continue

        # NOTE: pickle executes arbitrary code on load; only use trusted result directories.
        with open(os.path.join(dir_path, 'saved_data.pkl'), 'rb') as f:
            saved_data = pickle.load(f)

        model = utils.load(path=os.path.join(dir_path, 'checkpoints', f'epoch{epoch - 1}.mdl'),
                           methods=methods, device=args.device)

        print(list(saved_data.keys()))
        if 'all_examples_wo_data_aug' in saved_data:
            all_examples = saved_data['all_examples_wo_data_aug']
        else:
            all_examples = saved_data['all_examples']

        # One pass over the dataset yields both predictions and labels.
        outputs = utils.apply_on_dataset(model=model, dataset=all_examples,
                                         batch_size=args.batch_size)
        cur_preds = check_softmax(outputs['pred'])
        cur_labels = outputs['label_0']
        cur_mask = saved_data['mask']

        if make_binary:
            cur_preds, cur_labels = multi_pred_to_binary(cur_preds, cur_labels, binary_target)

        ### Pool preds, masks, and labels
        preds.append(cur_preds)
        masks.append(cur_mask)
        labels.append(cur_labels)

        # Indices of the training examples: entry i of the mask picks index
        # 2*i or 2*i + 1 (presumably the examples are stored in pairs -- confirm).
        train_idx = 2*np.arange(len(cur_mask)) + cur_mask

        ## Prediction on recalibration data
        if recalibrate:
            recal_f = _build_recal_method(recal_method, n_bins_eval, cur_labels)

            if traindata_reuse:
                ## proposed method: fit the recalibration map on the training examples
                recal_f.fit(cur_preds[train_idx].numpy(), cur_labels[train_idx].numpy())
            else:
                # Load the full dataset once, on the first run that needs it.
                # (Checking for None rather than S_seed == 0 keeps this working
                # when the S_seed=0 directory is missing.)
                # Using the test split is unproblematic here: train/test data
                # are prepared from the "training dataset" itself and the test
                # split is not used for model training/evaluation.
                if all_data is None:
                    all_data = load_all_data(saved_data=saved_data, test_data=True)
                print(f"preparing the {n_caldata}-th recalibration data.")
                recab_examples = preds_recalibrate(saved_data=saved_data, all_data=all_data, n_data=n_caldata, test_data=True)
                recab_outputs = utils.apply_on_dataset(model=model, dataset=recab_examples, batch_size=args.batch_size)
                fit_preds = check_softmax(recab_outputs['pred'])
                fit_labels = recab_outputs['label_0']
                if make_binary:
                    fit_preds, fit_labels = multi_pred_to_binary(fit_preds, fit_labels, binary_target)

                recal_f.fit(fit_preds.numpy(), fit_labels.numpy())

            cur_preds_recal = torch.tensor(recal_f.predict_proba(cur_preds.numpy())).float()
            cur_labels_recal = cur_preds.argmax(1)

            ### Pool preds and labels
            preds_recal.append(cur_preds_recal)
            labels_recal.append(cur_labels_recal)

        ## Pool bins to evaluate the binned ECE and the CMI
        ## (Remark: using the "non-recalibrated" predictive prob. on training data.)
        conf = cur_preds[train_idx].max(1).values

        n_bins = compute_bins(num_bins=n_bins_eval)
        bins.append(idx_bins(conf, n_bins).numpy())

        n_bins_umb = compute_bins(num_bins=n_bins_eval, confidences=conf, method='quantile')
        bins_umb.append(idx_bins(conf, n_bins_umb).numpy())

        ############################ Accuracy evaluation ############################
        if recalibrate:
            ## Acc.
            cur_train_recal_acc = compute_acc(preds=cur_preds_recal, mask=cur_mask, dataset=all_examples, target_label=binary_target)
            cur_val_recal_acc = compute_acc(preds=cur_preds_recal, mask=1-cur_mask, dataset=all_examples, target_label=binary_target)
            ## Acc L2
            cur_train_recal_acc_l2 = compute_acc_l2(preds=cur_preds_recal, mask=cur_mask, dataset=all_examples, target_label=binary_target)
            cur_val_recal_acc_l2 = compute_acc_l2(preds=cur_preds_recal, mask=1-cur_mask, dataset=all_examples, target_label=binary_target)

            train_recal_accs.append(cur_train_recal_acc)
            val_recal_accs.append(cur_val_recal_acc)
            train_recal_accs_l2.append(cur_train_recal_acc_l2)
            val_recal_accs_l2.append(cur_val_recal_acc_l2)
            print("recal_acc",cur_train_recal_acc, cur_val_recal_acc)
            print("recal_acc_l2", cur_train_recal_acc_l2, cur_val_recal_acc_l2)

        cur_train_acc = compute_acc(preds=cur_preds, mask=cur_mask, dataset=all_examples, target_label=binary_target)
        cur_val_acc = compute_acc(preds=cur_preds, mask=1-cur_mask, dataset=all_examples, target_label=binary_target)
        ## Acc L2
        cur_train_acc_l2 = compute_acc_l2(preds=cur_preds, mask=cur_mask, dataset=all_examples, target_label=binary_target)
        cur_val_acc_l2 = compute_acc_l2(preds=cur_preds, mask=1-cur_mask, dataset=all_examples, target_label=binary_target)

        print("acc", cur_train_acc, cur_val_acc)
        print("acc_l2", cur_train_acc_l2, cur_val_acc_l2)

        ## Pool results
        train_accs.append(cur_train_acc)
        val_accs.append(cur_val_acc)
        train_accs_l2.append(cur_train_acc_l2)
        val_accs_l2.append(cur_val_acc_l2)

        # The ECE is evaluated on the recalibrated predictions when
        # recalibration is enabled and on the raw predictions otherwise.
        eval_preds = cur_preds_recal if recalibrate else cur_preds

        ############################ ECE evaluation (UWB) ############################
        cur_train_ece = compute_ece(preds=eval_preds, mask=cur_mask, dataset=all_examples, bins=n_bins, norm='l1', method='uniform', recalibrate=recal_ece, strategy=strategy, target_label=binary_target)
        cur_val_ece = compute_ece(preds=eval_preds, mask=1-cur_mask, dataset=all_examples, bins=n_bins, norm='l1', method='uniform', recalibrate=recal_ece, strategy=strategy, target_label=binary_target)

        ece_loss = _ECELoss(n_bins=n_bins_eval)
        cur_train_ece2 = compute_ece2(preds=eval_preds, mask=cur_mask, dataset=all_examples, ece_loss=ece_loss, target_label=binary_target)
        cur_val_ece2 = compute_ece2(preds=eval_preds, mask=1-cur_mask, dataset=all_examples, ece_loss=ece_loss, target_label=binary_target)
        print(np.abs(cur_train_ece2 - cur_val_ece2))

        if recalibrate and traindata_reuse:
            loss = _LabelLoss(n_bins=n_bins_eval, method='uniform')
            cur_lsloss = compute_lsloss(cur_preds_recal, cur_mask, all_examples, loss)
            lsloss_uwb.append(cur_lsloss)

        print(cur_train_ece, cur_val_ece)
        print(np.abs(cur_train_ece - cur_val_ece))

        ## Pool results
        train_eces_uwb.append(cur_train_ece)
        val_eces_uwb.append(cur_val_ece)
        train_eces_uwb2.append(cur_train_ece2)
        val_eces_uwb2.append(cur_val_ece2)

        ############################ ECE evaluation (UMB) ############################
        cur_train_ece = compute_ece(preds=eval_preds, mask=cur_mask, dataset=all_examples, bins=n_bins_umb, norm='l1', method='quantile', recalibrate=recal_ece, strategy=strategy, target_label=binary_target)
        cur_val_ece = compute_ece(preds=eval_preds, mask=1-cur_mask, dataset=all_examples, bins=n_bins_umb, norm='l1', method='quantile', recalibrate=recal_ece, strategy=strategy, target_label=binary_target)

        ece_loss = _ECELoss(n_bins=n_bins_eval, method='quantile', logits=eval_preds[train_idx])
        cur_train_ece2 = compute_ece2(preds=eval_preds, mask=cur_mask, dataset=all_examples, ece_loss=ece_loss, target_label=binary_target)
        cur_val_ece2 = compute_ece2(preds=eval_preds, mask=1-cur_mask, dataset=all_examples, ece_loss=ece_loss, target_label=binary_target)
        print(np.abs(cur_train_ece2 - cur_val_ece2))

        if recalibrate and traindata_reuse:
            loss = _LabelLoss(n_bins=n_bins_eval, method='quantile', logits=cur_preds_recal[train_idx])
            cur_lsloss = compute_lsloss(cur_preds_recal, cur_mask, all_examples, loss)
            lsloss_umb.append(cur_lsloss)

        print(cur_train_ece, cur_val_ece)
        print(np.abs(cur_train_ece - cur_val_ece))

        ## Pool results
        train_eces_umb.append(cur_train_ece)
        val_eces_umb.append(cur_val_ece)
        train_eces_umb2.append(cur_train_ece2)
        val_eces_umb2.append(cur_val_ece2)

    ############################ Bound evaluation ############################

    ### UWB
    cal_bound_uwb, mis_list_uwb = estimate_cmi_bound(masks, preds, labels, bins, n_bins=n_bins_eval, num_examples=n, loss='diff', verbose=False)

    ### UMB
    cal_bound_umb, mis_list_umb = estimate_cmi_bound(masks, preds, labels, bins_umb, n_bins=n_bins_eval, num_examples=n, loss='diff', verbose=False)

    if traindata_reuse:
        ### UWB
        mis_reuse_list_uwb, mis_bin_reuse_list_uwb = estimate_cmi_recal_bound(lsloss_uwb, masks)
        print('MIS reuse', np.array(mis_reuse_list_uwb).sum(), np.array(mis_bin_reuse_list_uwb).sum())

        ### UMB
        mis_reuse_list_umb, mis_bin_reuse_list_umb = estimate_cmi_recal_bound(lsloss_umb, masks)
        print('MIS reuse', np.array(mis_reuse_list_umb).sum(), np.array(mis_bin_reuse_list_umb).sum())
    else:
        mis_reuse_list_uwb, mis_bin_reuse_list_uwb = None, None
        mis_reuse_list_umb, mis_bin_reuse_list_umb = None, None

    ### fCMI
    _, mis_list_fcmi = estimate_cmi_bound(masks, preds, labels, bins, n_bins=n_bins_eval, num_examples=n, loss='fcmi', verbose=False)

    ### ECE-based eCMI
    mis_ece_list_uwb = estimate_cmi_ece_bound(train_eces=train_eces_uwb2, val_eces=val_eces_uwb2, masks=masks)
    print("MIS ECE:", np.array(mis_ece_list_uwb).sum())

    mis_ece_list_umb = estimate_cmi_ece_bound(train_eces=train_eces_umb2, val_eces=val_eces_umb2, masks=masks)
    print("MIS ECE:", np.array(mis_ece_list_umb).sum())

    return {
        'exp_train_acc': np.mean(train_accs),
        'exp_val_acc': np.mean(val_accs),
        'exp_gap': np.abs(np.mean(train_accs) - np.mean(val_accs)),
        'exp_train_acc_l2': np.mean(train_accs_l2),
        'exp_val_acc_l2': np.mean(val_accs_l2),
        'exp_gap_l2': np.abs(np.mean(train_accs_l2) - np.mean(val_accs_l2)),
        'exp_train_recal_acc': np.mean(train_recal_accs),
        'exp_val_recal_acc': np.mean(val_recal_accs),
        'exp_recal_gap': np.abs(np.mean(train_recal_accs) - np.mean(val_recal_accs)),
        'exp_train_recal_acc_l2': np.mean(train_recal_accs_l2),
        'exp_val_recal_acc_l2': np.mean(val_recal_accs_l2),
        'exp_gap_recal_l2': np.abs(np.mean(train_recal_accs_l2) - np.mean(val_recal_accs_l2)),
        'exp_train_ece_uwb': np.mean(train_eces_uwb),
        'exp_val_ece_uwb': np.mean(val_eces_uwb),
        'exp_gap_ece_uwb': np.abs(np.mean(train_eces_uwb) - np.mean(val_eces_uwb)),
        'exp_train_ece_umb': np.mean(train_eces_umb),
        'exp_val_ece_umb': np.mean(val_eces_umb),
        'exp_gap_ece_umb': np.abs(np.mean(train_eces_umb) - np.mean(val_eces_umb)),
        'exp_train_ece_uwb2': np.mean(train_eces_uwb2),
        'exp_val_ece_uwb2': np.mean(val_eces_uwb2),
        'exp_gap_ece_uwb2': np.abs(np.mean(train_eces_uwb2) - np.mean(val_eces_uwb2)),
        'exp_train_ece_umb2': np.mean(train_eces_umb2),
        'exp_val_ece_umb2': np.mean(val_eces_umb2),
        'exp_gap_ece_umb2': np.abs(np.mean(train_eces_umb2) - np.mean(val_eces_umb2)),
        'exp_train_ece_kde': np.mean(train_eces_kde),
        'exp_val_ece_kde': np.mean(val_eces_kde),
        'exp_gap_ece_kde': np.abs(np.mean(train_eces_kde) - np.mean(val_eces_kde)),
        'mis_list_uwb': np.array(mis_list_uwb).sum(),
        'mis_list_umb': np.array(mis_list_umb).sum(),
        'mis_list_fcmi': np.array(mis_list_fcmi).sum(),
        'mis_ece_list_uwb': np.array(mis_ece_list_uwb).sum(),
        'mis_ece_list_umb': np.array(mis_ece_list_umb).sum(),
        'mis_reuse_list_uwb': np.array(mis_reuse_list_uwb).sum(),
        'mis_reuse_bin_list_uwb': np.array(mis_bin_reuse_list_uwb).sum(),
        'mis_reuse_list_umb': np.array(mis_reuse_list_umb).sum(),
        'mis_reuse_bin_list_umb': np.array(mis_bin_reuse_list_umb).sum(),
        'cal_bound_uwb': cal_bound_uwb,
        'cal_bound_umb': cal_bound_umb,
    }

def get_fcmi_results_for_fixed_model(n, lr, epoch, args, num_bins_eval=None, recalibrate=False, traindata_reuse=False, n_caldata=100, recal_method='hist_uwb', recal_ece=True, strategy='label', make_binary=False, binary_target=None):
    """Collect ``get_fcmi_results_for_fixed_z`` results for every seed.

    All arguments are forwarded unchanged; ``seed`` runs over
    ``range(args.n_seeds)``.

    Returns:
        List with one result dict per seed, in seed order.
    """
    shared_kwargs = dict(
        n=n,
        lr=lr,
        epoch=epoch,
        args=args,
        num_bins_eval=num_bins_eval,
        recalibrate=recalibrate,
        traindata_reuse=traindata_reuse,
        n_caldata=n_caldata,
        recal_method=recal_method,
        recal_ece=recal_ece,
        strategy=strategy,
        make_binary=make_binary,
        binary_target=binary_target,
    )
    return [get_fcmi_results_for_fixed_z(seed=seed, **shared_kwargs) for seed in range(args.n_seeds)]


def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``type=bool`` cannot be used with argparse because ``bool('False')`` is
    ``True``: every non-empty string would enable the flag.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def main():
    """Evaluate one experiment over its grid of ``n`` and epochs and pickle the results.

    The experiment is chosen with ``--exp_name``; its grid (seeds, sample
    sizes, epochs, batch size, learning rate) is hard-coded below. Results are
    stored in a ``NestedDict`` indexed by ``n`` and ``epoch`` and written to
    ``<results_dir>/<exp_name>/results_ecmi_*.pkl``.

    Raises:
        ValueError: if ``--exp_name`` is not one of the known experiments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', '-d', default='cuda', help='specifies the main device')
    parser.add_argument('--exp_name', type=str, required=True)
    parser.add_argument('--results_dir', type=str, default='results')
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
    # `--flag`, `--flag True` and `--flag False` are all accepted.
    parser.add_argument('--recalibration', type=_str2bool, nargs='?', const=True, default=False, help='Recalibration')
    parser.add_argument('--train_reuse', type=_str2bool, nargs='?', const=True, default=False, help='Re-using training data for recalibration (proposed)')
    parser.add_argument('--recal_method', type=str, default='hist_umb', help='method for recalibration')
    parser.set_defaults(parse=True)
    args = parser.parse_args()
    print(args)

    # Experiments that do not fix a learning rate below fall back to --lr
    # (previously `lr_rate` was undefined for them and the run crashed).
    lr_rate = args.lr

    if args.exp_name in ("fcmi-mnist-4vs9-CNN", "fcmi-mnist-4vs9-CNN-deterministic",
                         "fcmi-mnist-4vs9-CNN-noisy", "fcmi-mnist-4vs9-CNN-halfnoisy",
                         "fcmi-mnist-4vs9-CNN-LD"):
        args.n_seeds = 5
        args.n_S_seeds = 10  # 30 in the commented-out original setting
        args.ns = [75, 250, 1000, 4000]
        args.epochs = [200]
        args.num_classes = 2
        args.batch_size = 100
        lr_rate = 0.001
    elif args.exp_name in ("fcmi-mnist-CNN", "fcmi-mnist-CNN-LD"):
        args.n_seeds = 5
        args.n_S_seeds = 10  # 30 in the commented-out original setting
        args.ns = [500, 1000, 5000, 20000]
        args.epochs = [200]
        args.num_classes = 2
        args.batch_size = 100
        lr_rate = 0.001
    elif args.exp_name == "fcmi-mnist-4vs9-CNN-lr":
        print("hiiya")
        args.n_seeds = 5
        args.n_S_seeds = 30
        args.ns = [4000]
        args.lrs = [0.1, 1e-3, 1e-5]  # not iterated here; pick one with --lr
        args.epochs = [200]
        args.num_classes = 2
    elif args.exp_name in ("fcmi-mnist-4vs9-CNN-width", "fcmi-mnist-4vs9-CNN-depth"):
        print("hi")
        args.n_seeds = 5
        args.n_S_seeds = 30
        args.ns = [4000]
        args.lrs = [0.0, 1.0, 2.0]  # not iterated here; pick one with --lr
        args.epochs = [200]
        args.num_classes = 2
    elif args.exp_name == 'fcmi-mnist-4vs9-wide-CNN-deterministic':
        args.n_seeds = 5
        args.n_S_seeds = 30
        args.ns = [75, 250, 1000, 4000]
        args.epochs = [200]
        args.num_classes = 2
    elif args.exp_name in ('cifar10-pretrained-resnet50', 'cifar10-pretrained-resnet50-LD'):
        args.n_seeds = 2
        args.n_S_seeds = 5  # 40 in the commented-out original setting
        args.ns = [500, 1000, 5000, 20000]
        args.epochs = [40]
        args.num_classes = 10
        args.batch_size = 100  # 64 in the commented-out original setting
        args.make_binary = True
        args.binary_target = 3
        lr_rate = 0.01 if args.exp_name == 'cifar10-pretrained-resnet50' else 0.001
    else:
        raise ValueError(f"Unexpected exp_name: {args.exp_name}")

    # Only the CIFAR-10 experiments set these attributes; everything else uses
    # the callee's defaults (no binarisation).
    make_binary = getattr(args, 'make_binary', False)
    binary_target = getattr(args, 'binary_target', None)

    results = NestedDict()  # indexing with n, epoch
    for n in tqdm(args.ns):
        for epoch in tqdm(args.epochs, leave=False):
            results[n][epoch] = get_fcmi_results_for_fixed_model(
                n=n, lr=lr_rate, epoch=epoch, args=args,
                recalibrate=args.recalibration, traindata_reuse=args.train_reuse,
                recal_method=args.recal_method,
                make_binary=make_binary, binary_target=binary_target)

    if args.recalibration:
        if args.train_reuse:
            file_name = f'results_ecmi_recal_reuse_{args.recal_method}.pkl'
        else:
            file_name = f'results_ecmi_recal_{args.recal_method}.pkl'
    else:
        file_name = 'results_ecmi_nonrecal.pkl'
    results_file_path = os.path.join(args.results_dir, args.exp_name, file_name)
    with open(results_file_path, 'wb') as f:
        pickle.dump(results, f)

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()