
#DEPENDENCIES
import argparse
import sys

import torch
import torch.nn as nn

from torchvision import datasets, transforms
#from kymatio.torch import Scattering2D
import os
import pickle
import numpy as np
import logging



from opacus import PrivacyEngine

import torch.nn.functional as F

import math
import opacus.privacy_analysis as tf_privacy


from copy import deepcopy

'''
device = "cuda:0"

parser = argparse.ArgumentParser(description='Settings')
parser.add_argument('--model_type', default = 'target', choices=['target','shadow'])
parser.add_argument('--P_x', default=0.5, type = float)
parser.add_argument('--target_epsilon', default = 5.0, type = float)
parser.add_argument('--dataset', default = 'mnist', choices=['cifar10', 'fmnist', 'mnist'])
parser.add_argument('--Trial',default=0, type=int)
parser.add_argument('--model_number', default = 0, type=int)

args = parser.parse_args()

'''

#BELOW IS FROM TRAMER et al. code to train with DP

"""#Logging"""

import shutil
import sys
from torch.utils.tensorboard import SummaryWriter

import numpy as np
import os
import shutil
import sys
from torch.utils.tensorboard import SummaryWriter
import torch


def model_input(data, device):
    """Return the first sample of ``data.data`` as a float tensor on ``device``.

    The slice ``[0:1]`` keeps a leading dimension of size one. NumPy arrays are
    converted with ``torch.from_numpy``; any other value is assumed to already
    be a tensor exposing ``.float()``.
    """
    first = data.data[0:1]
    as_tensor = torch.from_numpy(first) if isinstance(first, np.ndarray) else first
    return as_tensor.float().to(device)


def get_script():
    """Return the running script's file name with directory and extension stripped."""
    stem, _extension = os.path.splitext(os.path.basename(sys.argv[0]))
    return stem


def get_specified_params(hparams):
    """Return the subset of ``hparams`` whose names were given on the command line.

    Only long options in ``sys.argv`` are considered, in either the
    ``--name=value`` or the ``--name value`` form. Bare value tokens and short
    ``-x`` flags are ignored, as are option names that are not keys of
    ``hparams``. Previously such tokens were sliced into bogus keys and caused
    a KeyError.

    Args:
        hparams: mapping of parameter name -> value (e.g. ``vars(args)``).

    Returns:
        dict mapping each explicitly specified name to its value in ``hparams``.
    """
    specified = {}
    for token in sys.argv[1:]:
        if not token.startswith("--"):
            continue
        key = token[2:].split("=", 1)[0]
        if key in hparams:
            specified[key] = hparams[key]
    return specified


def make_hparam_str(hparams, exclude):
    """Encode the non-excluded hyper-parameters as a comma-separated string.

    Items are emitted in sorted key order as ``key_value``; for example,
    ``{"lr": 0.1, "bs": 32}`` becomes ``"bs_32,lr_0.1"``.
    """
    pieces = (f"{name}_{val}"
              for name, val in sorted(hparams.items())
              if name not in exclude)
    return ",".join(pieces)


class Logger(object):
    """Thin wrapper around a TensorBoard ``SummaryWriter``.

    With ``logdir=None`` logging is disabled, and every ``log_*`` method
    returns without doing anything. Otherwise, any existing directory at
    ``logdir`` is deleted first, so each run writes into an empty directory.
    """

    def __init__(self, logdir):
        self.writer = None
        if logdir is None:
            return
        # wipe logs left over from a previous run using the same directory
        if os.path.isdir(logdir):
            shutil.rmtree(logdir)
        self.writer = SummaryWriter(log_dir=logdir)

    def log_model(self, model, input_to_model):
        """Record the graph of ``model`` traced with ``input_to_model``."""
        if self.writer is not None:
            self.writer.add_graph(model, input_to_model)

    def log_epoch(self, epoch, train_loss, train_acc, test_loss, test_acc, epsilon=None):
        """Log per-epoch loss/accuracy and, if given, accuracy at step ``100 * epsilon``."""
        if self.writer is None:
            return
        per_epoch = (
            ("Loss/train", train_loss),
            ("Loss/test", test_loss),
            ("Accuracy/train", train_acc),
            ("Accuracy/test", test_acc),
        )
        for tag, value in per_epoch:
            self.writer.add_scalar(tag, value, epoch)

        if epsilon is None:
            return
        # the "Acc@Eps" series uses 100 * epsilon as its step value
        eps_step = 100 * epsilon
        self.writer.add_scalar("Acc@Eps/train", train_acc, eps_step)
        self.writer.add_scalar("Acc@Eps/test", test_acc, eps_step)

    def log_scalar(self, tag, scalar_value, global_step):
        """Log one scalar; skipped when logging is disabled or the value is ``None``."""
        if self.writer is not None and scalar_value is not None:
            self.writer.add_scalar(tag, scalar_value, global_step)

"""#Training and Test Functions"""

