from __future__ import print_function
import os
import argparse
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from models.net_mnist import *
from trades import *

parser = argparse.ArgumentParser(description='PyTorch MNIST TRADES Adversarial Training (Binary)')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128)')
parser.add_argument('--test-batch-size', type=int, default=128, metavar='N',
                    help='input batch size for testing (default: 128)')
parser.add_argument('--epochs', type=int, default=50, metavar='N',
                    help='number of epochs to train')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
# The attack settings below need explicit types: without them a value given on
# the command line stays a str and breaks range()/tensor arithmetic.
parser.add_argument('--epsilon', type=float, default=0.1,
                    help='perturbation')
parser.add_argument('--num-steps', type=int, default=20,
                    help='perturb number of steps')
parser.add_argument('--step-size', type=float, default=0.01,
                    help='perturb step size')
parser.add_argument('--beta', type=float, default=5.0,
                    help='regularization, i.e., lambda in TRADES for binary case')
parser.add_argument('--weight-decay', '--wd', default=0.0,
                    type=float, metavar='W', help='weight decay')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=20, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--model-dir', default='./model-mnist-net-two-class',
                    help='directory of model for saving checkpoint')
parser.add_argument('--save-freq', '-s', default=10, type=int, metavar='N',
                    help='save frequency (default: 10)')
# Restrict to the weighting schemes get_snr_loss implements so a typo fails at
# startup instead of in the middle of training.
parser.add_argument('--layer_snr_weight_type', default="exp", type=str,
                    choices=["exp", "muln", "sum"],
                    help='type of layer snr weight')
parser.add_argument('--base', default=1.5, type=float,
                    help='base of the per-layer weight base**k (used when --layer_snr_weight_type=exp)')
parser.add_argument('--use_snr', action='store_true',
                    help='use snr loss or not')
parser.add_argument('--snr_layers', default=None, nargs='+', type=str,
                    help='the layers need to compute snr')
parser.add_argument('--snr_weight', default=10, type=float,
                    help='weight of snr loss')
args = parser.parse_args()

# ---- runtime settings ----
# CUDA is used only when it is available and --no-cuda was not given.
use_cuda = torch.cuda.is_available() if not args.no_cuda else False
torch.manual_seed(args.seed)
device = torch.device('cuda') if use_cuda else torch.device('cpu')
# extra DataLoader options, only applied when running on CUDA
kwargs = dict(num_workers=1, pin_memory=True) if use_cuda else {}

# ---- MNIST (the train split is downloaded into ../data if missing) ----
dataset_train = datasets.MNIST(
    '../data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

dataset_test = datasets.MNIST(
    '../data',
    train=False,
    transform=transforms.Compose([transforms.ToTensor()]),
)


def get_same_index(target, label_1, label_2):
    """Return the ascending indices of ``target`` whose label equals ``label_1`` or ``label_2``.

    ``target`` may be a tensor or any sequence of integer labels (it is converted
    with ``torch.as_tensor``). Each matching index appears exactly once, even when
    ``label_1 == label_2``.

    Returns:
        A Python list of ints, suitable for indexing ``dataset.targets`` / ``dataset.data``.
    """
    labels = torch.as_tensor(target)
    # vectorized mask instead of a per-element Python loop over the tensor
    keep = (labels == label_1) | (labels == label_2)
    return torch.nonzero(keep).flatten().tolist()


# Keep only the digits '1' and '3' in each split and shift the labels by -2,
# so class '1' becomes -1 and class '3' becomes +1 (the +/-1 labels that the
# hinge and logistic losses in this script multiply against the model score).
idx_train = get_same_index(dataset_train.targets, 1, 3)
dataset_train.data = dataset_train.data[idx_train]
dataset_train.targets = dataset_train.targets[idx_train] - 2

idx_test = get_same_index(dataset_test.targets, 1, 3)
dataset_test.data = dataset_test.data[idx_test]
dataset_test.targets = dataset_test.targets[idx_test] - 2

# data loaders (both splits are shuffled)
train_loader = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=args.batch_size,
    shuffle=True,
    **kwargs
)
test_loader = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=args.test_batch_size,
    shuffle=True,
    **kwargs
)


