import torch
from scipy.stats import norm, binom_test
import numpy as np
from math import ceil
from statsmodels.stats.proportion import proportion_confint
import itertools
from tqdm import tqdm
from torch.autograd import Variable
import hashlib

try:
    from autoattack import AutoAttack
except ImportError:
    print("AutoAttack not installed")

#import F


# Adapted from https://codereview.stackexchange.com/questions/118883/split-up-an-iterable-into-batches
class Batch:
    """Lazily split an iterable into consecutive groups.

    Iterating a ``Batch`` yields one generator per group. A group ends when
    ``limit`` items have been produced for it, or when ``condition(item)`` is
    true for an item other than the group's first; that item then opens the
    next group.

    NOTE(review): with the default ``condition`` (always true) every item opens
    a new group, so ``limit`` only has an effect when a custom ``condition`` is
    supplied.
    """

    def __init__(self, iterable, condition=(lambda x: True), limit=None):
        """
        :param iterable: the items to split
        :param condition: predicate; an item for which it returns True starts a new group
        :param limit: maximum number of items per group (None for no limit)
        """
        self.iterator = iter(iterable)
        self.condition = condition
        self.limit = limit
        try:
            # read ahead: this is the first item of the upcoming group, and an
            # empty iterable produces no groups at all
            self.current = next(self.iterator)
        except StopIteration:
            self.on_going = False
        else:
            self.on_going = True

    def group(self):
        """Yield the current group's items, leaving the item that ends it in ``self.current``."""
        yield self.current
        # start enumerate at 1 because we already yielded the last saved item
        for num, item in enumerate(self.iterator, 1):
            self.current = item
            if num == self.limit or self.condition(item):
                break
            yield item
        else:
            # underlying iterator exhausted: this was the final group
            self.on_going = False

    def __iter__(self):
        while self.on_going:
            current_group = self.group()
            yield current_group
            # Drain anything the consumer left unread. Without this, a group that
            # is never iterated never advances the underlying iterator and this
            # loop keeps yielding groups forever.
            for _ in current_group:
                pass

