'''
from /src/models/defense/baselines/TRADES
python train_trades.py
'''

from __future__ import print_function
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import Subset
from torch.utils.data import DataLoader
import numpy as np
import sys
import pickle
import time
import cleverhans
from cleverhans.torch.attacks.projected_gradient_descent import (
    projected_gradient_descent,
)
import subprocess

# from models.wideresnet import *
# from models.resnet import *
# from trades import trades_loss

from models.defense.baselines.TRADES.models.wideresnet import *
from models.defense.baselines.TRADES.models.resnet import *
from models.defense.baselines.TRADES.trades import trades_loss

from models.defense.nn_mnist import NN_MNIST


# Command-line interface of the TRADES adversarial-training script.
parser = argparse.ArgumentParser(description='PyTorch TRADES Adversarial Training')
# Only the two datasets this script knows how to load / build a model for are
# accepted; anything else is rejected by argparse with a usage message.
parser.add_argument('-d', '--data', type=str, default='fmnist',
                    choices=['cifar10', 'fmnist'],
                    help='dataset to train on (default: fmnist)')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128)')
parser.add_argument('--test-batch-size', type=int, default=128, metavar='N',
                    help='input batch size for testing (default: 128)')
parser.add_argument('--epochs', type=int, default=120, metavar='N',
                    help='number of epochs to train')
parser.add_argument('--weight-decay', '--wd', default=2e-4,
                    type=float, metavar='W')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                    help='learning rate')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
# The attack hyper-parameters carry an explicit type so that values supplied
# on the command line are parsed as numbers instead of being left as strings.
parser.add_argument('--epsilon', type=float, default=0.031,
                    help='perturbation')
parser.add_argument('--num-steps', type=int, default=20,
                    help='perturb number of steps')
parser.add_argument('--step-size', type=float, default=0.007,
                    help='perturb step size')
# --beta is deliberately left untyped: the script uses the value exactly as
# parsed when naming the model directory and converts it to float afterwards.
parser.add_argument('--beta', default=6.0,
                    help='regularization, i.e., 1/lambda in TRADES')
parser.add_argument('--seed', type=int, default=0, metavar='S',
                    help='random seed (default: 0)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--save-freq', '-s', default=1, type=int, metavar='N',
                    help='save frequency')
parser.add_argument('-gpu', default=0, type=int, help='gpu:id to be used')
parser.add_argument('--init-epoch', type=int, default=1, help='Epoch to start with')



args = parser.parse_args()

# Seed every RNG the script draws from. torch.manual_seed alone does not cover
# the NumPy-based shuffling done by RandomSubset, which would otherwise produce
# a different attacked / clean test split on every run despite --seed.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

class RandomSubset:
    """
        Build random ``Subset`` views of a wrapped dataset.

        All randomness is drawn from NumPy's global random state.
    """
    def __init__(self, dataset):
        self.dataset = dataset

    def get_sample(self, K):
        """Return a ``Subset`` of ``min(K, len(dataset))`` distinct random items."""
        total = len(self.dataset)
        sample_size = min(K, total)
        picked = np.random.choice(np.arange(total), sample_size, replace=False)
        return Subset(self.dataset, picked.tolist())

    def split_ds(self, K):
        """Shuffle all indices and return ``(first K, remaining)`` as two ``Subset``s."""
        order = np.arange(len(self.dataset))
        np.random.shuffle(order)
        head = Subset(self.dataset, order[:K])
        tail = Subset(self.dataset, order[K:])
        return (head, tail)

print(args)

# Fail fast, before any directory is created or stdout is redirected: only
# these two datasets are handled below.
if args.data not in ('cifar10', 'fmnist'):
    raise ValueError(
        f"Unsupported dataset {args.data!r}; expected 'cifar10' or 'fmnist'")

# The directory name uses beta exactly as it was parsed (before the float
# conversion on the next line).
model_dir = f'./model_{args.data}' + str(args.beta)
args.beta = float(args.beta)
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(model_dir, exist_ok=True)
use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")

print('Using device:', device)
# Extra DataLoader options are only passed when running on CUDA.
kwargs = {'num_workers': 5, 'pin_memory': True} if use_cuda else {}
print(kwargs)
print(args)

# From here on everything printed goes to the line-buffered training log
# instead of the console.
log_path = os.path.join(model_dir, "train_log.txt")
sys.stdout = open(log_path, 'w', 1)

if args.data == 'cifar10':
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])
    testset = torchvision.datasets.CIFAR10(root='../data', train=False, download=True, transform=transform_test)
elif args.data == 'fmnist':
    testset = torchvision.datasets.FashionMNIST('data', download=True, train=False, transform=transforms.Compose([transforms.ToTensor()]))

test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, **kwargs)

# Split the test set once into the part that will be attacked (eta percent)
# and the part that is evaluated clean.
sub = RandomSubset(testset)
eta = 10    ## percent to be attacked
eta_frac = eta/100
attack_size = int((len(testset)/100)*eta)
attack_ds, rem_ds = sub.split_ds(attack_size)

