from __future__ import print_function
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
import sys
sys.path.append('./')
from utils.prepare_corruption_dataset import *
from utils.misc import *
from utils.test_helpers import *
from utils.prepare_dataset import *
from utils.rotation import rotate_batch 
from advertorch.context import ctx_noparamgrad_and_eval
from advertorch.attacks import LinfPGDAttack


# All 15 CIFAR-10-C common-corruption types.
ALL_COMMON_CORRUPTIONS = ['gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur', 'glass_blur',
                          'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog',
                          'brightness', 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression']
# Corruptions actually trained by this run. Edit this list to choose a subset;
# every entry must be a member of ALL_COMMON_CORRUPTIONS.
common_corruptions = ['elastic_transform', 'pixelate', 'jpeg_compression']
# Fail fast on a typo instead of crashing after earlier corruptions have trained.
_unknown_corruptions = sorted(set(common_corruptions) - set(ALL_COMMON_CORRUPTIONS))
if _unknown_corruptions:
    raise ValueError('Unknown CIFAR-10-C corruption(s): {}'.format(_unknown_corruptions))
# Apply L-inf adversarial training to each CIFAR-10-C corruption listed in
# `common_corruptions` (the train and test sets are created by an 80%/20% split
# across all severity levels).

# Enable cuDNN autotuning once for the whole run (it does not depend on the
# corruption being trained).
import torch.backends.cudnn as cudnn
cudnn.benchmark = True

# For each selected corruption: build a model, train it with L-inf PGD
# adversarial examples, evaluate every epoch, and save error curves/checkpoints
# under args.outf.
for corruption in common_corruptions:
    # A fresh parser per corruption so the --corruption and --outf defaults
    # follow the loop variable.
    # NOTE(review): parse_args() re-reads sys.argv on every iteration, so an
    # explicit --corruption on the command line makes every iteration load that
    # same corruption, while --outf still defaults to the loop variable's name.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='cifar10')
    parser.add_argument('--corruption', default=corruption)
    parser.add_argument('--level', default=5, type=int)
    parser.add_argument('--dataroot', default='/nobackup/yguo/datasets')
    parser.add_argument('--shared', default=None)
    ########################################################################
    parser.add_argument('--depth', default=26, type=int)
    parser.add_argument('--width', default=1, type=int)
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--group_norm', default=0, type=int)
    ########################################################################
    parser.add_argument('--lr', default=3e-4, type=float)
    parser.add_argument('--nepoch', default=100, type=int)
    # NOTE(review): the milestones are not read in this file (the LR scheduler
    # below is disabled) — confirm whether build_model or other helpers use them.
    parser.add_argument('--milestone_1', default=50, type=int)
    parser.add_argument('--milestone_2', default=65, type=int)
    parser.add_argument('--rotation_type', default='rand')
    ########################################################################
    # NOTE(review): the '_gn' suffix is fixed while --group_norm defaults to 0;
    # confirm the directory name matches the model actually built.
    parser.add_argument('--outf', default='results/pretrain/cifar10c_{}_adv_none_gn'.format(corruption))
    parser.add_argument('--nb_iter', default=7, type=int)
    args = parser.parse_args()
    print(corruption)

    my_makedir(args.outf)
    net, ext, head, ssh = build_model(args)
    # Only the loader half of each returned (train, test) pair is used here.
    (_, trloader), (_, teloader) = prepare_corruption_data(args.corruption)
    print(len(trloader), len(teloader))

    # NOTE(review): the training loss depends only on net's output; unless head
    # is a submodule of net, head's parameters get no gradient and Adam skips them.
    parameters = list(net.parameters()) + list(head.parameters())
    optimizer = optim.Adam(parameters, lr=args.lr)
    criterion = nn.CrossEntropyLoss().cuda()

    all_err_cls = []
    all_err_ssh = []
    # Float literals: under Python 2 (this file imports print_function), 16/255
    # and 4/255 would be integer division and evaluate to 0, disabling the attack.
    eps = 16.0 / 255
    eps_iter = 4.0 / 255
    # Untargeted L-inf PGD: nb_iter steps of size eps_iter from a random start,
    # projected into an eps-ball. Its settings are fixed for the whole run, so it
    # is built once per corruption rather than once per epoch.
    # NOTE(review): clip range [-1, 1] assumes inputs are normalized to [-1, 1]
    # — confirm against prepare_corruption_data.
    adversary = LinfPGDAttack(net, loss_fn=nn.CrossEntropyLoss().cuda(), eps=eps,
        nb_iter=args.nb_iter, eps_iter=eps_iter, rand_init=True, clip_min=-1.0, clip_max=1.0,
        targeted=False)

    print('Running...')
    print('Error (%)\t\ttest\t\tself-supervised')
    for epoch in range(1, args.nepoch + 1):
        # Re-enter training mode each epoch; the evaluation at the end of the
        # previous epoch may have switched net to eval mode.
        net.train()
        for inputs, labels in trloader:
            optimizer.zero_grad()
            inputs_cls, labels_cls = inputs.cuda(), labels.cuda()
            # Generate adversarial examples with net in eval mode and with
            # parameter gradients disabled, so the attack does not touch them.
            with ctx_noparamgrad_and_eval(net):
                inputs_adv = adversary.perturb(inputs_cls, labels_cls)
            outputs_cls = net(inputs_adv)
            loss = criterion(outputs_cls, labels_cls)

            loss.backward()
            optimizer.step()

        # test(...)[0] is treated as an error rate (printed x100 under 'Error (%)').
        # The self-supervised error is only measured when --shared is given.
        err_cls = test(teloader, net)[0]
        err_ssh = 0 if args.shared is None else test(teloader, ssh, sslabel='expand')[0]
        all_err_cls.append(err_cls)
        all_err_ssh.append(err_ssh)
        # scheduler.step()

        print(('Epoch %d/%d:' % (epoch, args.nepoch)).ljust(24) +
              '%.2f\t\t%.2f' % (err_cls * 100, err_ssh * 100))
        torch.save((all_err_cls, all_err_ssh), args.outf + '/loss.pth')
        plot_epochs(all_err_cls, all_err_ssh, args.outf + '/loss.pdf')

        # Checkpoint every 5 epochs (1, 6, 11, ...) and always after the final
        # epoch. Previously the last epochs (97-100 with the defaults) were never
        # saved, so ckpt.pth did not hold the final model.
        if epoch % 5 == 1 or epoch == args.nepoch:
            state = {'err_cls': err_cls, 'err_ssh': err_ssh,
                     'net': net.state_dict(), 'head': head.state_dict(),
                     'optimizer': optimizer.state_dict()}
            torch.save(state, args.outf + '/ckpt.pth')