

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
import argparse
import time
from tensorboardX import SummaryWriter


import sys
import os

currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(os.path.dirname(currentdir))
sys.path.append(parentdir)

from utils.Certified.architectures import ARCHITECTURES
from utils.Certified.datasets import DATASETS
from utils.Certified.utils_ensemble import AverageMeter, accuracy, test, copy_code, requires_grad_
from utils.Certified.datasets import get_dataset
from utils.Certified.architectures import get_architecture
from train.Certified.third_party.smoothadv import PGD_L2, DDN

from train.Certified.trainer import SmoothAdv_Trainer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    argparse's ``type=bool`` simply calls ``bool(token)``, which is True for
    every non-empty string (so ``--random-start False`` would yield True).
    This converter accepts the usual spellings instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string was the only falsy input under ``type=bool``; keep it falsy.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line interface. Positional arguments select the dataset and the
# architecture; everything else is an optional hyper-parameter.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('dataset', type=str, choices=DATASETS)
parser.add_argument('arch', type=str, choices=ARCHITECTURES)
parser.add_argument('--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--batch', default=256, type=int, metavar='N',
                    help='batchsize (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    help='initial learning rate', dest='lr')
parser.add_argument('--lr_step_size', type=int, default=30,
                    help='How often to decrease learning by gamma.')
parser.add_argument('--gamma', type=float, default=0.1,
                    help='LR is multiplied by gamma on schedule.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--noise_sd', default=0.0, type=float,
                    help="standard deviation of Gaussian noise for data augmentation")
parser.add_argument('--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--num-models', type=int, required=True)

#####################
# Options added by Salman et al. (2019)
parser.add_argument('--resume', action='store_true',
                    help='if true, tries to resume training from existing checkpoint')
parser.add_argument('--pretrained-model', type=str, default='',
                    help='Path to a pretrained model')

#####################
# Attack params
parser.add_argument('--attack', default='DDN', type=str, choices=['DDN', 'PGD'])
parser.add_argument('--epsilon', default=64.0, type=float)
parser.add_argument('--num-steps', default=10, type=int)
parser.add_argument('--warmup', default=1, type=int, help="Number of epochs over which "
                                                          "the maximum allowed perturbation increases linearly "
                                                          "from zero to args.epsilon.")
parser.add_argument('--num-noise-vec', default=1, type=int,
                    help="number of noise vectors to use for finding adversarial examples. `m_train` in the paper.")
parser.add_argument('--no-grad-attack', action='store_true',
                    help="Choice of whether to use gradients during attack or do the cheap trick")

# PGD-specific
# Parsed with _str2bool so that "--random-start False" really disables it.
parser.add_argument('--random-start', default=True, type=_str2bool)

# DDN-specific
parser.add_argument('--init-norm-DDN', default=256.0, type=float)
parser.add_argument('--gamma-DDN', default=0.05, type=float)

# NOTE: arguments are parsed at import time, so importing this module
# consumes sys.argv.
args = parser.parse_args()


# Build a tag describing the attack configuration. This has to happen before
# epsilon / init_norm_DDN are rescaled below, so the tag carries the values
# exactly as they were given on the command line.
if args.attack not in ('PGD', 'DDN'):
    raise Exception('Unknown attack')
mode = (
    f"pgd_{args.epsilon}_{args.num_steps}_{args.warmup}"
    if args.attack == 'PGD'
    else f"ddn_{args.epsilon}_{args.num_steps}_{args.warmup}_{args.init_norm_DDN}_{args.gamma_DDN}"
)

# Output layout:
#   logs/Certified/<resume|scratch>/<dataset>/salman/<mode>/num_<m>/noise_<sd>
args.outdir = (
    f"logs/Certified/{'resume' if args.resume else 'scratch'}"
    f"/{args.dataset}/salman/{mode}/num_{args.num_noise_vec}/noise_{args.noise_sd}"
)

# Scale both norm budgets by 1/256 (presumably converting from a pixel-value
# scale to unit-range inputs -- confirm against the data pipeline).
args.epsilon = args.epsilon / 256.0
args.init_norm_DDN = args.init_norm_DDN / 256.0


def main():
    """Train an ensemble of ``args.num_models`` networks and checkpoint it.

    Reads all configuration from the module-level ``args``. For every epoch it
    updates the attacker's norm budget (linear warm-up), delegates one epoch of
    training to ``SmoothAdv_Trainer``, evaluates with ``test``, steps the LR
    scheduler and writes one checkpoint file per sub-model
    (``<outdir>/checkpoint.pth.tar.<i>``).

    Raises:
        Exception: if ``args.attack`` is neither ``'PGD'`` nor ``'DDN'``.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.outdir, exist_ok=True)

    copy_code(args.outdir)

    train_dataset = get_dataset(args.dataset, 'train')
    test_dataset = get_dataset(args.dataset, 'test')
    # Pinned host memory is only requested for ImageNet.
    pin_memory = (args.dataset == "imagenet")
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch,
                              num_workers=args.workers, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch,
                             num_workers=args.workers, pin_memory=pin_memory)

    # The "model" is a plain list of independently initialised sub-models.
    model = [nn.DataParallel(get_architecture(args.arch, args.dataset))
             for _ in range(args.num_models)]
    print("Model loaded")

    # Follow the module-level `device` so a CPU-only host does not crash here.
    criterion = nn.CrossEntropyLoss().to(device)

    # A single optimizer drives the parameters of every sub-model.
    param = [p for submodel in model for p in submodel.parameters()]

    optimizer = optim.SGD(param, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.gamma)

    model_path = os.path.join(args.outdir, 'checkpoint.pth.tar')
    writer = SummaryWriter(args.outdir)

    try:
        if args.resume:
            # Warm-start every sub-model from the per-model checkpoints found
            # under the "cohen" run with the same noise level.
            base_classifier = "logs/Certified/scratch/" + args.dataset + "/cohen/noise_" + \
                              str(args.noise_sd) + "/checkpoint.pth.tar"
            print(base_classifier)
            # Iterate over the actual ensemble size rather than a fixed 3.
            for i in range(args.num_models):
                checkpoint_file = base_classifier + ".%d" % (i)
                # map_location lets CUDA-saved checkpoints load on a CPU host.
                checkpoint = torch.load(checkpoint_file, map_location=device)
                print("Load " + checkpoint_file)
                model[i].load_state_dict(checkpoint['state_dict'])
                model[i].train()
            print("Loaded...")

        if args.attack == 'PGD':
            print('Attacker is PGD')
            attacker = PGD_L2(steps=args.num_steps, device=device, max_norm=args.epsilon)
        elif args.attack == 'DDN':
            print('Attacker is DDN')
            attacker = DDN(steps=args.num_steps, device=device, max_norm=args.epsilon,
                           init_norm=args.init_norm_DDN, gamma=args.gamma_DDN)
        else:
            raise Exception('Unknown attack')

        for epoch in range(args.epochs):
            # Linear warm-up of the perturbation budget, capped at epsilon.
            # A non-positive warm-up length means "no warm-up" instead of a
            # division by zero.
            if args.warmup > 0:
                norm_budget = np.min([args.epsilon, (epoch + 1) * args.epsilon / args.warmup])
            else:
                norm_budget = args.epsilon
            attacker.max_norm = norm_budget
            attacker.init_norm = norm_budget

            SmoothAdv_Trainer(args, train_loader, model, criterion, optimizer, epoch, args.noise_sd,
                              attacker, device, writer)
            test(test_loader, model, criterion, epoch, args.noise_sd, device, writer, args.print_freq)

            # NOTE(review): passing `epoch` to step() is deprecated in recent
            # PyTorch releases; kept as-is so the LR schedule is unchanged.
            scheduler.step(epoch)

            # One checkpoint per sub-model; the shared optimizer state is
            # stored in each of them.
            for i in range(args.num_models):
                model_path_i = model_path + ".%d" % (i)
                torch.save({
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model[i].state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, model_path_i)
    finally:
        # Close the event writer even when training aborts.
        writer.close()


# Start training only when this file is executed as a script.
if __name__ == "__main__":
    main()