# Train / validation datasets come from a pre-built pickle in the working
# directory: "dataset_split.pkl" for cifar10, "dataset_<data>_split.pkl" otherwise.
# NOTE: pickle.load can execute arbitrary code from the file; only load split
# files you produced yourself.
pkl_path = f'dataset{"_"+args.data if args.data!="cifar10" else ""}_split.pkl'
with open(pkl_path, 'rb') as f:
    dat = pickle.load(f)
    print("Loading datasets")

trainset = dat["train_ds"]
# NOTE(review): training batches are not shuffled (shuffle=False) - confirm
# this is intended.
train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=False, **kwargs)
valset = dat["val_ds"]
val_loader = torch.utils.data.DataLoader(valset, batch_size=args.test_batch_size, shuffle=False, **kwargs)


print(f"Train set size is {len(trainset)} and val set size is {len(valset)}")

def train(args, model, device, train_loader, optimizer, epoch):
    """Run one epoch of training with the TRADES loss.

    Args:
        args: parsed options; ``step_size``, ``epsilon``, ``num_steps``,
            ``beta`` and ``log_interval`` are read here.
        model: network being trained (switched to train mode).
        device: device each batch is moved to.
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, used only in the progress message.
    """
    model.train()
    tick = time.time()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        # robust (TRADES) loss for this batch
        robust_loss = trades_loss(
            model=model,
            x_natural=inputs,
            y=labels,
            optimizer=optimizer,
            step_size=args.step_size,
            epsilon=args.epsilon,
            perturb_steps=args.num_steps,
            beta=args.beta,
        )
        robust_loss.backward()
        optimizer.step()

        # only report every `log_interval` batches
        if step % args.log_interval != 0:
            continue

        seen = step * len(inputs)
        total = len(train_loader.dataset)
        percent = 100. * step / len(train_loader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, seen, total, percent, robust_loss.item()), end=" ")
        tock = time.time()
        print(f"Time taken for this since prev log is {round(tock-tick, 3)}")
        # restart the stopwatch for the next logging window
        tick = time.time()


def eval_train(model, device, train_loader):
    """Evaluate ``model`` on ``train_loader`` without gradient tracking.

    Args:
        model: network to evaluate (switched to eval mode).
        device: device each batch is moved to.
        train_loader: iterable of ``(data, target)`` batches exposing
            ``.dataset`` for the sample count.

    Returns:
        Tuple ``(train_loss, training_accuracy)``: the cross-entropy averaged
        over ``len(train_loader.dataset)`` and the fraction of correct
        predictions.
    """
    model.eval()
    train_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum per-sample losses; reduction='sum' is the supported
            # replacement for the deprecated size_average=False.
            train_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    train_loss /= len(train_loader.dataset)
    print('Training: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        train_loss, correct, len(train_loader.dataset),
        100. * correct / len(train_loader.dataset)))
    training_accuracy = correct / len(train_loader.dataset)
    return train_loss, training_accuracy


