import random
import torch
import numpy as np

from datasets import DATASETS
from models import model_settings
from train import train
from evaluation_another_copy import evaluation_loop, cutoff_test, cr_loop, scaling_analysis

import torch.backends.cudnn as cudnn

from collections import OrderedDict

import argparse

def rename_state_dict(base_dict, new_prefix):
    """Return a re-keyed copy of a state dict.

    Each key loses its first two dot-separated components (as far as they
    exist) and gets ``new_prefix`` prepended instead, so that e.g.
    ``'module.net.conv.weight'`` with prefix ``'model.'`` becomes
    ``'model.conv.weight'``. This corrects saved keys to match the model they
    are loaded into here.

    Args:
        base_dict: mapping of parameter names to values; it is not modified.
        new_prefix: string prepended to every stripped key.

    Returns:
        An ``OrderedDict`` preserving the iteration order of ``base_dict``.
        Keys with fewer than two dots strip down to the empty string. If two
        keys collide after renaming, the later value wins and the first
        key's position is kept.
    """
    def _drop_two_levels(name):
        # Two successive partitions on '.' discard the two leading components.
        for _ in range(2):
            _, _, name = name.partition('.')
        return name

    return OrderedDict(
        (new_prefix + _drop_two_levels(old_key), tensor)
        for old_key, tensor in base_dict.items()
    )

def str2bool(v):
    """Parse a boolean command-line value (for use as an argparse ``type=``).

    Booleans pass through unchanged. Strings are compared case-insensitively
    against the accepted true spellings (yes/true/t/y/1) and false spellings
    (no/false/f/n/0).

    Raises:
        argparse.ArgumentTypeError: if the string matches neither set.
    """
    if isinstance(v, bool):
        return v

    normalized = v.lower()
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')

    if normalized in truthy:
        return True
    if normalized in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean Value Expected')

# Accepted values for --parallel.
PARALLEL_CHOICES = ['never', 'always', 'eval']

parser = argparse.ArgumentParser(description='Certifying examples as per Smoothing')
parser.add_argument('--dataset', type=str, choices=DATASETS)
parser.add_argument('--parallel', type=str, choices=PARALLEL_CHOICES,
                    help='When to run the model in parallel: never = not at all, '
                         'always = training & evaluation, eval = evaluation only')
parser.add_argument('--batch_size', type=int, default=0, help='Batch Size (0 == Model default)')
parser.add_argument('--lr', type=float, default=0, help='Learning Rate (0 == Model default)')

parser.add_argument('--sigma', type=float, default=0.0, help='Noise level')
parser.add_argument('--samples', type=int, default=1500, help='Number of samples')
parser.add_argument('--epochs', type=int, default=80, help='Training Epochs')
parser.add_argument('--total_cutoff', type=int, default=250, help='Number of samples tested over')

# Stage switches: each flag enables one section of the __main__ pipeline.
parser.add_argument('--train', action='store_true', help='If training is required')
parser.add_argument('--eval', action='store_true', help='If evaluation is required')
parser.add_argument('--cutoff_experiment', action='store_true', help='If cutoff experiment is performed')
parser.add_argument('--sigma_scaling_experiment', action='store_true', help='If experiment on scaling of sigma is performed')
parser.add_argument('--new_cr', action='store_true', help='If improved cr experiment is performed')
parser.add_argument('--plotting', type=str2bool, nargs='?', const=True, default=True,
                    help='Whether to produce plots (a bare --plotting means True)')

parser.add_argument('--autoattack_radii', type=float, default=-1,
                    help='AutoAttack radius (used as given, not rescaled)')
parser.add_argument('--pgd_radii', type=float, default=20,
                    help='PGD radius in 0-255 units (divided by 255 after parsing)')

# Parsed at import time so the module-level ``args`` is available to the pipeline below.
args = parser.parse_args()
# Rescale the PGD radius from 0-255 units; presumably model inputs lie in [0, 1] — confirm.
args.pgd_radii = args.pgd_radii / 255

# Let cuDNN benchmark and choose the fastest convolution algorithms. Per the PyTorch
# reproducibility notes, this can make results nondeterministic even with fixed seeds.
cudnn.benchmark = True

#args.epochs
#args.parallel = ['never', 'always', 'eval']
#args.eval = [True, False]
#args.train = [True, False]
#args.cutoff_experiment = [True, False]

if __name__ == "__main__":
    # Seed every RNG source used so runs are repeatable.
    torch.manual_seed(0)
    random.seed(0)
    np.random.seed(0)

    # Preload model settings: model, optimisation objects, data loaders, device and class names.
    model, loss, optimizer, lr_scheduler, train_loader, val_loader, test_loader, device, classes = model_settings(args.dataset, args)

    # Either train a fresh model or load the saved weights for this dataset/sigma pair.
    if args.train:
        print('Training', flush=True)
        # The returned cutoff is not used further in this script.
        model, cutoff = train(device, model, optimizer, lr_scheduler, args.epochs, train_loader, val_loader, args, args.dataset, val_cutoff=1e6)
    else:
        print('Loading Model', flush=True)
        pth = f'./saved_models/{args.dataset}-{args.sigma}-weight.pth'
        # Map tensors to CPU first so checkpoints saved on a GPU also load on CPU-only
        # hosts. load_state_dict then copies them onto the model's own device.
        checkpoint = torch.load(pth, map_location='cpu')
        model.load_state_dict(checkpoint)

    # Put the model in eval mode for the experiments below. The load path always did
    # this; a freshly trained model may otherwise still be in training mode.
    model.eval()

    # Only test_loader is used from here on, so release the other loaders.
    del train_loader, val_loader

    if args.eval or args.new_cr:
        print('Evaluating attacks')
        # Wrap for multi-GPU evaluation, unless training already ran with parallel == 'always'
        # (presumably train() handled parallelism in that case — confirm).
        if args.parallel == 'eval' or (args.parallel == 'always' and not args.train):
            cuda_device_count = torch.cuda.device_count()
            print('Cuda device count: ', cuda_device_count)
            model = model.to("cpu")
            model = torch.nn.DataParallel(model, device_ids=list(range(cuda_device_count)))
            device = torch.device("cuda:0")
            model.to(device)

        if args.new_cr:
            cr_loop(device, model, test_loader, args.dataset, args.sigma, args.samples, classes, total_cutoff=args.total_cutoff, plotting=args.plotting)
        else:
            evaluation_loop(device, model, test_loader, args.dataset, args.sigma, args.samples, classes, total_cutoff=args.total_cutoff, plotting=args.plotting, autoattack_radii=args.autoattack_radii, pgd_radii=args.pgd_radii)

    if args.cutoff_experiment:
        print('Experimenting with cutoff threshold')
        cutoff_test(device, model, test_loader, args.dataset, args.sigma, args.samples, classes)

    if args.sigma_scaling_experiment:
        scaling_analysis(device, model, test_loader, args.dataset, args.sigma, args.samples, classes, total_cutoff=args.total_cutoff, plotting=args.plotting, autoattack_radii=args.autoattack_radii, pgd_radii=args.pgd_radii)
    

        
