# Import pytorch packages
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
from torch.utils.data import Dataset, DataLoader
from torch.nn.functional import normalize
# import torchvision.transforms as transforms
from torchvision import datasets, transforms
# from opacus import PrivacyEngine
from tqdm import tqdm
import numpy as np
import time
from sklearn.datasets import load_svmlight_file

import os
import sys
import argparse
import warnings
warnings.simplefilter("ignore")

from models import LinearNN, TwoLayerNN

# MAX_GRAD_NORM = 1.2
# EPSILON = 50.0
# DELTA = 1e-5
# EPOCHS = 20

# LR = 1e-3

# BATCH_SIZE = 128
# MAX_PHYSICAL_BATCH_SIZE = 128

# Precomputed characteristics of the MNIST dataset
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081

# Command-line configuration. Parsed once at import time; the resulting
# `args`, `device` and `batch_size` are read as module-level globals by the
# functions below.
parser = argparse.ArgumentParser(
    description="Opacus MNIST Example",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    "-b",
    "--batch-size",
    type=int,
    default=128,
    metavar="B",
    help="Batch size",
)
parser.add_argument(
    "--test-batch-size",
    type=int,
    default=256,
    metavar="TB",
    help="input batch size for testing",
)
parser.add_argument(
    "-n",
    "--epochs",
    type=int,
    default=10,
    metavar="N",
    help="number of epochs to train",
)
parser.add_argument(
    "-r",
    "--n-runs",
    type=int,
    default=1,
    metavar="R",
    help="number of runs to average on",
)
parser.add_argument(
    "--lr",
    type=float,
    default=0.01,
    metavar="LR",
    help="learning rate",
)
parser.add_argument(
    "--sigma",
    type=float,
    default=1.0,
    metavar="S",
    help="Noise multiplier",
)
parser.add_argument(
    "-c",
    "--max-per-sample-grad_norm",
    type=float,
    default=1.0,
    metavar="C",
    help="Clip per-sample gradients to this norm",
)
parser.add_argument(
    "--delta",
    type=float,
    default=1e-5,
    metavar="D",
    help="Target delta",
)
parser.add_argument(
    "--device",
    type=str,
    # Fall back to the CPU when no CUDA device is present instead of
    # failing on the first `.to(device)` call.
    default="cuda" if torch.cuda.is_available() else "cpu",
    help="GPU or CPU",
)
parser.add_argument(
    "--save-model",
    action="store_true",
    default=False,
    help="Save the trained model",
)
parser.add_argument(
    "--disable-dp",
    action="store_true",
    default=False,
    help="Disable privacy training and just train with vanilla SGD",
)
parser.add_argument(
    "-m",
    "--model",
    # The selector is an integer code, so parse it as one and reject
    # anything other than the two documented values.
    type=int,
    choices=(1, 2),
    default=1,  # 1 for LinearNN and 2 for two-layer NN.
    help="1 for LinearNN and 2 for two-layer NN",
)
parser.add_argument(
    "--data",
    type=str,
    default="mnist",  # "mnist" selects MNIST; any other value selects the CIFAR features.
    help="Training on MNIST or CIFAR",
)
# parser.add_argument(
#     "--secure-rng",
#     action="store_true",
#     default=False,
#     help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
# )
# parser.add_argument(
#     "--data-root",
#     type=str,
#     default="../mnist",
#     help="Where MNIST is/will be stored",
# )
args = parser.parse_args()
device = torch.device(args.device)

batch_size = args.batch_size
# privacy_engine = PrivacyEngine()

class CIFARDataLoader(Dataset):
    """In-memory dataset backed by two ``.npy`` files (features and labels).

    Despite its name this is a ``torch.utils.data.Dataset``; wrap it in a
    ``DataLoader`` to obtain batches.
    """

    def __init__(self, featureFile, labelFile, transform=None):
        """Load the feature and label arrays from disk.

        Args:
            featureFile: Path handed to ``np.load`` for the feature array.
                Its first axis is the sample axis; rows are indexed as
                ``X[idx, :]``, so the array must be at least 2-D.
            labelFile: Path handed to ``np.load`` for the label array.
            transform: Optional callable applied to a sample's feature
                tensor in ``__getitem__``.
        """
        features = np.load(featureFile)
        labels = np.load(labelFile)
        # Features are stored as float32 and labels as int64.
        self.X = torch.from_numpy(features).float()
        self.y = torch.from_numpy(labels).long()
        self.transform = transform

    def __len__(self):
        """Return the number of samples (size of the first feature axis)."""
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Return ``(features, label)`` for ``idx`` (int or index tensor)."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample = self.X[idx, :]
        # Previously the transform was stored but never used.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, self.y[idx]