def get_device():
    """Return the CUDA device; this script requires a GPU.

    The availability check raises explicitly instead of using ``assert``.
    An ``assert`` is removed under ``python -O``, and the function would then
    silently fall back to CPU.

    Raises:
        RuntimeError: if ``torch.cuda.is_available()`` is False.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required but torch.cuda.is_available() returned False")
    return torch.device("cuda")

def train(model, train_loader, optimizer, n_acc_steps=1):
    """Run one training epoch with optional gradient accumulation.

    Gradients are accumulated over ``n_acc_steps`` consecutive batches before
    ``optimizer.step()`` / ``optimizer.zero_grad()`` are called. In between,
    ``optimizer.virtual_step()`` is called; that method is presumably supplied
    by an attached Opacus ``PrivacyEngine``, so ``n_acc_steps > 1`` assumes one
    is attached. Trailing batches that cannot fill a whole accumulation group
    are skipped.

    Args:
        model: module to train; batches are moved to its parameters' device.
        train_loader: DataLoader of ``(data, target)`` batches.
        optimizer: optimizer stepping ``model``'s parameters.
        n_acc_steps: number of batches per optimizer step (must be >= 1).

    Returns:
        ``(train_loss, train_acc)``: the mean cross-entropy per processed
        example, and the accuracy in percent.

    Raises:
        ValueError: if ``n_acc_steps < 1``, if the loader has fewer batches
            than one accumulation group, or if no examples were processed.
    """
    if n_acc_steps < 1:
        raise ValueError(f"n_acc_steps must be >= 1, got {n_acc_steps}")

    device = next(model.parameters()).device
    model.train()

    # only whole groups of `n_acc_steps` batches are used; the remainder is dropped
    total_batches = len(train_loader)
    num_batches = total_batches - total_batches % n_acc_steps
    if num_batches == 0:
        raise ValueError(
            f"train_loader has {total_batches} batches, fewer than "
            f"n_acc_steps={n_acc_steps}; no optimizer step would be taken")

    bs = train_loader.batch_size if train_loader.batch_size is not None else train_loader.batch_sampler.batch_size
    print(f"training on {num_batches} batches of size {bs}")

    num_examples = 0
    correct = 0
    train_loss = 0.0

    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx >= num_batches:
            break

        data, target = data.to(device), target.to(device)

        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()

        if (batch_idx + 1) % n_acc_steps == 0:
            # end of an accumulation group: apply the update
            optimizer.step()
            optimizer.zero_grad()
        else:
            with torch.no_grad():
                # accumulate per-example gradients but don't take a step yet
                optimizer.virtual_step()

        # statistics only: no autograd graph is needed
        with torch.no_grad():
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            train_loss += F.cross_entropy(output, target, reduction='sum').item()
        num_examples += len(data)

    if num_examples == 0:
        raise ValueError("no training examples were processed (all batches were empty)")

    train_loss /= num_examples
    train_acc = 100. * correct / num_examples

    print(f'Train set: Average loss: {train_loss:.4f}, '
            f'Accuracy: {correct}/{num_examples} ({train_acc:.2f}%)')

    return train_loss, train_acc

def test(model, test_loader):
    device = next(model.parameters()).device
    model.eval()
    num_examples = 0
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            num_examples += len(data)

    test_loss /= num_examples
    test_acc = 100. * correct / num_examples

    print(f'Test set: Average loss: {test_loss:.4f}, '
          f'Accuracy: {correct}/{num_examples} ({test_acc:.2f}%)')

    return test_loss, test_acc

"""#Creating Different Datasets"""

# (height, width, channels) of the raw images of each supported dataset.
SHAPES = {
    "cifar10": (32, 32, 3),
    "cifar10_500K": (32, 32, 3),
    "fmnist": (28, 28, 1),
    "mnist": (28, 28, 1),
    # SVHN digits are 32x32 RGB; get_data and CNNS already accept "svhn_ext"
    "svhn_ext": (32, 32, 3),
}

#The DATA PREP Functions from Florian

def get_scatter_transform(dataset):
    """Build a kymatio ``Scattering2D`` transform (J=2) for ``dataset``'s image size.

    Returns:
        ``(scattering, K, (h // 4, w // 4))``. ``K = 81 * channels`` is the
        number of scattering channels that downstream code views the output
        as. ``(h // 4, w // 4)`` is the spatial size of the J=2 output.

    Raises:
        KeyError: if ``dataset`` has no entry in ``SHAPES``.
        ImportError: if ``kymatio`` is not installed.
    """
    h, w, channels = SHAPES[dataset]
    # Imported lazily: the module-level kymatio import is commented out, so
    # without this `Scattering2D` would be undefined (NameError on every call).
    from kymatio.torch import Scattering2D
    scattering = Scattering2D(J=2, shape=(h, w))
    K = 81 * channels
    return scattering, K, (h // 4, w // 4)

def get_data(name, augment=False, **kwargs):
    """Load the train/test splits of a supported image dataset.

    Args:
        name: one of ``"cifar10"``, ``"svhn_ext"``, ``"fmnist"``, ``"mnist"``
            or ``"cifar10_500K"``.
        augment: for the CIFAR variants, prepend random horizontal flips and
            4-pixel-padded random crops to the training transform.
        **kwargs: ``"cifar10_500K"`` requires ``aux_data_filename``.

    Returns:
        ``(train_set, test_set)``; ``test_set`` is ``None`` for ``"cifar10_500K"``.

    Raises:
        ValueError: if ``name`` is not recognised.
    """
    def cifar_normalize():
        # per-channel normalization shared by both CIFAR variants
        return transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])

    def cifar_train_transform(normalize):
        steps = [transforms.ToTensor(), normalize]
        if augment:
            steps = [transforms.RandomHorizontalFlip(),
                     transforms.RandomCrop(32, 4)] + steps
        return transforms.Compose(steps)

    if name == "cifar10":
        normalize = cifar_normalize()
        train_set = datasets.CIFAR10(root=".data", train=True,
                                     transform=cifar_train_transform(normalize),
                                     download=True)
        test_set = datasets.CIFAR10(root=".data", train=False,
                                    transform=transforms.Compose(
                                        [transforms.ToTensor(), normalize]))

    elif name == 'svhn_ext':
        # SVHN "extra" split, read from a local directory (never downloaded)
        train_set = datasets.SVHN(root='svhn_data', split='extra',
                                  transform=transforms.ToTensor(), download=False)
        test_set = datasets.SVHN(root='svhn_data', split='test',
                                 transform=transforms.ToTensor(), download=False)

    elif name in ("fmnist", "mnist"):
        dataset_cls = datasets.FashionMNIST if name == "fmnist" else datasets.MNIST
        train_set = dataset_cls(root='.data', train=True,
                                transform=transforms.ToTensor(), download=True)
        test_set = dataset_cls(root='.data', train=False,
                               transform=transforms.ToTensor(), download=True)

    elif name == "cifar10_500K":
        # extended version of CIFAR-10 with pseudo-labelled tinyimages
        train_set = SemiSupervisedDataset(kwargs['aux_data_filename'],
                                          root=".data",
                                          train=True,
                                          download=True,
                                          transform=cifar_train_transform(cifar_normalize()))
        test_set = None

    else:
        raise ValueError(f"unknown dataset {name}")

    return train_set, test_set


