# Import pytorch packages
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
from torch.utils.data import Dataset, DataLoader
from torch.nn.functional import normalize
# import torchvision.transforms as transforms
from torchvision import datasets, transforms
from opacus import PrivacyEngine
from tqdm import tqdm
import numpy as np
from sklearn.datasets import load_svmlight_file

import os
import sys
import argparse
import warnings
warnings.simplefilter("ignore")

from models import LinearNN, TwoLayerNN

# MAX_GRAD_NORM = 1.2
# EPSILON = 50.0
# DELTA = 1e-5
# EPOCHS = 20

# LR = 1e-3

# BATCH_SIZE = 128
# MAX_PHYSICAL_BATCH_SIZE = 128

# Precomputed characteristics of the MNIST dataset: mean and standard deviation
# of the pixel intensities, used by ``transforms.Normalize`` in ``main``.
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081

# Training settings
parser = argparse.ArgumentParser(
    description="Opacus MNIST Example",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    "-b",
    "--batch-size",
    type=int,
    default=128,
    metavar="B",
    help="Batch size",
)
parser.add_argument(
    "--test-batch-size",
    type=int,
    default=256,
    metavar="TB",
    help="input batch size for testing",
)
parser.add_argument(
    "-n",
    "--epochs",
    type=int,
    default=10,
    metavar="N",
    help="number of epochs to train",
)
parser.add_argument(
    "-r",
    "--n-runs",
    type=int,
    default=1,
    metavar="R",
    help="number of runs to average on",
)
parser.add_argument(
    "--lr",
    type=float,
    default=0.01,
    metavar="LR",
    help="learning rate",
)
parser.add_argument(
    "--sigma",
    type=float,
    default=1.0,
    metavar="S",
    help="Noise multiplier",
)
parser.add_argument(
    "-c",
    "--max-per-sample-grad_norm",  # stored as args.max_per_sample_grad_norm
    type=float,
    default=1.0,
    metavar="C",
    help="Clip per-sample gradients to this norm",
)
parser.add_argument(
    "--delta",
    type=float,
    default=1e-5,
    metavar="D",
    help="Target delta",
)
parser.add_argument(
    "--device",
    type=str,
    # Fall back to the CPU when no CUDA device is available so the script does
    # not fail on CPU-only machines.
    default="cuda" if torch.cuda.is_available() else "cpu",
    help="GPU or CPU",
)
parser.add_argument(
    "--save-model",
    action="store_true",
    default=False,
    help="Save the trained model",
)
parser.add_argument(
    "--disable-dp",
    action="store_true",
    default=False,
    help="Disable privacy training and just train with vanilla SGD",
)
parser.add_argument(
    "-m",
    "--model",
    type=int,  # a model id; main() only compares it against 1
    default=1,  # 1 for LinearNN and 2 for two-layer NN.
    help="1 for LinearNN and 2 for two-layer NN",
)
parser.add_argument(
    "--data",
    type=str,
    default="mnist",  # main() treats any value other than "mnist" as CIFAR.
    help="Training on MNIST or CIFAR",
)
args = parser.parse_args()
device = torch.device(args.device)

batch_size = args.batch_size
# Single engine shared by main(), train() and writeResult(); the latter two
# query the privacy budget spent so far through get_epsilon().
privacy_engine = PrivacyEngine()

class CIFARDataLoader(Dataset):
    """In-memory dataset of feature vectors and integer labels stored as ``.npy`` files.

    Despite its name this is a ``torch.utils.data.Dataset`` (wrap it in a
    ``DataLoader`` to batch it).

    Args:
        featureFile: path to a ``.npy`` feature array. Rows are indexed as
            ``X[idx, :]``, so it is presumably (n_samples, n_features).
        labelFile: path to a ``.npy`` label array, one entry per feature row.
        transform: optional callable applied to the feature tensor returned by
            ``__getitem__``; ``None`` returns features unchanged.
    """

    def __init__(self, featureFile, labelFile, transform=None):
        # Features are cast to float32 and labels to int64, the dtypes the
        # model and ``nn.CrossEntropyLoss`` expect.
        self.X = torch.from_numpy(np.load(featureFile)).float()
        self.y = torch.from_numpy(np.load(labelFile)).long()
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the feature array."""
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Return ``(features, label)`` for ``idx`` (an int, list, or tensor index)."""
        # Tensor indices are converted to plain Python ints/lists first.
        if torch.is_tensor(idx):
            idx = idx.tolist()
        features = self.X[idx, :]
        if self.transform is not None:
            features = self.transform(features)
        return features, self.y[idx]

