from __future__ import print_function
import os
import pickle
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.autograd import Variable
import path
import sys
import subprocess

import numpy as np
import time

# Extend the import path so the project-local modules imported below
# (``resnet``, ``nn_mnist``) can be found; presumably they live in one of
# these ancestor directories -- TODO confirm against the repo layout.
# NOTE: "../../../../" is resolved against the current working directory,
# not against this file's location.
sys.path.append("../../../../")
_this_dir = os.path.dirname(os.path.abspath(__file__))
# Grandparent of this file, then great-grandparent of this file.
folder_path = os.path.dirname(_this_dir)
sys.path.append(folder_path)
folder_path = os.path.dirname(folder_path)
sys.path.append(folder_path)
from resnet import ResNet18, ResNet50
from nn_mnist import NN_MNIST

import cleverhans
from cleverhans.torch.attacks.projected_gradient_descent import projected_gradient_descent

# os.environ["CUDA_VISIBLE_DEVICES"]="2"

def train(args, model, device, train_loader, optimizer, epoch):
    """Run one epoch of PGD adversarial training.

    Each batch is replaced by its L-inf PGD adversarial counterpart
    (radius ``args.epsilon``, step ``args.step_size``, ``args.num_steps``
    iterations), and the model is updated on cross-entropy over those
    adversarial inputs only. Progress is printed every ``args.log_interval``
    batches.

    Returns:
        The per-sample average adversarial loss over the epoch, i.e. the sum
        of ``loss * batch_size`` divided by ``len(train_loader.dataset)``.
    """
    model.train()
    running_total = 0
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # NOTE(review): the attack runs while the model is still in train mode,
        # so its forward passes also affect BatchNorm running stats / dropout
        # if the model has any -- confirm this is intended.
        adv_inputs = projected_gradient_descent(
            model, inputs, args.epsilon, args.step_size, args.num_steps, np.inf)

        optimizer.zero_grad()
        logits = model(adv_inputs)
        adv_loss = F.cross_entropy(logits, labels)
        adv_loss.backward()
        optimizer.step()

        if step % args.log_interval == 0:
            seen = step * len(adv_inputs)
            percent = 100. * step / len(train_loader)
            print('Train Epoch: {} [{}/{} ({:.0f}%)]Adv Loss: {:.6f}'.format(
                epoch, seen, len(train_loader.dataset), percent, adv_loss.item()))

        running_total += adv_loss.item() * adv_inputs.shape[0]
    return running_total / len(train_loader.dataset)

def adjust_learning_rate(optimizer, epoch, args):
    """Apply the step-decay learning-rate schedule for ``epoch``.

    The base rate ``args.lr`` is scaled by 0.1 from epoch 75, by 0.01 from
    epoch 90 and by 0.001 from epoch 100 onwards. A message is printed on the
    exact epoch at which each decay first takes effect. Every param group of
    ``optimizer`` receives the resulting rate.
    """
    # (first epoch of the stage, multiplier), checked from the latest stage back.
    milestones = ((100, 0.001), (90, 0.01), (75, 0.1))
    new_lr = args.lr
    for start_epoch, factor in milestones:
        if epoch < start_epoch:
            continue
        new_lr = args.lr * factor
        if epoch == start_epoch:
            print(f"Changing learning rate to {new_lr}")
        break
    for group in optimizer.param_groups:
        group['lr'] = new_lr
        
def _pgd_whitebox(args,
                  model,
                  X,
                  y,
                  epsilon=None,
                  num_steps=None,
                  step_size=None,
                  requires_error=True):
    """Run an L-inf PGD attack on a batch and optionally count errors.

    Args:
        args: parsed CLI namespace; its ``epsilon``, ``num_steps`` and
            ``step_size`` are used for any attack parameter left as ``None``.
        model: classifier under attack; switched to eval mode first.
        X: input batch.
        y: class labels for ``X``.
        epsilon: PGD radius (``None`` -> ``args.epsilon``).
        num_steps: number of PGD iterations (``None`` -> ``args.num_steps``).
        step_size: per-iteration step (``None`` -> ``args.step_size``).
        requires_error: if True return error counts, else the adversarial batch.

    Returns:
        When ``requires_error`` is True, ``(err, err_pgd)``: float tensors
        counting misclassified clean and adversarial examples in the batch.
        Otherwise the adversarial batch ``X_pgd``.
    """
    model.eval()
    # Compare against None so an explicit 0 is honoured instead of replaced.
    if epsilon is None:
        epsilon = args.epsilon
    if num_steps is None:
        num_steps = args.num_steps
    if step_size is None:
        step_size = args.step_size
    X_pgd = projected_gradient_descent(model, X, epsilon, step_size, num_steps, np.inf)
    if not requires_error:
        return X_pgd
    # Only predictions are needed here, so don't build autograd graphs.
    with torch.no_grad():
        err = (model(X).max(1)[1] != y).float().sum()
        err_pgd = (model(X_pgd).max(1)[1] != y).float().sum()
    return err, err_pgd