class SemiSupervisedDataset(torch.utils.data.Dataset):
    """CIFAR-10 wrapper that can append pseudo-labelled auxiliary data.

    The wrapped ``datasets.CIFAR10`` samples are shuffled once at construction.
    For ``train=True``, the pickle ``<root>/<aux_data_filename>`` is loaded;
    it is a dict with ``'data'`` and ``'extrapolated_targets'`` arrays. Those
    arrays are shuffled and appended after the real samples. ``sup_indices``
    lists the positions of the real samples, and ``unsup_indices`` the
    positions of the pseudo-labelled ones. Remaining keyword arguments (e.g.
    ``root``, ``transform``, ``download``) are forwarded to ``datasets.CIFAR10``.
    """

    def __init__(self,
                 aux_data_filename=None,
                 train=False,
                 **kwargs):
        self.dataset = datasets.CIFAR10(train=train, **kwargs)
        self.train = train

        # shuffle the real CIFAR-10 samples together with their labels
        order = np.random.permutation(len(self.data))
        self.data = self.data[order]
        self.targets = list(np.asarray(self.targets)[order])

        self.sup_indices = list(range(len(self.targets)))
        self.unsup_indices = []
        log = logging.getLogger()

        if not self.train:
            log.info("Test set")
            log.info("Number of samples: %d", len(self.targets))
            log.info("Label histogram: %s",
                     tuple(zip(*np.unique(self.targets, return_counts=True))))
            log.info("Shape of data: %s", np.shape(self.data))
            return

        aux_path = os.path.join(kwargs['root'], aux_data_filename)
        print("Loading data from %s" % aux_path)
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted aux files
        with open(aux_path, 'rb') as handle:
            aux = pickle.load(handle)
        aux_data = aux['data']
        aux_targets = aux['extrapolated_targets']
        n_real = len(self.data)

        # shuffle the auxiliary samples together with their pseudo-labels
        aux_order = np.random.permutation(len(aux_data))
        aux_data, aux_targets = aux_data[aux_order], aux_targets[aux_order]

        self.data = np.concatenate((self.data, aux_data), axis=0)
        self.targets.extend(aux_targets)
        # unsup_indices tracks the labelled datapoints whose labels are "fake"
        self.unsup_indices.extend(range(n_real, n_real + len(aux_data)))

        log.info("Training set")
        log.info("Number of training samples: %d", len(self.targets))
        log.info("Number of supervised samples: %d", len(self.sup_indices))
        log.info("Number of unsup samples: %d", len(self.unsup_indices))
        log.info("Label (and pseudo-label) histogram: %s",
                 tuple(zip(*np.unique(self.targets, return_counts=True))))
        log.info("Shape of training data: %s", np.shape(self.data))

    @property
    def data(self):
        """Sample array of the wrapped dataset."""
        return self.dataset.data

    @data.setter
    def data(self, value):
        self.dataset.data = value

    @property
    def targets(self):
        """Label list of the wrapped dataset."""
        return self.dataset.targets

    @targets.setter
    def targets(self, value):
        self.dataset.targets = value

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        # mirror `targets` onto a `labels` attribute of the wrapped dataset,
        # presumably for code that reads `.labels` ("because torchvision is annoying")
        self.dataset.labels = self.targets
        return self.dataset[item]

    def __repr__(self):
        def indented(prefix, obj):
            return prefix + obj.__repr__().replace('\n', '\n' + ' ' * len(prefix))

        lines = [
            'Semisupervised Dataset ' + self.__class__.__name__,
            '    Number of datapoints: {}'.format(self.__len__()),
            '    Training: {}'.format(self.train),
            '    Root Location: {}'.format(self.dataset.root),
            indented('    Transforms (if any): ', self.dataset.transform),
            indented('    Target Transforms (if any): ', self.dataset.target_transform),
        ]
        return '\n'.join(lines)


class SemiSupervisedSampler(torch.utils.data.Sampler):
    """Batch sampler that yields exactly ``num_batches`` batches of ``batch_size`` indices.

    Indices ``0 .. num_examples-1`` are taken from successive random
    permutations. A new permutation is appended whenever more indices are
    needed, so one "epoch" may cover the data more or less than once.
    """

    def __init__(self, num_examples, num_batches, batch_size):
        self.inds = list(range(num_examples))
        self.batch_size = batch_size
        self.num_batches = num_batches
        super().__init__(None)

    def __iter__(self):
        needed = self.num_batches * self.batch_size
        if needed > 0 and not self.inds:
            # the refill loop below would never terminate on an empty index set
            raise ValueError("cannot sample batches: num_examples is 0")

        inds_shuffled = [self.inds[i] for i in torch.randperm(len(self.inds))]
        while len(inds_shuffled) < needed:
            inds_shuffled.extend(self.inds[i] for i in torch.randperm(len(self.inds)))

        # `range` yields exactly num_batches start offsets
        for start in range(0, needed, self.batch_size):
            batch = inds_shuffled[start:start + self.batch_size]

            # this shuffle operation is very important, without it
            # batch-norm / DataParallel hell ensues
            np.random.shuffle(batch)
            yield batch

    def __len__(self):
        return self.num_batches


class PoissonSampler(torch.utils.data.Sampler):
    """Poisson batch sampler, as used for DP-SGD privacy accounting.

    Yields ``ceil(num_examples / batch_size)`` batches. Each batch contains
    every index independently with probability ``batch_size / num_examples``,
    so batch sizes vary and a batch may be empty.
    """

    def __init__(self, num_examples, batch_size):
        self.inds = np.arange(num_examples)
        self.batch_size = batch_size
        self.num_batches = int(np.ceil(num_examples / batch_size))
        self.sample_rate = self.batch_size / (1.0 * num_examples)
        super().__init__(None)

    def __iter__(self):
        for _ in range(self.num_batches):
            # select each data point independently with probability `sample_rate`
            mask = np.random.binomial(n=1, p=self.sample_rate, size=len(self.inds))
            # builtin `bool`: the `np.bool` alias was removed in NumPy 1.24
            batch = self.inds[mask.astype(bool)]
            np.random.shuffle(batch)
            yield batch

    def __len__(self):
        return self.num_batches


