import os
import argparse
import torch
import torch.optim as optim
import torchvision
from torchsummary import summary
import matplotlib.pyplot as plt
from datetime import datetime
from tqdm import tqdm
from mnist.mnist_model import *
from opacus import PrivacyEngine
from utils.data_helper import *
from utils.accountant_helper import *
from utils.math_helper import *


def train(device, train_loader, model, optimizer, epoch):
    """Run one epoch of standard (non-private) training and print a summary.

    Args:
        device: Device every batch is moved to before the forward pass.
        train_loader: DataLoader yielding ``(image, target)`` batches.
        model: Classifier whose output holds one logit per class.
        optimizer: Optimizer over ``model``'s parameters.
        epoch: Epoch number; used only in the printed summary.

    The printed loss is the loss of the last batch. The accuracy is computed
    over the samples actually processed this epoch. Because the loader may drop
    an incomplete final batch, that count can be smaller than
    ``len(train_loader.dataset)``.
    """
    model.train()
    correct = 0
    seen = 0
    loss = None  # stays None if the loader yields no batches
    for image, target in tqdm(train_loader):
        image = image.to(device)
        label = target.to(device)
        output = model(image)
        _, predicted = torch.max(output.detach(), 1)
        correct += (predicted == label).sum().item()
        seen += label.size(0)

        loss = torch.nn.functional.cross_entropy(output, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    last_loss = loss.item() if loss is not None else float('nan')
    accuracy = 100. * correct / seen if seen else 0.0
    print('Train Epoch: {} \tLoss: {:.6f} \tAccuracy: {}/{} ( {:.2f}% )'.format(
        epoch, last_loss, correct, seen, accuracy))


def train_dp(args, device, train_loader, model, optimizer, epoch):
    """Run one epoch of training through a privacy-engine-wrapped optimizer.

    Args:
        args: Parsed CLI namespace; ``batch_size`` and ``noise_multiplier`` are
            forwarded to ``optimizer.step``.
        device: Device every batch is moved to before the forward pass.
        train_loader: DataLoader yielding ``(image, target)`` batches.
        model: Classifier whose output holds one logit per class.
        optimizer: Optimizer with a PrivacyEngine attached. Its ``step`` accepts
            the extra keyword arguments passed below (project-specific API).
        epoch: Epoch number; used only in the printed summary.

    The printed loss is the loss of the last batch. The accuracy is computed
    over the samples actually processed this epoch, which can be fewer than
    ``len(train_loader.dataset)`` when the loader drops an incomplete batch.
    """
    model.train()
    correct = 0
    seen = 0
    loss = None  # stays None if the loader yields no batches
    for image, target in tqdm(train_loader):
        image = image.to(device)
        label = target.to(device)
        output = model(image)
        _, predicted = torch.max(output.detach(), 1)
        correct += (predicted == label).sum().item()
        seen += label.size(0)

        loss = torch.nn.functional.cross_entropy(output, label)
        optimizer.zero_grad()
        loss.backward()
        # Statistic from the project helper ``compute_norm``, handed to the
        # customized optimizer step.
        nabla = compute_norm(model)
        # The class label is passed as the per-sample group id.
        optimizer.step(device=device, group=label, batch_size=args.batch_size, nabla=nabla,
                       noise_multiplier=args.noise_multiplier)

    last_loss = loss.item() if loss is not None else float('nan')
    accuracy = 100. * correct / seen if seen else 0.0
    print('Train Epoch: {} \tLoss: {:.6f} \tAccuracy: {}/{} ( {:.2f}% )'.format(
        epoch, last_loss, correct, seen, accuracy))


def test(device, test_loader, model):
    model.eval()
    test_loss = 0
    correct = defaultdict(int)
    total = defaultdict(int)

    with torch.no_grad():
        for sample in test_loader:
            image, label = sample
            image = image.to(device)
            label = label.to(device)
            group = label
            output = model(image)
            _, predicted = torch.max(output.data, 1)
            for idx, g in enumerate(group):
                correct[g.item()] += predicted[idx].eq(label[idx]).item()
                total[g.item()] += 1
            test_loss += nn.functional.cross_entropy(output, label)

    cqt = sum(correct.values())
    test_loss /= len(test_loader.dataset)

    print('Test: \tLoss: {:.6f} \tAccuracy: {}/{} ( {:.2f}% )'.format(test_loss, cqt, len(test_loader.dataset),
                                                                      100. * cqt / len(test_loader.dataset)))
    for g, c in sorted(correct.items(), key=lambda x: x[0]):
        print('Test set: {} \tAccuracy: {}/{} ( {:.2f}% )'.format(g, c, total[g], 100. * c / total[g]))

    return 100. * cqt / len(test_loader.dataset)


def _build_privacy_engine(args, model, train_loader):
    """Return the PrivacyEngine configured for ``args.schema``, or None for 'vanilla'.

    Raises:
        ValueError: if ``args.schema`` is not a known schema.
    """
    import math

    if args.schema == 'vanilla':
        return None

    # Settings shared by every private schema.
    common = dict(
        args=args,
        module=model,
        batch_size=args.batch_size,
        # NOTE(review): despite the name this is (steps per epoch) * epochs, not a
        # sample count — confirm it is what the customized PrivacyEngine expects.
        sample_size=int(len(train_loader.dataset) / args.batch_size) * args.epoch,
        alphas=[1.25, 1.5, 1.75, 2., 2.25, 2.5, 3., 3.5, 4., 4.5] + list(range(5, 64)) + [128, 256, 512],
        noise_multiplier=args.noise_multiplier,
        target_delta=args.delta,
    )

    if args.schema == 'dpsgd':
        return PrivacyEngine(max_grad_norm=args.l2_norm_clip, **common)

    if args.schema == 'dp-fedavg':
        # Split the overall clip bound across layers (parameter tensors / 2).
        layer = int(sum(1 for _ in model.parameters()) / 2)
        l2_norm_clip = args.l2_norm_clip / math.sqrt(layer)
        # NOTE(review): the list length is hard-coded to 4, which presumably equals
        # the number of parameter tensors in MNISTNet — confirm.
        return PrivacyEngine(max_grad_norm=[l2_norm_clip] * 4, **common)

    # Experimental clipping variants of the customized engine, keyed by schema.
    clipping_methods = {'opt-q': 5, 'dpsgd-f': 6, 'fairdp': 7}
    if args.schema in clipping_methods:
        return PrivacyEngine(max_grad_norm=args.l2_norm_clip, experimental=True,
                             clipping_method=clipping_methods[args.schema], **common)

    raise ValueError('Unknown schema: {}'.format(args.schema))


def main():
    """Train and evaluate an MNIST classifier under the selected training schema.

    Parses CLI arguments and builds the training loader. Without ``--balance``,
    the training data comes from the project helper ``sample_train_imba_data``.
    The model is trained for ``--epoch`` epochs and evaluated after each one. A
    checkpoint is saved whenever test accuracy improves.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='MNIST')
    # NOTE(review): only "-1 (CPU) vs. anything else (default CUDA device)" is
    # used below; the GPU index itself is not applied — confirm whether
    # CUDA_VISIBLE_DEVICES is expected to select the GPU.
    parser.add_argument('--device', default=1, type=int, choices=[-1, 0, 1, 2, 3])
    parser.add_argument('--seed', default=7, type=int)
    parser.add_argument('--data-dir', default='../../mnist/', type=str)
    parser.add_argument('--batch-size', default=256, type=int)
    # 'vanilla' must be listed: argparse does not check defaults against
    # ``choices`` but does reject an explicit ``--schema vanilla``.
    parser.add_argument('--schema', default='vanilla', type=str,
                        choices=['vanilla', 'dpsgd', 'dp-fedavg', 'opt-q', 'dpsgd-f', 'fairdp'])
    parser.add_argument('--balance', action='store_true')
    parser.add_argument('--lr', default=0.01, type=float)
    parser.add_argument('--epoch', default=60, type=int)
    # Differential privacy settings
    parser.add_argument('--l2-norm-clip', default=1.0, type=float)
    parser.add_argument('--noise-multiplier', default=1.0, type=float)
    parser.add_argument('--num-microbatches', default=256, type=int)
    parser.add_argument('--delta', default=1e-5, type=float)
    args = parser.parse_args()

    print('Argument =', vars(args))
    num_classes = 10

    use_cuda = args.device != -1
    device = 'cuda' if use_cuda else 'cpu'
    torch.manual_seed(args.seed)
    if use_cuda:
        torch.cuda.manual_seed(args.seed)

    kwargs = {'batch_size': args.batch_size}
    if use_cuda:
        kwargs.update({'num_workers': 4, 'pin_memory': True})

    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_dataset = torchvision.datasets.MNIST(args.data_dir, train=True, transform=transform, download=True)
    test_dataset = torchvision.datasets.MNIST(args.data_dir, train=False, transform=transform)
    # --balance trains on the full training set; otherwise on the dataset returned
    # by the project helper (presumably a class-imbalanced subset).
    if args.balance:
        train_source = train_dataset
    else:
        train_source = sample_train_imba_data(num_classes, train_dataset)
    sampler = torch.utils.data.RandomSampler(train_source, replacement=True)
    train_loader = torch.utils.data.DataLoader(train_source, **kwargs, sampler=sampler, drop_last=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, **kwargs)

    model = MNISTNet(num_classes).to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.5)
    privacy_engine = _build_privacy_engine(args, model, train_loader)
    if privacy_engine is not None:
        privacy_engine.attach(optimizer)
    if use_cuda:
        # cudnn.benchmark autotunes algorithms per run, which undermines the
        # reproducibility that deterministic=True and the fixed seed aim for.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    acc_opt = 0
    for epoch in range(1, args.epoch + 1):
        start_time = datetime.now()

        if args.schema == 'vanilla':
            train(device, train_loader, model, optimizer, epoch)
        else:
            train_dp(args, device, train_loader, model, optimizer, epoch)
            compute_dp_sgd_privacy(len(train_loader.dataset), args.batch_size, args.noise_multiplier, epoch, args.delta)

        acc = test(device, test_loader, model)
        if acc > acc_opt:
            # No spaces or colons in the timestamp, so the filename is valid on every OS.
            stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
            torch.save(model.state_dict(),
                       'mnist_{}_{}_{}_{}_model.pkl'.format(args.schema, args.seed, epoch, stamp))
            print('Best Accuracy')
            acc_opt = acc

        end_time = datetime.now()
        print('{} - {} <{}>'.format(start_time, end_time, end_time - start_time))

    # For notification only: finish of the process
    plt.subplots()
    plt.show()


# Run the experiment only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
