from __future__ import print_function
import os
import pickle
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.autograd import Variable
import path
import sys
import subprocess

import numpy as np
import time

# Extend sys.path so the project modules imported below (presumably resnet and
# nn_mnist) can be found: one cwd-relative root plus two ancestors of this file.
sys.path.append("../../../../")
_this_file = path.Path(__file__).abspath()
for _extra_root in (_this_file.parent.parent, _this_file.parent.parent.parent):
    sys.path.append(_extra_root)
# NOTE(review): this six-level ancestor is computed but never appended to
# sys.path nor read again in this file.
folder_path = _this_file.parent.parent.parent.parent.parent.parent

from resnet import ResNet18, ResNet50
from nn_mnist import NN_MNIST

import cleverhans
from cleverhans.torch.attacks.projected_gradient_descent import projected_gradient_descent

# os.environ["CUDA_VISIBLE_DEVICES"]="2"

parser = argparse.ArgumentParser(description='PyTorch CIFAR MART Defense')
parser.add_argument('-d', '--data', type=str, default='cifar10', choices=['cifar10', 'fmnist'])
# parser.add_argument('-gpu', type=int, default=0)
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128)')
parser.add_argument('--test-batch-size', type=int, default=100, metavar='N',
                    help='input batch size for testing (default: 100)')
parser.add_argument('--epochs', type=int, default=100, metavar='N',
                    help='number of epochs to train')
parser.add_argument('--weight-decay', '--wd', default=2e-4,
                    type=float, metavar='W')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                    help='learning rate')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
# PGD attack hyper-parameters. Explicit types are required: without them a
# value supplied on the command line is kept as a string, which breaks the
# numeric comparisons/arithmetic inside the attack.
parser.add_argument('--epsilon', default=0.031, type=float,
                    help='perturbation')
parser.add_argument('--num_steps', default=20, type=int,
                    help='perturb number of steps')
parser.add_argument('--step_size', default=0.007, type=float,
                    help='perturb step size')
parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--model_dir', default='saved_models/baselines',
                    help='directory of model for saving checkpoint')
parser.add_argument('--save-freq', '-s', default=50, type=int, metavar='N',
                    help='save frequency')

args = parser.parse_args()

# settings
model_dir = os.path.join(args.model_dir, args.data.lower(), 'PGD_AT')
# exist_ok=True avoids the check-then-create race when several runs start together.
os.makedirs(model_dir, exist_ok=True)

# NOTE(review): log_dir is created but main() writes its log to ./log.txt instead.
log_dir = './log'
os.makedirs(log_dir, exist_ok=True)

use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(0)
device = torch.device("cuda" if use_cuda else "cpu")
# Destination handed to `scp` when checkpoints are saved; an empty string means
# no remote destination has been set.
remote_save_dir = ""
# if use_cuda:
#     torch.set_default_tensor_type('torch.cuda.FloatTensor')
#     device = torch.device('cuda')
#     torch.cuda.set_device(args.gpu)
#     print('Using Device: ', torch.cuda.get_device_name())
# else:
#     device = torch.device('cpu')

print('Using device:',device)

# NOTE(review): kwargs is not passed to any DataLoader in this file (they hard-code num_workers=10).
kwargs = {'num_workers': 5, 'pin_memory': True} if use_cuda else {}
# Lets cuDNN pick the fastest kernels; this trades away bitwise reproducibility
# despite the fixed manual seed above.
torch.backends.cudnn.benchmark = True
print(args)

# setup data loader
def _load_pickled_split(pkl_name):
    """Open a pickled split file and return its ``train_ds`` and ``val_ds`` entries."""
    # NOTE(review): pickle.load can execute arbitrary code -- only load split
    # files produced by a trusted source.
    with open(pkl_name, 'rb') as handle:
        split = pickle.load(handle)
    return split['train_ds'], split['val_ds']


def _make_split_loaders(train_ds, val_ds):
    """Wrap the train/validation datasets in shuffling loaders of size --batch-size."""
    return (DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=10),
            DataLoader(val_ds, batch_size=args.batch_size, shuffle=True, num_workers=10))


# The training/validation splits come from local pickle files; only the test
# set is fetched through torchvision. Test images are plain ToTensor() output.
if args.data == "cifar10":
    transform_test = transforms.Compose([transforms.ToTensor()])
    testset = torchvision.datasets.CIFAR10(root='../data_attack/', train=False, download=True,
                                           transform=transform_test)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size,
                                              shuffle=False, num_workers=10)

    print('using saved dataset...')
    train_dataset, val_dataset = _load_pickled_split('dataset_split.pkl')  # val = validation split
    train_loader, val_loader = _make_split_loaders(train_dataset, val_dataset)
    print("Loading done!")
