"""
References: Opacus Tutorial, https://opacus.ai/
"""
import argparse
import os

import torch
import torch.nn as nn
import torch.optim as optim

from vgg import VGG
from dataset import CIFAR10, SVHN

from opacus import PrivacyEngine
from opacus.utils.module_modification import convert_batchnorm_modules
from opacus.utils.uniform_sampler import UniformWithReplacementSampler

import numpy as np
from tqdm import tqdm

def accuracy(preds, labels):
    """Return the fraction of positions at which ``preds`` equals ``labels``.

    The two arguments are compared element-wise with ``==`` and the mean of
    the resulting boolean array is returned. The callers in this file pass
    NumPy arrays; the comparison result must expose a ``.mean()`` method.
    """
    hits = preds == labels
    return hits.mean()

def test(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` and return its top-1 accuracy.

    The loss and accuracy are averaged over *samples*, not over batches, so a
    smaller final batch does not skew the result. A one-line summary is
    printed. The model is switched to eval mode and left in that mode.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        test_loader: Iterable yielding ``(images, target)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Fraction of correctly classified samples as a float in ``[0, 1]``,
        or NaN when ``test_loader`` yields no samples.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for images, target in test_loader:
            images = images.to(device)
            target = target.to(device)

            output = model(images)
            batch_size = target.size(0)

            # CrossEntropyLoss returns the batch mean; scale it back to a sum
            # so every sample carries the same weight in the final average.
            total_loss += criterion(output, target).item() * batch_size
            total_correct += (output.argmax(dim=1) == target).sum().item()
            total_samples += batch_size

    if total_samples:
        mean_loss = total_loss / total_samples
        top1_avg = total_correct / total_samples
    else:
        # Nothing was evaluated: report NaN rather than dividing by zero.
        mean_loss = top1_avg = float("nan")

    print(
        f"\tTest set:"
        f"Loss: {mean_loss:.6f} "
        f"Acc: {top1_avg * 100:.6f} "
    )
    return top1_avg


# The name starts with "test"; mark it so pytest never collects this helper
# as a test case when the module is imported from a test file.
test.__test__ = False

def train_cls_opacus(trainset, testloader, save_dir, name=None):
    """Train a VGG19 classifier with Opacus DP-SGD and save its weights.

    Hyper-parameters (``max_grad_norm``, ``epsilon``, ``delta``, ``epochs``,
    ``lr`` and ``batch_size``) are read from the module-level ``args``
    namespace, so ``args`` must exist before this function is called.

    After training, the model is evaluated with ``test`` on the training set
    and on ``testloader``, and its ``state_dict`` is written to ``save_dir``.

    Args:
        trainset: Training dataset. Batches are drawn from it with
            ``UniformWithReplacementSampler``; it is also wrapped in a
            sequential loader for the final evaluation.
        testloader: Loader evaluated once after training finishes.
        save_dir: Directory for the checkpoint; created if it is missing.
        name: Checkpoint file name; defaults to ``'final_model.pt'``.

    Raises:
        ValueError: If the virtual batch size (512) is not a multiple of
            ``args.batch_size``.
    """
    MAX_GRAD_NORM = args.max_grad_norm
    EPSILON = args.epsilon
    DELTA = args.delta
    EPOCHS = args.epochs

    LR = args.lr
    NUM_WORKERS = 2

    BATCH_SIZE = args.batch_size
    VIRTUAL_BATCH_SIZE = 512
    # Raise explicitly: an ``assert`` would be stripped under ``python -O``
    # and the accumulation count below would then be silently wrong.
    if VIRTUAL_BATCH_SIZE % BATCH_SIZE != 0:
        raise ValueError(
            f"VIRTUAL_BATCH_SIZE ({VIRTUAL_BATCH_SIZE}) should be divisible "
            f"by BATCH_SIZE ({BATCH_SIZE})"
        )
    N_ACCUMULATION_STEPS = VIRTUAL_BATCH_SIZE // BATCH_SIZE

    SAMPLE_RATE = BATCH_SIZE / len(trainset)

    trainloader = torch.utils.data.DataLoader(
        trainset,
        num_workers=NUM_WORKERS,
        batch_sampler=UniformWithReplacementSampler(
            num_samples=len(trainset),
            sample_rate=SAMPLE_RATE,
        ),
    )

    # Plain sequential loader, used only for the post-training evaluation.
    trainloader_eval = torch.utils.data.DataLoader(
        trainset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
    )

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = VGG('VGG19')
    net = convert_batchnorm_modules(net)
    net.to(device).train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.RMSprop(net.parameters(), lr=LR)

    # A real optimizer step happens once per virtual batch, so the engine is
    # given the sample rate of the virtual batch rather than the physical one.
    privacy_engine = PrivacyEngine(
        net,
        sample_rate=SAMPLE_RATE * N_ACCUMULATION_STEPS,
        epochs=EPOCHS,
        target_epsilon=EPSILON,
        target_delta=DELTA,
        max_grad_norm=MAX_GRAD_NORM,
    )
    privacy_engine.attach(optimizer)

    print(f"Using sigma={privacy_engine.noise_multiplier} and C={MAX_GRAD_NORM}")

    n_batches = len(trainloader)
    for epoch in tqdm(range(EPOCHS), desc="Epoch", unit="epoch"):
        # Reset the running metrics so each "Train Epoch" line reports the
        # current epoch only and the lists do not grow across epochs.
        losses = []
        top1_acc = []
        for i, (images, target) in enumerate(trainloader):
            images = images.to(device)
            target = target.to(device)

            # compute output
            output = net(images)
            loss = criterion(output, target)

            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
            labels = target.detach().cpu().numpy()

            # measure accuracy and record loss
            acc = accuracy(preds, labels)

            losses.append(loss.item())
            top1_acc.append(acc)

            loss.backward()

            # Take a real optimizer step every N_ACCUMULATION_STEPS batches
            # and on the last batch of the epoch; otherwise take a virtual
            # step, which leaves the weight update for a later real step.
            if ((i + 1) % N_ACCUMULATION_STEPS == 0) or ((i + 1) == n_batches):
                optimizer.step()
            else:
                optimizer.virtual_step()

            if i % 200 == 0:
                epsilon, _ = optimizer.privacy_engine.get_privacy_spent(DELTA)
                print(
                    f"\tTrain Epoch: {epoch} \t"
                    f"Loss: {np.mean(losses):.6f} "
                    f"Acc@1: {np.mean(top1_acc) * 100:.6f} "
                    f"(ε = {epsilon:.2f}, δ = {DELTA})"
                )
    test(net, trainloader_eval, device)
    test(net, testloader, device)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_dir, exist_ok=True)
    torch.save(net.state_dict(), os.path.join(save_dir, name if name is not None else 'final_model.pt'))


if __name__ == "__main__":
    # Script entry point: parse the hyper-parameters, build the requested
    # dataset, and train one DP classifier per training split.
    # NOTE: ``args`` stays a module-level name because train_cls_opacus
    # reads it as a global.
    dataset_classes = {'cifar10': CIFAR10, 'svhn': SVHN}

    parser = argparse.ArgumentParser("Classifier training")
    # Required and restricted to known names: previously a missing or unknown
    # value left ``dataset`` unbound and crashed later with a NameError.
    parser.add_argument("--dataset", type=str, required=True, choices=list(dataset_classes))
    # Accepted for command-line compatibility; not used by this script.
    parser.add_argument("--load_from", type=str)
    parser.add_argument("--save_to", type=str, default='models/cifar')
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--delta", type=float, default=1e-5)
    parser.add_argument("--epsilon", type=float, default=50.0)
    parser.add_argument("--max_grad_norm", type=float, default=1.2)

    args = parser.parse_args()

    dataset = dataset_classes[args.dataset]()
    dataset.split()
    split_loaders, testloader = dataset.get_dataloaders(batch_size=args.batch_size, split=True)

    # Only the index is needed: training rebuilds its own loaders from the
    # matching entry of dataset.split_trainsets.
    for i, _ in enumerate(split_loaders):
        train_cls_opacus(dataset.split_trainsets[i], testloader, save_dir=args.save_to, name='model_'+str(i)+'.pt')

    