def get_scattered_dataset(loader, scattering, device, data_size):
    """Materialize the first ``data_size`` (optionally scattered) samples of ``loader``.

    Each ``(data, target)`` batch is moved to ``device``. If ``scattering`` is
    not None, it is applied to ``data``. Loading stops as soon as
    ``data_size`` samples have been collected, so no extra batch is
    transformed.

    Returns:
        ``torch.utils.data.TensorDataset`` of at most ``data_size``
        ``(features, target)`` pairs.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    scatters = []
    targets = []

    num = 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        if scattering is not None:
            data = scattering(data)
        scatters.append(data)
        targets.append(target)

        num += len(data)
        # `>=`: stop once enough samples are held; another batch would be discarded
        if num >= data_size:
            break

    if not scatters:
        raise ValueError("loader yielded no batches")

    scatters = torch.cat(scatters, dim=0)[:data_size]
    targets = torch.cat(targets, dim=0)[:data_size]

    return torch.utils.data.TensorDataset(scatters, targets)


def get_scattered_loader(loader, scattering, device, drop_last=False, sample_batches=False):
    """Pre-compute features for all of ``loader`` and serve them from a new DataLoader.

    Each batch is moved to ``device`` and, if ``scattering`` is not None,
    transformed by it. The concatenated results are wrapped in a
    ``TensorDataset``. With ``sample_batches=True`` the new loader draws
    Poisson batches via ``PoissonSampler``. Otherwise it reuses
    ``loader.batch_size`` and shuffles only if ``loader`` used a
    ``RandomSampler``.
    """
    features, labels = [], []
    for inputs, target in loader:
        inputs = inputs.to(device)
        target = target.to(device)
        features.append(inputs if scattering is None else scattering(inputs))
        labels.append(target)

    all_features = torch.cat(features, dim=0)
    dataset = torch.utils.data.TensorDataset(all_features, torch.cat(labels, dim=0))

    if not sample_batches:
        keep_shuffling = isinstance(loader.sampler, torch.utils.data.RandomSampler)
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=loader.batch_size,
                                           shuffle=keep_shuffling,
                                           num_workers=0,
                                           pin_memory=False,
                                           drop_last=drop_last)

    poisson = PoissonSampler(len(all_features), loader.batch_size)
    return torch.utils.data.DataLoader(dataset, batch_sampler=poisson,
                                       num_workers=0, pin_memory=False)

def sampling_vector(shape, probability):
    """Return a boolean mask of ``shape`` where each entry is True with chance ``probability``.

    One ``np.random.uniform(0, 1)`` draw is made per entry; an entry is kept
    when its draw is ``<= probability``.
    """
    draws = np.random.uniform(low=0.0, high=1.0, size=shape)
    threshold = np.full(shape, probability, dtype=float)
    return draws <= threshold

def new_dataset(dataset, sampling):
    """Return a list of the items ``dataset[i]`` whose mask entry ``sampling[i]`` is truthy."""
    return [dataset[idx] for idx in range(len(dataset)) if sampling[idx]]

"""#DP Functions"""

# Renyi-DP orders (alphas): 1.1, 1.2, ..., 10.9 followed by the integers 12..63.
ORDERS = [*(1 + step / 10.0 for step in range(1, 100)), *range(12, 64)]


def get_renyi_divergence(sample_rate, noise_multiplier, orders=ORDERS):
    """Return the one-step RDP of the sampled Gaussian mechanism as a tensor over ``orders``.

    Wraps ``opacus.privacy_analysis.compute_rdp`` with ``steps=1``; callers
    scale the result by the number of steps to compose it.
    """
    single_step_rdp = tf_privacy.compute_rdp(sample_rate, noise_multiplier, 1, orders)
    return torch.tensor(single_step_rdp)


def get_privacy_spent(total_rdp, target_delta=1e-5, orders=ORDERS):
    """Convert accumulated RDP values over ``orders`` into an epsilon at ``target_delta``.

    Delegates to ``opacus.privacy_analysis.get_privacy_spent``, which takes
    ``(orders, rdp, delta)`` in that order. Returns its result unchanged;
    callers unpack it as ``(epsilon, _)``.
    """
    spent = tf_privacy.get_privacy_spent(orders, total_rdp, target_delta)
    return spent


def get_epsilon(sample_rate, mul, num_steps, target_delta=1e-5, orders=ORDERS, rdp_init=0):
    """Return the epsilon spent after ``num_steps`` DP-SGD steps.

    Args:
        sample_rate: batch sampling rate q.
        mul: Gaussian noise multiplier.
        num_steps: number of composed steps.
        target_delta: delta at which epsilon is reported.
        orders: RDP orders used for accounting.
        rdp_init: RDP already spent before these steps (e.g. on normalization).
    """
    per_step = get_renyi_divergence(sample_rate, mul, orders=orders)
    total_rdp = rdp_init + per_step * num_steps
    epsilon, _best_order = get_privacy_spent(total_rdp, target_delta=target_delta, orders=orders)
    return epsilon