class MNISTDataLoader(Dataset):
    """In-memory dataset loaded from a single svmlight/libsvm format file.

    Despite its name this is a ``torch.utils.data.Dataset``; wrap it in a
    ``DataLoader`` to obtain batches.
    """

    def __init__(self, filename, transform=None):
        """Load features and labels with ``load_svmlight_file``.

        Args:
            filename: Path of the svmlight file to read.
            transform: Optional callable applied to a sample's feature
                tensor in ``__getitem__``.
        """
        sparse_features, labels = load_svmlight_file(filename)
        # toarray() yields a plain 2-D ndarray (todense() would give the
        # legacy np.matrix type).
        dense_features = sparse_features.toarray()
        # Features are stored as float32 and labels as int64.
        self.X = torch.from_numpy(dense_features).float()
        self.y = torch.from_numpy(labels).long()
        self.transform = transform

    def __len__(self):
        """Return the number of samples (rows of the feature matrix)."""
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Return ``(features, label)`` for ``idx`` (int or index tensor)."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample = self.X[idx, :]
        # Previously the transform was stored but never used.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, self.y[idx]

def accuracy(preds, labels):
    """Return the fraction of positions where ``preds`` equals ``labels``.

    Both arguments are compared element-wise with ``==`` and the mean of the
    resulting boolean array is returned; callers in this file pass NumPy
    arrays.
    """
    matches = preds == labels
    return matches.mean()

def getPerSampleNorm(model, batchSize):
    """Return the per-sample gradient L2 norms across all parameters.

    Every parameter of ``model`` must carry a ``grad_sample`` attribute whose
    first axis indexes the samples of the batch.

    Args:
        model: Module whose parameters expose ``grad_sample``.
        batchSize: Number of samples to report norms for.

    Returns:
        A list of ``batchSize`` values; entry ``i`` is the square root of the
        summed squared norms of ``grad_sample[i]`` over all parameters.
    """
    parameters = [p for _, p in model.named_parameters()]
    norms = []
    for sample in range(batchSize):
        squared_total = 0.0
        # Accumulate in parameter order so the float sum matches a
        # parameter-major traversal.
        for p in parameters:
            squared_total += (torch.norm(p.grad_sample[sample, :]) ** 2).item()
        norms.append(np.sqrt(squared_total))
    return norms

def PowerIteration(A, maxIter=10):
    """Estimate the dominant eigenvalue magnitude of a square matrix.

    Runs ``maxIter`` steps of power iteration from a random start vector and
    returns ``||A b|| / ||b||`` for the final iterate ``b``.

    Args:
        A: Square 2-D tensor (each step feeds ``A @ b`` back in as ``b``,
            so a non-square matrix cannot be iterated).
        maxIter: Number of power-iteration steps.

    Returns:
        A 0-dim tensor on ``A``'s device. An all-zero ``A`` yields NaN
        because the iterate is normalised by a zero norm.
    """
    # Create the start vector where A lives (device and dtype) instead of
    # forcing CUDA, so the routine also works on CPU-only hosts.
    b_k = torch.randn(A.shape[1], device=A.device, dtype=A.dtype)
    for _ in range(maxIter):
        b_k1 = torch.matmul(A, b_k)
        b_k1_norm = torch.norm(b_k1)
        b_k = b_k1 / b_k1_norm
    ret = torch.norm(torch.matmul(A, b_k)) / torch.norm(b_k)
    return ret

