from __future__ import print_function
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm import tqdm
from models.wideresnet import *
from models.preactresnet import PreActResNet18
from losses import trades_loss
import time

parser = argparse.ArgumentParser(description='Adversarial Training')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128)')
parser.add_argument('--test-batch-size', type=int, default=128, metavar='N',
                    help='input batch size for testing (default: 128)')
parser.add_argument('--epochs', type=int, default=110, metavar='N',
                    help='number of epochs to train')
parser.add_argument('--weight-decay', '--wd', default=2e-4,
                    type=float, metavar='W')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                    help='learning rate')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--epsilon', type=float, default=0.031,
                    help='perturbation')
# type=int is required: without it a value given on the command line stays a
# string and is forwarded as-is to trades_loss(perturb_steps=...).
parser.add_argument('--num-steps', type=int, default=10,
                    help='perturb number of steps')
parser.add_argument('--step-size', type=float, default=0.007,
                    help='perturb step size')
parser.add_argument('--beta', type=float, default=1.0)
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--snap-epoch', type=int, default=5, metavar='N',
                    help='how many batches to test')
parser.add_argument('--model-dir', default='./wideResNet',
                    help='directory of model for saving checkpoint')
# Restrict to the architectures main() knows how to build, so a typo is
# reported by argparse instead of failing later.
parser.add_argument('--model', default='WideResNet',
                    choices=['WideResNet', 'PreActResNet18'])
parser.add_argument('--save-freq', default=1, type=int, metavar='N',
                    help='save frequency')
parser.add_argument('--start-freq', default=0, type=int, metavar='N',
                    help='start point')
parser.add_argument('--loss', default='pgd_he', type=str)
parser.add_argument('--distance', default='l_inf', type=str, help='distance')
parser.add_argument('--widen_factor', default=10, type=int, help='widen_factor')
# Whether to use eval mode for BatchNorm when crafting adversarial examples.
parser.add_argument('--BNeval', action='store_true', default=False)
parser.add_argument('--softplus_beta', default=10., type=float)
parser.add_argument('--activation', default='ReLU', type=str, choices=['ReLU', 'Softplus', 'GELU'])

args = parser.parse_args()

# Build a run-specific suffix so different configurations write their
# checkpoints to different directories.
preidx = ''
if args.model == 'PreActResNet18':
    preidx += '_PreActResNet18'
if args.activation == 'Softplus':
    preidx += '_softplus'

if 'trades' in args.loss:
    preidx += '_b{}'.format(args.beta)
    if args.widen_factor == 20:
        preidx += '_wide{}'.format(args.widen_factor)
    preidx += '_wd' + str(args.weight_decay)
    # 8/255 (~0.0314) is just above 0.031, so such a run is tagged '_eps8'.
    if args.epsilon > 0.031:
        preidx += '_eps8'
    if args.BNeval:
        preidx += '_BNeval'

# Only options defined by the parser may be used here. It defines no --s / --m,
# so reading args.s / args.m raised AttributeError. Those values were also never
# passed to the model or to trades_loss, so they are not part of the name.
model_dir = os.path.join('checkpoint', args.loss + preidx)
print(model_dir)
if not os.path.exists(model_dir):
    os.makedirs(model_dir)
use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
# Extra DataLoader options, only used when training on CUDA.
kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {}

# setup data loader
# Training split: random 32x32 crop (4-pixel padding) + horizontal flip, then
# ToTensor. Test split: ToTensor only. No normalization is applied here.
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor()])
transform_test = transforms.Compose([transforms.ToTensor()])

trainset = torchvision.datasets.CIFAR10(
    root='../data',
    train=True,
    download=True,
    transform=transform_train,
)
train_loader = torch.utils.data.DataLoader(
    trainset,
    batch_size=args.batch_size,
    shuffle=True,
    **kwargs
)
testset = torchvision.datasets.CIFAR10(
    root='../data',
    train=False,
    download=True,
    transform=transform_test,
)
test_loader = torch.utils.data.DataLoader(
    testset,
    batch_size=args.test_batch_size,
    shuffle=False,
    **kwargs
)