elif args.data == "fmnist":
    transform_test = transforms.Compose([transforms.ToTensor()])
    train_dataset, val_dataset = _load_pickled_split('dataset_fmnist_split.pkl')
    train_loader, val_loader = _make_split_loaders(train_dataset, val_dataset)
    testset = torchvision.datasets.FashionMNIST(root='../data_attack/', download=True, train=False,
                                                transform=transform_test)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size,
                                              shuffle=False, num_workers=10)

def train(args, model, device, train_loader, optimizer, epoch):
    """Run one epoch of PGD adversarial training.

    Every clean batch is replaced by an L-inf PGD adversarial batch crafted
    against the current model, and the model is updated on the cross-entropy
    of that adversarial batch only.

    Args:
        args: parsed CLI namespace; reads epsilon, step_size, num_steps and
            log_interval.
        model: network being trained (switched to train mode here).
        device: torch device the batches are moved to.
        train_loader: loader yielding ``(data, target)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        epoch: current epoch number, used only for logging.

    Returns:
        Mean adversarial cross-entropy over the training set (0.0 if the
        dataset is empty).
    """
    model.train()
    # Cast defensively: if these options are parsed without a type=, values
    # given on the command line arrive as strings.
    eps = float(args.epsilon)
    step_size = float(args.step_size)
    num_steps = int(args.num_steps)

    epoch_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        # Craft the attack against the true labels. Without y=, cleverhans uses
        # the model's own predictions as labels, which yields weaker training
        # attacks whenever the model misclassifies the clean input.
        adv_data = projected_gradient_descent(model, data, eps, step_size, num_steps,
                                              np.inf, y=target)

        optimizer.zero_grad()
        adv_loss = F.cross_entropy(model(adv_data), target)
        adv_loss.backward()
        optimizer.step()

        batch_loss = adv_loss.item()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]Adv Loss: {:.6f}'.format(
                epoch, batch_idx * len(adv_data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), batch_loss))

        # Weight by batch size so the result is a per-sample mean.
        epoch_loss += batch_loss * adv_data.shape[0]

    n_samples = len(train_loader.dataset)
    return epoch_loss / n_samples if n_samples else 0.0

def adjust_learning_rate(optimizer, epoch):
    """Apply the step learning-rate schedule for ``epoch`` to every param group.

    The base rate ``args.lr`` holds until epoch 74. It is scaled by 0.1 from
    epoch 75, by 0.01 from epoch 90 and by 0.001 from epoch 100 onwards. The new
    rate is announced only on the first epoch of each reduced stage.
    """
    # (first epoch of stage, multiplier), latest stage first so the first
    # match wins.
    stages = ((100, 0.001), (90, 0.01), (75, 0.1))
    new_lr = args.lr
    for first_epoch, factor in stages:
        if epoch < first_epoch:
            continue
        new_lr = args.lr * factor
        if epoch == first_epoch:
            print(f"Changing learning rate to {new_lr}")
        break
    for group in optimizer.param_groups:
        group['lr'] = new_lr
        
def _pgd_whitebox(model,
                  X,
                  y,
                  epsilon=args.epsilon,
                  num_steps=args.num_steps,
                  step_size=args.step_size,
                  requires_error=True):
    """Attack a batch with L-inf PGD and optionally count misclassifications.

    ``y`` is not passed to the attack, so (per cleverhans' documented default)
    the model's own predictions are used as attack labels; ``y`` is only used
    for scoring.

    Args:
        model: classifier under attack; switched to eval mode.
        X: input batch.
        y: true labels for ``X``.
        epsilon: PGD L-inf radius (default from the command line).
        num_steps: number of PGD iterations (default from the command line).
        step_size: per-iteration step size (default from the command line).
        requires_error: if True return error counts, otherwise the adversarial
            batch.

    Returns:
        ``(clean_errors, adversarial_errors)`` as scalar float tensors when
        ``requires_error`` is True, else the adversarial batch ``X_pgd``.
    """
    model.eval()
    # Cast defensively in case the CLI options were parsed as strings.
    X_pgd = projected_gradient_descent(model, X, float(epsilon), float(step_size),
                                       int(num_steps), np.inf)
    if not requires_error:
        return X_pgd
    # Scoring needs no gradients; skipping graph construction saves memory.
    with torch.no_grad():
        err = (model(X).max(1)[1] != y).float().sum()
        err_pgd = (model(X_pgd).max(1)[1] != y).float().sum()
    return err, err_pgd