def eval_adv_test_whitebox(model, device, test_loader, args=None):
    """Measure clean and PGD-robust accuracy of ``model`` on ``test_loader``.

    Args:
        model: classifier to evaluate (left in eval mode afterwards).
        device: device the batches are moved to.
        test_loader: loader yielding ``(data, target)`` batches.
        args: namespace carrying ``epsilon``, ``num_steps`` and ``step_size``
            for the attack. Defaults to the module-level ``args`` global that
            the script's ``__main__`` section sets.

    Returns:
        ``(natural_acc, robust_acc)`` as tensors: the fraction of
        ``test_loader.dataset`` classified correctly without and with the
        PGD attack. Both values are also printed.

    Raises:
        ValueError: if no ``args`` is passed and no module-level one exists.
    """
    if args is None:
        args = globals().get('args')
        if args is None:
            raise ValueError("eval_adv_test_whitebox needs `args` with the PGD settings")
    model.eval()
    robust_err_total = 0
    natural_err_total = 0

    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        # Attack parameters left as None so _pgd_whitebox takes them from args.
        err_natural, err_robust = _pgd_whitebox(args, model, data, target, None, None, None)
        robust_err_total += err_robust
        natural_err_total += err_natural

    num_examples = len(test_loader.dataset)
    natural_acc = 1 - natural_err_total / num_examples
    robust_acc = 1 - robust_err_total / num_examples
    print('natural_acc: ', natural_acc)
    print('robust_acc: ', robust_acc)
    return natural_acc, robust_acc

def get_model_for_dataset(args):
    """Build the network architecture for the dataset named by ``args.data``.

    Returns:
        ``ResNet18()`` for ``"cifar10"`` and ``NN_MNIST()`` for ``"fmnist"``.

    Raises:
        ValueError: for any other dataset name, instead of returning ``None``
            and failing later with a confusing attribute error.
    """
    if args.data == "cifar10":
        return ResNet18()
    if args.data == "fmnist":
        return NN_MNIST()
    raise ValueError(f"Unsupported dataset {args.data!r}; expected 'cifar10' or 'fmnist'")
    
def _save_checkpoint(model, optimizer, epoch, model_dir, remote_save_dir):
    """Save this epoch's checkpoint, refresh ``model-final`` and optionally upload.

    The model state goes to ``PGD-AT-model-epoch{epoch}.pt`` and is copied to
    ``model-final``. From epoch 95 the optimizer state is saved as well. When
    ``remote_save_dir`` is set, both model files are sent there with ``scp``.
    Checkpoints from epochs before 90 are then deleted locally, but only if
    every upload succeeded.
    """
    import shutil

    out_path = os.path.join(model_dir, 'PGD-AT-model-epoch{}.pt'.format(epoch))
    torch.save(model.state_dict(), out_path)
    if epoch >= 95:
        torch.save(optimizer.state_dict(),
                   os.path.join(model_dir, 'PGD-AT-opt-checkpoint_epoch{}.tar'.format(epoch)))

    model_final_save_path = os.path.join(model_dir, 'model-final')
    shutil.copyfile(out_path, model_final_save_path)

    if not remote_save_dir:
        # No upload target: keep the local file rather than deleting the only
        # per-epoch copy (scp with None as a destination would also crash).
        return
    uploads = [subprocess.run(["scp", "-r", file_path, remote_save_dir])
               for file_path in (out_path, model_final_save_path)]
    if epoch < 90 and all(result.returncode == 0 for result in uploads):
        os.remove(out_path)