def perturb_hinge(net, x_nat, num_steps=None, step_size=None, epsilon=None, beta=None):
    """Craft L-inf PGD adversarial inputs against the binary TRADES hinge surrogate.

    Starting from ``x_nat`` plus small Gaussian noise, repeatedly takes a
    signed-gradient ascent step on ``mean(max(0, 1 - f(x') * f(x_nat) / beta))``
    (i.e. E[phi(f(x) f(x'))]), then projects back onto the ``epsilon``-ball
    around ``x_nat`` and onto [0, 1].

    Args:
        net: binary classifier. Its output is squeezed on dim 1, so a
            (batch, 1) score tensor is assumed.
        x_nat: clean inputs.
        num_steps, step_size, epsilon, beta: attack settings. Each one that is
            omitted falls back to the matching value in the module-level ``args``.

    Returns:
        Adversarial inputs (not requiring grad) with the same shape and device
        as ``x_nat``.

    The network runs in eval mode during the attack; its previous train/eval
    mode is restored afterwards, even if the attack raises.
    """
    num_steps = args.num_steps if num_steps is None else num_steps
    step_size = args.step_size if step_size is None else step_size
    epsilon = args.epsilon if epsilon is None else epsilon
    beta = args.beta if beta is None else beta

    was_training = net.training
    net.eval()
    try:
        x_nat = x_nat.detach()
        # f(x_nat) does not depend on x', so compute it once instead of every step.
        with torch.no_grad():
            nat_scores = net(x_nat).squeeze(1)
        # Random start; randn_like keeps the noise on x_nat's device (works on CPU too).
        x = x_nat + 0.001 * torch.randn_like(x_nat)
        for _ in range(num_steps):
            x.requires_grad_()
            with torch.enable_grad():
                # perturb based on hinge loss
                loss = torch.mean(torch.clamp(1 - net(x).squeeze(1) * (nat_scores / beta), min=0))
            grad = torch.autograd.grad(loss, [x])[0]
            x = x.detach() + step_size * torch.sign(grad.detach())
            x = torch.min(torch.max(x, x_nat - epsilon), x_nat + epsilon)
            x = torch.clamp(x, 0.0, 1.0)
    finally:
        net.train(was_training)
    return x


def perturb_logistic(net, x_nat, target, num_steps=None, step_size=None, epsilon=None):
    """Craft L-inf PGD adversarial inputs that maximise the logistic loss.

    Starting from ``x_nat`` plus small Gaussian noise, repeatedly takes a
    signed-gradient ascent step on ``mean(log(1 + exp(-y * f(x'))))``, then
    projects back onto the ``epsilon``-ball around ``x_nat`` and onto [0, 1].

    Args:
        net: binary classifier. Its output is squeezed on dim 1, so a
            (batch, 1) score tensor is assumed.
        x_nat: clean inputs.
        target: labels multiplied against the score, i.e. expected to be -1/+1.
        num_steps, step_size, epsilon: attack settings. Each one that is
            omitted falls back to the matching value in the module-level ``args``.

    Returns:
        Adversarial inputs (not requiring grad) with the same shape and device
        as ``x_nat``.

    The network runs in eval mode during the attack; its previous train/eval
    mode is restored afterwards, so callers that evaluate in eval mode stay
    in eval mode.
    """
    num_steps = args.num_steps if num_steps is None else num_steps
    step_size = args.step_size if step_size is None else step_size
    epsilon = args.epsilon if epsilon is None else epsilon

    was_training = net.training
    net.eval()
    try:
        x_nat = x_nat.detach()
        # Random start; randn_like keeps the noise on x_nat's device (works on CPU too).
        x = x_nat + 0.001 * torch.randn_like(x_nat)
        for _ in range(num_steps):
            x.requires_grad_()
            with torch.enable_grad():
                margins = target.float() * net(x).squeeze(1)
                # softplus(-m) == log(1 + exp(-m)) cannot overflow, unlike exp(-m); per sample
                # its gradient only differs from that of exp(-m) by a positive factor, so the
                # sign step direction is unchanged.
                loss = torch.mean(F.softplus(-margins))
            grad = torch.autograd.grad(loss, [x])[0]
            x = x.detach() + step_size * torch.sign(grad.detach())
            x = torch.min(torch.max(x, x_nat - epsilon), x_nat + epsilon)
            x = torch.clamp(x, 0.0, 1.0)
    finally:
        net.train(was_training)
    return x


