import utils
import torch
import numpy as np
import random
import torchvision
import torch.nn.functional as F
from torchvision import datasets, transforms
cuda = torch.cuda.is_available()
from analysis.private_knn import PrivateKnn
import dfmenetwork
import scipy
import scipy.stats
import math
# Log whether a CUDA device was detected (see `cuda` above).
print(cuda)
# Placeholder for the program parameters passed to PateKNN and read by
# load_private_model_by_id(); it is never reassigned in the code shown here.
args = None


def computeentropy(t_logits, num_classes=None):
    """Return the normalized prediction entropy for each row of logits.

    Each row is converted to a probability vector with a softmax, its Shannon
    entropy (natural log) is computed, and the result is divided by
    ``log(num_classes)`` -- the entropy of the uniform distribution -- so with
    the default ``num_classes`` the values lie in [0, 1] (1 = maximally
    uncertain).

    :param t_logits: 2-D tensor of logits, one row per sample.
    :param num_classes: class count used for normalization; defaults to the
        number of logit columns.
    :return: numpy array with one utility value per row.
    """
    if num_classes is None:
        num_classes = t_logits.shape[1]
    prob = F.softmax(t_logits, dim=1).cpu().detach().numpy()
    entropy = scipy.stats.entropy(prob, axis=1)
    if num_classes < 2:
        # A one-class distribution has zero entropy and log(1) == 0; avoid 0/0.
        return np.zeros_like(entropy)
    return entropy / np.log(num_classes)

def computegap(t_logits):
    """Return ``1 - margin`` for each row of a batch of logits.

    The two largest logits of every row are renormalized with a softmax over
    just those two values; the margin is the difference between the resulting
    top-1 and top-2 probabilities. Rows whose two best classes are close score
    near 1, confident rows near 0.

    NOTE(review): only the top-2 logits enter the softmax, so this is not the
    margin of the full class distribution; confirm that is intended.

    :param t_logits: 2-D tensor of logits, one row per sample.
    :return: numpy array with one utility value per row.
    """
    ranked = torch.sort(t_logits, dim=-1, descending=True).values
    pair_prob = F.softmax(ranked[:, :2], dim=1).cpu().detach().numpy()
    margin = pair_prob[:, 0] - pair_prob[:, 1]
    return 1 - margin

def get_votes_for_pate_knn(model, t_logits, train_represent,
                           train_labels, args=None, num_teachers=300,
                           num_classes=None):
    """
    Count, for each query, the labels of its nearest training points.

    Each query is represented by the log-softmax of its logits. The
    ``num_teachers`` training points closest to it in Euclidean distance act
    as teachers, and each contributes one vote for its ground-truth label.

    :param model: unused; kept for interface compatibility.
    :param t_logits: tensor of query logits, one row per query.
    :param train_represent: representations of the training points (the
        caller builds them as log-softmax outputs); each row must broadcast
        against one query representation.
    :param train_labels: integer ground-truth labels of the training points.
    :param args: unused; kept for interface compatibility.
    :param num_teachers: number of nearest neighbours that vote; clamped to the
        number of training points.
    :param num_classes: length of each vote histogram; defaults to the number
        of logit columns.
    :return: integer array of shape (num_queries, num_classes) of vote counts.
    """
    if num_classes is None:
        num_classes = t_logits.shape[-1]
    train_labels = np.asarray(train_labels)
    num_train = len(train_represent)
    num_teachers = min(num_teachers, num_train)

    with torch.no_grad():
        outputs = F.log_softmax(t_logits, dim=-1).cpu().numpy()

    votes = []
    for output in outputs:
        dis = np.linalg.norm(train_represent - output, axis=-1)
        if num_teachers < num_train:
            # argpartition requires kth < len(dis); the first k entries are
            # the k smallest distances (in no particular order).
            k_index = np.argpartition(dis, kth=num_teachers)[:num_teachers]
        else:
            # Fewer training points than requested teachers: all of them vote.
            k_index = np.arange(num_train)
        teachers_preds = np.asarray(train_labels[k_index], dtype=np.int32)
        votes.append(np.bincount(teachers_preds, minlength=num_classes))

    if not votes:
        # np.stack cannot stack an empty list; return an empty histogram batch.
        return np.zeros((0, num_classes), dtype=np.intp)
    return np.stack(votes)