def eval_adv_test_whitebox(model, device, test_loader):
    """Measure clean and PGD-robust accuracy of ``model`` over ``test_loader``.

    Both accuracies are printed and returned as ``(natural_acc, robust_acc)``,
    each computed as ``1 - errors / len(test_loader.dataset)``.
    """
    model.eval()
    robust_err_total = 0
    natural_err_total = 0

    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        # Plain tensors are passed: torch.autograd.Variable is a deprecated
        # no-op wrapper, and the attack computes its own input gradients.
        err_natural, err_robust = _pgd_whitebox(model, data, target)
        robust_err_total += err_robust
        natural_err_total += err_natural

    n_samples = len(test_loader.dataset)
    natural_acc = 1 - natural_err_total / n_samples
    robust_acc = 1 - robust_err_total / n_samples
    print('natural_acc: ', natural_acc)
    print('robust_acc: ', robust_acc)
    return natural_acc, robust_acc

def get_model_for_dataset(args):
    """Build a fresh network for ``args.data``.

    Returns ``ResNet18()`` for ``'cifar10'`` and ``NN_MNIST()`` for ``'fmnist'``.

    Raises:
        ValueError: for any other dataset name, rather than silently returning
            None and failing later at ``model.to(device)``.
    """
    if args.data == "cifar10":
        return ResNet18()
    if args.data == "fmnist":
        return NN_MNIST()
    raise ValueError(f"Unsupported dataset: {args.data!r}")
    
def _save_checkpoint(model, optimizer, epoch):
    """Persist model/optimizer state for ``epoch`` and publish it.

    Writes ``PGD-AT-model-epoch{epoch}.pt`` and
    ``PGD-AT-opt-checkpoint_epoch{epoch}.tar`` under ``model_dir``, overwrites
    ``model_dir/model-final`` with the model weights, and copies all three to
    ``remote_save_dir`` with ``scp`` when that destination is non-empty.
    """
    import shutil

    out_path = os.path.join(model_dir, 'PGD-AT-model-epoch{}.pt'.format(epoch))
    out_opt_path = os.path.join(model_dir, 'PGD-AT-opt-checkpoint_epoch{}.tar'.format(epoch))
    model_final_save_path = os.path.join(model_dir, 'model-final')

    torch.save(model.state_dict(), out_path)
    torch.save(optimizer.state_dict(), out_opt_path)
    # Copy in-process instead of shelling out to `cp`, so a failed copy
    # raises rather than being silently ignored.
    shutil.copyfile(out_path, model_final_save_path)

    # `scp` to an empty destination always fails, so only sync when one is set.
    # Remote sync stays best-effort: scp failures are not raised.
    if remote_save_dir:
        for local_path in (out_path, model_final_save_path, out_opt_path):
            subprocess.run(["scp", "-r", local_path, remote_save_dir])


def main():
    """Train with PGD adversarial training, evaluating after every epoch.

    The model summary is printed to the console. Everything after that
    (training progress, accuracies, timings) goes to ``log.txt`` in the
    working directory. Checkpoints are saved every ``--save-freq`` epochs and
    at every epoch >= 90.
    """
    import contextlib

    model = get_model_for_dataset(args)
    print("Using model", model)

    model.to(device)
    log_path = 'log.txt'
    # Line-buffered (buffering=1) so the log can be followed live. The context
    # managers restore sys.stdout and close the file even if training raises.
    with open(log_path, 'w', 1) as log_file, contextlib.redirect_stdout(log_file):
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                              weight_decay=args.weight_decay)

        # Per-epoch history, kept in memory only (nothing in this file saves it).
        natural_acc = []
        robust_acc = []
        loss_list = []

        for epoch in range(1, args.epochs + 1):
            # adjust learning rate for SGD
            adjust_learning_rate(optimizer, epoch)

            start_time = time.time()

            # adversarial training
            loss = train(args, model, device, train_loader, optimizer, epoch)
            loss_list.append(loss)

            print('================================================================')

            # The evaluator returns accuracies (1 - error rate), not error counts.
            epoch_natural_acc, epoch_robust_acc = eval_adv_test_whitebox(model, device, test_loader)

            print('using time:', time.time()-start_time)

            natural_acc.append(epoch_natural_acc.detach().cpu())
            robust_acc.append(epoch_robust_acc.detach().cpu())
            print('================================================================')

            # save checkpoint
            if epoch % args.save_freq == 0 or epoch >= 90:
                _save_checkpoint(model, optimizer, epoch)


# Start training only when run as a script, not when this module is imported.
if __name__ == "__main__":
    main()