def get_snr_loss(layer_outputs_natural, layer_outputs_robust, layer_snr_weight_type, base, eps=1e-8):
    """Weighted sum of per-layer noise-to-signal ratios between clean and adversarial activations.

    For each paired layer output (flattened per sample) the ratio is
    ``mean(|natural - robust| / max(natural**2, eps) + 1)``.
    Layer k (1-based, in list order) is weighted according to ``layer_snr_weight_type``:

    - ``"exp"``:  ``base ** k``
    - ``"muln"``: ``k``
    - ``"sum"``:  1, except the last layer of ``layer_outputs_natural``, which gets 10

    Args:
        layer_outputs_natural: per-layer outputs on clean inputs (dim 0 = batch).
        layer_outputs_robust: matching per-layer outputs on adversarial inputs.
            Pairs are formed with ``zip``, so extra entries in the longer list are ignored.
        layer_snr_weight_type: one of ``"exp"``, ``"muln"``, ``"sum"``.
        base: base of the exponential weight (only used by ``"exp"``).
        eps: lower bound on the squared clean activation used as the denominator.
            Without it, an activation that is exactly zero (e.g. a ReLU output)
            yields inf/NaN and poisons ``loss.backward()``.

    Returns:
        ``(snr_loss, nsr)``: the weighted total (a tensor, or ``0.0`` when no layers
        are given) and the list of unweighted per-layer ratios.

    Raises:
        NameError: if ``layer_snr_weight_type`` is not one of the supported values.
    """
    if layer_snr_weight_type not in ("exp", "muln", "sum"):
        raise NameError("layer_snr_weight_type is illegal")

    num_layers = len(layer_outputs_natural)
    snr_loss = 0.0
    nsr = []
    pairs = zip(layer_outputs_natural, layer_outputs_robust)
    for cnt, (output_natural, output_robust) in enumerate(pairs, start=1):
        # reshape (not view) also accepts non-contiguous layer outputs
        output_natural = output_natural.reshape(output_natural.size(0), -1)
        output_robust = output_robust.reshape(output_robust.size(0), -1)
        noise = torch.abs(output_natural - output_robust)
        signal = torch.pow(output_natural, 2).clamp_min(eps)
        current_snr_loss = torch.mean(noise / signal + 1)

        if layer_snr_weight_type == "exp":
            weight = base ** cnt
        elif layer_snr_weight_type == "muln":
            weight = cnt
        else:  # "sum": emphasise the last layer
            weight = 10 if cnt == num_layers else 1
        snr_loss += current_snr_loss * weight
        nsr.append(current_snr_loss)
    return snr_loss, nsr

def _collect_layer_outputs(model, inputs, layer_names):
    """Run ``model(inputs)`` and capture the outputs of the sub-modules named in ``layer_names``.

    Names are matched against ``model.named_modules()``. Captured outputs are
    listed in the order the hooked modules fire during the forward pass. The
    temporary forward hooks are removed even if the forward pass raises.

    Returns:
        ``(output, layer_outputs)``: the model output and the captured outputs.
    """
    layer_outputs = []

    def hook_fn(module, input, output):
        layer_outputs.append(output)

    hooks = [layer.register_forward_hook(hook_fn)
             for name, layer in model.named_modules() if name in layer_names]
    try:
        output = model(inputs)
    finally:
        for hook in hooks:
            hook.remove()
    return output, layer_outputs