class PateKNN:
    """
    Compute the PATE-kNN privacy cost of queries answered by a victim model.

    The victim's log-softmax outputs on the training set form the
    representation space; each query is labeled by the votes of its nearest
    training points and those votes are charged to a PrivateKnn accountant.
    """

    def __init__(self, model, trainloader, args):
        """
        Args:
            model: the victim model.
            trainloader: the data loader for the training data.
            args: the program parameters.
        """
        self.model = model
        self.args = args

        # Extract the log-softmax representation of the training points and
        # their ground-truth labels.
        train_represent = []
        train_labels = []
        # Extract in eval mode: in train mode BatchNorm layers would use
        # per-batch statistics and overwrite their running statistics with
        # these batches. The caller's train/eval mode is restored afterwards.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                for data, target in trainloader:
                    if cuda:
                        data = data.cuda()
                    outputs = F.log_softmax(model(data), dim=-1)
                    train_represent.append(outputs.cpu().numpy())
                    train_labels.append(target.cpu().numpy())
        finally:
            model.train(was_training)
        self.train_represent = np.concatenate(train_represent, axis=0)
        self.train_labels = np.concatenate(train_labels, axis=0)

        self.private_knn = PrivateKnn(
            delta=1e-5, sigma_gnmax=28,
            apply_data_independent_bound=False)

    def compute_privacy_cost(self, t_logits):
        """
        Args:
            t_logits: victim logits for the new queries, one row per query.
        Returns:
            The value returned by PrivateKnn.add_privacy_cost for the votes of
            these queries (originally documented as the total privacy cost
            incurred by all the queries seen so far).
        """
        votes = get_votes_for_pate_knn(
            model=self.model, t_logits=t_logits, train_labels=self.train_labels,
            train_represent=self.train_represent, args=self.args
        )

        dp_eps = self.private_knn.add_privacy_cost(votes=votes)

        return dp_eps