class Smooth(object):
    """A smoothed classifier g built on top of a base classifier f.

    Besides the Gaussian randomized-smoothing routines (``certify``, ``predict``),
    this class hosts attack/evaluation helpers that perturb an input inside the
    span of a set of basis vectors (scaled by ``sigma``), optionally aggregating
    the base classifier's predictions over a fixed ensemble of ``corruption_noise``
    samples (mode of the predictions, mean of the losses).
    """

    # to abstain, Smooth returns this int
    ABSTAIN = -1

    def __init__(self, base_classifier: torch.nn.Module, num_classes: int, sigma: float):
        """
        :param base_classifier: maps from [batch x channel x height x width] to [batch x num_classes]
        :param num_classes: number of classes scored by the base classifier
        :param sigma: the noise level hyperparameter
        """
        self.base_classifier = base_classifier
        self.num_classes = num_classes
        self.sigma = sigma
        # must be set before _autoattack_with_basis (used for AutoAttack's log path)
        self.log_dir = None
        # state read by forward_with_basis_vectors; populated by _autoattack_with_basis
        self.basis_x = None
        self.basis_vectors = None
        self.corruption_noise = None
        self.eot_counter = 0
        self.last_basis_weights_hash = None
        self.prediction_idx_match_mode = None

    def certify(self, x: torch.Tensor, n0: int, n: int, alpha: float, batch_size: int) -> "tuple[int, float]":
        """ Monte Carlo algorithm for certifying that g's prediction around x is constant within some L2 radius.
        With probability at least 1 - alpha, the class returned by this method will equal g(x), and g's prediction will
        be robust within an L2 ball of radius R around x.

        :param x: the input [channel x height x width]
        :param n0: the number of Monte Carlo samples to use for selection
        :param n: the number of Monte Carlo samples to use for estimation
        :param alpha: the failure probability
        :param batch_size: batch size to use when evaluating the base classifier
        :return: (predicted class, certified radius)
                 in the case of abstention, the class will be ABSTAIN and the radius 0.
        """
        self.base_classifier.eval()
        # draw samples of f(x + epsilon) and use them to guess the top class
        counts_selection = self._sample_noise(x, n0, batch_size)
        cAHat = counts_selection.argmax().item()
        # draw fresh samples to estimate a lower bound on pA for that class
        counts_estimation = self._sample_noise(x, n, batch_size)
        nA = counts_estimation[cAHat].item()
        pABar = self._lower_confidence_bound(nA, n, alpha)
        if pABar < 0.5:
            return Smooth.ABSTAIN, 0.0
        radius = self.sigma * norm.ppf(pABar)
        return cAHat, radius

    def certify_discrete(self, x: torch.Tensor, n0: int, n: int, alpha: float, num_basis_vectors: int, batch_size: int) -> "tuple[int, float]":
        """ Discrete-grid counterpart of ``certify``.

        NOTE(review): as written this method cannot run. It calls
        ``self._bruteforce_discrete_noise``, which is not defined on this class
        (the closest helper, ``_bruteforce_discrete_basis``, needs the basis vectors
        themselves and returns a ``(counts, max_loss)`` tuple). The grid resolutions
        9 and 10 are hard-coded, ``n0`` is unused, and ``n`` is used as the sample count
        for the confidence bound although the grid size is unrelated to it.
        TODO: fix once the intended call is known.

        :param x: the input [channel x height x width]
        :param n0: (unused) number of selection samples
        :param n: number of samples assumed by the confidence bound
        :param alpha: the failure probability
        :param num_basis_vectors: number of basis vectors spanning the grid
        :param batch_size: batch size to use when evaluating the base classifier
        :return: (predicted class, certified radius)
                 in the case of abstention, the class will be ABSTAIN and the radius 0.
        """
        self.base_classifier.eval()
        counts_selection = self._bruteforce_discrete_noise(x, 9, num_basis_vectors, batch_size)
        cAHat = counts_selection.argmax().item()
        counts_estimation = self._bruteforce_discrete_noise(x, 10, num_basis_vectors, batch_size)
        nA = counts_estimation[cAHat].item()
        pABar = self._lower_confidence_bound(nA, n, alpha)
        if pABar < 0.5:
            return Smooth.ABSTAIN, 0.0
        radius = self.sigma * norm.ppf(pABar)
        return cAHat, radius

    def predict(self, x: torch.Tensor, n: int, alpha: float, batch_size: int) -> int:
        """ Monte Carlo algorithm for evaluating the prediction of g at x.  With probability at least 1 - alpha, the
        class returned by this method will equal g(x).

        This function uses the hypothesis test described in https://arxiv.org/abs/1610.03944
        for identifying the top category of a multinomial distribution.

        NOTE(review): ``scipy.stats.binom_test`` is deprecated and removed in recent SciPy;
        ``scipy.stats.binomtest(k, n, p).pvalue`` is the replacement.

        :param x: the input [channel x height x width]
        :param n: the number of Monte Carlo samples to use
        :param alpha: the failure probability
        :param batch_size: batch size to use when evaluating the base classifier
        :return: the predicted class, or ABSTAIN
        """
        self.base_classifier.eval()
        counts = self._sample_noise(x, n, batch_size)
        top2 = counts.argsort()[::-1][:2]
        count1 = counts[top2[0]]
        count2 = counts[top2[1]]
        if binom_test(count1, count1 + count2, p=0.5) > alpha:
            return Smooth.ABSTAIN
        return top2[0]

    def _sample_noise(self, x: torch.Tensor, num: int, batch_size) -> np.ndarray:
        """ Sample the base classifier's prediction under Gaussian corruptions of the input x.

        :param x: the input [channel x height x width]
        :param num: number of samples to collect
        :param batch_size: max number of samples per forward pass
        :return: an ndarray[int] of length num_classes containing the per-class counts
        """
        with torch.no_grad():
            counts = np.zeros(self.num_classes, dtype=int)
            for this_batch_size in self._chunk_sizes(num, batch_size):
                batch = x.repeat((this_batch_size, 1, 1, 1))
                noise = torch.randn_like(batch) * self.sigma
                predictions = self.base_classifier(torch.clamp(batch + noise, 0., 1.)).argmax(1)
                counts += self._count_arr(predictions.cpu().numpy(), self.num_classes)
            return counts

    def _count_arr(self, arr: np.ndarray, length: int) -> np.ndarray:
        """Histogram of the integer class indices in ``arr`` as an int array of size ``length``."""
        counts = np.zeros(length, dtype=int)
        # unbuffered add, so repeated indices are each counted (same semantics as counts[idx] += 1)
        np.add.at(counts, np.asarray(arr, dtype=np.intp), 1)
        return counts

    @staticmethod
    def _chunk_sizes(total: int, chunk: int):
        """Yield consecutive batch sizes of at most ``chunk`` that add up to ``total``."""
        for _ in range(ceil(total / chunk)):
            size = min(chunk, total)
            total -= size
            yield size

    # Grid-point cache keyed by (bins_per_basis_vector, num_basis_vectors, infty).
    # It is a class attribute, so every Smooth instance shares it.
    iterator_dict = {}

    def discrete_sphere_iterator_helper(self, bins_per_basis_vector: int, num_basis_vectors: int, max_batch_size: int, infty: bool):
        """ Generate the points of a regular grid over [-1, 1]^num_basis_vectors.

        With ``infty`` every grid point is kept (the L-inf unit ball); otherwise only
        points with L2 norm <= 1 are kept.

        :return: a generator of lists of float32 arrays (each of length num_basis_vectors),
                 each list holding at most ``max_batch_size`` points
        """
        if bins_per_basis_vector == 1:
            # a single bin per axis collapses the grid to the origin
            yield [np.zeros(num_basis_vectors, dtype=np.float32)]
            return

        scale = float(bins_per_basis_vector - 1)
        batch = []
        for candidate in itertools.product(range(bins_per_basis_vector), repeat=num_basis_vectors):
            # map bin indices 0..bins-1 onto [-1, 1]
            point = (np.array(candidate, dtype=np.float32) / scale) * 2. - 1.
            if infty or np.linalg.norm(point) <= 1.:
                batch.append(point)
            if len(batch) >= max_batch_size:
                yield batch
                batch = []
        if batch:
            yield batch

    def discrete_sphere_iterator(self, bins_per_basis_vector: int, num_basis_vectors: int, max_batch_size: int, infty: bool):
        """ Yield the (cached) grid points in ndarray batches of shape [<= max_batch_size x num_basis_vectors]."""
        # infty is part of the key: the L-inf cube and the L2 ball are different grids
        key = (bins_per_basis_vector, num_basis_vectors, infty)
        grid_points = self.iterator_dict.get(key)
        if grid_points is None:
            grid_points = list(self.discrete_sphere_iterator_helper(bins_per_basis_vector, num_basis_vectors, 1, infty))
            grid_points = np.array(grid_points).reshape(-1, num_basis_vectors)
            self.iterator_dict[key] = grid_points
        for start in range(0, grid_points.shape[0], max_batch_size):
            yield grid_points[start:start + max_batch_size]

    @staticmethod
    def _project_weights(weights: torch.Tensor, infty: bool) -> torch.Tensor:
        """Project each row of a [B x num_basis] weight matrix onto the unit L-inf (infty) or L2 ball."""
        if infty:
            return torch.clamp(weights, -1., 1.)
        # rows with norm <= 1 are divided by 1 (unchanged); longer rows are rescaled to unit norm
        return weights / weights.norm(dim=1, keepdim=True).clamp(min=1.)

    @staticmethod
    def _pgd_step(weights: torch.Tensor, grad: torch.Tensor, learning_rate: float, infty: bool) -> torch.Tensor:
        """One projected gradient-ascent step on a [B x num_basis] weight matrix; rows are independent."""
        if infty:
            return torch.clamp(weights + torch.sign(grad) * learning_rate, -1., 1.)
        # per-row normalised step; the clamp avoids 0/0 (NaN) when a row's gradient vanishes
        direction = grad / grad.norm(dim=1, keepdim=True).clamp(min=1e-12)
        return Smooth._project_weights(weights + direction * learning_rate, False)

    def _reference_label(self, x: torch.Tensor, corruption_noise: torch.Tensor) -> torch.Tensor:
        """Label an attack moves away from: the clean prediction for x, or the mode of the predictions
        over clamp(x + corruption_noise, 0, 1). For a single [channel x height x width] x this has shape [1]."""
        if corruption_noise is None:
            return self.base_classifier(x.unsqueeze(0)).argmax(1)
        predictions = self.base_classifier(torch.clamp(x + corruption_noise, 0., 1.)).argmax(1)
        return torch.mode(predictions.reshape((-1, corruption_noise.shape[0])), dim=1)[0]

    def _evaluate_basis_weights(self, x: torch.Tensor, weights: torch.Tensor, basis: torch.Tensor, corruption_noise: torch.Tensor, y: torch.Tensor):
        """ Classify x perturbed by ``sigma * (weights @ basis)``, clamped to [0, 1].

        With a K-sample ``corruption_noise``, each weight row is evaluated on all K corrupted
        copies of x using the same subspace perturbation for every copy.

        :return: (predictions [B], losses [B]): the (mode over K) predicted class and the
                 (mean over K) cross-entropy against ``y`` for each weight row
        """
        num_corrupt = 1 if corruption_noise is None else corruption_noise.shape[0]
        num_rows = weights.shape[0]
        base = x if corruption_noise is None else x + corruption_noise
        # sample r*K + k pairs corruption k with weight row r
        batch = base.repeat((num_rows, 1, 1, 1))
        perturb = torch.einsum('bn,ncwh->bcwh', weights, basis) * self.sigma
        perturb = perturb.repeat_interleave(num_corrupt, dim=0)

        logits = self.base_classifier(torch.clamp(batch + perturb, 0., 1.))
        losses = torch.nn.functional.cross_entropy(logits, y.repeat(num_rows * num_corrupt), reduction='none')
        predictions = logits.argmax(1)
        if corruption_noise is not None:
            predictions = torch.mode(predictions.reshape((num_rows, num_corrupt)), dim=1)[0]
            losses = losses.reshape((num_rows, num_corrupt)).mean(dim=1)
        return predictions, losses

    def _sample_discrete_noise(self, x: torch.Tensor, num: int, basis_vectors: np.ndarray, batch_size: int, infty: bool, corruption_noise: torch.Tensor) -> "tuple[np.ndarray, float]":
        """ Sample the base classifier's prediction under random perturbations of x inside the span of the basis vectors.

        Weight vectors are drawn from a standard normal and projected onto the unit L-inf ball
        (``infty``) or unit L2 ball before being mapped through the basis and scaled by sigma.

        :param x: the input [channel x height x width]
        :param num: number of perturbations (weight vectors) to draw
        :param basis_vectors: [num_basis x channel x height x width] array spanning the subspace
        :param batch_size: max number of base-classifier inputs per forward pass (shared by the K corruption copies)
        :param infty: sample in the L-inf unit ball instead of the L2 unit ball
        :param corruption_noise: optional [K x ...] additive noise ensemble; predictions are the mode over it
        :return: (counts, max_loss): per-class counts of length num_classes, and the largest per-perturbation
                 cross-entropy against the reference label (-inf if nothing was sampled)
        """
        basis = torch.from_numpy(basis_vectors).float().to(x.device)
        self.base_classifier.eval()
        num_corrupt = 1 if corruption_noise is None else corruption_noise.shape[0]

        counts = np.zeros(self.num_classes, dtype=int)
        max_loss = -np.inf
        with torch.no_grad():
            y = self._reference_label(x, corruption_noise)
            for this_batch_size in self._chunk_sizes(num, batch_size // num_corrupt):
                weights = torch.randn((this_batch_size, basis.shape[0]), device=x.device)
                weights = self._project_weights(weights, infty)
                predictions, losses = self._evaluate_basis_weights(x, weights, basis, corruption_noise, y)
                max_loss = max(max_loss, float(losses.max()))
                counts += self._count_arr(predictions.cpu().numpy(), self.num_classes)
        return counts, max_loss

    def _bruteforce_discrete_basis(self, x: torch.Tensor, bins_per_axis: int, basis_vectors: np.ndarray, batch_size: int, infty: bool, corruption_noise: torch.Tensor) -> "tuple[np.ndarray, float]":
        """ Evaluate the base classifier on every point of a discrete grid in the basis-vector subspace.

        :param x: the input [channel x height x width]
        :param bins_per_axis: grid resolution per basis vector
        :param basis_vectors: [num_basis x channel x height x width] array spanning the subspace
        :param batch_size: max number of base-classifier inputs per forward pass (shared by the K corruption copies)
        :param infty: use the full L-inf grid instead of only points inside the L2 unit ball
        :param corruption_noise: optional [K x ...] additive noise ensemble; predictions are the mode over it
        :return: (counts, max_loss): per-class counts over the grid points, and the largest per-point
                 cross-entropy against the reference label (-inf for an empty grid)
        """
        basis = torch.from_numpy(basis_vectors).float().to(x.device)
        self.base_classifier.eval()

        num_corrupt = 1
        if corruption_noise is not None:
            assert corruption_noise.shape[0] < batch_size, "corruption noise must be less than batch_size"
            num_corrupt = corruption_noise.shape[0]

        counts = np.zeros(self.num_classes, dtype=int)
        max_loss = -np.inf
        with torch.no_grad():
            y = self._reference_label(x, corruption_noise)
            grid_batches = self.discrete_sphere_iterator(bins_per_axis, basis.shape[0], batch_size // num_corrupt, infty)
            for grid_batch in grid_batches:
                # np.array copies, so the cached grid can never be mutated through the tensor
                weights = torch.from_numpy(np.array(grid_batch)).float().to(x.device)
                predictions, losses = self._evaluate_basis_weights(x, weights, basis, corruption_noise, y)
                max_loss = max(max_loss, float(losses.max()))
                counts += self._count_arr(predictions.cpu().numpy(), self.num_classes)
        return counts, max_loss

    def _pgd_with_basis(self, x: torch.Tensor, basis_vectors: np.ndarray, learning_rate: float, num_iterations: int, num_restarts: int, batch_size: int, infty: bool, corruption_noise: torch.Tensor) -> "tuple[np.ndarray, np.ndarray]":
        """ Do PGD in the subspace defined by the basis vectors, optionally averaged over a corruption-noise ensemble.

        The attack maximises the cross-entropy against the base classifier's clean prediction
        for x. Restarts are processed in parallel batches.

        :param x: the input [channel x height x width]
        :param basis_vectors: [num_basis x channel x height x width] array spanning the attack subspace
        :param learning_rate: step size in weight space
        :param num_iterations: number of PGD steps per restart
        :param num_restarts: number of independent random restarts
        :param batch_size: max number of base-classifier inputs per forward pass (shared by the K corruption copies)
        :param infty: constrain weights to the L-inf unit ball (sign steps) instead of the L2 unit ball
        :param corruption_noise: optional [K x ...] additive noise ensemble; predictions are the mode over it
        :return: (counts, pgd_loss): counts is [num_restarts x num_classes], the number of iterations in which
                 each class was predicted; pgd_loss is [num_restarts x num_iterations], the loss before each step
        """
        self.base_classifier.eval()
        basis = torch.from_numpy(basis_vectors).float().to(x.device)
        num_corrupt = 1 if corruption_noise is None else corruption_noise.shape[0]
        assert num_corrupt < batch_size, "num_corrupt must be less than batch_size for basis PGD with smoothing"

        with torch.no_grad():
            # unlike the discrete samplers, the target is the clean prediction, not the ensemble mode
            y = self.base_classifier(x.unsqueeze(0)).argmax(1)

        all_counts = []
        all_pgd_loss = []
        for restart_batch_size in self._chunk_sizes(num_restarts, batch_size // num_corrupt):
            counts = np.zeros((restart_batch_size, self.num_classes), dtype=int)
            pgd_loss = np.full((restart_batch_size, num_iterations), -1.)
            # random start inside the constraint ball
            weights = torch.randn((restart_batch_size, basis.shape[0]), device=x.device)
            weights = self._project_weights(weights, infty).requires_grad_()

            for pgd_i in range(num_iterations):
                predictions, losses = self._evaluate_basis_weights(x, weights, basis, corruption_noise, y)
                counts += torch.nn.functional.one_hot(predictions.cpu(), self.num_classes).numpy()
                # autograd.grad returns a fresh gradient each step (.backward() would keep
                # accumulating into weights.grad across iterations)
                grad, = torch.autograd.grad(losses.mean(), weights)
                with torch.no_grad():
                    pgd_loss[:, pgd_i] = losses.detach().cpu().numpy()
                    weights = self._pgd_step(weights, grad, learning_rate, infty)
                weights.requires_grad_()

            all_counts.append(counts)
            all_pgd_loss.append(pgd_loss)

        if not all_counts:
            return np.zeros((0, self.num_classes), dtype=int), np.zeros((0, num_iterations))
        return np.concatenate(all_counts, axis=0), np.concatenate(all_pgd_loss, axis=0)

    def forward_with_basis_vectors(self, basis_weights: torch.Tensor) -> torch.Tensor:
        """ Base-classifier logits for ``basis_x`` perturbed inside the span of ``basis_vectors``.

        ``basis_weights`` rows are expected in [0, 1] (the box AutoAttack works in) and are
        mapped linearly to [-1, 1] before being combined with the basis and scaled by sigma.
        When ``corruption_noise`` is set, each call adds one noise sample, cycling through the
        ensemble via ``eot_counter``. A call with new weights restarts that cycle at
        ``prediction_idx_match_mode``.
        """
        assert self.basis_x is not None, "Must set basis_x before calling forward_with_basis_vectors"
        assert self.basis_vectors is not None, "Must set basis_vectors before calling forward_with_basis_vectors"
        device = basis_weights.device
        self.basis_x = self.basis_x.to(device)
        self.basis_vectors = self.basis_vectors.to(device)
        basis_weights_scaled = (basis_weights * 2.) - 1.  # scale from [0,1] to [-1,1]

        pgd_perturb = torch.einsum('bn,ncwh->bcwh', basis_weights_scaled, self.basis_vectors) * self.sigma
        if self.corruption_noise is None:
            return self.base_classifier(torch.clamp(self.basis_x + pgd_perturb, 0., 1.))

        # detach(): AutoAttack passes inputs that require grad, and .numpy() refuses those
        weights_hash = hash(basis_weights.detach().cpu().numpy().tobytes())
        if weights_hash != self.last_basis_weights_hash:
            assert self.prediction_idx_match_mode is not None, "Must set prediction_idx_match_mode before calling forward_with_basis_vectors with corruption_noise"
            self.eot_counter = self.prediction_idx_match_mode
            self.last_basis_weights_hash = weights_hash
        logits = self.base_classifier(torch.clamp(self.basis_x + pgd_perturb + self.corruption_noise[self.eot_counter, ...], 0., 1.))
        self.eot_counter = (self.eot_counter + 1) % self.corruption_noise.shape[0]
        return logits

    def _autoattack_with_basis(self, x: torch.Tensor, basis_vectors: np.ndarray, learning_rate: float, num_iterations: int, num_restarts: int, batch_size: int, infty: bool, corruption_noise: torch.Tensor) -> "tuple[np.ndarray, np.ndarray]":
        """ Run AutoAttack's APGD-CE in the subspace defined by the basis vectors.

        Gradients are averaged over the corruption ensemble by setting APGD's ``eot_iter`` to
        the number of corruption samples (``forward_with_basis_vectors`` cycles through them).

        :param x: the input [channel x height x width]
        :param basis_vectors: [num_basis x channel x height x width] array spanning the attack subspace
        :param learning_rate: unused
        :param num_iterations: number of APGD iterations
        :param num_restarts: number of attacked copies (each starts at zero perturbation)
        :param batch_size: batch size passed to AutoAttack (``bs``)
        :param infty: attack under the Linf norm instead of L2
        :param corruption_noise: optional [K x ...] additive noise ensemble; predictions are the mode over it
        :return: (counts, pgd_loss): counts is [num_restarts x num_classes] with one final prediction per copy;
                 pgd_loss is [num_restarts x num_iterations] with only column 0 (the final loss) filled, the rest -1
        """
        assert self.log_dir is not None, "Must set log_dir before calling _autoattack_with_basis"

        self.base_classifier.eval()
        basis = torch.from_numpy(basis_vectors).float().to('cuda')

        self.last_basis_weights_hash = None
        self.prediction_idx_match_mode = None
        self.corruption_noise = corruption_noise
        num_corrupt = 1 if corruption_noise is None else corruption_noise.shape[0]

        with torch.no_grad():
            if corruption_noise is None:
                y = self.base_classifier(x.unsqueeze(0)).argmax(1)
            else:
                predictions = self.base_classifier(torch.clamp(x + corruption_noise, 0., 1.)).argmax(1)
                predictions = predictions.reshape((-1, num_corrupt))
                y = torch.mode(predictions, dim=1)[0]
                # First corruption index (column of `predictions`) whose prediction equals the mode.
                # Starting the EOT cycle there keeps AutoAttack from treating the unperturbed input
                # as already misclassified.
                self.prediction_idx_match_mode = int(torch.nonzero(predictions[0] == y[0])[0, 0])
                self.eot_counter = self.prediction_idx_match_mode

        counts = np.zeros((num_restarts, self.num_classes), dtype=int)
        pgd_loss = np.full((num_restarts, num_iterations), -1.)
        # weights live in AutoAttack's [0, 1] box; 0.5 maps to a zero perturbation
        basis_weights = torch.full((num_restarts, basis.shape[0]), 0.5, device='cuda', requires_grad=True)

        autoattack_pgd = AutoAttack(self.forward_with_basis_vectors, norm='L2' if not infty else 'Linf', eps=0.5,
                                    version='apgd-ce', attacks_to_run=['apgd-ce'],
                                    log_path=f'{self.log_dir}/log_resnet.txt', device='cuda')
        # the batch already holds num_restarts copies, so one APGD restart per copy suffices
        autoattack_pgd.apgd.n_restarts = 1
        autoattack_pgd.apgd.n_iter = num_iterations
        # makes APGD average gradients over the corruption ensemble
        autoattack_pgd.apgd.eot_iter = num_corrupt

        self.basis_x = x
        self.basis_vectors = basis

        adv_basis_weights = autoattack_pgd.run_standard_evaluation(basis_weights, y.repeat(num_restarts), bs=batch_size)

        with torch.no_grad():
            # num_corrupt consecutive calls visit every corruption sample once
            logits = torch.stack([self.forward_with_basis_vectors(adv_basis_weights) for _ in range(num_corrupt)], dim=1)
            assert logits.shape == (num_restarts, num_corrupt, self.num_classes)
            predictions = torch.mode(logits.argmax(-1), dim=1)[0]
            counts += torch.nn.functional.one_hot(predictions.cpu(), self.num_classes).numpy()
            batch_loss = torch.nn.functional.cross_entropy(
                logits.reshape(-1, self.num_classes), y.repeat(num_restarts * num_corrupt), reduction='none')
            pgd_loss[:, 0] = batch_loss.cpu().numpy().reshape((num_restarts, num_corrupt)).mean(1)

        return counts, pgd_loss

    def _smoothing_logits(self, x: torch.Tensor, pgd_perturb: torch.Tensor, count: int) -> torch.Tensor:
        """Logits for ``count`` draws of clamp(x + N(0, sigma^2) + sigma * pgd_perturb, 0, 1)."""
        batch = x.repeat((count, 1, 1, 1))
        noise = torch.randn_like(batch) * self.sigma
        perturb = pgd_perturb.repeat((count, 1, 1, 1)) * self.sigma
        return self.base_classifier(torch.clamp(batch + noise + perturb, 0., 1.))

    def _pgd_on_smoothing(self, x: torch.Tensor, learning_rate: float, num_iterations: int, num_corrupt: int, batch_size: int, infty: bool, make_deterministic: bool, num_eval_queries: int, grad_estimate_num: int) -> tuple:
        """ Do PGD on the smoothed classifier, estimating gradients with Gaussian noise samples.

        The perturbation lives in the unit L-inf (``infty``) or L2 ball and is scaled by sigma
        before being added to x. The loss is cross-entropy against the clean prediction for x.

        :param x: the input [channel x height x width]
        :param learning_rate: step size
        :param num_iterations: number of PGD steps
        :param num_corrupt: noise samples per evaluation query (and the default gradient sample count)
        :param batch_size: max number of samples per forward pass
        :param infty: use sign steps in the L-inf ball instead of normalised steps in the L2 ball
        :param make_deterministic: reseed the global torch RNG from a hash of x before every noise draw,
                                   so each step (and the final evaluation) sees the same noise
        :param num_eval_queries: the final evaluation uses num_eval_queries * num_corrupt samples
                                 (only num_corrupt when make_deterministic)
        :param grad_estimate_num: noise samples per gradient estimate; None/<= 0 (or make_deterministic) means num_corrupt
        :return: (counts, losses, loss_steps, inferences, eval_count): per-class counts of the majority
                 prediction at each step, mean loss per step, first-sample loss of each step's last batch,
                 the majority class per step, and per-class counts of the final evaluation
        """
        self.base_classifier.eval()

        with torch.no_grad():
            y = self.base_classifier(x.unsqueeze(0)).argmax(1)

        seed = None
        if make_deterministic:
            seed = int(hashlib.sha256(x.detach().cpu().numpy().tobytes()).hexdigest(), 16) % (2**32 - 1)

        # initialise the perturbation inside the unit ball
        pgd_perturb = torch.randn_like(x)
        if infty:
            pgd_perturb = torch.clamp(pgd_perturb, -1., 1.)
        elif torch.norm(pgd_perturb) > 1.:
            pgd_perturb = pgd_perturb / torch.norm(pgd_perturb)
        pgd_perturb.requires_grad_()

        if make_deterministic or grad_estimate_num is None or grad_estimate_num <= 0:
            grad_estimate_num = num_corrupt

        counts = np.zeros(self.num_classes, dtype=int)
        inferences = []
        losses = []
        loss_steps = []
        for _ in range(num_iterations):
            if seed is not None:
                torch.manual_seed(seed)

            total_grads = torch.zeros_like(x)
            loss_sum = 0.
            iter_count = np.zeros(self.num_classes, dtype=int)
            batch_loss = None
            for this_batch_size in self._chunk_sizes(grad_estimate_num, batch_size):
                logits = self._smoothing_logits(x, pgd_perturb, this_batch_size)
                iter_count += self._count_arr(logits.argmax(1).cpu().numpy(), self.num_classes)
                batch_loss = torch.nn.functional.cross_entropy(logits, y.repeat(this_batch_size), reduction='none')
                loss = torch.mean(batch_loss)
                # fresh gradient per mini-batch; .backward() would accumulate into pgd_perturb.grad
                # across mini-batches and iterations and double-count earlier batches
                grad, = torch.autograd.grad(loss, pgd_perturb)
                total_grads += grad * this_batch_size
                loss_sum += loss.item() * this_batch_size

            with torch.no_grad():
                grads = total_grads / grad_estimate_num  # sample-weighted mean gradient
                if infty:
                    pgd_perturb = torch.clamp(pgd_perturb + torch.sign(grads) * learning_rate, -1., 1.)
                else:
                    grads = grads / torch.norm(grads).clamp(min=1e-12)
                    pgd_perturb = pgd_perturb + grads * learning_rate
                    if torch.norm(pgd_perturb) > 1:
                        pgd_perturb = pgd_perturb / torch.norm(pgd_perturb)
            pgd_perturb.requires_grad_()

            losses.append(float(loss_sum) / grad_estimate_num)
            loss_steps.append(batch_loss.detach().cpu().numpy()[0])
            counts += self._count_arr([np.argmax(iter_count)], self.num_classes)
            inferences.append(np.argmax(iter_count))

        # evaluate the final perturbation
        if seed is not None:
            torch.manual_seed(seed)
            num_eval = num_corrupt
        else:
            num_eval = num_eval_queries * num_corrupt

        eval_count = np.zeros(self.num_classes, dtype=int)
        with torch.no_grad():
            for this_batch_size in self._chunk_sizes(num_eval, batch_size):
                logits = self._smoothing_logits(x, pgd_perturb, this_batch_size)
                eval_count += self._count_arr(logits.argmax(1).cpu().numpy(), self.num_classes)

        return counts, np.array(losses), np.array(loss_steps), inferences, eval_count

    def _lower_confidence_bound(self, NA: int, N: int, alpha: float) -> float:
        """ Returns a (1 - alpha) lower confidence bound on a bernoulli proportion.

        This function uses the Clopper-Pearson method: the lower end of the two-sided
        interval at level 2 * alpha is a one-sided (1 - alpha) lower bound.

        :param NA: the number of "successes"
        :param N: the number of total draws
        :param alpha: the failure probability
        :return: a lower bound on the binomial proportion which holds true w.p at least (1 - alpha) over the samples
        """
        return proportion_confint(NA, N, alpha=2 * alpha, method="beta")[0]
