# Ref: https://github.com/BorealisAI/advertorch/blob/master/advertorch_examples/tutorial_train_mnist.py
from __future__ import print_function

from tqdm import tqdm
import os
import argparse
import wandb
import numpy as np

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset

from advertorch.context import ctx_noparamgrad_and_eval
from advertorch.test_utils import LeNet5
# from advertorch_examples.utils import TRAINED_MODEL_PATH
from torch import nn

from utils.data_loader import get_central_data
from utils.config import CHECKPOINT_ROOT
from utils.utils import AverageMeter, set_seed


def parse_args():
    """Build the command-line parser for this script and parse ``sys.argv``.

    Returns:
        argparse.Namespace: one attribute per option declared below.
    """
    parser = argparse.ArgumentParser(description='Train MNIST')

    # Run mode and data selection.
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--train', default="cln", help="cln | adv")
    parser.add_argument('--data', default='digits', choices=['digits', 'domainnet'])

    # Optimisation and attack hyper-parameters.
    parser.add_argument('--lr', default=1e-2, type=float)
    parser.add_argument('--l1_reg', default=0., type=float)
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--train_batch_size', default=50, type=int)
    parser.add_argument('--eps', default=None, type=float)
    parser.add_argument('--test_batch_size', default=256, type=int)
    parser.add_argument('--adv_lmbd', default=0.5, type=float)

    # Domains, model choice and logging.
    parser.add_argument('--domain', default='MNIST', type=str)
    parser.add_argument('--test_domain', default=None, type=str)
    parser.add_argument('--model', default='convnet')
    parser.add_argument('--log_interval', default=200, type=int)
    # type=float is required here: without it a ratio given on the command
    # line stays a string, and multiplying it by a dataset length would
    # repeat the string instead of computing a split size.
    parser.add_argument('--val_ratio', default=0.1, type=float)
    parser.add_argument('--no_log', action='store_true')
    parser.add_argument('--attack', default='LinfPGD')
    return parser.parse_args()


class AvgChannel(nn.Module):
    """Parameter-free layer that averages its input over dimension 1."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the mean of ``x`` along dim 1; that dimension is dropped."""
        return torch.mean(x, dim=1)