#
# def computepknn(t_logits, model, trainloader, args):
#     train_represent = []
#     train_labels = []
#     with torch.no_grad():
#         for batch_id, (data, target) in enumerate(trainloader):
#             if cuda:
#                 data = data.cuda()
#             outputs = model(data)
#             outputs = F.log_softmax(outputs, dim=-1)
#             outputs = outputs.cpu().numpy()
#             train_represent.append(outputs)
#             train_labels.append(target.cpu().numpy())
#     train_represent = np.concatenate(train_represent, axis=0)
#     train_labels = np.concatenate(train_labels, axis=0)
#     num_teachers = 300
#     #print("len", len(t_logits)) #256
#     with torch.no_grad():
#         outputs = F.log_softmax(t_logits, dim=-1)
#         outputs = outputs.cpu().numpy()
#         votes = []
#         predictions = []
#         predictions.append(np.argmax(outputs, axis=-1))
#         for output in outputs:
#             dis = np.linalg.norm(train_represent - output, axis=-1)
#             k_index = np.argpartition(dis, kth=num_teachers)[:num_teachers]
#             teachers_preds = np.array(train_labels[k_index], dtype=np.int32)
#             label_count = np.bincount(
#                 teachers_preds, minlength=10)
#             votes.append(label_count)
#         predictions = np.concatenate(predictions, axis=-1)
#     votes = np.stack(votes)
#     # FIX THIS PART
#     max_num_query, dp_eps, _, _, _ = analyze_multiclass_gnmax(
#         votes=votes,
#         threshold=0,
#         sigma_threshold=0,
#         sigma_gnmax=28,
#         budget=np.inf,
#         delta=1e-5)
#     privacy_cost = dp_eps
#     return privacy_cost
#
# def analyze_multiclass_gnmax(
#         votes, threshold, sigma_threshold, sigma_gnmax, budget, delta,
#         file=None, show_dp_budget='disable', args=None):
#     max_num_query = 0
#
#     def compute_partition(order_opt, eps):
#         """Analyze how the current privacy cost is divided."""
#         idx = np.searchsorted(orders, order_opt)
#         rdp_eps_gnmax = rdp_eps_total_curr[idx]
#         p = np.array([rdp_eps_gnmax, -math.log(delta) / (order_opt - 1)])
#         # assert sum(p) == eps
#         # Normalize p so that sum(p) = 1
#         return p / eps
#
#     # RDP orders.
#     orders = np.concatenate((np.arange(2, 100, .5),
#                              np.logspace(np.log10(100), np.log10(1000),
#                                          num=200)))
#     # Number of queries
#     n = votes.shape[0]
#
#     # All cumulative results
#     dp_eps = np.zeros(n)
#     partition = [None] * n
#     order_opt = np.full(n, np.nan, dtype=float)
#
#     # Current cumulative results
#     rdp_eps_total_curr = np.zeros(len(orders))
#     # Iterating over all queries
#     for i in range(n):
#         v = votes[i]
#         logq = compute_logq_gnmax(v, sigma_gnmax)
#         rdp_eps_gnmax = compute_rdp_data_dependent_gnmax_no_upper_bound(
#             logq, sigma_gnmax, orders)
#
#         # Update current cumulative results.
#         rdp_eps_total_curr += rdp_eps_gnmax
#         # Update all cumulative results.
#         dp_eps[i], order_opt[i] = rdp_to_dp(orders, rdp_eps_total_curr,
#                                             delta)
#         partition[i] = compute_partition(order_opt[i], dp_eps[i])
#         # Verify if the pre-defined privacy budget is exhausted.
#         if dp_eps[i] <= budget:
#             max_num_query = i + 1
#         else:
#             break
#         # Logs
#         # if i % 100000 == 0 and i > 0:
#
#     # print(f"{threshold},{sigma_threshold},{sigma_gnmax}")
#     # analyze_results(votes=votes, max_num_query=max_num_query, dp_eps=dp_eps)
#     # answered = [x for x in range(1, max_num_query + 1)]
#     # answered is the probability of a given label being answered. For the GNMax
#     # without the confidence (no thresholding mechanism) each
#     # label < max_num_query is answered.
#     answered = np.zeros(n, dtype=float)
#     answered[0:max_num_query] = 1
#     return max_num_query, dp_eps, partition, answered, order_opt
#
# def rdp_to_dp(orders, rdp_eps, delta):
#     """
#     Conversion from (lambda, eps)-RDP to conventional (eps, delta)-DP.
#     Papernot 2018, Theorem 5. (From RDP to DP)
#
#     Args:
#         orders: an array-like list of RDP orders.
#         rdp_eps: an array-like list of RDP guarantees (of the same length as
#         orders).
#         delta: target delta (a scalar).
#
#     Returns:
#         A pair of (dp_eps, optimal_order).
#     """
#     assert not np.isscalar(orders) and not np.isscalar(rdp_eps) and len(
#         orders) == len(
#         rdp_eps), "'orders' and 'rdp_eps' must be array-like and of the same length!"
#
#     dp_eps = np.array(rdp_eps) - math.log(delta) / (np.array(orders) - 1)
#     idx_opt = np.argmin(dp_eps)
#     return dp_eps[idx_opt], orders[idx_opt]
#
# def compute_logq_gnmax(votes, sigma):
#     """
#     Computes an upper bound on log(Pr[outcome != argmax]) for the GNMax mechanism.
#
#     Implementation of Proposition 7 from PATE 2018 paper.
#
#     Args:
#         votes: a 1-D numpy array of raw ensemble votes for a given query.
#         sigma: std of the Gaussian noise in the GNMax mechanism.
#
#     Returns:
#         A scalar upper bound on log(Pr[outcome != argmax]) where log denotes natural logarithm.
#     """
#     num_classes = len(votes)
#     variance = sigma ** 2
#     idx_max = np.argmax(votes)
#     votes_gap = votes[idx_max] - votes
#     votes_gap = votes_gap[np.arange(num_classes) != idx_max]  # exclude argmax
#     # Upper bound log(q) via a union bound rather than a more precise
#     # calculation.
#     logq = _logsumexp(
#         scipy.stats.norm.logsf(votes_gap, scale=math.sqrt(2 * variance)))
#     return min(logq,
#                math.log(1 - (1 / num_classes)))  # another obvious upper bound
#
# def _logsumexp(x):
#     """
#     Sum in the log space.
#
#     An addition operation in the standard linear-scale becomes the
#     LSE (log-sum-exp) in log-scale.
#
#     Args:
#         x: array-like.
#
#     Returns:
#         A scalar.
#     """
#     x = np.array(x)
#     m = max(x)  # for numerical stability
#     return m + math.log(sum(np.exp(x - m)))
#
# def _log1mexp(x):
#     """
#     Numerically stable computation of log(1-exp(x)).
#
#     Args:
#         x: a scalar.
#
#     Returns:
#         A scalar.
#     """
#     assert x <= 0, "Argument must be positive!"
#     # assert x < 0, "Argument must be non-negative!"
#     if x < -1:
#         return math.log1p(-math.exp(x))
#     elif x < 0:
#         return math.log(-math.expm1(x))
#     else:
#         return -np.inf
#
# def compute_rdp_data_dependent_gnmax_no_upper_bound(logq, sigma, orders):
#     """
#     If the data dependent bound applies, then use it even though its higher than
#     the data independent bound. In this case, we are interested in estimating
#     the privacy budget solely on the data and are not optimizing its value to be
#     as small as possible.
#
#     Computes data-dependent RDP guarantees for the GNMax mechanism.
#     This is the bound D_\lambda(M(D) || M(D'))  from Theorem 6 (equation 2),
#     PATE 2018 (Appendix A).
#
#     Bounds RDP from above of GNMax given an upper bound on q.
#
#     Args:
#         logq: a union bound on log(Pr[outcome != argmax]) for the GNMax
#             mechanism.
#         sigma: std of the Gaussian noise in the GNMax mechanism.
#         orders: an array-like list of RDP orders.
#
#     Returns:
#         A numpy array of upper bounds on RDP for all orders.
#
#     Raises:
#         ValueError: if the inputs are invalid.
#     """
#     if logq > 0 or sigma < 0 or np.isscalar(orders) or np.any(orders <= 1):
#         raise ValueError(
#             "'logq' must be non-positive, 'sigma' must be non-negative, "
#             "'orders' must be array-like, and all elements in 'orders' must be "
#             "greater than 1!")
#
#     if np.isneginf(logq):  # deterministic mechanism with sigma == 0
#         return np.full_like(orders, 0., dtype=np.float)
#
#     variance = sigma ** 2
#     orders = np.array(orders)
#     rdp_eps = orders / variance  # data-independent bound as baseline
#
#     # Two different higher orders computed according to Proposition 10.
#     # See Appendix A in PATE 2018.
#     # rdp_order2 = sigma * math.sqrt(-logq)
#     rdp_order2 = math.sqrt(variance * -logq)
#     rdp_order1 = rdp_order2 + 1
#
#     # Filter out entries to which data-dependent bound does not apply.
#     mask = np.logical_and(rdp_order1 > orders, rdp_order2 > 1)
#
#     # Corresponding RDP guarantees for the two higher orders.
#     # The GNMAx mechanism satisfies:
#     # (order = \lambda, eps = \lambda / sigma^2)-RDP.
#     rdp_eps1 = rdp_order1 / variance
#     rdp_eps2 = rdp_order2 / variance
#
#     log_a2 = (rdp_order2 - 1) * rdp_eps2
#
#     # Make sure that logq lies in the increasing range and that A is positive.
#     if (np.any(mask) and -logq > rdp_eps2 and logq <= log_a2 - rdp_order2 *
#             (math.log(1 + 1 / (rdp_order1 - 1)) + math.log(
#                 1 + 1 / (rdp_order2 - 1)))):
#         # Use log1p(x) = log(1 + x) to avoid catastrophic cancellations when x ~ 0.
#         log1mq = _log1mexp(logq)  # log1mq = log(1-q)
#         log_a = (orders - 1) * (
#                 log1mq - _log1mexp(
#             (logq + rdp_eps2) * (1 - 1 / rdp_order2)))
#         log_b = (orders - 1) * (rdp_eps1 - logq / (rdp_order1 - 1))
#
#         # Use logaddexp(x, y) = log(e^x + e^y) to avoid overflow for large x, y.
#         log_s = np.logaddexp(log1mq + log_a, logq + log_b)
#
#         # Do not apply the minimum between the data independent and data
#         # dependent bound - but limit the computation to data dependent bound
#         # only!
#         rdp_eps[mask] = (log_s / (orders - 1))[mask]
#
#     assert np.all(rdp_eps >= 0)
#     return rdp_eps
#