def main():
    """Train a model with PGD adversarial training, evaluating every epoch.

    Relies on module-level globals created in the ``__main__`` section:
    ``args``, ``device``, ``train_loader``, ``test_loader``, ``model_dir`` and
    ``remote_save_dir``. After the model summary is printed, all further
    output goes to ``log.txt`` (line-buffered).
    """
    import contextlib

    model = get_model_for_dataset(args)
    print("Using model", model)
    model.to(device)

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    natural_acc = []
    robust_acc = []
    loss_list = []

    log_path = 'log.txt'
    # Line buffering keeps the log current while training runs; the context
    # managers close the file and restore stdout when training ends or fails.
    with open(log_path, 'w', 1) as log_file, contextlib.redirect_stdout(log_file):
        for epoch in range(1, args.epochs + 1):
            # adjust learning rate for SGD
            adjust_learning_rate(optimizer, epoch, args)
            start_time = time.time()

            # adversarial training
            loss = train(args, model, device, train_loader, optimizer, epoch)
            loss_list.append(loss)

            print('================================================================')
            natural_acc_epoch, robust_acc_epoch = eval_adv_test_whitebox(model, device, test_loader)
            print('using time:', time.time() - start_time)

            natural_acc.append(natural_acc_epoch.detach().cpu())
            robust_acc.append(robust_acc_epoch.detach().cpu())
            print('================================================================')

            if epoch % args.save_freq == 0 or epoch >= 90:
                _save_checkpoint(model, optimizer, epoch, model_dir, remote_save_dir)
                       


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch CIFAR MART Defense')
    parser.add_argument('-d', '--data', type=str, default='fmnist', choices=['cifar10', 'fmnist'])
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--test-batch-size', type=int, default=100, metavar='N',
                        help='input batch size for testing (default: 100)')
    parser.add_argument('--epochs', type=int, default=100, metavar='N',
                        help='number of epochs to train')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    # Explicit types: without them, values passed on the command line arrive
    # as strings and break the numeric handling inside the PGD attack.
    parser.add_argument('--epsilon', type=float, default=0.3,
                        help='perturbation')
    parser.add_argument('--num_steps', type=int, default=40,
                        help='perturb number of steps')
    parser.add_argument('--step_size', type=float, default=0.01,
                        help='perturb step size')
    parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--model_dir', default='saved_models/baselines',
                        help='directory of model for saving checkpoint')
    parser.add_argument('--save-freq', '-s', default=50, type=int, metavar='N',
                        help='save frequency')
    parser.add_argument('--remote-save-dir', default=None,
                        help='optional scp destination for checkpoints; uploads are skipped when unset')

    args = parser.parse_args()

    # settings
    model_dir = os.path.join(args.model_dir, args.data.lower(), 'PGD_AT')
    os.makedirs(model_dir, exist_ok=True)

    log_dir = './log'
    os.makedirs(log_dir, exist_ok=True)

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(0)
    device = torch.device("cuda" if use_cuda else "cpu")
    # None disables checkpoint uploads in main().
    remote_save_dir = args.remote_save_dir

    print('Using device:', device)

    torch.backends.cudnn.benchmark = True
    print(args)

    # setup data loaders
    # NOTE: pickle.load can execute arbitrary code; only load split files you created.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])
    if args.data == "cifar10":
        testset = torchvision.datasets.CIFAR10(root='../data_attack/', train=False, download=True, transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, num_workers=10)

        print('using saved dataset...')
        with open('dataset_split.pkl', 'rb') as f:
            pkl_data = pickle.load(f)
        train_dataset = pkl_data['train_ds']
        val_dataset = pkl_data['val_ds']  # validation dataset
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=10)
        val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True, num_workers=10)
        print("Loading done!")
    elif args.data == "fmnist":
        with open('dataset_fmnist_split.pkl', 'rb') as f:
            pkl_data = pickle.load(f)
        train_dataset = pkl_data['train_ds']
        val_dataset = pkl_data['val_ds']  # validation dataset
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=10)
        val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True, num_workers=10)
        testset = torchvision.datasets.FashionMNIST(root='../data_attack/', download=True, train=False, transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, num_workers=10)

    main()