def GetSpectralNorm(param):
    """Estimate the spectral norm (largest singular value) of a 2-D tensor.

    Forms the smaller of the two Gram matrices of ``param``, estimates its
    dominant eigenvalue with ``PowerIteration`` and returns the square root.

    Args:
        param: 2-D tensor of shape ``(n, m)``.

    Returns:
        A 0-dim tensor holding the estimate.
    """
    n, m = param.shape
    # Use whichever Gram matrix is smaller: (n x n) when n < m, otherwise
    # (m x m). Both share the same non-zero eigenvalues.
    if n < m:
        gram = torch.matmul(param, param.T)
    else:
        # The second operand was missing here, which raised a TypeError.
        gram = torch.matmul(param.T, param)
    return torch.sqrt(PowerIteration(gram))

def getGrowthParameter(model):
    """Return the growth parameter of ``model`` built from layer spectral norms.

    For every parameter tensor the spectral norm ``s_i`` is estimated with
    ``GetSpectralNorm``. With ``P`` the product of all ``s_i``, the result is
    ``sum_i (P / s_i) ** 2``.

    Args:
        model: Module whose parameters are all accepted by
            ``GetSpectralNorm``.

    Returns:
        The sum described above (the integer 0 if the model has no
        parameters).
    """
    # Spectral norms are statistics of the weights, not part of the graph.
    with torch.no_grad():
        layer_norms = [GetSpectralNorm(weight) for weight in model.parameters()]
    norm_vector = torch.tensor(layer_norms)
    product_all = torch.prod(norm_vector)
    # P / s_i is the product of every layer's norm except layer i.
    return sum([(product_all / s) ** 2 for s in norm_vector])

def getLoss(output, target):
    """Return the per-sample cross-entropy loss, detached from autograd.

    Args:
        output: Logits of shape ``(batch, classes)``.
        target: Class indices of shape ``(batch,)``.

    Returns:
        A tensor with one loss value per sample, computed under
        ``torch.no_grad()``.
    """
    # reduction="none" replaces the deprecated reduce=False keyword.
    criterion = nn.CrossEntropyLoss(reduction="none")
    with torch.no_grad():
        ret = criterion(output, target)
    return ret

# Clip the loss function
def clippedLoss(output, target, weight):
    """Return the mean of the per-sample cross-entropy losses scaled by ``weight``.

    Args:
        output: Logits of shape ``(batch, classes)``.
        target: Class indices of shape ``(batch,)``.
        weight: Per-sample multipliers, multiplied element-wise with the
            per-sample losses.

    Returns:
        A 0-dim tensor that remains attached to the autograd graph of
        ``output``.
    """
    # reduction="none" replaces the deprecated reduce=False keyword.
    criterion = nn.CrossEntropyLoss(reduction="none")
    losses = criterion(output, target)
    return torch.mean(torch.mul(losses, weight))

def getWeights(model, images, output, target, C):
    """Compute the per-sample loss weights used for value clipping.

    For each sample the weight is ``min(1, C / (bound + 1e-6))`` where
    ``bound = sqrt(4 * G * ||x||^2 * min(1, 2 * loss + 1e-6) + 1e-6)``,
    ``G`` is ``getGrowthParameter(model)``, ``||x||`` is the L2 norm of the
    flattened input and ``loss`` is the detached per-sample loss.

    Args:
        model: Network passed to ``getGrowthParameter``.
        images: Input batch; flattened from dimension 1 onward.
        output: Model output for ``images``.
        target: Labels for the batch.
        C: Clipping threshold.

    Returns:
        A tensor with one weight per sample, created on the module-level
        ``device``.
    """
    n_samples = output.shape[0]
    unit_cap = torch.ones(n_samples).to(device)

    per_sample_loss = getLoss(output, target)
    input_norms = torch.norm(torch.flatten(images, 1), dim=1)
    growth = getGrowthParameter(model)

    # The per-sample loss can occasionally be -0.00, hence the small offset.
    last_layer_bound = torch.minimum(unit_cap, 2 * per_sample_loss + 1e-6)
    grad_bound = torch.sqrt(
        4 * growth * input_norms ** 2 * last_layer_bound + 1e-6
    )
    return torch.minimum(unit_cap, C / (grad_bound + 1e-6))