def main(args):
    """Train a classifier (clean or adversarial) or evaluate a saved checkpoint.

    The run name, and therefore the checkpoint file name, is derived from
    ``args.data``, ``args.domain``, ``args.model``, ``args.train``,
    ``args.lr`` and ``args.seed`` (plus ``l1_reg`` / ``adv_lmbd`` when they
    differ from their defaults).

    * With ``args.test`` set, the checkpoint for that run name is loaded and
      evaluated on the test split of ``args.test_domain``; clean and
      adversarial metrics are both reported.
    * Otherwise the model is trained for ``args.epochs`` epochs on the
      training part of ``args.domain`` and validated after every epoch. The
      weights with the best validation accuracy so far are written under
      ``CHECKPOINT_ROOT``.

    Args:
        args: namespace with the options used below (as produced by the
            script's argument parser).

    Raises:
        RuntimeError: if ``args.train`` is neither ``"cln"`` nor ``"adv"``.
        NotImplementedError: if ``args.model`` is not a supported model name.
    """
    run_name = f'{args.data}-{args.domain}_{args.model}_{args.train}_lr{args.lr}_s{args.seed}'
    if args.l1_reg > 0.:
        run_name += f'_l1-{args.l1_reg}'
    if args.adv_lmbd != 0.5:
        run_name += f'_adv-{args.adv_lmbd}'

    wandb.init(project='AdvTransfer', name=run_name,
               config={**vars(args)},
               mode='offline' if args.no_log else 'online')

    set_seed(args.seed)
    rng = np.random.RandomState(args.seed)
    # Prefer the GPU but fall back to the CPU so the script still runs on
    # machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_filename = f"{run_name}.pt"
    checkpoint_path = os.path.join(CHECKPOINT_ROOT, model_filename)
    if args.train == "cln":
        flag_advtrain = False
    elif args.train == "adv":
        flag_advtrain = True
    else:
        raise RuntimeError(
            f"unknown --train mode {args.train!r}; expected 'cln' or 'adv'")
    if args.test:
        # Testing always reports adversarial metrics, whatever the train mode.
        flag_advtrain = True

    train_set, _ = get_central_data(args.data, [args.domain])
    train_set = train_set[0]
    set_idxs = rng.permutation(np.arange(len(train_set)))
    # float() guards against the ratio arriving as a string (argparse leaves
    # a command-line value as str when the option declares no type).
    num_val = int(float(args.val_ratio) * len(train_set))
    val_set = Subset(train_set, set_idxs[:num_val])
    train_set = Subset(train_set, set_idxs[num_val:])
    train_loader = DataLoader(
        train_set, batch_size=args.train_batch_size, shuffle=True)
    val_loader = DataLoader(
        val_set, batch_size=args.train_batch_size, shuffle=False)

    if args.model == 'lenet5':
        model = LeNet5()
    elif args.model == 'convnet':
        from utils.models import DigitModel
        model = DigitModel()
    elif args.model == 'alex':
        from utils.models import AlexNet
        model = AlexNet()
    elif args.model == 'linear':
        assert args.data == 'digits'
        model = nn.Sequential(
            AvgChannel(),
            nn.Flatten(),
            nn.Linear(28 * 28, 10)
        )
    elif args.model == 'mlp2':
        assert args.data == 'digits'
        model = nn.Sequential(
            AvgChannel(),
            nn.Flatten(),
            nn.Linear(28 * 28, 256),
            nn.ReLU(),
            nn.Linear(256, 10),
        )
    else:
        raise NotImplementedError(f"model: {args.model}")

    if flag_advtrain:
        from utils.attacks import AdversaryCreator

        adversary = AdversaryCreator(args.attack)(model, eps=args.eps)
    else:
        adversary = None

    if args.test:
        # map_location lets a checkpoint saved on one device load on another.
        loaded = torch.load(checkpoint_path, map_location=device)
        print(f"load from {checkpoint_path}")
        model.load_state_dict(loaded['state_dict'])
        model.to(device)

        _, test_set = get_central_data(args.data, [args.test_domain])
        test_set = test_set[0]
        test_loader = DataLoader(
            test_set, batch_size=args.test_batch_size, shuffle=True)

        print("==== Start test ====")
        test_res = run_test(
            adversary, device, flag_advtrain, model, test_loader, 'test')

        # Update the run's summary in place; assigning to ``wandb.summary``
        # would only rebind the module attribute and record nothing.
        wandb.run.summary.update(test_res)
        wandb.finish()
        return

    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    lr_sch = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)

    # Make sure the checkpoint directory exists before the first save.
    os.makedirs(os.path.dirname(checkpoint_path) or '.', exist_ok=True)

    best_val_acc = -1
    print("==== Start training ====")
    for epoch in range(args.epochs):
        model.train()
        loss_mt = AverageMeter()
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()

            output = model(data)
            loss = F.cross_entropy(
                output, target, reduction='mean')

            if flag_advtrain:
                # when performing attack, the model needs to be in eval mode
                # also the parameters should NOT be accumulating gradients
                with ctx_noparamgrad_and_eval(model):
                    data = adversary.perturb(data, target)

                output = model(data)
                lmbd = args.adv_lmbd
                # Convex combination of the clean and the adversarial loss.
                loss = (1 - lmbd) * loss + lmbd * F.cross_entropy(
                    output, target, reduction='mean')
            if args.l1_reg > 0.:
                # Recompute the L1 norm from the live parameters on every
                # step; a concatenated copy taken before training would keep
                # penalising the initial weights instead of the current ones.
                l1_norm = sum(p.abs().sum() for p in model.parameters())
                loss = loss + args.l1_reg * l1_norm
            loss.backward()
            optimizer.step()
            loss_mt.append(loss.item())

        # Note: this reports the loss of the last batch of the epoch.
        print(f'[{epoch}/{args.epochs}] Loss: {loss.item():.6f}')

        lr_sch.step()
        val_res = run_test(adversary, device, flag_advtrain, model, val_loader, 'val')

        # Model selection uses adversarial accuracy when training adversarially.
        if flag_advtrain:
            val_acc = val_res['adv acc']
        else:
            val_acc = val_res['cln acc']
        wandb.log({'val_acc': val_acc}, commit=False)

        if val_acc > best_val_acc:
            # Update first so the checkpoint records the accuracy of the
            # weights it actually contains rather than the previous best.
            best_val_acc = val_acc
            print(f"save to {checkpoint_path}")
            torch.save(
                {
                    "state_dict": model.state_dict(),
                    "best_val_acc": best_val_acc,
                    "epoch": epoch,
                    "lr": lr_sch.get_last_lr(),
                    **val_res
                },
                checkpoint_path)
        # This call commits the metrics buffered with commit=False above.
        wandb.log({'epoch': epoch})


