import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms

import argparse
from tqdm import tqdm
import os
from pprint import pprint

from utils import set_seed, PoisonDataset, make_and_restore_cifar_model
from attacks.natural import natural_attack
from attacks.adv import adv_attack
from tf_attacks.attacks.torchattack_lib import square

from tf_attacks.attacks.AA import aa

def eval_model(args, model, data_loader):
    """Evaluate ``model`` with natural, FGSM and PGD-20 settings and save a CSV.

    The CSV has a header row (``'Model'`` followed by one name per evaluation)
    and one data row (the model path followed by the value each attack helper
    returns as its second element -- presumably accuracy; confirm in
    ``attacks``). It is written to
    ``'{model_path}.{eval_mode}-{eval_data_arch}.csv'``.

    Note: this overwrites ``args.num_steps``, ``args.step_size`` and
    ``args.random_restarts``; on return they hold the PGD-20 settings.

    Args:
        args: namespace providing at least ``model_path``, ``eps``,
            ``eval_mode`` and ``eval_data_arch``, plus whatever the attack
            helpers read.
        model: classifier to evaluate; it is switched to eval mode here.
        data_loader: loader over the evaluation set.
    """
    import csv

    model.eval()

    keys, values = ['Model'], [args.model_path]

    # Natural (unperturbed) evaluation.
    _, acc, name = natural_attack(args, model, data_loader)
    keys.append(name)
    values.append(acc)

    # FGSM: one step whose size equals the full budget eps.
    args.num_steps = 1
    args.step_size = args.eps
    args.random_restarts = 1
    _, acc, _ = adv_attack(args, model, data_loader)
    keys.append('FGSM')
    values.append(acc)

    # PGD-20 with step size eps / 4.
    args.num_steps = 20
    args.step_size = args.eps / 4
    args.random_restarts = 1
    _, acc, name = adv_attack(args, model, data_loader)
    keys.append(name)
    values.append(acc)

    # Save results. newline='' lets the csv module control line endings;
    # without it, blank rows appear between records on Windows.
    csv_fn = '{}.{}-{}.csv'.format(args.model_path, args.eval_mode, args.eval_data_arch)
    with open(csv_fn, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(keys)
        writer.writerow(values)

    print('=> csv file is saved at [{}]'.format(csv_fn))


def main(args):
    """Load the model, prepare or load the evaluation data, and evaluate it.

    With ``eval_mode == 'Clean'`` the CIFAR-10 test split is used directly.
    With ``'Hyp'`` or ``'TF'`` the perturbed test set is generated once from the
    clean test split with the loaded model and cached at ``args.hyp_data_path``.
    Later runs reuse the cached file.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).

    Raises:
        ValueError: if evaluation data must be generated and
            ``args.eval_mode`` is neither ``'Hyp'`` nor ``'TF'``.
    """
    model = make_and_restore_cifar_model(args.arch, args.model_path)
    model.eval()
    set_seed(args.seed)

    if args.eval_mode in ['Clean'] or os.path.isfile(args.hyp_data_path):
        print('Evaluation data [{}] already exists. Loading ...'.format(args.eval_mode))
    else:
        print('Preparing evaluation data [{}] ...'.format(args.eval_mode))

        clean_eval_data = datasets.CIFAR10(args.data_path, train=False, transform=transforms.ToTensor())
        clean_data_loader = DataLoader(clean_eval_data, batch_size=128, shuffle=False, num_workers=8, pin_memory=True)

        # NOTE(review): the data is generated with the model loaded above
        # (args.arch) but cached under the eval_data_arch directory; confirm
        # this is intended when the two architectures differ.
        if args.eval_mode in ['Hyp']:
            from attacks.hyp import hyp_attack
            hyp_input, clean_target, _ = hyp_attack(args, model, clean_data_loader)
        elif args.eval_mode in ['TF']:
            from attacks.tf import tf_attack
            hyp_input, clean_target, _ = tf_attack(args, model, clean_data_loader)
        else:
            # Fail clearly instead of with a NameError on hyp_input below.
            raise ValueError('Unsupported eval_mode: {}'.format(args.eval_mode))

        # The cache directory may not exist yet (e.g. no run has written to
        # the eval_data_arch result directory), so create it before saving.
        cache_dir = os.path.dirname(args.hyp_data_path)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        torch.save((hyp_input, clean_target), args.hyp_data_path)

    if args.eval_mode in ['Clean']:
        eval_data = datasets.CIFAR10(args.data_path, train=False, transform=transforms.ToTensor())
    else:
        # Presumably PoisonDataset reads '<root>/<eval_mode>.data', i.e. the
        # file cached above -- confirm in utils.PoisonDataset.
        eval_data = PoisonDataset(args.hyp_data_root, data_type=args.eval_mode, transform=transforms.ToTensor())

    data_loader = DataLoader(eval_data, batch_size=128, shuffle=False, num_workers=8, pin_memory=True)
    print('\nEvaluation on [{}].'.format(args.eval_mode))
    eval_model(args, model, data_loader)



if __name__ == '__main__':
    # The first positional argument of ArgumentParser is `prog`; pass the text
    # as `description` so it appears as the help summary, not the program name.
    parser = argparse.ArgumentParser(description='Evaluating classifiers on robust hypocritical examples')

    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--arch', default='ResNet18', type=str, choices=['MLP', 'VGG16', 'ResNet18', 'DenseNet121', 'WRN28-10'])
    parser.add_argument('--train_loss', default='ST', type=str, choices=['ST', 'AT', 'TRADES', 'TRHRM_2', 'THRM'])
    parser.add_argument('--constraint', default='Linf', choices=['Linf', 'L2'], type=str)
    parser.add_argument('--eps', default=8/255, type=float)
    parser.add_argument('--data_type', default='Quality', choices=['Naive', 'Noise', 'Mislabeling', 'Poisoning', 'Quality'])

    parser.add_argument('--beta', default=6, type=float)

    parser.add_argument('--eval_mode', default='Clean', type=str, choices=['Clean', 'Hyp', 'TF'])
    # Architecture whose result directory holds the cached evaluation data.
    # 'Clean' is the default, so it is listed in `choices` as well
    # (argparse does not validate defaults against `choices`).
    parser.add_argument('--eval_data_arch', default='Clean', type=str, choices=['Clean', 'MLP', 'VGG16', 'ResNet18', 'DenseNet121', 'WRN28-10'])
    parser.add_argument('--device', default=0, type=int)

    args = parser.parse_args()

    # Restrict the visible GPUs; this must happen before CUDA is first used.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.device)

    # Default attack options (eval_model overrides these per attack).
    args.num_steps = 20
    args.step_size = args.eps / 4
    args.random_restarts = 1

    # Poisoning data is suffixed with the constraint, e.g. 'PoisoningLinf'.
    if args.data_type == 'Poisoning':
        args.data_type = args.data_type + args.constraint

    # Miscellaneous paths.
    args.out_dir = 'results/CIFAR10'
    args.data_path = '../datasets/CIFAR10'
    args.exp_name = '{}-{}-{}-{}-seed{}'.format(args.arch, args.train_loss, args.data_type, args.constraint, args.seed)
    exp_dir = os.path.join(args.out_dir, args.exp_name)
    args.tensorboard_path = os.path.join(exp_dir, 'tensorboard')
    args.model_path_best = os.path.join(exp_dir, 'checkpoint_best.pth')
    args.model_path_last = os.path.join(exp_dir, 'checkpoint_last.pth')
    args.model_path = args.model_path_last

    # The evaluation data lives in the result directory named after eval_data_arch.
    data_exp_name = '{}-{}-{}-{}-seed{}'.format(args.eval_data_arch, args.train_loss, args.data_type, args.constraint, args.seed)
    args.hyp_data_root = os.path.expanduser(os.path.join(args.out_dir, data_exp_name))
    args.hyp_data_path = os.path.join(args.hyp_data_root, '{}.data'.format(args.eval_mode))

    pprint(vars(args))

    torch.backends.cudnn.benchmark = True
    main(args)