def addNoise(model, batchSize, C, sigma, device):
    """Add zero-mean Gaussian noise to every parameter gradient of ``model``.

    Args:
        model: Module whose ``param.grad`` tensors are perturbed.
        batchSize: Batch size; the noise standard deviation is
            ``C * sigma / batchSize``.
        C: Clipping threshold.
        sigma: Noise multiplier.
        device: Device on which the noise tensors are created.

    Parameters whose ``grad`` is ``None`` (frozen, or not reached by the
    backward pass) are skipped instead of raising.
    """
    noise_std = C * sigma / batchSize
    with torch.no_grad():
        for param in model.parameters():
            if param.grad is None:
                continue
            # Add noise to each layer.
            noise = torch.normal(
                mean=0, std=noise_std, size=param.shape, device=device
            )
            param.grad = param.grad + noise
                
def train(model, train_loader, optimizer, epoch, C, device, train_acc, epoch_record):
    """Run one epoch of value-clipped, noise-perturbed SGD.

    For every batch the per-sample losses are re-weighted by ``getWeights``,
    the weighted mean loss is back-propagated, Gaussian noise is added to the
    gradients with ``addNoise`` (noise multiplier ``args.sigma``) and the
    optimizer takes a step.

    Args:
        model: Network to train; switched to training mode.
        train_loader: Iterable yielding ``(images, target)`` batches.
        optimizer: Optimizer stepping ``model``'s parameters.
        epoch: Epoch number recorded in ``epoch_record``.
        C: Clipping threshold forwarded to ``getWeights`` and ``addNoise``.
        device: Device the batches are moved to.
        train_acc: List that receives this epoch's mean per-batch top-1
            accuracy, in percent.
        epoch_record: List that receives ``float(epoch)``.

    Previously ``train_acc`` and ``epoch_record`` were never written, so the
    results file contained empty "Epoch" and "Training" rows.
    """
    model.train()

    top1_acc = []

    # Diagnostic output: growth parameter of the model at the start of the epoch.
    modelLipschitz = getGrowthParameter(model)
    print(modelLipschitz)
    for images, target in train_loader:
        optimizer.zero_grad()
        images = images.to(device)
        target = target.to(device)

        # compute output
        output = model(images)
        weights = getWeights(model, images, output, target, C)
        loss = clippedLoss(output, target, weights)

        # measure accuracy
        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
        labels = target.detach().cpu().numpy()
        top1_acc.append(accuracy(preds, labels))

        loss.backward()

        # Add noise, scaled by the actual size of this batch
        addNoise(model, images.shape[0], C, args.sigma, device)

        optimizer.step()

    # Record one data point per epoch; skip when the loader was empty so no
    # NaN ends up in the results.
    if top1_acc:
        epoch_record.append(float(epoch))
        train_acc.append(np.mean(top1_acc) * 100)