def get_noise_mul(num_samples, batch_size, target_epsilon, epochs, rdp_init=0, target_delta=1e-5, orders=ORDERS):
    """Bisect for a noise multiplier whose epsilon after ``epochs`` is at most ``target_epsilon``.

    The search runs between multipliers 100 (small epsilon) and 0.1 (large
    epsilon). It stops when the epsilons of the two brackets differ by at
    most 0.01, and returns the bracket that satisfies the budget.

    Raises:
        AssertionError: if ``target_epsilon`` is not strictly between the
            epsilons of the two initial brackets.
    """
    steps = math.floor(num_samples // batch_size) * epochs
    q = batch_size / (1.0 * num_samples)

    def eps_for(noise):
        return get_epsilon(q, noise, steps, target_delta, orders, rdp_init=rdp_init)

    # `noisy` always satisfies the budget; `quiet` always exceeds it
    noisy, quiet = 100, 0.1
    eps_noisy = eps_for(noisy)
    eps_quiet = eps_for(quiet)

    assert eps_noisy < target_epsilon
    assert eps_quiet > target_epsilon

    while eps_quiet - eps_noisy > 0.01:
        middle = (quiet + noisy) / 2
        eps_middle = eps_for(middle)
        if eps_middle <= target_epsilon:
            noisy, eps_noisy = middle, eps_middle
        else:
            quiet, eps_quiet = middle, eps_middle

    return noisy


def get_noise_mul_privbyiter(num_samples, batch_size, target_epsilon, epochs, target_delta=1e-5):
    """Bisect for a noise multiplier using the ``priv_by_iter_guarantees`` accountant.

    This is the same search as ``get_noise_mul``, with brackets 100 and 0.1
    and an epsilon tolerance of 0.01. Epsilon comes from
    ``priv_by_iter_guarantees`` (non-verbose).

    Raises:
        AssertionError: if ``target_epsilon`` is not strictly between the
            epsilons of the two initial brackets.
    """
    def eps_for(noise):
        return priv_by_iter_guarantees(epochs, batch_size, num_samples, noise, target_delta, verbose=False)

    # `noisy` always satisfies the budget; `quiet` always exceeds it
    noisy, quiet = 100, 0.1
    eps_noisy = eps_for(noisy)
    eps_quiet = eps_for(quiet)

    assert eps_noisy < target_epsilon
    assert eps_quiet > target_epsilon

    while eps_quiet - eps_noisy > 0.01:
        middle = (quiet + noisy) / 2
        eps_middle = eps_for(middle)
        if eps_middle <= target_epsilon:
            noisy, eps_noisy = middle, eps_middle
        else:
            quiet, eps_quiet = middle, eps_middle

    return noisy


def scatter_normalization(train_loader, scattering, K, device,
                          data_size, sample_size,
                          noise_multiplier=1.0, orders=ORDERS, save_dir=None):
    """Privately estimate per-channel mean and variance of scattering features.

    The statistics come from roughly the first ``sample_size`` examples of
    ``train_loader``. With ``noise_multiplier > 0``, each example's
    per-channel mean and squared-mean vectors are clipped to their median
    norm, averaged, and perturbed with Gaussian noise. The RDP cost of these
    two releases is returned. With ``noise_multiplier == 0``, exact
    statistics are computed and no privacy is spent.

    Args:
        train_loader: iterable of ``(data, target)`` batches (data assumed
            ``(B, C, H, W)`` — TODO confirm).
        scattering: callable applied to each batch; must not be None. Its
            output is viewed as ``(-1, K, H // 4, W // 4)``.
        K: number of scattering channels.
        device: device for the batches and the returned statistics.
        data_size: size of the full training set (sets the sampling rate).
        sample_size: number of examples used for the statistics.
        noise_multiplier: Gaussian noise multiplier; 0 disables noise.
        orders: RDP orders used for the privacy accounting.
        save_dir: optional directory where the statistics are cached as
            ``.npy`` files. With ``None`` (the default) nothing is loaded or
            saved; previously this crashed in ``os.path.join``.

    Returns:
        ``((mean, var), rdp)``. ``mean``/``var`` are tensors on ``device``,
        and ``rdp`` is the RDP spent (0 without noise).
    """
    rdp = 0
    epsilon_norm = np.inf
    if noise_multiplier > 0:
        # two noisy releases (mean and squared mean), hence the factor 2
        sample_rate = sample_size / (1.0 * data_size)
        rdp = 2*get_renyi_divergence(sample_rate, noise_multiplier, orders)
        # convert with the same orders the RDP was computed for
        epsilon_norm, _ = get_privacy_spent(rdp, orders=orders)

    use_scattering = scattering is not None
    assert use_scattering

    # cache files are only used when a directory is given
    mean_path = var_path = None
    if save_dir is not None:
        mean_path = os.path.join(save_dir, f"mean_bn_{sample_size}_{noise_multiplier}_{use_scattering}.npy")
        var_path = os.path.join(save_dir, f"var_bn_{sample_size}_{noise_multiplier}_{use_scattering}.npy")

    print(f"Using BN stats for {sample_size}/{data_size} samples")
    print(f"With noise_mul={noise_multiplier}, we get ε_norm = {epsilon_norm:.3f}")

    loaded = False
    if mean_path is not None:
        # try loading pre-computed stats
        try:
            print(f"loading {mean_path}")
            mean = np.load(mean_path)
            var = np.load(var_path)
            print(mean.shape, var.shape)
            loaded = True
        except OSError:
            pass  # cache miss: compute the statistics below

    if not loaded:
        # compute the scattering transform and the mean and squared mean of features
        scatters = []
        mean = 0
        sq_mean = 0
        count = 0
        with torch.no_grad():
            for data, _ in train_loader:
                data = data.to(device)
                if scattering is not None:
                    # the view sizes use the pre-scattering spatial dims
                    data = scattering(data).view(-1, K, data.shape[2]//4, data.shape[3]//4)
                if noise_multiplier == 0:
                    data = data.reshape(len(data), K, -1).mean(-1)
                    mean += data.sum(0).cpu().numpy()
                    sq_mean += (data**2).sum(0).cpu().numpy()
                else:
                    scatters.append(data.cpu().numpy())

                count += len(data)
                if count >= sample_size:
                    break

        if noise_multiplier > 0:
            scatters = np.concatenate(scatters, axis=0)
            scatters = np.transpose(scatters, (0, 2, 3, 1))

            scatters = scatters[:sample_size]

            # s x K
            scatter_means = np.mean(scatters.reshape(len(scatters), -1, K), axis=1)
            norms = np.linalg.norm(scatter_means, axis=-1)

            # technically a small privacy leak, sue me...
            thresh_mean = np.quantile(norms, 0.5)
            scatter_means /= np.maximum(norms / thresh_mean, 1).reshape(-1, 1)
            mean = np.mean(scatter_means, axis=0)

            mean += np.random.normal(scale=thresh_mean * noise_multiplier,
                                     size=mean.shape) / sample_size

            # s x K
            scatter_sq_means = np.mean((scatters ** 2).reshape(len(scatters), -1, K),
                                       axis=1)
            norms = np.linalg.norm(scatter_sq_means, axis=-1)

            # technically a small privacy leak, sue me...
            thresh_var = np.quantile(norms, 0.5)
            print(f"thresh_mean={thresh_mean:.2f}, thresh_var={thresh_var:.2f}")
            scatter_sq_means /= np.maximum(norms / thresh_var, 1).reshape(-1, 1)
            sq_mean = np.mean(scatter_sq_means, axis=0)
            sq_mean += np.random.normal(scale=thresh_var * noise_multiplier,
                                        size=sq_mean.shape) / sample_size
            var = np.maximum(sq_mean - mean ** 2, 0)
        else:
            mean /= count
            sq_mean /= count
            var = np.maximum(sq_mean - mean ** 2, 0)

        if save_dir is not None:
            print(f"saving mean and var: {mean.shape} {var.shape}")
            np.save(mean_path, mean)
            np.save(var_path, var)

    mean = torch.from_numpy(mean).to(device)
    var = torch.from_numpy(var).to(device)

    return (mean, var), rdp


def priv_by_iter_guarantees(epochs, batch_size, samples, noise_multiplier, delta=1e-5, verbose=True):
    """Position-dependent privacy guarantees of DP-SGD (Theorem 34 of arXiv:1808.06651).

    For the sample fractions p = 0.5, 0.9 and 0.99, computes the epsilon that
    at least that fraction of samples enjoys. Each value is printed when
    ``verbose`` is set.

    Returns:
        The epsilon for p = 0.99, or ``np.inf`` when ``noise_multiplier`` is 0.
    """
    if noise_multiplier == 0:
        if verbose:
            print('No differential privacy (additive noise is 0).')
        return np.inf

    if verbose:
        print('In the conditions of Theorem 34 (https://arxiv.org/abs/1808.06651) '
              'the training procedure results in the following privacy guarantees.')
        print('Out of the total of {} samples:'.format(samples))

    steps_per_epoch = samples // batch_size
    alphas = np.concatenate([np.linspace(2, 20, num=181), np.linspace(20, 100, num=81)])

    def eps_for_fraction(frac):
        last_epoch_steps = math.ceil(steps_per_epoch * frac)
        # privacy loss from all-but-last epochs plus the position in the last one
        coef = 2 * (noise_multiplier)**-2 * (
            (epochs - 1) / steps_per_epoch +
            1 / (steps_per_epoch - last_epoch_steps + 1))
        # RDP of this composition is linear in the order
        rdp = [alpha * coef for alpha in alphas]
        epsilon, _ = get_privacy_spent(rdp, delta, alphas)
        return epsilon

    eps = None
    for frac in (.5, .9, .99):
        eps = eps_for_fraction(frac)
        if verbose:
            print('\t{:g}% enjoy at least ({:.2f}, {})-DP'.format(
                frac * 100, eps, delta))

    return eps

def setup_priv_engine(model, optimizer, bs, num_samples, target_epsilon, epochs, target_delta, max_grad_norm):
    """Attach an Opacus ``PrivacyEngine`` calibrated to ``target_epsilon`` to ``optimizer``.

    The noise multiplier comes from ``get_noise_mul``. It is chosen so that
    ``epochs`` epochs with batch size ``bs`` over ``num_samples`` examples
    spend at most ``target_epsilon`` at ``target_delta``. Per-sample gradients
    are clipped to ``max_grad_norm``.

    Returns:
        The attached ``PrivacyEngine``, so callers can inspect it (for example,
        to query the privacy spent). Existing callers may ignore it.
    """
    noise_multiplier = get_noise_mul(num_samples, bs, target_epsilon, epochs, target_delta=target_delta)

    privacy_engine = PrivacyEngine(
        model,
        sample_rate=bs / num_samples,
        # the module-level RDP orders, the same ones get_noise_mul calibrated with
        alphas=list(ORDERS),
        noise_multiplier=noise_multiplier,
        max_grad_norm=max_grad_norm,
    )

    privacy_engine.attach(optimizer)
    return privacy_engine

"""#Models"""

def standardize(x, bn_stats):
    """Normalize ``x`` channel-wise using fixed ``(mean, var)`` statistics.

    ``bn_stats`` is a ``(mean, var)`` pair broadcast along dimension 1 of
    ``x``. With ``None``, ``x`` is returned unchanged. Channels whose variance
    is exactly zero are zeroed out.
    """
    if bn_stats is None:
        return x

    mean, var = bn_stats
    shape = [1] * x.dim()
    shape[1] = -1
    mean, var = mean.view(shape), var.view(shape)

    normalized = (x - mean) / torch.sqrt(var + 1e-5)
    # if variance is too low, just ignore the channel
    keep = (var != 0).float()
    return normalized * keep


def clip_data(data, max_norm):
    """Rescale each sample of ``data`` so that its L2 norm is at most ``max_norm``.

    Dimension 0 indexes samples, and each sample's norm is taken over all
    remaining dimensions. Samples already within the bound are unchanged.
    Any input rank >= 1 is supported. A new tensor is returned and ``data``
    is left unmodified; previously it was scaled in place, which mutated the
    caller's batch when this received a view of the model input.
    """
    norms = torch.norm(data.reshape(data.shape[0], -1), dim=-1)
    scale = (max_norm / norms).clamp(max=1.0)
    # (N, 1, ..., 1) so the per-sample scale broadcasts over any rank
    return data * scale.reshape(-1, *([1] * (data.dim() - 1)))


def get_num_params(model):
    """Count the elements of ``model``'s parameters that have ``requires_grad`` set."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(map(torch.Tensor.numel, trainable))


class StandardizeLayer(nn.Module):
    """``nn.Module`` wrapper that applies ``standardize`` with fixed ``bn_stats``.

    ``bn_stats`` is kept as a plain attribute, not as registered buffers.
    It is therefore not part of ``state_dict`` and is not moved by ``.to()``.
    """

    def __init__(self, bn_stats):
        super().__init__()
        self.bn_stats = bn_stats

    def forward(self, x):
        return standardize(x, self.bn_stats)


class ClipLayer(nn.Module):
    """``nn.Module`` wrapper that applies ``clip_data`` with a fixed per-sample ``max_norm``."""

    def __init__(self, max_norm):
        super().__init__()
        self.max_norm = max_norm

    def forward(self, x):
        return clip_data(x, self.max_norm)


class CIFAR10_CNN(nn.Module):
    """Small Tanh CNN producing 10 logits.

    With ``in_channels == 3`` it takes raw 32x32 RGB images; three 2x2
    max-pools reduce them to 4x4 before a two-layer classifier, and
    ``input_norm`` is ignored. Otherwise it takes scattering features: the
    input is viewed as ``(-1, in_channels, 8, 8)`` and optionally normalized.

    Args:
        in_channels: 3 for raw images, otherwise the number of scattering channels.
        input_norm: feature input only. ``None`` means no normalization,
            ``"GroupNorm"`` uses non-affine GroupNorm, and any other value
            means standardization with the fixed ``bn_stats``.
        **kwargs: forwarded to ``build`` (``num_groups``, ``bn_stats``, ``size``).
    """

    def __init__(self, in_channels=3, input_norm=None, **kwargs):
        super(CIFAR10_CNN, self).__init__()
        self.in_channels = in_channels
        self.features = None
        self.classifier = None
        self.norm = None

        self.build(input_norm, **kwargs)

    def build(self, input_norm=None, num_groups=None,
              bn_stats=None, size=None):
        """Create ``self.norm``, ``self.features`` and ``self.classifier``."""
        if self.in_channels == 3:
            if size == "small":
                cfg = [16, 16, 'M', 32, 32, 'M', 64, 'M']
            else:
                cfg = [32, 32, 'M', 64, 64, 'M', 128, 128, 'M']

            self.norm = nn.Identity()
        else:
            # 8x8 feature maps; one max-pool brings them to 4x4
            if size == "small":
                cfg = [16, 16, 'M', 32, 32]
            else:
                cfg = [64, 'M', 64]
            if input_norm is None:
                self.norm = nn.Identity()
            elif input_norm == "GroupNorm":
                self.norm = nn.GroupNorm(num_groups, self.in_channels, affine=False)
            else:
                # a Module instead of a lambda, so the model stays picklable
                # (e.g. torch.save(model)); computes standardize(x, bn_stats)
                self.norm = StandardizeLayer(bn_stats)

        layers = []
        act = nn.Tanh

        c = self.in_channels
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(c, v, kernel_size=3, stride=1, padding=1)

                layers += [conv2d, act()]
                c = v

        self.features = nn.Sequential(*layers)

        if self.in_channels == 3:
            hidden = 128
            self.classifier = nn.Sequential(nn.Linear(c * 4 * 4, hidden), act(), nn.Linear(hidden, 10))
        else:
            self.classifier = nn.Linear(c * 4 * 4, 10)

    def forward(self, x):
        if self.in_channels != 3:
            x = self.norm(x.view(-1, self.in_channels, 8, 8))
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


class MNIST_CNN(nn.Module):
    """Small Tanh CNN producing 10 logits.

    With ``in_channels == 1`` it takes raw 28x28 images
    (28 -> 13 -> 12 -> 5 -> 4 through conv/pool layers). Otherwise it takes
    scattering features: the input is viewed as ``(-1, in_channels, 7, 7)``
    and optionally normalized.

    Args:
        in_channels: 1 for raw images, otherwise the number of scattering channels.
        input_norm: feature input only. ``"GroupNorm"`` uses non-affine
            GroupNorm, ``"BN"`` means standardization with the fixed
            ``bn_stats``, and anything else means no normalization.
        **kwargs: forwarded to ``build`` (``num_groups``, ``bn_stats``, ``size``).
            Any non-None ``size`` selects the wider (32, 64) channel widths.
    """

    def __init__(self, in_channels=1, input_norm=None, **kwargs):
        super(MNIST_CNN, self).__init__()
        self.in_channels = in_channels
        self.features = None
        self.classifier = None
        self.norm = None

        self.build(input_norm, **kwargs)

    def build(self, input_norm=None, num_groups=None,
              bn_stats=None, size=None):
        """Create ``self.norm``, ``self.features`` and ``self.classifier``."""
        ch1, ch2 = (16, 32) if size is None else (32, 64)
        if self.in_channels == 1:
            # entries are (filters, kernel_size, stride, padding) or 'M' (max-pool)
            cfg = [(ch1, 8, 2, 2), 'M', (ch2, 4, 2, 0), 'M']
            self.norm = nn.Identity()
        else:
            # 7x7 feature maps -> 4x4 after the stride-2 conv
            cfg = [(ch1, 3, 2, 1), (ch2, 3, 1, 1)]
            if input_norm == "GroupNorm":
                self.norm = nn.GroupNorm(num_groups, self.in_channels, affine=False)
            elif input_norm == "BN":
                # a Module instead of a lambda, so the model stays picklable
                # (e.g. torch.save(model)); computes standardize(x, bn_stats)
                self.norm = StandardizeLayer(bn_stats)
            else:
                self.norm = nn.Identity()

        layers = []

        c = self.in_channels
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=1)]
            else:
                filters, k_size, stride, pad = v
                conv2d = nn.Conv2d(c, filters, kernel_size=k_size, stride=stride, padding=pad)

                layers += [conv2d, nn.Tanh()]
                c = filters

        self.features = nn.Sequential(*layers)

        hidden = 32
        self.classifier = nn.Sequential(nn.Linear(c * 4 * 4, hidden),
                                        nn.Tanh(),
                                        nn.Linear(hidden, 10))

    def forward(self, x):
        if self.in_channels != 1:
            x = self.norm(x.view(-1, self.in_channels, 7, 7))
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


class ScatterLinear(nn.Module):
    """Single linear layer over per-example feature grids of shape (K, h, w).

    Forward pass: reshape to ``(-1, K, h, w)`` -> normalise -> clip ->
    flatten -> ``nn.Linear(K * h * w, classes)``.
    """

    def __init__(self, in_channels, hw_dims, input_norm=None, classes=10, clip_norm=None, **kwargs):
        super().__init__()
        self.K = in_channels
        # Only the first two entries of hw_dims (height, width) are used.
        self.h, self.w = hw_dims[0], hw_dims[1]
        # Placeholders; build() installs the real sub-modules.
        self.fc = None
        self.norm = None
        self.clip = None
        self.build(input_norm, classes=classes, clip_norm=clip_norm, **kwargs)

    def build(self, input_norm=None, num_groups=None, bn_stats=None, clip_norm=None, classes=10):
        """Create ``fc``, ``norm`` and ``clip``.

        Args:
            input_norm: None (identity), "GroupNorm", or any other value,
                which selects ``standardize(x, bn_stats)``.
            num_groups: group count for ``nn.GroupNorm``.
            bn_stats: statistics forwarded to ``standardize``.
            clip_norm: None for no clipping, otherwise passed to ``ClipLayer``
                (presumably bounds the feature norm -- confirm in ClipLayer).
            classes: number of output logits.
        """
        in_features = self.K * self.h * self.w
        self.fc = nn.Linear(in_features, classes)

        if input_norm is None:
            self.norm = nn.Identity()
        elif input_norm == "GroupNorm":
            self.norm = nn.GroupNorm(num_groups, self.K, affine=False)
        else:
            # Every other setting falls back to standardisation with bn_stats.
            self.norm = lambda x: standardize(x, bn_stats)

        self.clip = nn.Identity() if clip_norm is None else ClipLayer(clip_norm)

    def forward(self, x):
        """Return class scores of shape ``(batch, classes)``."""
        grid = x.view(-1, self.K, self.h, self.w)
        clipped = self.clip(self.norm(grid))
        flat = clipped.reshape(clipped.size(0), -1)
        return self.fc(flat)


# Maps each supported --dataset name to the CNN class instantiated for it.
CNNS = dict(
    cifar10=CIFAR10_CNN,
    fmnist=MNIST_CNN,
    mnist=MNIST_CNN,
    svhn_ext=CIFAR10_CNN,
)


if __name__ == "__main__":
    # Train one differentially-private model (Opacus DP-SGD) on a random
    # subset of the training set, then save the model weights and the
    # subset-membership vector. The folder names suggest these feed a
    # LiRA-style membership-inference experiment (target vs. shadow
    # models) -- confirm with the consuming scripts.

    device = "cuda:0"

    parser = argparse.ArgumentParser(description='Settings')
    parser.add_argument('--model_type', default='target', choices=['target', 'shadow'])
    # Passed to sampling_vector() to pick the training subset (presumably the
    # per-example inclusion probability).
    parser.add_argument('--P_x', default=0.5, type=float)
    parser.add_argument('--target_epsilon', default=5.0, type=float)
    parser.add_argument('--dataset', default='mnist', choices=['cifar10', 'fmnist', 'mnist', 'svhn_ext'])
    parser.add_argument('--Trial', default=0, type=int)
    parser.add_argument('--model_number', default=0, type=int)

    args = parser.parse_args()

    # Inputs
    model_type = args.model_type
    P_x = args.P_x
    dataset = args.dataset
    target_epsilon = args.target_epsilon
    Trial = args.Trial

    # Output paths. Both artefacts share one run tag so that a saved model can
    # be matched with the subset it was trained on.
    # NOTE(review): int() truncates epsilon and P_x * 10, so e.g. epsilon 0.5
    # and 0.9 produce identical file names.
    Models_Folder = "./lira_models"
    Sampling_Folder = "./lira_samplings"
    run_tag = f"{dataset}_{int(target_epsilon)}_{model_type}_{int(P_x*10)}_{Trial}_{args.model_number}"
    model_path = f"{Models_Folder}/model_{run_tag}.pt"
    sampling_path = f"{Sampling_Folder}/sampling_{run_tag}"

    print("Sampling Path")
    print(sampling_path)

    # A run counts as done once its sampling vector exists; np.save (at the
    # end of this script) appends the ".npy" suffix itself.
    if os.path.exists(sampling_path + ".npy"):
        print("already exists")
        sys.exit(0)

    # Hyperparameters

    logdir = None
    logger = Logger(logdir)

    early_stop = False

    augment = False
    # Per-dataset logical batch size and learning rate.
    if dataset == "mnist":
        batch_size = 512
        lr = 0.5
    elif dataset in ("cifar10", "svhn_ext"):
        batch_size = 1024
        lr = 1
    elif dataset == "fmnist":
        batch_size = 2048
        lr = 4
    else:
        raise ValueError(f"no hyperparameters configured for dataset {dataset!r}")

    epochs = 50
    momentum = 0.9
    nesterov = False
    mini_batch_size = 256  # batch size of each DataLoader step
    sample_batches = False

    # More Hyperparameters

    max_grad_norm = 0.1

    max_epsilon = None  # optional: stop training once epsilon reaches this
    target_delta = 1e-5

    input_norm = None
    size = None
    num_groups = 81

    # Draw the training subset.
    train_data_before, test_data = get_data(dataset)
    sampling = sampling_vector(len(train_data_before), P_x)
    train_data = new_dataset(train_data_before, sampling)

    # Noise multiplier for the target (epsilon, delta) over `epochs` epochs at
    # the logical batch size (presumably calibrated inside get_noise_mul).
    num_samples = len(train_data)
    noise_multiplier = get_noise_mul(num_samples, batch_size, target_epsilon, epochs, target_delta=target_delta)

    # Adapted from Florian Tramer's code. scattering is None, so no
    # scattering transform is applied; K is 3 when the raw training array is
    # 4-D and 1 otherwise.
    scattering = None
    K = 3 if len(train_data_before.data.shape) == 4 else 1

    bs = batch_size
    # The logical batch is assembled from n_acc_steps loader mini-batches.
    assert bs % mini_batch_size == 0
    n_acc_steps = bs // mini_batch_size

    if sample_batches:
        assert n_acc_steps == 1
        assert not augment

    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=mini_batch_size, shuffle=True, num_workers=1, pin_memory=True)

    test_loader = torch.utils.data.DataLoader(
        test_data, batch_size=mini_batch_size, shuffle=False, num_workers=1, pin_memory=True)

    # Added to the SGD RDP below; presumably the cost of data-dependent input
    # normalisation in the original code, zero here since input_norm is None.
    rdp_norm = 0
    model = CNNS[dataset](K, input_norm=input_norm, num_groups=num_groups, size=size)
    model.to(device)

    train_loader = get_scattered_loader(train_loader, scattering, device,
                                        drop_last=True, sample_batches=sample_batches)

    test_loader = get_scattered_loader(test_loader, scattering, device)

    print(f"model has {get_num_params(model)} parameters")

    optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                                momentum=momentum,
                                nesterov=nesterov)

    # sample_rate uses the logical batch size bs, the same batch size that was
    # passed to get_noise_mul above.
    privacy_engine = PrivacyEngine(
        model,
        sample_rate=bs / len(train_data),
        alphas=ORDERS,
        noise_multiplier=noise_multiplier,
        max_grad_norm=max_grad_norm,
    )

    privacy_engine.attach(optimizer)

    # TRAINING RUN

    best_acc = 0
    flat_count = 0

    for epoch in range(epochs):
        print(f"\nEpoch: {epoch}")

        train_loss, train_acc = train(model, train_loader, optimizer, n_acc_steps=n_acc_steps)
        test_loss, test_acc = test(model, test_loader)

        if noise_multiplier > 0:
            # RDP from get_renyi_divergence scaled by the number of optimizer
            # steps taken so far, then converted to (epsilon, delta).
            rdp_sgd = get_renyi_divergence(
                privacy_engine.sample_rate, privacy_engine.noise_multiplier
            ) * privacy_engine.steps
            epsilon, _ = get_privacy_spent(rdp_norm + rdp_sgd)
            epsilon2, _ = get_privacy_spent(rdp_sgd)
            print(f"ε = {epsilon:.3f} (sgd only: ε = {epsilon2:.3f})")

            if max_epsilon is not None and epsilon >= max_epsilon:
                break
        else:
            epsilon = None

        logger.log_epoch(epoch, train_loss, train_acc, test_loss, test_acc, epsilon)
        logger.log_scalar("epsilon/train", epsilon, epoch)

        # Stop if we're not making progress (only when early_stop is enabled).
        if test_acc > best_acc:
            best_acc = test_acc
            flat_count = 0
        else:
            flat_count += 1
            if flat_count >= 20 and early_stop:
                print("plateau...")
                print(f"Best Accuracy {best_acc}")
                break

    print(best_acc)

    # exist_ok avoids a FileExistsError when several runs (e.g. shadow models
    # launched in parallel) try to create the folders at the same time --
    # a failure here would otherwise throw away the finished training run.
    os.makedirs(Models_Folder, exist_ok=True)
    os.makedirs(Sampling_Folder, exist_ok=True)

    torch.save(deepcopy(model.state_dict()), model_path)
    np.save(sampling_path, sampling)