import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import argparse
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt 
import time
from utils import * 
from conformal_logits_full_pss_calib import ConformalModel
import torch.backends.cudnn as cudnn
import random


parser = argparse.ArgumentParser(description='Conformalize Torchvision Model on Imagenet')
# Integer-valued options declare ``type=int``: without it argparse hands back
# the raw string for any value given on the command line (only the defaults
# are ints), and arithmetic such as ``len(dataset) - args.num_calib`` fails.
parser.add_argument('--data', metavar='IN', help='path to Imagenet Val', default='/data')
parser.add_argument('--batch_size', metavar='BSZ', type=int, help='batch size', default=1024)
parser.add_argument('--num_workers', metavar='NW', type=int, help='number of workers', default=0)
parser.add_argument('--num_calib', metavar='NCALIB', type=int, help='number of calibration points', default=10000)
parser.add_argument('--num_classes', metavar='NCLASSES', type=int, help='number of classes', default=1000)
parser.add_argument('--save_score_name', help='path to save score', default='score')

parser.add_argument('--seed', metavar='SEED', type=int, help='random seed', default=0)

# Adversarial-attack settings. For a clean (non-adversarial) run use
# ``--epsilon 0 --k 0``.
# NOTE(review): the epsilon help text quotes 4/255 while the default is 0.031 (~8/255).
parser.add_argument('--epsilon', '-e', type=float, default=0.031,
        help='maximum perturbation of adversaries (4/255=0.0157)')
parser.add_argument('--alpha', '-a', type=float, default=0.00784,
        help='movement multiplier per iteration when generating adversarial examples (2/255=0.00784)')
parser.add_argument('--k', '-k', type=int, default=100,
        help='maximum iteration when generating adversarial examples')

# Hyper-parameters read from ``args`` by ConformalModel.
# NOTE(review): the help strings of --lambda_dirichlet and --lambda_ent look
# copy-pasted from neighbouring options; confirm their real meaning.
# NOTE(review): --round_opt is a float although it names a round count.
parser.add_argument('--lambda_temp_sample', type=float, default=2,
        help='parameter of sampling temperature')
parser.add_argument('--lambda_dirichlet', type=float, default=0.1,
        help='parameter of sampling temperature')
parser.add_argument('--lambda_k', type=float, default=0.3,
        help='parameter of power function')
parser.add_argument('--lambda_ent', type=float, default=0.1,
        help='parameter of power function')
parser.add_argument('--round_opt', type=float, default=10,
        help='parameter of optimization round')
parser.add_argument('--lr_opt', type=float, default=0.01,
        help='parameter of optimization learning rate')

parser.add_argument('--perturbation_type', '-p', choices=['linf', 'l2'], default='linf',
        help='the type of the perturbation (linf or l2)')
# NOTE(review): no ``type`` is given, so any value passed on the command line
# is a non-empty string and therefore truthy (even "False").
parser.add_argument('--use_adv_calib', help='whether use adv clib', default=False)


def probabilities_to_logits(probabilities):
    """
    Turn rows of class probabilities into logits via the natural logarithm.

    Parameters:
    probabilities (array-like): 2D array (or nested list) whose rows are
                                probability vectors; every entry must be > 0
                                and every row must sum to 1.

    Returns:
    numpy array: ``log(probabilities)``, element-wise, same shape as the input.

    Raises:
    ValueError: if any entry is <= 0 or some row does not sum to 1.
    """
    probs = np.array(probabilities)

    # Non-positive entries are rejected first; the row sums are only examined
    # once every entry is known to be positive.
    if (probs <= 0).any():
        raise ValueError("Probabilities must be strictly positive and rows must sum to 1.")
    if not np.allclose(probs.sum(axis=1), 1):
        raise ValueError("Probabilities must be strictly positive and rows must sum to 1.")

    return np.log(probs)