def test(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` and print loss and accuracy.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        test_loader: Iterable yielding ``(images, target)`` batches.
        device: Device the batches are moved to.

    Returns:
        The mean of the per-batch top-1 accuracies, as a fraction.
    """
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    batch_losses, batch_accs = [], []

    with torch.no_grad():
        for images, target in test_loader:
            images, target = images.to(device), target.to(device)

            logits = model(images)
            batch_losses.append(loss_fn(logits, target).item())

            predicted = np.argmax(logits.detach().cpu().numpy(), axis=1)
            truth = target.detach().cpu().numpy()
            batch_accs.append(accuracy(predicted, truth))

    mean_acc = np.mean(batch_accs)

    print(
        f"\tTest set:"
        f"Loss: {np.mean(batch_losses):.6f} "
        f"Acc: {mean_acc * 100:.6f} "
    )
    return mean_acc


def main():
    """Train with value clipping and write the accuracy history to disk.

    Reads the configuration from the module-level ``args``: builds the MNIST
    or CIFAR-feature data loaders, instantiates the selected model, trains
    for ``args.epochs`` epochs with ``train``, evaluates after each epoch
    with ``test`` and finally calls ``writeResult``.
    """
    # Load data
    data_name = args.data
    if data_name == "mnist":
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(MNIST_MEAN, MNIST_STD),
        ])
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST('data', train=True, download=True, transform=transform),
            batch_size=args.batch_size, shuffle=True)
        # Honour --test-batch-size, which used to be parsed but ignored.
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST('data', train=False, download=True, transform=transform),
            batch_size=args.test_batch_size, shuffle=True)
    else:
        train_set = CIFARDataLoader("./data/cifar-10-resnet-features/train_features2.npy", "./data/cifar-10-resnet-features/train_labels2.npy")
        test_set = CIFARDataLoader("./data/cifar-10-resnet-features/test_features2.npy", "./data/cifar-10-resnet-features/test_labels2.npy")
        # Batch sizes come from the command line (defaults 128 / 256 match
        # the values that used to be hard-coded here).
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=0)
        test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=True, num_workers=0)
        print("Training on CIFAR")

    if args.model == 1:
        print("Using linear model.")
        model = LinearNN(data_name).to(device)
    else:
        print("Using 2-layer networks")
        model = TwoLayerNN(data_name).to(device)

    # Plain SGD; clipping and noise are applied manually inside train().
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0)

    C = args.max_per_sample_grad_norm

    # Training
    train_acc = []
    epoch_record = []
    test_acc = []

    for epoch in tqdm(range(args.epochs), desc="Epoch", unit="epoch"):
        # train() receives the 1-based epoch number.
        train(model, train_loader, optimizer, epoch + 1, C, device, train_acc, epoch_record)
        test_acc.append(test(model, test_loader, device))

    architecture = "Linear" if args.model == 1 else "TwoLayerNN"
    filename = (
        f"./expResults/VC{data_name}C{args.max_per_sample_grad_norm}"
        f"SIGMA{args.sigma}M{architecture}.txt"
    )
    writeResult(filename, epoch_record, train_acc, test_acc, architecture)


def writeResult( filename, epoch_record, train_acc, test_acc, architecture ):
    """Write the run configuration and accuracy histories to a text file.

    The file contains a configuration line followed by three labelled,
    comma-separated rows ("Epoch:", "Training:", "Testing:").

    Args:
        filename: Output path; missing parent directories are created.
        epoch_record: Epoch values for the "Epoch:" row.
        train_acc: Values for the "Training:" row.
        test_acc: Values for the "Testing:" row.
        architecture: Architecture label written in the configuration line.
    """
    # Create the output directory up front so a finished training run is not
    # lost to a FileNotFoundError.
    out_dir = os.path.dirname(filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # The context manager closes the file even if a write fails.
    with open(filename, "w") as f:
        print("Configuration:", file=f)
        print(f"LR: {args.lr}, SIGMA: {args.sigma}, C: {args.max_per_sample_grad_norm}, ARCHI: {architecture} ", file=f)
        print("Epoch:", file=f)
        print(','.join(str(e) for e in epoch_record), file=f)
        print("Training:", file=f)
        print(','.join(str(acc) for acc in train_acc), file=f)
        print("Testing:", file=f)
        print(','.join(str(acc) for acc in test_acc), file=f)

def timingGrowthCondition():
    """Print the wall-clock seconds taken by 500 ``getGrowthParameter`` calls.

    The measurement uses a freshly constructed ``TwoLayerNN('mnist')`` moved
    to ``args.device``.
    """
    net = TwoLayerNN('mnist').to(args.device)
    started = time.time()
    for _ in range(500):
        getGrowthParameter(net)
    finished = time.time()
    print(finished - started)

def Test2():
    """Time plain (non-private) SGD and print the mean seconds per epoch.

    Builds the data loaders and model selected by the module-level ``args``,
    trains for ``args.epochs`` epochs with unmodified cross-entropy SGD and
    prints the elapsed wall-clock time divided by the number of epochs.
    """
    print("Testing time for vanilla SGD!")
    # Load data
    data_name = args.data
    if data_name == "mnist":
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(MNIST_MEAN, MNIST_STD),
        ])
        train_loader = torch.utils.data.DataLoader(datasets.MNIST('data', train=True, download=True,
                        transform=transform),batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(datasets.MNIST('data', train=False, download=True,
                        transform=transform),batch_size=batch_size, shuffle=True)
        print("Training on MNIST!")
    else:
        train_set = CIFARDataLoader("./data/cifar-10-resnet-features/train_features2.npy", "./data/cifar-10-resnet-features/train_labels2.npy")
        test_set = CIFARDataLoader("./data/cifar-10-resnet-features/test_features2.npy", "./data/cifar-10-resnet-features/test_labels2.npy")
        train_loader = DataLoader(train_set, batch_size=128,shuffle=True, num_workers=0)
        test_loader = DataLoader(test_set, batch_size=256,shuffle=True, num_workers=0)
        print("Training on CIFAR")

    if args.model == 1:
        print("Using linear model.")
        model = LinearNN(data_name).to(device)
    else:
        print("Using 2-layer networks")
        model = TwoLayerNN(data_name).to(device)

    # Setup optimizer
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0)
    start_time = time.time()
    for epoch in range(args.epochs):
        model.train()
        criterion = nn.CrossEntropyLoss()
        losses = []
        top1_acc = []
        for images, target in train_loader:
            optimizer.zero_grad()
            images = images.to(device)
            target = target.to(device)
            output = model(images)

            loss = criterion(output, target)

            # The per-batch bookkeeping is kept inside the timed region so
            # the measurement stays comparable with earlier runs.
            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
            labels = target.detach().cpu().numpy()
            acc = accuracy(preds, labels)

            losses.append(loss.item())
            top1_acc.append(acc)

            loss.backward()
            optimizer.step()
    end_time = time.time()
    # Avoid a ZeroDivisionError when run with --epochs 0.
    per_epoch = (end_time - start_time) / max(args.epochs, 1)
    # The format spec must sit inside the braces; it used to be printed
    # literally as ":.3f".
    print(f"\t time elapsed: {per_epoch:.3f}")



def TestVanillaSGD():
    """Train with plain (non-private) SGD and print running training metrics.

    Builds the data loaders and model selected by the module-level ``args``
    and trains for ``args.epochs`` epochs. Every 100 batches, and after the
    last batch of an epoch, the mean loss and mean top-1 accuracy of the
    epoch so far are printed.
    """
    print("Running vanilla SGD!")
    # Load data
    data_name = args.data
    if data_name != "mnist":
        train_set = CIFARDataLoader("./data/cifar-10-resnet-features/train_features2.npy", "./data/cifar-10-resnet-features/train_labels2.npy")
        test_set = CIFARDataLoader("./data/cifar-10-resnet-features/test_features2.npy", "./data/cifar-10-resnet-features/test_labels2.npy")
        train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=0)
        test_loader = DataLoader(test_set, batch_size=256, shuffle=True, num_workers=0)
        print("Training on CIFAR")
    else:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(MNIST_MEAN, MNIST_STD),
        ])
        mnist_train = datasets.MNIST('data', train=True, download=True, transform=transform)
        train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
        mnist_test = datasets.MNIST('data', train=False, download=True, transform=transform)
        test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=True)
        print("Training on MNIST!")

    if args.model == 1:
        print("Using linear model.")
        model = LinearNN(data_name).to(device)
    else:
        print("Using 2-layer networks")
        model = TwoLayerNN(data_name).to(device)

    # Setup optimizer
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0)
    for epoch in tqdm(range(args.epochs), desc="Epoch", unit="epoch"):
        model.train()
        criterion = nn.CrossEntropyLoss()
        # Running per-batch statistics, reset at the start of each epoch.
        losses = []
        top1_acc = []
        for step, (images, target) in enumerate(train_loader, start=1):
            optimizer.zero_grad()
            images, target = images.to(device), target.to(device)
            output = model(images)

            loss = criterion(output, target)

            # measure accuracy and record loss
            predicted = np.argmax(output.detach().cpu().numpy(), axis=1)
            truth = target.detach().cpu().numpy()
            losses.append(loss.item())
            top1_acc.append(accuracy(predicted, truth))

            loss.backward()
            optimizer.step()

            report_now = step % 100 == 0 or step == len(train_loader)
            if report_now:
                print(
                    f"\tTrain Epoch: {epoch} \t"
                    f"Loss: {np.mean(losses):.6f} "
                    f"Acc@1: {np.mean(top1_acc) * 100:.6f} "
                )


def TestMicroBatching():
    """Compute per-sample gradients one sample at a time (micro-batching).

    Builds the data loaders and model selected by the module-level ``args``.
    For every batch, each sample is pushed through the model on its own and
    the resulting parameter gradients are cloned into a per-batch list.

    NOTE(review): the collected gradients are never aggregated and
    ``optimizer.step()`` is never called, so the model parameters are not
    updated; the printed loss and accuracy come from placeholder lists that
    only ever contain ``0``.
    """
    print("Running micro-batching!")
    # Load data
    data_name = args.data
    if data_name != "mnist":
        train_set = CIFARDataLoader("./data/cifar-10-resnet-features/train_features2.npy", "./data/cifar-10-resnet-features/train_labels2.npy")
        test_set = CIFARDataLoader("./data/cifar-10-resnet-features/test_features2.npy", "./data/cifar-10-resnet-features/test_labels2.npy")
        train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=0)
        test_loader = DataLoader(test_set, batch_size=256, shuffle=True, num_workers=0)
        print("Training on CIFAR")
    else:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(MNIST_MEAN, MNIST_STD),
        ])
        mnist_train = datasets.MNIST('data', train=True, download=True, transform=transform)
        train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
        mnist_test = datasets.MNIST('data', train=False, download=True, transform=transform)
        test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=True)
        print("Training on MNIST!")

    if args.model == 1:
        print("Using linear model.")
        model = LinearNN(data_name).to(device)
    else:
        print("Using 2-layer networks")
        model = TwoLayerNN(data_name).to(device)

    # Setup optimizer
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0)
    for epoch in tqdm(range(args.epochs), desc="Epoch", unit="epoch"):
        model.train()
        criterion = nn.CrossEntropyLoss()
        # Placeholders: nothing is appended to these lists below.
        losses = [0]
        top1_acc = [0]
        for step, (images, target) in enumerate(train_loader, start=1):
            optimizer.zero_grad()
            images, target = images.to(device), target.to(device)
            all_per_sample_gradients = []
            n_samples = images.shape[0]
            for sample_idx in range(n_samples):
                # Re-add the batch axis so the model sees a batch of one.
                single_image = images[sample_idx, :].to(device)[None, :]
                single_label = torch.tensor([target[sample_idx]]).to(device)

                single_loss = criterion(model(single_image), single_label)
                single_loss.backward()

                # Clone before zeroing so each stored gradient is an
                # independent copy.
                all_per_sample_gradients.append(
                    [p.grad.detach().clone() for p in model.parameters()]
                )
                model.zero_grad()

            report_now = step % 100 == 0 or step == len(train_loader)
            if report_now:
                print(
                    f"\tTrain Epoch: {epoch} \t"
                    f"Loss: {np.mean(losses):.6f} "
                    f"Acc@1: {np.mean(top1_acc) * 100:.6f} "
                )
                

if __name__ == "__main__":
    # Script entry point. Exactly one experiment is run per invocation; the
    # commented-out calls are the other experiments defined in this file.
    # Test2()
    # TestVanillaSGD()
    # TestMicroBatching()
    # timingGrowthCondition()
    main()
    # timingGrowthCondition()


# python MNISTscript.py -c 80 --lr 0.001 -m 1
# python MNISTscript.py -c 1 --lr 0.1 -m 1 --data cifar
# python VCScript.py -c 5 --sigma 1.0 --lr 0.01 -m 1 --data mnist
# dataloader = DataLoader(dataset, batch_size=128,shuffle=True, num_workers=0) 


# Vanilla SGD:
# -- MNIST Linear:   6.76s/epoch
# -- MNIST 2-layer:  6.77s/epoch
# -- CIFAR Linear:   1.66s/epoch
# -- CIFAR 2-layer:  1.60s/epoch

# Value clipping:
# -- MNIST Linear:   8.62s/epoch
# -- MNIST 2-layer:  8.96s/epoch
# -- CIFAR Linear:   1.60s/epoch
# -- CIFAR 2-layer:  1.81s/epoch

# Gradient clipping, OPACUS
# -- MNIST Linear:   12.80s/epoch
# -- MNIST 2-layer:  16.61s/epoch
# -- CIFAR Linear:   4.13s/epoch
# -- CIFAR 2-layer:  7.47s/epoch