#CIFAR10

#
# Preprocessing for the CIFAR-10 training split used to build the PATE-kNN
# reference representations.
# NOTE(review): these mean/std values differ from the ones used for the test
# split further below; confirm which matches the victim's training setup.
cifar_train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])
cifar_train_set = torchvision.datasets.CIFAR10(
    '/ssd003/home/akaleem/data',
    train=True,
    download=False,
    transform=cifar_train_transform,
)
# Un-shuffled, batches of 64.
trainloader = torch.utils.data.DataLoader(cifar_train_set, batch_size=64)
def load_private_model_by_id():
    """
    Load the first private CIFAR-10 ResNet34 checkpoint.

    Reads ``private-models/cifar10/ResNet34/1-models/checkpoint-model(1).pth.tar``,
    moves the model to the GPU when one is available and switches it to eval
    mode. Label weights stored in the checkpoint are attached only when the
    module-level ``args`` has ``label_reweight == 'apply'``.

    :return: the instance of the model
    :raises Exception: if the checkpoint file does not exist.
    """
    import os
    from architectures.resnet import ResNet34
    filepath = "private-models/cifar10/ResNet34/1-models/checkpoint-model(1).pth.tar"
    if not os.path.isfile(filepath):
        raise Exception(
            f"Checkpoint file {filepath} does not exist, please generate it via "
            f"train_private_models(args)!")

    model = ResNet34(name='model({:d})'.format(0 + 1))
    # Keep the saved device when CUDA is present; map to CPU otherwise so a
    # GPU-saved checkpoint still loads on a CPU-only machine.
    checkpoint = torch.load(filepath, map_location=None if cuda else 'cpu')
    model.load_state_dict(checkpoint['state_dict'])
    if cuda:
        model.cuda()
    # `args` may be None (its module-level default), so read the attribute
    # defensively; compare strings with ==, not identity (`is`).
    if ('label_weights' in checkpoint
            and getattr(args, 'label_reweight', None) == 'apply'):
        model.label_weights = checkpoint['label_weights']
    model.eval()
    return model