def train(args, model, device, train_loader, optimizer, epoch):
    """Run one epoch of adversarial training on ``train_loader``.

    Adversarial inputs x' come from ``perturb_hinge``. Without ``args.use_snr``
    the loss is the binary TRADES objective
    ``mean(max(0, 1 - y f(x))) + mean(max(0, 1 - f(x') f(x) / beta))``.
    With ``args.use_snr`` the robust term is replaced by
    ``args.snr_weight * get_snr_loss(...)`` on the outputs of the sub-modules
    named in ``args.snr_layers`` (clean vs. adversarial forward pass).

    Returns:
        ``(mean_snr_loss, train_layer_snr)``: the epoch-average SNR loss and the
        epoch-average per-layer ratios, as Python floats. Without
        ``args.use_snr`` these are ``0.0`` and ``[]``.

    Raises:
        ValueError: if ``args.use_snr`` is set but ``args.snr_layers`` is empty.
    """
    model.train()
    if args.use_snr:
        if not args.snr_layers:
            raise ValueError("--use_snr requires --snr_layers")
        print("Use Snr loss")

    # Always defined, so the return below also works without --use_snr.
    mean_snr_loss = 0.0
    train_layer_snr = []
    train_number = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        train_number += 1
        data, target = data.to(device), target.to(device)

        # perturb input x
        x_adv = perturb_hinge(net=model, x_nat=data)

        # optimize
        optimizer.zero_grad()

        if not args.use_snr:
            output = model(data)
            loss_natural = torch.mean(torch.clamp(1 - output.squeeze(1) * target.float(), min=0))
            # reuse the clean forward pass instead of running model(data) a second time
            loss_robust = torch.mean(torch.clamp(1 - model(x_adv).squeeze(1) * (output.squeeze(1) / args.beta), min=0))
            loss = loss_natural + loss_robust
        else:
            output, layer_outputs_natural = _collect_layer_outputs(model, data, args.snr_layers)
            _, layer_outputs_robust = _collect_layer_outputs(model, x_adv, args.snr_layers)
            snr_loss, nsr = get_snr_loss(layer_outputs_natural, layer_outputs_robust,
                                         args.layer_snr_weight_type, args.base)

            loss_natural = torch.mean(torch.clamp(1 - output.squeeze(1) * target.float(), min=0))
            loss = loss_natural + args.snr_weight * snr_loss

            # Accumulate plain floats: summing the loss tensors themselves would keep
            # every batch's autograd graph alive for the whole epoch.
            if not train_layer_snr:
                train_layer_snr = [0.0] * len(nsr)
            train_layer_snr = [total + float(value) for total, value in zip(train_layer_snr, nsr)]
            mean_snr_loss += float(snr_loss)

        loss.backward()
        optimizer.step()

        # print progress
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))

    if train_number:
        mean_snr_loss /= train_number
        train_layer_snr = [total / train_number for total in train_layer_snr]

    return mean_snr_loss, train_layer_snr
            
        


