import argparse
import os

# Number of ImageNet examples shared between calibration and validation.
# NOTE(review): presumably the 50,000-image ImageNet validation split --
# confirm against the data loader used by main.experiment.
_IMAGENET_POOL_SIZE = 50000
# Validation size used for ImageNet whenever the pool is large enough.
_IMAGENET_MAX_VAL = 30000
# Upper bound on calibration examples and fixed validation size for CIFAR-100.
_CIFAR100_MAX_CAL = 5000
_CIFAR100_VAL = 5000


def build_arg_parser():
    """Create the command-line parser for the evaluation script.

    Returns:
        An ``argparse.ArgumentParser`` exposing every option of this script.
    """
    parser = argparse.ArgumentParser(description='Evaluates conformal predictors',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--dataset', type=str, default='imagenet', help='dataset')
    parser.add_argument('--score', type=str, default='THR', help='the score function of CP.')
    parser.add_argument('--predictor', type=str, default='Split', help='the conformal predictor of CP.')
    parser.add_argument('--truncation_name', type=str, default='base', help='the transformation of base conformal prediction. Optional: base, truncated')
    parser.add_argument('--sparse_topk', type=int, default=5, help='the top-k retained probabilities.')
    parser.add_argument('--alpha', type=float, default=0.1, help='the error rate.')
    parser.add_argument('--trials', type=int, default=1, help='number of trials')
    parser.add_argument('--conf_cal', type=str, default="None", help='the confidence calibration method.')
    parser.add_argument('--hyperpar', type=int, default=1, help='tuning the hyper-parameter of score functions.')
    parser.add_argument('--tuning_criterion', type=str, default="size",
                        help='the criterion to tune the best hyper-parameter.')
    parser.add_argument('--gpu', type=int, default=4, help='chose gpu id')
    parser.add_argument('--n_cal', type=int, default=5000, help='the number of calibration examples')
    parser.add_argument('--model', type=str, default=None, help='model')
    return parser


def resolve_split_sizes(dataset_name, n_cal):
    """Work out the calibration/validation sizes and default models.

    Args:
        dataset_name: ``"imagenet"`` or ``"cifar100"``.
        n_cal: Requested number of calibration examples (``--n_cal``).

    Returns:
        A tuple ``(n_data_conf, n_data_val, models)`` where ``n_data_conf`` is
        the number of calibration examples actually used, ``n_data_val`` the
        number of validation examples, and ``models`` the list of model names
        evaluated when ``--model`` is not given.

    Raises:
        ValueError: If ``n_cal`` is not positive, or leaves no ImageNet
            examples for validation.
        NotImplementedError: If ``dataset_name`` is not supported.
    """
    if n_cal < 1:
        raise ValueError(f"--n_cal must be a positive integer, got {n_cal}")

    if dataset_name == "imagenet":
        if n_cal >= _IMAGENET_POOL_SIZE:
            raise ValueError(
                f"--n_cal must be smaller than {_IMAGENET_POOL_SIZE} for imagenet, got {n_cal}"
            )
        # Shrink the validation set when calibration + validation would
        # otherwise exceed the pool.  The original computed
        # ``50000 - n_data_val`` (always 20000), which ignored n_cal and
        # overflowed the pool once n_cal exceeded 30000.
        n_data_val = min(_IMAGENET_MAX_VAL, _IMAGENET_POOL_SIZE - n_cal)
        return n_cal, n_data_val, ["ResNeXt101"]

    if dataset_name == "cifar100":
        # Requests above the cap are silently clamped, as before.
        return min(n_cal, _CIFAR100_MAX_CAL), _CIFAR100_VAL, ["ResNet101"]

    raise NotImplementedError(f"unsupported dataset: {dataset_name!r}")


def build_save_csv_path(args, n_data_conf, n_data_val, model=None):
    """Build the path of the CSV file the experiment results are written to.

    Args:
        args: Parsed command-line namespace; reads ``dataset``, ``predictor``,
            ``score``, ``truncation_name``, ``tuning_criterion``, ``conf_cal``
            and ``hyperpar``.
        n_data_conf: Number of calibration examples.
        n_data_val: Number of validation examples.
        model: Model name to append as a ``_model=...`` suffix, or ``None``
            for the shared, model-independent file name.

    Returns:
        The relative CSV path as a string.
    """
    # The duplicated "_cal=" key is kept on purpose so that paths stay
    # identical to those produced by earlier runs.
    stem = (
        f"cache/{args.dataset}/res/predictor={args.predictor}_score={args.score}"
        f"_truncation_name={args.truncation_name}_tuning_criterion={args.tuning_criterion}"
        f"_cal={args.conf_cal}_cal={n_data_conf}_val={n_data_val}_hyperpar={args.hyperpar}"
    )
    if model is not None:
        stem += f"_model={model}"
    return stem + ".csv"


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    # Restrict the visible GPU before torch is imported below, so the setting
    # is already in place when CUDA gets initialised.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    import torch.backends.cudnn as cudnn
    # NOTE(review): `optim` is not referenced in the visible part of this
    # script; kept in case code outside this view relies on it.
    from torch import optim

    n_data_conf, n_data_val, default_models = resolve_split_sizes(args.dataset, args.n_cal)

    # Forwarded unchanged to experiment.run(); presumably the fraction of data
    # used for hyper-parameter tuning and the batch size -- confirm in main.py.
    pct_paramtune = 0.2
    bsz = 128
    cudnn.benchmark = True

    # SAPS with hyper-parameter tuning enabled always uses "TS" as the
    # confidence calibration method, overriding --conf_cal.  This must happen
    # before the CSV path is built because the path embeds conf_cal.
    if args.score == "SAPS" and args.hyperpar:
        args.conf_cal = "TS"

    from main import experiment

    if args.model is None:
        # No explicit model: evaluate the dataset defaults into one shared CSV.
        model_names = default_models
        save_csv_path = build_save_csv_path(args, n_data_conf, n_data_val)
        print(save_csv_path)
    else:
        model_names = [args.model]
        save_csv_path = build_save_csv_path(args, n_data_conf, n_data_val, model=args.model)

    # Start from a clean results file so stale rows from an earlier run with
    # the same configuration are discarded.
    try:
        os.remove(save_csv_path)
    except FileNotFoundError:
        pass

    for model in model_names:
        args.model = model
        this_experiment = experiment(args, save_csv_path)
        this_experiment.run(n_data_conf, n_data_val, pct_paramtune, bsz)