#victim = load_private_model_by_id()
victim = dfmenetwork.resnet_8x.ResNet34_8x(num_classes=10)
ckpt = 'dfmodels/teacher/cifar10-resnet34_8x.pt'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
victim.load_state_dict(torch.load(ckpt, map_location=device))
if cuda:
    victim = victim.cuda()
# Switch to eval mode BEFORE the first forward pass. PateKNN's constructor
# runs the victim over the whole training set; in train mode every BatchNorm
# layer would use batch statistics and overwrite its running statistics,
# silently changing every prediction computed below.
victim.eval()
pate_knn = PateKNN(model=victim, trainloader=trainloader,
                   args=args)
# NOTE(review): these normalization constants differ from the ones used by
# `trainloader` ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)); confirm
# which matches the checkpoint's training preprocessing.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        (0.49421429, 0.4851314, 0.45040911),
        (0.24665252, 0.24289226, 0.26159238))])
base_dataset = torchvision.datasets.CIFAR10("/ssd003/home/akaleem/data", train=False, download=False, transform=transform)
x = []     # predicted label per test image
ent = []   # computeentropy() utility per test image
gap = []   # computegap() utility per test image
pknn = []  # PATE-kNN privacy cost per query (computation disabled below)
# Inference only: no_grad avoids building an autograd graph per sample.
with torch.no_grad():
    for i in range(len(base_dataset)):
        print(i)
        l = base_dataset[i][0]
        # Add a leading batch dimension of size 1.
        l = l.reshape((1, -1, 32, 32)).to(torch.float32)
        if cuda:
            l = l.cuda()
        preds = victim(l)
        x.append(int(torch.argmax(preds)))
        entropy = computeentropy(preds)
        g = computegap(preds)
        #pknn_cost = pate_knn.compute_privacy_cost(t_logits=preds)
        #print("pknn cost curr", pknn_cost)
        ent.append(entropy[0])
        gap.append(g[0])
        #pknn.append(pknn_cost)

