# Modification of Opacus library MNIST example. The bulk of the code comes from the Opacus library. We only modify the privacy accounting to take into account realized gradient norms instead of the clipping value, and add appropriate plotting code.

"""
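Runs full-batch MNIST training with differential privacy, accounting for privacy with the realized per-sample gradient norms (against a squared-norm budget) rather than the clipping bound, and plots test accuracy and the number of training points still under budget.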

"""

import argparse

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchdp import PrivacyEngine
from torchvision import datasets, transforms
from tqdm import tqdm

import matplotlib.pyplot as plt

import os


# Normalization statistics of the (single-channel) MNIST images, precomputed
# over the dataset and passed to ``transforms.Normalize`` for both splits.
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081


class SampleConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 8, 2, padding=3)
        self.conv2 = nn.Conv2d(16, 32, 4, 2)
        self.fc1 = nn.Linear(32 * 4 * 4, 32)
        self.fc2 = nn.Linear(32, 10)

    def forward(self, x):
        # x of shape [B, 1, 28, 28]
        x = F.relu(self.conv1(x))   # -> [B, 16, 14, 14]
        x = F.max_pool2d(x, 2, 1)   # -> [B, 16, 13, 13]
        x = F.relu(self.conv2(x))   # -> [B, 32, 5, 5]
        x = F.max_pool2d(x, 2, 1)   # -> [B, 32, 4, 4]
        x = x.view(-1, 32 * 4 * 4)  # -> [B, 512]
        x = F.relu(self.fc1(x))     # -> [B, 32]
        x = self.fc2(x)             # -> [B, 10]
        return x

    def name(self):
        return "SampleConvNet"


def train(args, model, device, train_loader, optimizer, epoch, running_norms):
    """Run one training epoch and return the squared gradient norms of the last step.

    Args:
        args: Parsed command-line arguments; ``disable_dp`` and ``delta`` are read.
        model: Network to train; switched to train mode here.
        device: Device every batch is moved to.
        train_loader: Iterable of ``(data, target)`` batches.
        optimizer: Optimizer whose ``step`` takes ``running_norms`` and returns
            the realized gradient norms. This relies on the modified ``torchdp``
            ``PrivacyEngine`` being attached. NOTE(review): a stock
            ``torch.optim`` optimizer would treat ``running_norms`` as a
            closure, so the ``--disable-dp`` path likely fails here — confirm.
        epoch: 1-based epoch number, used only for logging.
        running_norms: Cumulative squared gradient norms so far; forwarded
            unchanged to ``optimizer.step``.

    Returns:
        The element-wise square of the norms returned by the last
        ``optimizer.step`` call of the epoch, or ``None`` if the loader yielded
        no batches. NOTE(review): per-sample accumulation by the caller
        presumes a single full-dataset batch per epoch.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    losses = []
    correct = 0  # correct training predictions, accumulated over all batches
    seen = 0  # training samples processed this epoch (accuracy denominator)
    gradient_norms_sq = None
    for data, target in tqdm(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)

        # Training accuracy, measured on the pre-update predictions.
        pred = output.argmax(dim=1, keepdim=True)  # index of the max logit
        correct += pred.eq(target.view_as(pred)).sum().item()
        seen += target.size(0)

        loss = criterion(output, target)
        loss.backward()
        gradient_norms = optimizer.step(running_norms)
        gradient_norms_sq = gradient_norms * gradient_norms
        losses.append(loss.item())

    accuracy = correct / seen if seen else 0.0

    if not args.disable_dp:
        epsilon, best_alpha = optimizer.privacy_engine.get_privacy_spent(args.delta)
        print(
            f"Train Epoch: {epoch} \t"
            f"Loss: {np.mean(losses):.6f} "
            f"Acc: {accuracy:.6f} "
            f"(ε = {epsilon:.2f}, δ = {args.delta}) for α = {best_alpha}"
        )
    else:
        print(f"Train Epoch: {epoch} \t Loss: {np.mean(losses):.6f}")

    return gradient_norms_sq


def test(args, model, device, test_loader):
    """Evaluate ``model`` on ``test_loader``, print loss/accuracy, and return accuracy.

    Args:
        args: Parsed command-line arguments (unused; kept for a uniform signature).
        model: Network to evaluate; switched to eval mode here.
        device: Device every batch is moved to.
        test_loader: DataLoader of ``(data, target)`` batches; the length of its
            ``dataset`` is the denominator for both averages.

    Returns:
        Fraction of test samples classified correctly.
    """
    model.eval()
    # reduction="sum" so that dividing by the dataset size below yields the
    # true per-sample mean loss; summing per-batch means would not.
    criterion = nn.CrossEntropyLoss(reduction="sum")
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in tqdm(test_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()  # sum of per-sample losses
            pred = output.argmax(dim=1, keepdim=True)  # index of the max logit
            correct += pred.eq(target.view_as(pred)).sum().item()

    num_samples = len(test_loader.dataset)
    test_loss /= num_samples

    print(
        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n".format(
            test_loss,
            correct,
            num_samples,
            100.0 * correct / num_samples,
        )
    )
    return correct / num_samples


def str2bool(value):
    """Parse a command-line boolean ('true'/'false', 'yes'/'no', '1'/'0', ...).

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    "False") as True, so boolean flags need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Train SampleConvNet on MNIST with per-sample norm accounting and plot results.

    Every epoch is a single full-batch step. The running per-sample squared
    gradient norms are handed to the (modified) privacy engine each step, and
    training stops early once no training point remains below ``--budget``.
    Writes ``active_points.pdf`` (points still under budget per step) and
    ``accuracies.pdf`` (test accuracy per step).
    """
    # Training settings
    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=1024,
        metavar="TB",
        help="input batch size for testing (default: 1024)",
    )
    parser.add_argument(
        "-n",
        "--epochs",
        type=int,
        default=150,
        metavar="N",
        help="number of epochs to train",
    )
    parser.add_argument(
        "-budget",
        "--budget",
        type=float,
        default=11200,
        metavar="BG",
        help="norm squared budget",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=.2,
        metavar="LR",
        help="learning rate",
    )
    parser.add_argument(
        "--sigma",
        type=float,
        default=170,
        metavar="S",
        help="Noise multiplier",
    )
    parser.add_argument(
        "--should-clip",
        type=str2bool,
        default=True,
        metavar="SC",
        help="Indicator whether to clip the gradients (default True)",
    )
    parser.add_argument(
        "-c",
        "--max-per-sample-grad_norm",
        type=float,
        default=10,
        metavar="C",
        help="Clip per-sample gradients to this norm",
    )
    parser.add_argument(
        "--delta",
        type=float,
        default=1e-5,
        metavar="D",
        help="Target delta (default: 1e-5)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="device ID for this process (default: 'cpu')",
    )
    parser.add_argument(
        "--disable-dp",
        action="store_true",
        default=False,
        help="Disable privacy training and just train with vanilla SGD",
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default="../mnist",
        help="Where MNIST is/will be stored",
    )
    args = parser.parse_args()
    device = torch.device(args.device)

    kwargs = {"num_workers": 1, "pin_memory": True}

    normalize = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))]
    )
    train_dataset = datasets.MNIST(
        args.data_root, train=True, download=True, transform=normalize
    )
    # Full-batch training: one unshuffled batch covering the whole training set,
    # so each step yields one gradient norm per training point in fixed order.
    full_batch = len(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=full_batch, shuffle=False, **kwargs
    )
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(args.data_root, train=False, transform=normalize),
        batch_size=args.test_batch_size,
        shuffle=True,
        **kwargs,
    )

    path1 = "accuracies.pdf"
    path2 = "active_points.pdf"

    # Rényi orders shared by the engine's accountant and the final ε below, so
    # both are optimized over the same grid.
    alphas = [1 + x / 10.0 for x in range(1, 100)] + list(range(12, 80))

    model = SampleConvNet().to(device)

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0)
    if not args.disable_dp:
        privacy_engine = PrivacyEngine(
            model,
            batch_size=full_batch,
            sample_size=len(train_loader.dataset),
            alphas=alphas,
            noise_multiplier=args.sigma,
            max_grad_norm=args.max_per_sample_grad_norm,
            norm_sq_budget=args.budget,
            should_clip=args.should_clip,
        )
        privacy_engine.attach(optimizer)

    run_results = []  # test accuracy after each step
    active_points = []  # number of points still under budget after each step
    running_sq_norms = 0  # cumulative per-sample squared gradient norms
    for epoch in range(1, args.epochs + 1):
        step_sq_norms = train(
            args, model, device, train_loader, optimizer, epoch, running_sq_norms
        )
        running_sq_norms = running_sq_norms + step_sq_norms

        run_results.append(test(args, model, device, test_loader))
        n_active = int(torch.sum(running_sq_norms < args.budget).item())
        active_points.append(n_active)
        # Stop once every point has used up its squared-norm budget.
        if n_active == 0:
            break

    # ε from the RDP bound α·budget / (2σ²C²) plus the RDP→(ε, δ) conversion
    # term log(1/δ)/(α−1), minimized over the shared α grid.
    rdp_scale = args.budget / (2 * args.sigma**2 * args.max_per_sample_grad_norm**2)
    eps_val = min(
        alpha * rdp_scale + np.log(1 / args.delta) / (alpha - 1) for alpha in alphas
    )
    title = f"ε = {eps_val:.1f}, δ = {args.delta}"
    # Step at which a point contributing exactly C² per step would presumably
    # exhaust its budget; drawn as a reference line on both plots.
    budget_step = int(args.budget / (args.max_per_sample_grad_norm**2))
    steps = range(1, len(run_results) + 1)  # 1-based step index for both plots

    plt.plot(steps, active_points)
    plt.axvline(budget_step, 0, 1, c='m', linewidth=3)
    plt.ylim(0, 65000)
    plt.title(title, fontsize=22)
    plt.ylabel("Number of active points", fontsize=22)
    plt.xlabel("Step", fontsize=22)
    plt.savefig(path2, bbox_inches='tight')

    plt.clf()

    plt.plot(steps, run_results, linewidth=3)
    plt.axvline(budget_step, 0, 1, c='m', linewidth=3)
    plt.xlabel("Step", fontsize=22)
    plt.ylabel("Test accuracy", fontsize=22)
    plt.title(title, fontsize=22)
    plt.yticks([0.8, 0.825, 0.85, 0.875, 0.9, 0.925, 0.95, 0.975, 1.0], fontsize=16)
    plt.xticks(fontsize=16)
    plt.ylim(0.875, 0.975)
    plt.xlim(0, args.epochs)
    plt.grid(True)
    plt.savefig(path1, bbox_inches='tight')
    plt.close()


# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