def eval_train(model, device, train_loader):
    """
    Report the average hinge loss and clean accuracy of ``model`` on the training set.

    Predictions are ``sign(model(x))`` and are compared with the targets, so the
    labels are expected to be -1/+1. Returns ``(average_loss, accuracy)`` where
    the accuracy is the fraction of correctly classified samples.
    """
    model.eval()
    n_samples = len(train_loader.dataset)
    total_hinge = 0
    n_correct = 0
    with torch.no_grad():
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            scores = model(inputs)
            margins = labels.float() * scores.squeeze(1)
            total_hinge += (1 - margins).clamp(min=0).sum()
            predicted = scores.sign().long()
            n_correct += (predicted == labels.view_as(predicted)).sum().item()
    avg_hinge = total_hinge / n_samples
    # print loss and accuracy
    print('Training: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        avg_hinge, n_correct, n_samples, 100. * n_correct / n_samples))
    return avg_hinge, n_correct / n_samples


def eval_test(model, device, test_loader):
    """
    Compute the average hinge loss and clean accuracy of ``model`` on the test split.

    Each prediction is ``sign(model(x))`` and is compared with the target, so the
    labels are expected to be -1/+1. Returns ``(average_loss, accuracy)``.
    """
    model.eval()
    dataset_size = len(test_loader.dataset)
    loss_sum, hits = 0, 0
    with torch.no_grad():
        for batch_x, batch_y in test_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            logits = model(batch_x)
            hinge = torch.clamp(1 - batch_y.float() * logits.squeeze(1), min=0)
            loss_sum = loss_sum + hinge.sum()
            signs = torch.sign(logits).long()
            hits += signs.eq(batch_y.view_as(signs)).sum().item()
    mean_loss = loss_sum / dataset_size
    print('Test: Average loss: {:.6f}, Accuracy: {}/{} ({:.0f}%)'.format(
        mean_loss, hits, dataset_size, 100. * hits / dataset_size))
    return mean_loss, hits / dataset_size


def eval_adv_test(model, device, test_loader):
    """
    Evaluate robust accuracy and the adversarial hinge loss on the test set.

    Each batch is attacked with ``perturb_logistic`` (PGD on the logistic loss)
    and the model is scored on the perturbed inputs.

    Returns:
        ``(adv_loss, robust_accuracy)``: the per-sample mean of
        ``max(0, 1 - f(x') * f(x) / args.beta)`` over the test set, and the
        fraction of perturbed inputs whose ``sign(f(x'))`` equals the target.
    """
    model.eval()
    adv_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            # use pgd attack on logistic loss
            x_perturb_linf = perturb_logistic(net=model, x_nat=data, target=target)
            # The attack helper toggles train/eval mode; make sure scoring happens in eval mode.
            model.eval()
            output = model(x_perturb_linf)
            # adversarial loss (E[\phi(f(x)f(x'))]); reuses `output` instead of a second forward pass
            adv_loss += torch.sum(torch.clamp(1 - output.squeeze(1) * (model(data).squeeze(1) / args.beta), min=0))
            pred = torch.sign(output).long()
            correct += pred.eq(target.view_as(pred)).sum().item()

    adv_loss /= len(test_loader.dataset)
    print('Test: Average Adv loss: {:.6f}, Robust Accuracy: {}/{} ({:.0f}%)'.format(
        adv_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return adv_loss, correct / len(test_loader.dataset)

def main():
    """Train ``Net_binary`` for ``args.epochs`` epochs, logging metrics to wandb each epoch.

    After every epoch the model is evaluated on clean train/test data and on
    PGD-attacked test data. The final weights are written to
    ``<args.model_dir>/mnist_tiny_cnn_w1.pt``; the directory is created if needed.
    """
    wandb.init(project="unsupervised-robust-learning", name="Trades-mnist-tiny_cnn-w1")
    model = Net_binary().to(device)
    wandb.watch(model)
    # --weight-decay was parsed but previously never passed to the optimizer
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                          weight_decay=args.weight_decay)

    for epoch in range(1, args.epochs + 1):
        # adversarial training
        mean_snr_loss, train_layer_snr = train(args, model, device, train_loader, optimizer, epoch)

        # evaluation on natural and adversarial examples
        print('================================================================')
        train_clear_loss, train_clear_acc = eval_train(model, device, train_loader)
        test_clear_loss, test_clear_acc = eval_test(model, device, test_loader)
        test_robust_loss, test_robust_acc = eval_adv_test(model, device, test_loader)
        print('================================================================')
        metrics = {
            "epoch": epoch,
            "train_clear_loss": train_clear_loss,
            "train_clear_acc": train_clear_acc,
            "test_clear_loss": test_clear_loss,
            "test_clear_acc": test_clear_acc,
            "test_robust_loss": test_robust_loss,
            "test_robust_acc": test_robust_acc,
            "learning_rate": optimizer.param_groups[0]['lr'],
            "snr_loss": mean_snr_loss,
        }
        # One entry per hooked layer (layer1_snr, layer2_snr, ...). Iterating avoids an
        # IndexError when fewer than four layers are hooked or SNR training is off.
        for layer_idx, layer_snr in enumerate(train_layer_snr, start=1):
            metrics["layer{}_snr".format(layer_idx)] = layer_snr
        wandb.log(metrics)

    # Save under --model-dir instead of a hard-coded, machine-specific absolute path.
    os.makedirs(args.model_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(args.model_dir, "mnist_tiny_cnn_w1.pt"))


if __name__ == '__main__':
    main()
  