import argparse
import torch.optim as optim
from datetime import datetime
from dutch.dutch_dataset import *
from dutch.dutch_model import *
from torch.utils.data import RandomSampler
from tqdm import tqdm
import matplotlib.pyplot as plt
from opacus import PrivacyEngine
from utils.data_helper import *
from utils.accountant_helper import *
from utils.math_helper import *


def train(device, train_loader, model, optimizer, epoch):
    """Run one epoch of ordinary (non-private) training and print a summary line.

    Args:
        device: device the batch tensors are moved to.
        train_loader: DataLoader yielding dicts with 'attribute' and 'label' tensors.
        model: network whose output is treated as a probability (thresholded at
            0.5 and fed to binary_cross_entropy); presumably shape (batch, 1) or
            (batch,) -- TODO confirm against DutchNet.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch index, used only in the printed summary.

    The printed loss is the loss of the last batch (NaN if the loader was
    empty); accuracy is measured over the samples actually drawn this epoch.
    """
    model.train()
    correct = 0
    seen = 0
    loss = None
    for sample in tqdm(train_loader):
        attribute = sample['attribute'].to(device)
        label = sample['label'].to(device).float().reshape(-1)
        # Flatten rather than torch.squeeze: squeeze turns a size-1 batch into a
        # 0-d tensor, which binary_cross_entropy rejects against a (1,) target.
        output = model(attribute).reshape(-1)
        pred = (output >= 0.5).float()
        # Kept as a tensor to avoid a host/device sync on every batch.
        correct += pred.eq(label).sum()
        seen += label.numel()
        loss = nn.functional.binary_cross_entropy(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    correct = int(correct)
    last_loss = loss.item() if loss is not None else float('nan')
    # With drop_last=True (and sampling with replacement) the number of samples
    # seen can differ from len(dataset), so use the real count as denominator.
    accuracy = 100. * correct / seen if seen else 0.
    print('Train Epoch: {} \tLoss: {:.6f} \tAccuracy: {}/{} ( {:.2f}% )'.format(epoch, last_loss, correct,
                                                                                seen, accuracy))


def train_dp(args, device, train_loader, model, optimizer, epoch):
    """Run one epoch of differentially private training and print a summary line.

    Args:
        args: parsed CLI namespace; ``batch_size`` and ``noise_multiplier`` are
            forwarded to ``optimizer.step``.
        device: device the batch tensors are moved to.
        train_loader: DataLoader yielding dicts with 'attribute', 'label' and
            'group' tensors.
        model: network whose output is treated as a probability; presumably
            shape (batch, 1) or (batch,) -- TODO confirm against DutchNet.
        optimizer: optimizer with a PrivacyEngine attached; its ``step`` accepts
            the extra keyword arguments passed below (project-patched opacus).
        epoch: epoch index, used only in the printed summary.

    The printed loss is the loss of the last batch (NaN if the loader was
    empty); accuracy is measured over the samples actually drawn this epoch.
    """
    model.train()
    correct = 0
    seen = 0
    loss = None
    for sample in tqdm(train_loader):
        attribute = sample['attribute'].to(device)
        label = sample['label'].to(device).float().reshape(-1)
        group = sample['group'].to(device)
        # Flatten rather than torch.squeeze: squeeze turns a size-1 batch into a
        # 0-d tensor, which binary_cross_entropy rejects against a (1,) target.
        output = model(attribute).reshape(-1)
        pred = (output >= 0.5).float()
        # Kept as a tensor to avoid a host/device sync on every batch.
        correct += pred.eq(label).sum()
        seen += label.numel()

        loss = nn.functional.binary_cross_entropy(output, label)
        optimizer.zero_grad()
        # retain_graph=True kept as in the original; presumably compute_norm or
        # the patched optimizer step needs the graph -- TODO confirm.
        loss.backward(retain_graph=True)
        # || nabla f(theta_t) ||, as computed by the utils helper compute_norm
        nabla = compute_norm(model)
        optimizer.step(device=device, group=group, batch_size=args.batch_size, nabla=nabla,
                       noise_multiplier=args.noise_multiplier)

    correct = int(correct)
    last_loss = loss.item() if loss is not None else float('nan')
    # With drop_last=True (and sampling with replacement) the number of samples
    # seen can differ from len(dataset), so use the real count as denominator.
    accuracy = 100. * correct / seen if seen else 0.
    print('Train Epoch: {} \tLoss: {:.6f} \tAccuracy: {}/{} ( {:.2f}% )'.format(epoch, last_loss, correct,
                                                                                seen, accuracy))


def test(device, test_loader, model):
    model.eval()
    test_loss = 0
    correct = defaultdict(int)
    total = defaultdict(int)
    with torch.no_grad():
        for sample in test_loader:
            attribute = sample['attribute'].to(device)
            label = sample['label'].to(device)
            group = sample['group'].to(device)
            output = model(attribute)
            pred = (output >= 0.5).float()
            for idx, g in enumerate(group):
                correct[g.item()] += pred[idx].eq(label.float().view_as(pred)[idx]).item()
                total[g.item()] += 1
            test_loss += nn.functional.binary_cross_entropy(torch.squeeze(output), label.float())

    cqt = sum(correct.values())
    test_loss /= len(test_loader.dataset)

    print('Test: \tLoss: {:.6f} \tAccuracy: {}/{} ( {:.2f}% )'.format(test_loss, cqt, len(test_loader.dataset),
                                                                      100. * cqt / len(test_loader.dataset)))
    for g, c in sorted(correct.items(), key=lambda x: x[0]):
        print('Test set: {} \tAccuracy: {}/{} ( {:.2f}% )'.format(g, c, total[g], 100. * c / total[g]))

    return 100. * cqt / len(test_loader.dataset)


def _make_privacy_engine(args, model, train_loader):
    """Build the PrivacyEngine for a DP schema (any ``args.schema`` except 'vanilla').

    All schemas share the same accounting and clipping arguments. 'opt-q',
    'dpsgd-f' and 'fairdp' additionally pass ``experimental=True`` with
    ``clipping_method`` 5, 6 and 7 respectively; 'dpsgd' passes neither.

    Raises:
        KeyError: if ``args.schema`` is not one of the DP schemas listed here.
    """
    schema_kwargs = {
        'dpsgd': {},
        'opt-q': {'experimental': True, 'clipping_method': 5},
        'dpsgd-f': {'experimental': True, 'clipping_method': 6},
        'fairdp': {'experimental': True, 'clipping_method': 7},
    }
    return PrivacyEngine(
        args=args,
        module=model,
        batch_size=args.batch_size,
        # NOTE(review): this value is batches-per-epoch * epochs (a step count),
        # not a sample count; kept as-is -- confirm what the patched engine expects.
        sample_size=int(len(train_loader.dataset) / args.batch_size) * args.epoch,
        alphas=[1.25, 1.5, 1.75, 2., 2.25, 2.5, 3., 3.5, 4., 4.5] + list(range(5, 64)) + [128, 256, 512],
        noise_multiplier=args.noise_multiplier,
        max_grad_norm=args.l2_norm_clip,
        target_delta=args.delta,
        **schema_kwargs[args.schema]
    )


def main():
    """Parse CLI options, train DutchNet under the chosen schema, and checkpoint improvements.

    Each epoch trains (plain SGD for 'vanilla', otherwise through the attached
    PrivacyEngine followed by compute_dp_sgd_privacy), evaluates on the
    held-out 20% split, and saves the model state dict whenever test accuracy
    improves.
    """
    import os

    # Training settings
    parser = argparse.ArgumentParser(description='Dutch')
    parser.add_argument('--device', default=-1, type=int, choices=[-1, 0, 1, 2, 3])
    parser.add_argument('--seed', default=7, type=int)
    parser.add_argument('--data-dir', default='../../dutch/', type=str)
    parser.add_argument('--batch-size', default=128, type=int)
    # 'vanilla' is the default, so it must also be accepted when passed explicitly.
    parser.add_argument('--schema', default='vanilla', type=str,
                        choices=['vanilla', 'dpsgd', 'opt-q', 'dpsgd-f', 'fairdp'])
    parser.add_argument('--lr', default=0.22360679774997896, type=float)
    parser.add_argument('--weight-decay', default=0.01, type=float)
    parser.add_argument('--epoch', default=20, type=int)
    # Differential privacy settings
    parser.add_argument('--l2-norm-clip', default=0.5, type=float)
    parser.add_argument('--noise-multiplier', default=1, type=float)
    parser.add_argument('--num-microbatches', default=128, type=int)
    parser.add_argument('--delta', default=1e-5, type=float)
    args = parser.parse_args()
    print('Argument =', vars(args))

    use_cuda = args.device != -1
    device = 'cpu'
    torch.manual_seed(args.seed)
    if use_cuda:
        # Honour the requested GPU index instead of always using the default 'cuda' device.
        torch.cuda.set_device(args.device)
        device = 'cuda:{}'.format(args.device)
        torch.cuda.manual_seed(args.seed)

    kwargs = {'batch_size': args.batch_size}
    if use_cuda:
        kwargs.update({'num_workers': 1, 'pin_memory': True})

    # os.path.join works whether or not --data-dir ends with a separator.
    full_dataset = DutchDataset(os.path.join(args.data_dir, 'dutch.csv'))
    train_length = int(len(full_dataset) * 0.8)
    test_length = len(full_dataset) - train_length
    train_data, test_data = torch.utils.data.random_split(full_dataset, [train_length, test_length])

    sampler = RandomSampler(train_data, replacement=True)
    train_loader = torch.utils.data.DataLoader(train_data, **kwargs, sampler=sampler, drop_last=True)
    test_loader = torch.utils.data.DataLoader(test_data, **kwargs)

    model = DutchNet().to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    if args.schema != 'vanilla':
        privacy_engine = _make_privacy_engine(args, model, train_loader)
        privacy_engine.attach(optimizer)
    if use_cuda:
        # cudnn.benchmark autotunes and may select nondeterministic kernels,
        # which would defeat deterministic=True and the fixed seeds above.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    acc_opt = 0
    for epoch in range(1, args.epoch + 1):
        start_time = datetime.now()

        if args.schema == 'vanilla':
            train(device, train_loader, model, optimizer, epoch)
        else:
            train_dp(args, device, train_loader, model, optimizer, epoch)
            compute_dp_sgd_privacy(len(train_loader.dataset), args.batch_size, args.noise_multiplier, epoch, args.delta)

        acc = test(device, test_loader, model)
        if acc > acc_opt:
            # str(datetime.now()) contains ':' which is not allowed in Windows file names.
            timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
            torch.save(model.state_dict(),
                       'dutch_{}_{}_{}_{}_model.pkl'.format(args.schema, args.seed, epoch, timestamp))
            print('Best Accuracy')
            acc_opt = acc

        end_time = datetime.now()
        print('{} - {} <{}>'.format(start_time, end_time, end_time - start_time))

    # For notification only: opens an empty figure window (blocking until it is
    # closed) to signal that the process has finished.
    plt.subplots()
    plt.show()


if __name__ == '__main__':
    # Run the training pipeline only when executed as a script, not on import.
    main()