def train(args, model, device, train_loader, optimizer, epoch):
    """Run one epoch of adversarial training.

    For each batch:
    - move inputs and labels to ``device``;
    - compute the loss with ``trades_loss``, configured from ``args``
      (step size, epsilon, perturbation steps, beta, loss name, distance,
      BNeval);
    - back-propagate it and take an ``optimizer.step()``.

    A progress line is printed every ``args.log_interval`` batches, starting
    with the first batch.
    """
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # Robust loss. trades_loss also receives the optimizer; presumably it
        # uses it while crafting the perturbation (see losses.trades_loss).
        attack_cfg = dict(
            step_size=args.step_size,
            epsilon=args.epsilon,
            perturb_steps=args.num_steps,
            beta=args.beta,
            loss=args.loss,
            distance=args.distance,
            BNeval=args.BNeval,
        )
        loss = trades_loss(model=model, x_natural=inputs, y=labels,
                           optimizer=optimizer, **attack_cfg)
        loss.backward()
        optimizer.step()

        if step % args.log_interval:
            continue  # only report every log_interval batches
        processed = step * len(inputs)
        percent = 100. * step / len(train_loader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, processed, len(train_loader.dataset), percent, loss.item()))


def eval_train(model, device, train_loader):
    """Evaluate ``model`` on the clean (unperturbed) training data.

    The model is put in eval mode and run without gradient tracking.

    Args:
        model: network to evaluate.
        device: device each batch is moved to before the forward pass.
        train_loader: DataLoader over the training set; ``len(train_loader.dataset)``
            is used as the number of samples when averaging.

    Returns:
        tuple: (average cross-entropy loss per sample, accuracy as a fraction
        ``correct / len(train_loader.dataset)``).
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum (not mean) per batch so the final division is a true per-sample
            # average; reduction='sum' replaces the deprecated size_average=False.
            total_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    num_samples = len(train_loader.dataset)
    train_loss = total_loss / num_samples
    print('Training: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        train_loss, correct, num_samples,
        100. * correct / num_samples))
    training_accuracy = correct / num_samples
    return train_loss, training_accuracy


def eval_test(model, device, test_loader):
    """Evaluate ``model`` on the clean (unperturbed) test data.

    The model is put in eval mode and run without gradient tracking.

    Args:
        model: network to evaluate.
        device: device each batch is moved to before the forward pass.
        test_loader: DataLoader over the test set; ``len(test_loader.dataset)``
            is used as the number of samples when averaging.

    Returns:
        tuple: (average cross-entropy loss per sample, accuracy as a fraction
        ``correct / len(test_loader.dataset)``).
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum (not mean) per batch so the final division is a true per-sample
            # average; reduction='sum' replaces the deprecated size_average=False.
            total_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    num_samples = len(test_loader.dataset)
    test_loss = total_loss / num_samples
    print('Test: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, num_samples,
        100. * correct / num_samples))
    test_accuracy = correct / num_samples
    return test_loss, test_accuracy


def adjust_learning_rate(optimizer, epoch):
    """Set every param group's learning rate according to a step schedule.

    Uses the module-level ``args.lr`` as the base rate:
    - epochs below 100 use ``args.lr``;
    - epochs 100-104 use ``args.lr * 0.1``;
    - epoch 105 onward uses ``args.lr * 0.01``.
    """
    if epoch >= 105:
        new_lr = args.lr * 0.01
    elif epoch >= 100:
        new_lr = args.lr * 0.1
    else:
        new_lr = args.lr

    for group in optimizer.param_groups:
        group['lr'] = new_lr


def main():
    """Build the model and optimizer, then train, evaluate and checkpoint each epoch.

    Relies on module-level state: ``args`` (parsed options), ``device``,
    ``train_loader`` / ``test_loader`` and ``model_dir``.

    Raises:
        ValueError: if ``args.model`` is not a supported architecture.
    """
    if args.model == 'WideResNet':
        net = WideResNet(widen_factor=args.widen_factor, activation=args.activation,
                         softplus_beta=args.softplus_beta)
    elif args.model == 'PreActResNet18':
        net = PreActResNet18(activation=args.activation, softplus_beta=args.softplus_beta)
    else:
        # An unknown architecture name is a configuration error, not an I/O error.
        raise ValueError('Unsupported model: {!r}'.format(args.model))
    model = nn.DataParallel(net).to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                          weight_decay=args.weight_decay)

    for epoch in tqdm(range(1, args.epochs + 1)):
        # adjust learning rate for SGD
        adjust_learning_rate(optimizer, epoch)

        # adversarial training
        train_start = time.time()
        train(args, model, device, train_loader, optimizer, epoch)
        train_end = time.time()
        print('Train epoch {} time: {:.4f} minutes'.format(epoch, (train_end - train_start)/60))

        # evaluation on natural examples
        eval_train(model, device, train_loader)
        eval_test(model, device, test_loader)

        # Save the wrapped network's weights (model.module), not the
        # DataParallel wrapper's.
        if epoch >= args.start_freq and epoch % args.save_freq == 0:
            torch.save(model.module.state_dict(),
                       os.path.join(model_dir, 'epoch{}.pt'.format(epoch)))


if __name__ == '__main__':
    main()