def eval_test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` without gradient tracking.

    Args:
        model: network to evaluate (switched to eval mode).
        device: device each batch is moved to.
        test_loader: iterable of ``(data, target)`` batches exposing
            ``.dataset`` for the sample count.

    Returns:
        Tuple ``(test_loss, test_accuracy)``: the cross-entropy averaged over
        ``len(test_loader.dataset)`` and the fraction of correct predictions.
    """
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum per-sample losses; reduction='sum' is the supported
            # replacement for the deprecated size_average=False.
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('Validation: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    test_accuracy = correct / len(test_loader.dataset)
    return test_loss, test_accuracy


def adjust_learning_rate(optimizer, epoch):
    """Apply the step-wise learning-rate schedule for ``epoch``.

    The base rate ``args.lr`` is used before epoch 75, then scaled by 0.1
    from epoch 75, by 0.01 from epoch 90 and by 0.001 from epoch 100. The
    result is written into every parameter group of ``optimizer``.
    """
    # Check the latest milestone first so exactly one branch applies.
    if epoch >= 100:
        new_lr = args.lr * 0.001
    elif epoch >= 90:
        new_lr = args.lr * 0.01
    elif epoch >= 75:
        new_lr = args.lr * 0.1
    else:
        new_lr = args.lr

    for group in optimizer.param_groups:
        group['lr'] = new_lr

def eval_test_acc(model, device, epoch):
    """Measure attacked / clean / combined test accuracy and pickle the result.

    Reads the module-level ``attack_ds``, ``rem_ds``, ``eta_frac`` and ``args``.
    Samples in ``attack_ds`` are perturbed with ``projected_gradient_descent``
    (``args.epsilon``, ``args.step_size``, ``args.num_steps``, L-inf norm)
    before classification; samples in ``rem_ds`` are classified as-is. The
    combined accuracy weights the two by ``eta_frac`` and ``1 - eta_frac``.

    Args:
        model: network to evaluate (switched to eval mode).
        device: device each batch is moved to.
        epoch: epoch number, used in the log line and as the key of the
            stored dictionaries.

    Side effects:
        Prints one summary line and writes
        ``results/<data>/TRADES/beta_<beta>/test_accs.pkl`` with the three
        accuracies in percent.
    """
    attack_dl = DataLoader(attack_ds, batch_size=256, shuffle=False)
    rem_dl = DataLoader(rem_ds, batch_size=256, shuffle=False)

    # Evaluate attacked accuracies
    model.eval()
    att_correct = 0
    for (x, y) in attack_dl:
        x = x.to(device)
        y = y.to(device)
        # The attack itself is called with autograd enabled.
        x_adv = projected_gradient_descent(model, x, args.epsilon, args.step_size, args.num_steps, np.inf)
        # Classification only: no autograd graph is required.
        with torch.no_grad():
            output = model(x_adv)
        pred = output.max(1, keepdim=True)[1]
        att_correct += pred.eq(y.view_as(pred)).sum().item()
    attacked_accuracies = att_correct/len(attack_dl.dataset)

    # Evaluate unattacked accuracies
    model.eval()
    unatt_correct = 0
    with torch.no_grad():
        for (x, y) in rem_dl:
            x = x.to(device)
            y = y.to(device)
            output = model(x)
            pred = output.max(1, keepdim=True)[1]
            unatt_correct += pred.eq(y.view_as(pred)).sum().item()
    unattacked_accuracies = unatt_correct/len(rem_dl.dataset)

    combined_accuracies = eta_frac*attacked_accuracies + (1-eta_frac)*unattacked_accuracies
    print(f"Model {epoch}: Attacked {round(attacked_accuracies, 4)}, Unattacked {round(unattacked_accuracies, 4)}, Combined {round(combined_accuracies, 4)}")

    # Accuracies are stored in percent, keyed by epoch.
    pkl_dict = {
        "Attacked Accuracies": {epoch: attacked_accuracies*100},
        "Unattacked Accuracies": {epoch: unattacked_accuracies*100},
        "Combined Accuracies": {epoch: combined_accuracies*100},
    }

    remote_save_dir = f"results/{args.data}/TRADES/beta_{args.beta}/"
    os.makedirs(remote_save_dir, exist_ok=True)
    pkl_path = os.path.join(remote_save_dir, 'test_accs.pkl')
    # NOTE(review): the file is rewritten on every call, so it only ever holds
    # the most recent epoch's numbers - confirm this is intended.
    with open(pkl_path, 'wb') as f:
        pickle.dump(pkl_dict, f)

def main():
    """Train a model with TRADES, evaluating and checkpointing every epoch.

    Builds ``ResNet18`` for ``cifar10`` or ``NN_MNIST`` for ``fmnist``, trains
    with SGD from ``args.init_epoch`` to ``args.epochs`` (inclusive), logs
    train accuracy and attacked / clean test accuracy after each epoch, and
    saves model and optimizer state dicts into ``model_dir``.

    Raises:
        ValueError: if ``args.data`` names a dataset with no model mapping.
    """
    if args.data == 'cifar10':
        model = ResNet18().to(device)
    elif args.data == 'fmnist':
        model = NN_MNIST().to(device)
    else:
        # Without this branch `model` would be unbound and fail later with an
        # unrelated-looking NameError.
        raise ValueError(
            f"Unsupported dataset {args.data!r}; expected 'cifar10' or 'fmnist'")

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    # optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Manual resume: uncomment to restore a checkpoint, then pass --init-epoch.
    # sd = torch.load("model-cifar-ResNet_40k_comparison/model-resnet-epoch48.pt")
    # model.load_state_dict(sd)
    # optim_dict = torch.load("model-cifar-ResNet_40k_comparison/opt-resnet-checkpoint_epoch48.tar")
    # optimizer.load_state_dict(optim_dict)

    # Honour --init-epoch (default 1 keeps the previous behaviour). int() also
    # covers a value that reached here as a string from the command line.
    first_epoch = int(args.init_epoch)

    for epoch in range(first_epoch, args.epochs + 1):
        # adjust learning rate for SGD
        start = time.time()
        adjust_learning_rate(optimizer, epoch)

        # adversarial training
        train(args, model, device, train_loader, optimizer, epoch)

        # evaluation on natural examples
        print('================================================================')
        eval_train(model, device, train_loader)
        # eval_test(model, device, val_loader)
        eval_test_acc(model, device, epoch)
        end = time.time()
        print(f"Time taken for current epoch is {round(end-start, 3)}")
        print('================================================================')

        # save checkpoint: every `save_freq` epochs, and every epoch from 100 on
        if (epoch % args.save_freq == 0) or epoch >= 100:
            torch.save(model.state_dict(),
                       os.path.join(model_dir, 'model-resnet-epoch{}.pt'.format(epoch)))
            torch.save(optimizer.state_dict(),
                       os.path.join(model_dir, 'opt-resnet-checkpoint_epoch{}.tar'.format(epoch)))


if __name__ == '__main__':
    main()