class MNISTDataLoader(Dataset):
    """In-memory dataset read from a single svmlight/LIBSVM-format file.

    Despite its name this is a ``torch.utils.data.Dataset`` (wrap it in a
    ``DataLoader`` to batch it).

    Args:
        filename: path passed to ``sklearn.datasets.load_svmlight_file``.
        transform: optional callable applied to the feature tensor returned by
            ``__getitem__``; ``None`` returns features unchanged.
    """

    def __init__(self, filename, transform=None):
        X, y = load_svmlight_file(filename)
        # ``toarray`` yields a plain 2-D ndarray. ``todense`` would return an
        # ``np.matrix``, a class NumPy no longer recommends using.
        self.X = torch.from_numpy(X.toarray()).float()
        self.y = torch.from_numpy(y).long()
        self.transform = transform

    def __len__(self):
        """Return the number of samples (rows) in the file."""
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Return ``(features, label)`` for ``idx`` (an int, list, or tensor index)."""
        # Tensor indices are converted to plain Python ints/lists first.
        if torch.is_tensor(idx):
            idx = idx.tolist()
        features = self.X[idx, :]
        if self.transform is not None:
            features = self.transform(features)
        return features, self.y[idx]


def accuracy(preds, labels):
    """Return the fraction of entries where ``preds`` matches ``labels``.

    Both arguments are expected to be NumPy arrays of equal shape; the
    element-wise comparison produces a boolean array whose mean is returned.
    """
    hits = preds == labels
    return hits.mean()

def getPerSampleNorm(model, batchSize):
    """Return the total L2 norm of each sample's gradient across all parameters.

    Reads the ``grad_sample`` attribute that Opacus attaches to parameters after
    ``backward()``. The first dimension of ``grad_sample`` indexes samples.
    Each sample's gradient slice is flattened, and its squared norms are summed
    over every parameter that has a ``grad_sample``. Parameters without one
    (e.g. frozen ones) are skipped.

    Args:
        model: module whose parameters carry ``grad_sample`` tensors.
        batchSize: number of samples to report (the leading ``batchSize`` rows
            of each ``grad_sample``).

    Returns:
        list[float]: one gradient norm per sample, length ``batchSize``.
    """
    # Accumulate in float64 on the CPU: one device transfer per parameter
    # instead of one ``.item()`` sync per (sample, parameter) pair.
    squared = torch.zeros(batchSize, dtype=torch.float64)
    with torch.no_grad():
        for param in model.parameters():
            grad_sample = getattr(param, "grad_sample", None)
            if grad_sample is None:
                continue
            per_sample = grad_sample[:batchSize].flatten(start_dim=1)
            squared += per_sample.norm(dim=1).pow(2).double().cpu()
    return squared.sqrt().tolist()

def getGrowthParameter(model):
    """Return the product of ``max(1, ||p||_2)`` over all parameters but the last.

    ``||p||_2`` follows ``np.linalg.norm(..., ord=2)`` semantics: the spectral
    norm (largest singular value) for 2-D tensors and the Euclidean norm for
    1-D tensors. The last tensor yielded by ``model.parameters()`` is excluded.
    An empty model returns 1.0.

    Args:
        model: any ``nn.Module``; its parameters may require grad and may live
            on the GPU.

    Returns:
        The product as a float (a Python or NumPy scalar).
    """
    ret = 1.0
    params = list(model.parameters())
    with torch.no_grad():
        for p in params[:-1]:
            # NumPy cannot read a tensor that requires grad or lives on the GPU;
            # hand it a detached CPU copy.
            norm = np.linalg.norm(p.detach().cpu().numpy(), ord=2)
            ret = ret * max(1.0, norm)
    return ret

def getLoss(output, target):
    """Return the per-sample (unreduced) cross-entropy loss, without autograd tracking.

    Args:
        output: logits passed to ``nn.CrossEntropyLoss``.
        target: class indices matching ``output``.

    Returns:
        A tensor with one loss value per sample; it does not require grad.
    """
    # ``reduction="none"`` replaces the deprecated ``reduce=False`` argument.
    criterion = nn.CrossEntropyLoss(reduction="none")
    with torch.no_grad():
        ret = criterion(output, target)
    return ret

# Clip the loss function
def clippedLoss(output, target, weight):
    """Return the mean of per-sample cross-entropy losses scaled by ``weight``.

    Args:
        output: logits passed to ``nn.CrossEntropyLoss``.
        target: class indices matching ``output``.
        weight: per-sample multipliers, broadcast against the per-sample losses.

    Returns:
        A scalar tensor that remains part of the autograd graph.
    """
    # ``reduction="none"`` replaces the deprecated ``reduce=False`` argument.
    criterion = nn.CrossEntropyLoss(reduction="none")
    losses = criterion(output, target)
    return torch.mean(torch.mul(losses, weight))
                
def train(model, train_loader, optimizer, epoch, device, train_acc, epoch_record, clip_freq_record):
    """Run one training epoch and append progress statistics to the given lists.

    Uses the module-level ``args`` (clipping threshold, delta) and
    ``privacy_engine`` (to report the epsilon spent so far).

    Args:
        model: the module returned by ``privacy_engine.make_private``; its
            parameters are expected to carry ``grad_sample`` after backward().
        train_loader: iterable of ``(images, target)`` batches.
        optimizer: the optimizer returned by ``privacy_engine.make_private``.
        epoch: 1-based epoch number, used for logging and ``epoch_record``.
        device: device the batches are moved to.
        train_acc: appended with the running mean batch accuracy (percent)
            every 100 batches and after the last batch.
        epoch_record: appended, at the same points, with the fractional epoch.
        clip_freq_record: appended once, at the end of the epoch, with the
            number of samples whose per-sample gradient norm exceeded
            ``args.max_per_sample_grad_norm``.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    num_batches = len(train_loader)

    losses = []
    top1_acc = []
    numClipping = []

    for i, (images, target) in enumerate(train_loader):
        optimizer.zero_grad()
        images = images.to(device)
        target = target.to(device)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
        labels = target.detach().cpu().numpy()

        # measure accuracy and record loss
        losses.append(loss.item())
        top1_acc.append(accuracy(preds, labels))

        loss.backward()

        # Count samples whose (pre-clipping) per-sample gradient norm exceeds
        # the clipping threshold. This reads ``grad_sample``, so it runs after
        # backward() and before optimizer.step().
        perSampleGradNorm = getPerSampleNorm(model, images.shape[0])
        numClipping.append(sum(x > args.max_per_sample_grad_norm for x in perSampleGradNorm))

        optimizer.step()

        if (i + 1) % 100 == 0 or (i + 1) == num_batches:
            epsilon = privacy_engine.get_epsilon(args.delta)
            running_acc = np.mean(top1_acc) * 100
            print(
                f"\tTrain Epoch: {epoch} \t"
                f"Loss: {np.mean(losses):.6f} "
                f"Acc@1: {running_acc:.6f} "
                f"Clip: {int(np.sum(numClipping))} "
                f"(ε = {epsilon:.2f}, δ = {args.delta})"
            )
            epoch_record.append(epoch - 1 + (i + 1) / num_batches)
            train_acc.append(running_acc)
    clip_freq_record.append(np.sum(numClipping))