def _batch_loss_and_correct(output, target):
    """Return (summed cross-entropy, number of correct predictions) for a batch."""
    loss_sum = F.cross_entropy(output, target, reduction='sum').item()
    pred = output.max(1, keepdim=True)[1]
    num_correct = pred.eq(target.view_as(pred)).sum().item()
    return loss_sum, num_correct


def run_test(adversary, device, flag_advtrain, model, test_loader, set_name):
    """Evaluate ``model`` on ``test_loader`` and log the metrics to wandb.

    Clean loss and accuracy are always computed. When ``flag_advtrain`` is
    true, every batch is also perturbed with ``adversary.perturb`` and the
    loss and accuracy on the perturbed inputs are reported as well.

    Args:
        adversary: object exposing ``perturb(data, target)``; only used when
            ``flag_advtrain`` is true (may be ``None`` otherwise).
        device: device the batches are moved to before the forward pass.
        flag_advtrain: whether to additionally evaluate adversarial examples.
        model: the network to evaluate; it is left in eval mode.
        test_loader: loader yielding ``(data, target)`` batches. Its
            ``dataset`` must be non-empty, since metrics are divided by its
            length.
        set_name: label used in the progress bar and the printed summary.

    Returns:
        dict: ``'cln loss'`` and ``'cln acc'`` (accuracy in percent), plus
        ``'adv loss'`` and ``'adv acc'`` when ``flag_advtrain`` is true.
    """
    model.eval()
    num_samples = len(test_loader.dataset)
    test_clnloss = 0
    clncorrect = 0
    test_advloss = 0
    advcorrect = 0
    for clndata, target in tqdm(test_loader, desc=set_name):
        clndata, target = clndata.to(device), target.to(device)
        with torch.no_grad():
            output = model(clndata)
        batch_loss, batch_correct = _batch_loss_and_correct(output, target)
        test_clnloss += batch_loss
        clncorrect += batch_correct

        if flag_advtrain:
            # Craft the attack the same way the training loop does: model in
            # eval mode and parameters frozen, so the attack does not compute
            # or accumulate parameter gradients during evaluation.
            with ctx_noparamgrad_and_eval(model):
                advdata = adversary.perturb(clndata, target)
            with torch.no_grad():
                output = model(advdata)
            batch_loss, batch_correct = _batch_loss_and_correct(output, target)
            test_advloss += batch_loss
            advcorrect += batch_correct

    # Losses were summed per sample, so dividing by the dataset size gives
    # the per-sample average.
    test_clnloss /= num_samples
    print(f'{set_name} set: '
          f'avg cln loss: {test_clnloss:.4f},'
          f' cln acc: {clncorrect}/{num_samples} '
          f'({100. * clncorrect / num_samples:.0f}%)',
          )
    res = {
        'cln loss': test_clnloss,
        'cln acc': 100. * clncorrect / num_samples
    }
    if flag_advtrain:
        test_advloss /= num_samples
        print(f'{set_name} set: '
              f'avg adv loss: {test_advloss:.4f},'
              f' adv acc: {advcorrect}/{num_samples} '
              f'({100. * advcorrect / num_samples:.0f}%)',
              )
        res['adv loss'] = test_advloss
        res['adv acc'] = 100. * advcorrect / num_samples
    # commit=False buffers the metrics; the caller decides when to commit.
    wandb.log(res, commit=False)
    return res


if __name__ == '__main__':
    # Script entry point: parse the command-line options and run with them.
    main(parse_args())