x = np.array(x)
print(x)
np.save("MixMatch-pytorch/cifartargets.npy", x)
ent = np.array(ent)
print(ent)
np.save("MixMatch-pytorch/cifarent.npy", ent)
gap = np.array(gap)
np.save("MixMatch-pytorch/cifargap.npy", gap)
print(gap)
# pknn = np.array(pknn)
# np.save("MixMatch-pytorch/cifarpknn.npy", pknn)
# print(pknn)
print("Done saving cifar")
y = np.load("MixMatch-pytorch/cifartargets.npy")
print(y)
# Top-1 accuracy (in percent) of the victim on the test split.
targets = np.asarray(base_dataset.targets)
corr = int(np.sum(x == targets))
acc = corr / len(base_dataset) * 100
print(acc)


#Imagenet
# normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                                      std=[0.229, 0.224, 0.225])
# preprocessing = [
#     transforms.Resize(256),
#     transforms.CenterCrop(224),
#     transforms.ToTensor(),
#     normalize,
# ]
# preprocessing.append(transforms.Resize(32))
# preprocessing.append(transforms.CenterCrop(32))
# #preprocessing.append(transforms.Grayscale())
#
#
# trainloader = torch.utils.data.DataLoader(
#     torchvision.datasets.ImageNet(root = "/scratch/ssd002/datasets/imagenet256/", split='val',
#                      transform=transforms.Compose(preprocessing)),
#     batch_size=64)
#
# victim = dfmenetwork.resnet_8x.ResNet34_8x(num_classes=10)
# ckpt = 'dfmodels/teacher/cifar10-resnet34_8x.pt'
# victim.load_state_dict(torch.load(ckpt))#, map_location=device))
# if cuda:
#     victim = victim.cuda()
# victim.eval()
# transform = transforms.Compose([
#                 transforms.ToTensor(),
#                 transforms.Normalize(
#                     (0.49421429, 0.4851314, 0.45040911),
#                     (0.24665252, 0.24289226,
#                      0.26159238))])
# base_dataset =  torchvision.datasets.ImageNet(root = "/scratch/ssd002/datasets/imagenet256/", split='val',
#                      transform=transforms.Compose(preprocessing))
# x = []
# ent = []
# gap = []
# pknn = []
# for i in range(len(base_dataset)):
#     print(i)
#     l = base_dataset[i][0]
#     l = l.reshape((1, -1, 32, 32))
#     l = l.to(torch.float32)
#     if cuda:
#         l = l.cuda()
#     preds = victim(l)
#     # print(preds)
#     x.append(torch.argmax(preds).cpu())
#     entropy = computeentropy(preds)
#     g = computegap(preds)
#     #p = computepknn(preds, victim, trainloader, args)
#     ent.append(entropy[0])
#     gap.append(g[0])
#     #pknn.append(p[0])
#
#
# x = np.array(x)
# print(x)
# np.save("MixMatch-pytorch/imagenettargets.npy", x)
# ent = np.array(ent)
# print(ent)
# np.save("MixMatch-pytorch/imagenetent.npy", ent)
# gap = np.array(gap)
# np.save("MixMatch-pytorch/imagenetgap.npy", gap)
# print(gap)
# # pknn = np.array(pknn)
# # np.save("MixMatch-pytorch/cifarpknn.npy", pknn)
# # print(pknn)
# print("Done saving imagenet")
# y = np.load("MixMatch-pytorch/imagenettargets.npy")
# print(y)