def output_mean_std(args, load_checkpoint_list, save_file, lambda_k, lambda_temp_sample, lambda_dirichlet, round_opt, lr_opt, lambda_ent):
    """Calibrate a ConformalModel on saved predictions and summarise the results.

    For every ``(preds_path, targets_path)`` pair in *load_checkpoint_list*
    the two ``.npy`` arrays are loaded. For each of 5 seeds the data is then
    randomly split into ``args.num_calib`` calibration points and the
    remaining evaluation points. A ``ConformalModel`` is built on the
    calibration split and evaluated with ``validate_with_calibration`` on the
    rest. Per-seed details are pickled to
    ``result_imagenet_calibrate_Gaussian_ViT_B/<save_file>_seed_<seed>.pkl``.

    Args:
        args: parsed CLI namespace. The hyper-parameters below are written
            onto it, overriding any CLI values, before ``ConformalModel``
            receives it.
        load_checkpoint_list: iterable of (predictions path, targets path)
            pairs loadable with ``np.load``.
        save_file: filename prefix for the per-seed result pickles.
        lambda_k, lambda_temp_sample, lambda_dirichlet, round_opt, lr_opt,
        lambda_ent: hyper-parameters copied onto *args*.

    Returns:
        The summary row ``" cov & size & top1 & top5"``. Each entry is formatted
        ``mean (std)`` over all checkpoints and seeds. An empty string is
        returned when nothing was evaluated.
    """
    # Explicit local import: previously ``pickle`` was only in scope if
    # ``from utils import *`` happened to export it.
    import pickle

    result_dir = 'result_imagenet_calibrate_Gaussian_ViT_B'
    num_seeds = 5

    coverage_normal_all = []
    size_normal_all = []
    clean_acc_all = []
    clean_acc_top5_all = []

    # Hyper-parameters reach ConformalModel through ``args``.
    args.lambda_k = lambda_k
    args.lambda_temp_sample = lambda_temp_sample
    args.lambda_dirichlet = lambda_dirichlet
    args.round_opt = round_opt
    args.lr_opt = lr_opt
    args.lambda_ent = lambda_ent

    # Conformal options: no empty prediction sets, randomized procedure.
    allow_zero_sets = False
    randomized = True

    for load_checkpoint in load_checkpoint_list:
        preds = np.load(load_checkpoint[0])
        targets = np.load(load_checkpoint[1])

        # Sanity diagnostics on the loaded arrays.
        print(preds.shape)
        pred_label = np.argmax(preds, axis=1)
        print(np.max(pred_label))
        print('accuracy:')
        print(np.sum(pred_label == targets))  # count of correct argmax predictions
        print(targets.shape)
        print(np.max(targets))

        dataset = torch.utils.data.TensorDataset(
            torch.from_numpy(preds), torch.from_numpy(targets).long())

        # Make sure the output directory exists before any pickle is written.
        os.makedirs(result_dir, exist_ok=True)

        for set_seed in range(num_seeds):
            np.random.seed(seed=set_seed)
            torch.manual_seed(set_seed)
            torch.cuda.manual_seed(set_seed)
            random.seed(set_seed)

            # random_split uses torch's default generator seeded above, so the
            # calibration/evaluation split is reproducible per seed.
            imagenet_calib_data, imagenet_val_data = torch.utils.data.random_split(
                dataset, [args.num_calib, len(dataset) - args.num_calib])
            val_loader = torch.utils.data.DataLoader(
                imagenet_val_data, batch_size=args.batch_size, shuffle=True, pin_memory=True)

            model = ConformalModel(imagenet_calib_data, args, alpha=0.1, lamda=0., kreg=0,
                                   randomized=randomized, allow_zero_sets=allow_zero_sets)

            print("Model calibrated and conformalized! Now evaluate over remaining data.")
            (clean_acc, clean_acc_top_5, coverage_normal, size_normal,
             _coverage_list_normal, _size_list_normal, size_list, acc_sample_list,
             correct_term_index_list, ps_score_list, ps_score_normalized_list_binned,
             top_1_acc) = validate_with_calibration(val_loader, model, print_bool=True)

            print('normal result:')
            print(coverage_normal)
            print(size_normal)
            print("Complete!")

            # Coverage and accuracies are scaled by 100 (presumably fractions -> percent).
            coverage_normal_all.append(coverage_normal * 100.0)
            size_normal_all.append(size_normal)
            clean_acc_all.append(clean_acc * 100.)
            clean_acc_top5_all.append(clean_acc_top_5 * 100.)

            # NOTE: the *_mean / *_std fields aggregate every run so far (all
            # previous seeds and checkpoints), not only the current seed.
            record = {
                'size_list': size_list,
                'acc_sample_list': acc_sample_list,
                'correct_term_index_list': correct_term_index_list,
                'top_1_acc': top_1_acc,
                'ps_score_list': ps_score_list,
                'ps_score_normalized_list_binned': ps_score_normalized_list_binned,
                'coverage_mean': np.mean(coverage_normal_all),
                'coverage_std': np.std(coverage_normal_all),
                'size_mean': np.mean(size_normal_all),
                'size_std': np.std(size_normal_all),
            }
            out_path = os.path.join(result_dir, save_file + '_seed_' + str(set_seed) + '.pkl')
            # ``with`` guarantees the file is flushed and closed.
            with open(out_path, 'wb') as fh:
                pickle.dump(record, fh)

    if not coverage_normal_all:
        # Avoid np.mean/np.std on empty lists (nan + RuntimeWarning).
        print('No checkpoints were evaluated; nothing to summarise.')
        return ''

    def _mean_std(values):
        """Format *values* as ``mean (std)`` with three decimals."""
        return f'{np.mean(values):.3f} ({np.std(values):.3f})'

    result_string = ' ' + ' & '.join(
        _mean_std(values)
        for values in (coverage_normal_all, size_normal_all, clean_acc_all, clean_acc_top5_all))
    print(result_string)
    return result_string



if __name__ == "__main__":
    args = parser.parse_args()
    # Fill with (predictions .npy path, targets .npy path) pairs to evaluate.
    load_checkpoint_list = []

    # ``filename`` was previously undefined (NameError at startup). The
    # otherwise-unused --save_score_name option now supplies the result-file prefix.
    # NOTE: these hyper-parameters overwrite the corresponding --lambda_* /
    # --round_opt / --lr_opt values parsed from the command line.
    output_mean_std(
        args,
        load_checkpoint_list,
        args.save_score_name,
        lambda_k=0.5,
        lambda_temp_sample=3,
        lambda_dirichlet=0,
        round_opt=4,
        lr_opt=0.0001,
        lambda_ent=1.0,
    )
    
    
    