def test(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader``, print loss/accuracy, return accuracy.

    Loss and accuracy are averaged over samples rather than over batches, so a
    smaller final batch does not get extra weight.

    Args:
        model: classifier producing logits.
        test_loader: iterable of ``(images, target)`` batches.
        device: device the batches are moved to.

    Returns:
        float: fraction of correctly classified samples, in [0, 1]
        (``nan`` if the loader yields no samples).
    """
    model.eval()
    # Sum (not mean) per batch so the totals can be divided by the sample count.
    criterion = nn.CrossEntropyLoss(reduction="sum")
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, target in test_loader:
            images = images.to(device)
            target = target.to(device)

            output = model(images)
            total_loss += criterion(output, target).item()
            correct += (output.argmax(dim=1) == target).sum().item()
            total += target.shape[0]

    top1_avg = correct / total if total else float("nan")
    mean_loss = total_loss / total if total else float("nan")

    print(
        f"\tTest set: "
        f"Loss: {mean_loss:.6f} "
        f"Acc: {top1_avg * 100:.6f} "
    )
    return top1_avg


def main():
    """Train a model with the module-level privacy engine and write the results.

    Reads its configuration from the module-level ``args``, wraps the model,
    optimizer and training loader with ``privacy_engine.make_private``, trains
    for ``args.epochs`` epochs (evaluating after each one), and writes the
    collected curves with ``writeResult`` to
    ``./expResults/<data>C<clip>SIGMA<sigma>M<architecture>.txt``.

    NOTE(review): ``--disable-dp``, ``--save-model`` and ``--n-runs`` are
    parsed but not honoured here.
    """
    # Load data
    data_name = args.data
    if data_name == "mnist":
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(MNIST_MEAN, MNIST_STD),
        ])
        train_set = datasets.MNIST('data', train=True, download=True, transform=transform)
        test_set = datasets.MNIST('data', train=False, download=True, transform=transform)
        print("Training on MNIST!")
    else:
        # Pre-extracted features (ResNet ones, judging by the directory name).
        feature_dir = "./data/cifar-10-resnet-features/"
        train_set = CIFARDataLoader(feature_dir + "train_features2.npy", feature_dir + "train_labels2.npy")
        test_set = CIFARDataLoader(feature_dir + "test_features2.npy", feature_dir + "test_labels2.npy")
        print("Training on CIFAR")

    # Both datasets use --batch-size / --test-batch-size. Evaluation order does
    # not affect the metrics, so the test set is not shuffled.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False, num_workers=0)

    if args.model == 1:
        architecture = "Linear"
        print("Using linear model.")
        model = LinearNN(data_name).to(device)
    else:
        architecture = "TwoLayerNN"
        print("Using 2-layer networks")
        model = TwoLayerNN(data_name).to(device)

    # Setup optimizer and privacy engine
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0, weight_decay=1e-5)
    model, optimizer, train_loader = privacy_engine.make_private(
        module=model,
        optimizer=optimizer,
        data_loader=train_loader,
        noise_multiplier=args.sigma,
        max_grad_norm=args.max_per_sample_grad_norm,
    )

    # Training
    train_acc = []
    epoch_record = []
    clip_freq_record = []
    test_acc = []

    for epoch in tqdm(range(args.epochs), desc="Epoch", unit="epoch"):
        train(model, train_loader, optimizer, epoch + 1, device, train_acc, epoch_record, clip_freq_record)
        test_acc.append(test(model, test_loader, device))

    results_dir = "./expResults"
    # writeResult opens the file for writing, which fails if the directory is missing.
    os.makedirs(results_dir, exist_ok=True)
    filename = f"{results_dir}/{data_name}C{args.max_per_sample_grad_norm}SIGMA{args.sigma}M{architecture}.txt"
    print(filename)
    writeResult(filename, epoch_record, train_acc, test_acc, clip_freq_record, architecture)


def writeResult( filename, epoch_record, train_acc, test_acc, clip_freq_record, architecture ):
    """Write the run configuration and the recorded curves to ``filename``.

    The file contains a "Configuration:" header and one line of settings
    (epsilon from the module-level ``privacy_engine``, plus values from the
    module-level ``args``). This is followed by "Epoch:", "Training:",
    "Testing:" and "Clip:" headers, each on its own line and each followed by
    the matching list as comma-separated values.

    Args:
        filename: output path; it is overwritten if it exists.
        epoch_record: fractional epochs at which ``train_acc`` was sampled.
        train_acc: running training accuracies (percent).
        test_acc: per-epoch test accuracies as returned by ``test``.
        clip_freq_record: per-epoch counts of samples whose gradient was clipped.
        architecture: architecture label written into the configuration line.
    """
    # Retrieve epsilon before opening the file so a failure here leaves no
    # half-written output behind.
    eps = privacy_engine.get_epsilon(args.delta)
    sections = (
        ("Epoch:", epoch_record),
        ("Training:", train_acc),
        ("Testing:", test_acc),
        ("Clip:", clip_freq_record),
    )
    # ``with`` closes the file even if writing raises.
    with open(filename, "w", encoding="utf-8") as f:
        print("Configuration:", file=f)
        print(f"Epsilon: {eps:.4f}, DELTA: {args.delta}, LR: {args.lr}, SIGMA: {args.sigma}, C: {args.max_per_sample_grad_norm}, ARCHI: {architecture} ", file=f)
        for header, values in sections:
            print(header, file=f)
            print(','.join(str(v) for v in values), file=f)

# Run the experiment only when executed as a script. Note that importing the
# module still parses command-line arguments, because parse_args() runs at
# module level.
if __name__ == "__main__":
    main()


# python MNISTscript.py -c 80 --lr 0.001 -m 1
# python MNISTscript.py -c 1 --lr 0.1 -m 1 --data cifar
# python MNISTscript.py -c 4 --sigma 1.0 --lr 0.1 -m 1 --data cifar
# dataloader = DataLoader(dataset, batch_size=128,shuffle=True, num_workers=0)    