#Mnist
# def load_private_model_by_id():
#     """
#     Load a single model by its id.
#     :param args: program parameters
#     :param id: id of the model
#     :return: the instance of the model
#     """
#     import os
#     from architectures.mnist_net_pate import MnistNetPate
#     filepath = "private-models/mnist/MnistNetPate/1-models/checkpoint-model(1).pth.tar"
#     if os.path.isfile(filepath):
#         model = MnistNetPate(name='model({:d})'.format(0 + 1), args = args)
#         checkpoint = torch.load(filepath)
#         model.load_state_dict(checkpoint['state_dict'])
#         if cuda:
#             model.cuda()
#         if 'label_weights' in checkpoint and args.label_reweight is 'apply':
#             model.label_weights = checkpoint['label_weights']
#         model.eval()
#         return model
#     else:
#         raise Exception(
#             f"Checkpoint file {filepath} does not exist, please generate it via "
#             f"train_private_models(args)!")
#
# victim = load_private_model_by_id()
# victim.eval()
# transform = transforms.Compose([
#                 transforms.ToTensor(),
#                 transforms.Normalize((0.13251461,), (0.31048025,))])
# base_dataset = torchvision.datasets.MNIST("/ssd003/home/akaleem/data/MNIST", train=False, download=False, transform=transform) # train was true
# x = []
# ent = []
# gap = []
# for i in range(len(base_dataset)):
#     print(i)
#     l = base_dataset[i][0]
#     l = l.reshape((-1, 1, 28, 28))
#     l = l.to(torch.float32)
#     if cuda:
#         l = l.cuda()
#     preds = victim(l)
#     x.append(torch.argmax(preds).cpu())
#     entropy = computeentropy(preds)
#     g = computegap(preds)
#     ent.append(entropy[0])
#     gap.append(g[0])
#
# x = np.array(x)
# print(x)
# np.save("MixMatch-pytorch/mnisttargets.npy", x)
# ent = np.array(ent)
# print(ent)
# np.save("MixMatch-pytorch/mnistent.npy", ent)
# gap = np.array(gap)
# np.save("MixMatch-pytorch/mnistgap.npy", gap)
# print(gap)
#
# print("Done saving mnist")
# y = np.load("MixMatch-pytorch/mnisttargets.npy")
# print(y)
# corr = 0
# for i in range(len(base_dataset)):
#     if x[i] == base_dataset.targets[i] :
#         corr += 1
# acc = corr/len(base_dataset) * 100
# print(acc)




# #SVHN
# def load_private_model_by_id():
#     """
#     Load a single model by its id.
#     :param args: program parameters
#     :param id: id of the model
#     :return: the instance of the model
#     """
#     import os
#     from architectures.resnet import ResNet18, ResNet34
#     filepath = "private-models/svhn/ResNet34/1-models/checkpoint-model(1).pth.tar"
#     if os.path.isfile(filepath):
#         model = ResNet34(name='model({:d})'.format(0 + 1))
#         checkpoint = torch.load(filepath)
#         model.load_state_dict(checkpoint['state_dict'])
#         if cuda:
#             model.cuda()
#         if 'label_weights' in checkpoint and args.label_reweight is 'apply':
#             model.label_weights = checkpoint['label_weights']
#         model.eval()
#         return model
#     else:
#         raise Exception(
#             f"Checkpoint file {filepath} does not exist, please generate it via "
#             f"train_private_models(args)!")
#
# #victim = load_private_model_by_id()
# victim = dfmenetwork.resnet_8x.ResNet34_8x(num_classes=10)
# ckpt = 'dfmodels/teacher/svhn-resnet34_8x.pt'
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# victim.load_state_dict(torch.load(ckpt))#, map_location=device))
# if cuda:
#     victim = victim.cuda()
# transform = transforms.Compose([
#                 transforms.ToTensor(),
#     transforms.Normalize(
#         (0.43768212, 0.44376972, 0.47280444),
#         (
#             0.19803013, 0.20101563,
#             0.19703615))])
# base_dataset = torchvision.datasets.SVHN("/ssd003/home/akaleem/data/SVHN", split='test', download=False, transform=transform) #split = 'train'
# victim.eval()
# x = []
# ent = []
# gap = []
# print(len(base_dataset))
# for i in range(len(base_dataset)):
#     print(i)
#     l = base_dataset[i][0]
#     l = l.reshape((1, -1, 32, 32))
#     l = l.to(torch.float32)
#     if cuda:
#         l = l.cuda()
#     preds = victim(l)
#     x.append(torch.argmax(preds).cpu())
#     entropy = computeentropy(preds)
#     g = computegap(preds)
#     ent.append(entropy[0])
#     gap.append(g[0])
#
# x = np.array(x)
# print(x)
# np.save("MixMatch-pytorch/svhntargets.npy", x)
# ent = np.array(ent)
# print(ent)
# np.save("MixMatch-pytorch/svhnent.npy", ent)
# gap = np.array(gap)
# np.save("MixMatch-pytorch/svhngap.npy", gap)
# print(gap)
#
# print("Done saving svhn")
# y = np.load("MixMatch-pytorch/svhntargets.npy")
# print(y)
# corr = 0
# for i in range(len(base_dataset)):
#     if x[i] == base_dataset.labels[i] :
#         corr += 1
# acc = corr/len(base_dataset) * 100
# print(acc)



