from torchvision.models.resnet import ResNet, Bottleneck
from torch.cuda.amp import autocast, GradScaler

from torch.nn.parallel import DistributedDataParallel as DDP


############################################################################

from torchvision.models.resnet import (
    resnet18,
    resnet50,
)  # Not needed, just for testing in here

import argparse
import os


import sys

# import setGPU
# from time import time
# import time as basetime
import time
import datetime

# from tqdm import tqdm

from statsmodels.stats.proportion import multinomial_proportions_confint as multi_conf
from statsmodels.stats.proportion import proportion_confint as binom_conf

import torch

from scipy.stats import norm, binom_test

# from statsmodels.stats.proportion import proportion_confint
import numpy as np
from math import ceil

from collections import OrderedDict

import torch.backends.cudnn as cudnn

import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision.utils
from torchvision import models
import torchvision.datasets as dsets
from torchvision import datasets
import torchvision.transforms as transforms

# import torchattacks
# from torchattacks import PGD, FGSM

from statsmodels.stats.proportion import proportion_confint as binom_conf
from statsmodels.stats.proportion import multinomial_proportions_confint as multi_conf

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt

plt.style.use(["seaborn-paper", "./paper.mplstyle"])

from scipy.stats import norm as stats_norm

from autoattack import AutoAttack

from logging_manager import LogSet


def save_figs(fn, types=(".pdf", ".png")):
    """Write the current matplotlib figure once per file extension.

    The figure is laid out with ``tight_layout`` before saving, and every
    output path is ``fn`` followed by one entry of ``types``.
    """
    current_figure = plt.gcf()
    current_figure.tight_layout()
    for suffix in types:
        current_figure.savefig(fn + suffix)


# Replace pyplot's module-level ``savefig`` so that ``plt.savefig(name)``
# writes every format in ``types``. Because this rebinds the attribute on the
# pyplot module itself, it affects every user of ``plt.savefig`` in the process.
plt.savefig = save_figs


class WrapperModule(torch.nn.Module):
    """Monte-Carlo randomized-smoothing wrapper around a base classifier.

    ``forward`` draws ``samples`` Gaussian-noised copies of the input, scores
    each copy with ``model``, turns every copy's scores into a one-hot vote via
    hard Gumbel-softmax, and returns the fraction of votes each class received.
    """

    def __init__(self, model, samples, sigma):
        """
        :param model: base classifier, called on a batch of noised images and
            expected to return class scores along dim 1.
        :param samples: number of noisy copies drawn per forward pass.
        :param sigma: standard deviation of the additive Gaussian noise.
        """
        super().__init__()
        self.model = model
        self.samples = samples
        self.sigma = sigma

    def forward(self, inpt: torch.Tensor) -> torch.Tensor:
        """Return the per-class vote fractions of the smoothed classifier.

        :param inpt: image batch in NCHW layout. Single-channel inputs are
            tiled to 3 channels before noise is added.
        :returns: tensor of shape ``(1, num_classes)`` whose entries are vote
            fractions, i.e. multiples of ``1 / samples``.

        NOTE(review): all noisy rows are summed along dim 0, so this assumes
        ``inpt`` holds a single image (batch size 1). With a larger batch the
        votes of different images would be pooled together.
        """
        channel_factor = 3 if inpt.shape[1] == 1 else 1
        noisy = inpt.repeat(self.samples, channel_factor, 1, 1)
        noisy = noisy + torch.randn_like(noisy) * self.sigma
        # Scaling the scores by 100 sharpens the soft part of the
        # straight-through Gumbel-softmax, so gradients follow the hard vote.
        votes = F.gumbel_softmax(100 * self.model(noisy), tau=1, hard=True, dim=1)
        return torch.sum(votes, dim=0).reshape(1, -1) / self.samples


###############################################################
# Utilities
###############################################################


def at_radius(quantity, radius, count, flag_upper=True):
    """Fraction of ``count`` entries of ``quantity`` on one side of ``radius``.

    With ``flag_upper`` the entries ``>= radius`` are counted, otherwise the
    entries ``<= radius``. ``quantity`` must support elementwise comparison
    (e.g. a NumPy array).
    """
    hits = (quantity >= radius) if flag_upper else (quantity <= radius)
    return np.sum(hits) / count


def basic_predict(model, images, device):
    """Majority-vote class over a batch of (already noised) inputs.

    Each row of ``images`` gets a hard Gumbel-softmax prediction from
    ``return_wrapper``; the class predicted most often is returned as a
    detached tensor (ties go to the smallest class index, as ``torch.unique``
    returns classes sorted).
    """
    batch_copy = images.detach().clone()

    _, _, per_row_class = return_wrapper(model, batch_copy, device)

    seen_classes, tallies = torch.unique(
        per_row_class, sorted=True, return_counts=True
    )
    winner = seen_classes[torch.argmax(tallies)]
    return winner.detach()


def return_wrapper(model, inpt, device, validate=False):
    """Classify every row of ``inpt`` and summarise the votes.

    Single-channel inputs are tiled to 3 channels before calling ``model``.

    With ``validate=True`` the raw argmax of ``model`` is used and the result
    is ``(unique_classes, counts, per_row_argmax)``.

    Otherwise each row's scores (times 100) go through hard Gumbel-softmax and
    the result is ``(vote_fractions, inputs_copy, per_row_class)``. Here
    ``vote_fractions`` has one entry per class and ``inputs_copy`` is a
    detached copy of the inputs, reduced back to one channel if they were
    tiled. The first and third elements are moved to ``device``.
    """
    n_rows = inpt.shape[0]

    was_single_channel = inpt.shape[1] == 1
    if was_single_channel:
        inpt = inpt.repeat(1, 3, 1, 1)

    if validate:
        row_classes = torch.argmax(model(inpt), dim=1)
        uniq, freq = torch.unique(row_classes, return_counts=True)
        return uniq, freq, row_classes

    snapshot = inpt.detach().clone()
    one_hot = F.gumbel_softmax(100 * model(inpt), tau=1, hard=True, dim=1)
    vote_totals = torch.sum(one_hot, dim=0).reshape(1, -1)

    if was_single_channel:
        snapshot = snapshot[:, 0, :, :].unsqueeze(1)

    fractions = (vote_totals[0, :] / n_rows).to(device)
    row_winners = torch.argmax(one_hot, dim=1).to(device)
    return fractions, snapshot, row_winners


def norm_distance(a, b, flag):
    """L2 (Frobenius) distance between ``a`` and ``b``.

    When ``flag`` is set only channel 0 of each tensor is compared, which is
    how single-channel images that were tiled to 3 channels are measured.
    """
    difference = a[:, 0, :, :] - b[:, 0, :, :] if flag else a - b
    return torch.linalg.norm(difference)


############################################
# Attacks
############################################

# CW-L2 Attack
# Based on the paper, i.e. not exact same version of the code on https://github.com/carlini/nn_robust_attacks
# (1) Binary search method for c, (2) Optimization on tanh space, (3) Choosing method best l2 adversaries is NOT IN THIS CODE.
# Based upon https://github.com/Harry24k/CW-pytorch/blob/master/CW.ipynb
def cw_l2_attack(
    model,
    images,
    labels,
    samples,
    sigma,
    device,
    targeted=False,
    c=1e-4,
    kappa=0,
    max_iter=100,
    learning_rate=0.01,
    printing=False,
    z_val=4.0,
    definitive=False,
):
    inpt_flag = True if images.shape[1] == 1 else False

    start_time = time.time()

    if definitive is not False:
        alpha = 0.5 * (1 - stats_norm.cdf(z_val))

    device = images.device
    images = images.to(device)
    labels = labels.to(device)

    # Define f-function
    def f(x, samples, sigma):
        device = x.device
        adv_input = x.repeat(samples, 1, 1, 1)
        adv_input += torch.randn_like(adv_input) * sigma

        if adv_input.shape[1] != 3:
            adv_input = adv_input.repeat(1, 3, 1, 1)

        outputs, _, _ = return_wrapper(model, adv_input, device)

        one_hot_labels = torch.eye(outputs.shape[0])[labels].to(device)

        i, _ = torch.max((1 - one_hot_labels) * outputs, dim=1)
        j = torch.masked_select(outputs, one_hot_labels.bool())

        # If targeted, optimize for making the other class most likely
        if targeted:
            return torch.clamp(i - j, min=-kappa)

        # If untargeted, optimize for making the other class most likely
        else:
            return torch.clamp(j - i, min=-kappa)

    w = torch.zeros_like(images, requires_grad=True).to(device)

    optimizer = optim.Adam([w], lr=learning_rate)

    prev = 1e10

    flag = False
    min_val, min_recorded = 1e6, None

    for step in range(max_iter):

        a = 1 / 2 * (nn.Tanh()(w) + 1)

        loss1 = nn.MSELoss(reduction="sum")(a, images)
        loss2 = torch.sum(c * f(a, samples, sigma))

        cost = loss1 + loss2

        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

        # Early Stop when loss does not converge.
        if step % (max_iter // 10) == 0:
            if cost > prev:
                if printing:
                    print("Attack Stopped due to CONVERGENCE....")
                return a
            prev = cost
        if printing:
            print(
                "- Learning Progress : %2.2f %%        "
                % ((step + 1) / max_iter * 100),
                end="\r",
            )

        attack_images = 1 / 2 * (nn.Tanh()(w) + 1)

        adv_input = attack_images.repeat(samples, 1, 1, 1)
        adv_input += torch.randn_like(adv_input) * sigma

        if adv_input.shape[1] != 3:
            adv_input = adv_input.repeat(1, 3, 1, 1)

        sm_output, x, pred_classes = return_wrapper(model, adv_input, device)
        vals, indices = torch.topk(sm_output, 2, sorted=True)

        max_class, second_class = indices[0], indices[1]
        E0, E1 = vals[0], vals[1]

        if max_class != labels:
            E_0 = binom_conf(
                E0.detach().cpu().numpy() * samples, samples, alpha=alpha, method="beta"
            )[0]
            E_1 = binom_conf(
                E1.detach().cpu().numpy() * samples, samples, alpha=alpha, method="beta"
            )[
                1
            ]  # 1 - E_0
            E0, E1 = torch.tensor(E0).to(device), torch.tensor(E1).to(device)
            if E0 > E1:
                min_temp = (
                    norm_distance(attack_images, images, inpt_flag)
                    .detach()
                    .cpu()
                    .numpy()
                )
                if min_temp < min_val:
                    min_val = min_temp
                    final_cw_time = time.time() - start_time
                    min_recorded = attack_images.detach().clone()
                    if flag is False:
                        first_cw = min_temp
                        first_time = time.time() - start_time
                        flag = True

    if flag is False:
        final_cw_time = time.time() - start_time
        first_cw = min_val
        first_time = final_cw_time

    return (
        attack_images,
        flag,
        min_val,
        min_recorded,
        first_cw,
        first_time,
        final_cw_time,
    )


def direct_attack(
    model,
    yp,
    images,
    samples,
    class_set,
    z_val,
    min_recorded,
    sigma,
    device,
    inpt_flag,
    yp_best=None,
    lambda_val=0.5,
    delta=None,
    cutoff_step=0.5,
):
    """Perform one step of the "direct" attack on the smoothed classifier.

    The current iterate ``yp`` is evaluated with ``samples`` Gaussian-noised
    copies (std ``sigma``). The top-2 vote fractions are replaced by
    Clopper-Pearson bounds at level ``alpha = 0.5 * (1 - Phi(z_val))``: a
    lower bound for the top class (``E0``) and an upper bound for the
    runner-up (``E1``).

    An objective is then chosen depending on whether the top class still
    equals ``class_set[0]``, and ``yp`` takes one normalised-gradient step on
    it. If the gradient is NaN or ``E0 - E1`` is numerically 1, a random
    Gaussian step is taken instead.

    :param model: base classifier passed through to ``return_wrapper``.
    :param yp: current iterate. NOTE: it is detached *in place* and given
        ``requires_grad=True``, so the caller's tensor is modified.
    :param images: the clean image; distances are measured against it with
        ``norm_distance``.
    :param samples: number of noisy copies for the Monte-Carlo estimate.
    :param class_set: tensor whose entry 0 is the class to move away from
        (``direct_loop`` puts the true label there).
    :param z_val: z-score that sets the confidence level ``alpha``.
    :param min_recorded: smallest distance of a confident adversarial example
        found so far.
    :param sigma: noise standard deviation.
    :param device: ignored; overwritten with ``yp.device``.
    :param inpt_flag: forwarded to ``norm_distance`` (channel-0-only distance).
    :param yp_best: best iterate so far; returned unchanged unless this step
        records a new minimum.
    :param lambda_val: weight of the distance penalty. It is multiplied by
        1.05 whenever the top class differs from ``class_set[0]``, and by a
        further 1.15 when in addition ``E0 > E1``.
    :param delta: NOTE(review): overwritten with 0 before it is used, so the
        incoming value has no effect on the objective. It is still updated and
        returned.
    :param cutoff_step: upper bound on the step size.
    :returns: ``(yp_new, yp_grad, min_recorded, yp_best, lambda_val, delta)``.
        ``yp_new`` is the next iterate, clipped to [0, 1] and with
        ``requires_grad`` set. ``yp_grad`` is the normalised gradient that was
        computed.
    """
    stepsize = 0.05
    printing = False

    device = yp.device

    relu = torch.nn.ReLU()
    norm = torch.distributions.Normal(0, 1)

    # Two-sided significance level for the Clopper-Pearson bounds.
    alpha = 0.5 * (1 - stats_norm.cdf(z_val))

    # Make ``yp`` a fresh leaf so the gradient below is taken w.r.t. it alone.
    yp.detach_()
    yp.requires_grad = True

    # Monte-Carlo estimate of the smoothed classifier at ``yp``.
    adv_input = yp.repeat(samples, 1, 1, 1)
    adv_input += torch.randn_like(adv_input) * sigma

    # Non-3-channel inputs are tiled along the channel axis.
    if adv_input.shape[1] != 3:
        adv_input = adv_input.repeat(1, 3, 1, 1)

    sm_output, x, pred_classes = return_wrapper(model, adv_input, device)

    vals, indices = torch.topk(sm_output, 2, sorted=True)

    max_class, second_class = indices[0], indices[1]

    E0, E1 = vals[0], vals[1]

    # Clopper-Pearson (lower, upper) interval of the top class ...
    E0_t, E0_u_t = binom_conf(
        E0.detach().cpu().numpy() * samples, samples, alpha=alpha, method="beta"
    )
    # ... and of the runner-up.
    E1_l_t, E1_t = binom_conf(
        E1.detach().cpu().numpy() * samples, samples, alpha=alpha, method="beta"
    )
    # Turn the bounds into offsets from the point estimates ...
    E0_v, E1_v = E0.detach().cpu().numpy(), E1.detach().cpu().numpy()
    E0_t, E0_u_t = E0_t - E0_v, E0_u_t - E0_v
    E1_l_t, E1_t = E1_l_t - E1_v, E1_t - E1_v

    # ... and add them back as constants, so the bounds keep the gradient of
    # the vote fractions w.r.t. ``yp``.
    # E0 -> lower bound of the top class, E0_u -> its upper bound,
    # E1_l -> lower bound of the runner-up, E1 -> its upper bound.
    E0, E0_u, E1_l, E1 = (
        E0 + torch.tensor(E0_t).to(device),
        E0 + torch.tensor(E0_u_t).to(device),
        E1 + torch.tensor(E1_l_t).to(device),
        E1 + torch.tensor(E1_t).to(device),
    )

    # Cohen-style radius 0.5 * sigma * (Phi^-1(E0) - Phi^-1(E1)), clamped at
    # 0. It is only used below to scale the step size.
    cohen_val = relu(0.5 * sigma * (norm.icdf(relu(E0)) - norm.icdf(relu(E1))))

    # Goal:
    # - make the top class differ from class_set[0];
    # - once it does, make the lower bound of the new top class (E0) strictly
    #   exceed the upper bound of the runner-up (E1), while keeping the
    #   distance to ``images`` small.
    if delta is None:
        delta = torch.max(torch.abs(E0 - E0_u), torch.abs(E1_l - E1))
    # NOTE(review): this unconditionally overrides both the line above and the
    # caller-supplied ``delta``.
    delta = 0  # Just trying this out
    if max_class == class_set[0]:
        # Still predicting class_set[0]: drive the bounded gap towards
        # E0 - E1 = -0.1. The step is proportional to the radius estimate and
        # clipped to [0.05, cutoff_step].
        objective_recorded = torch.abs(E0 - E1 + 0.1)  # + delta)

        stepsize = torch.min(
            torch.max(1.05 * cohen_val, torch.tensor(0.05)), torch.tensor(cutoff_step)
        )
        if printing:
            print(E0 - E1, "#" * 5)

    else:
        # The top class already differs from class_set[0].
        if printing:
            print(E0 - E1, torch.linalg.norm(yp - images), lambda_val, delta, "!" * 5)
        if (
            E0 < E1
        ):  # Adversarial but not yet confident (bounds still overlap)
            # Push the bounded gap towards +0.05 (weighted 10x) with a
            # distance penalty, using half of the maximum step.
            objective_recorded = 10 * torch.abs(
                E0 - E1 - 0.05
            ) + lambda_val * norm_distance(
                yp, images, inpt_flag
            )  # torch.linalg.norm(yp-images)
            stepsize = torch.tensor(cutoff_step / 2)
        else:
            # Confident adversarial (E0 >= E1): shrink the gap while the
            # distance penalty pulls the iterate back towards ``images``.
            objective_recorded = torch.abs(E0 - E1) + lambda_val * norm_distance(
                yp, images, inpt_flag
            )
            stepsize = torch.min(
                torch.max(0.99 * cohen_val, torch.tensor(0.05)),
                torch.tensor(cutoff_step),
            )
        # Strengthen the distance penalty on every adversarial step.
        lambda_val *= 1.05

    # Normalised gradient of the objective w.r.t. the iterate.
    yp_grad = torch.autograd.grad(objective_recorded, yp)[0].detach()
    yp_grad = yp_grad / (1e-5 + torch.linalg.norm(yp_grad))

    # Descend, unless the gradient is NaN or the bounded gap is numerically 1.
    # In those cases take a random Gaussian step with std 0.05 instead.
    if (torch.isnan(torch.sum(yp_grad)) == 0) and ((E0 - E1) < (1 - 1e-5)):
        yp_new = yp - stepsize * yp_grad
    else:
        yp_new = yp + 0.05 * torch.randn_like(yp)

    yp_new = torch.clip(yp_new, 0, 1).detach()
    yp_new.requires_grad = True

    if max_class != class_set[0]:
        # original_gap / new_gap are not used further.
        original_gap = (E0 - E1).detach()
        new_gap = E0 - E1
        if printing:
            print("AT: ", E0 - E1)
        if E0 > E1:
            # Confident adversarial at the current (pre-step) iterate ``yp``.
            distance = norm_distance(yp, images, inpt_flag)
            lambda_val *= 1.15
            if printing:
                print("First delta: ", delta)
            delta = torch.max(delta - (0.9 * (E0 - E1)), torch.tensor(0.001))
            if distance < min_recorded:
                # New best: record its distance and a copy of ``yp``.
                min_recorded = distance
                return (
                    yp_new,
                    yp_grad,
                    min_recorded,
                    yp.detach().clone(),
                    lambda_val,
                    delta,
                )
        else:
            if printing:
                print("Second delta: ", delta)
            delta = delta + (1.01 * (E1 - E0))
            if printing:
                print("Changing delta to: ", delta)

    return yp_new, yp_grad, min_recorded, yp_best, lambda_val, delta


def deepfool_attack(
    model,
    images,
    label,
    device,
    classes,
    samples,
    sigma,
    overshoot=0.02,
    nb_candidate=10,
    max_iter=100,
):
    """DeepFool against the Monte-Carlo smoothed classifier.

    Each iteration evaluates ``samples`` Gaussian-noised copies (std
    ``sigma``) of the current iterate and linearises the top-``nb_candidate``
    vote fractions. It then adds the smallest step that crosses the nearest
    linearised boundary to the accumulated perturbation ``r_tot``.

    Every iterate is ``clip(original + r_tot)``: the accumulated perturbation
    is measured from the clean input. The previous version re-added the
    cumulative ``r_tot`` to the already-perturbed iterate, which compounded
    the steps.

    :param model: base classifier passed to ``return_wrapper``.
    :param images: clean input. It is not modified; the previous version
        called ``requires_grad_()`` on the caller's tensor.
    :param label: true label; the loop runs while the smoothed prediction
        equals it.
    :param device: device for the ``return_wrapper`` outputs.
    :param classes: number of classes; caps ``nb_candidate``.
    :param samples: number of noisy copies per evaluation.
    :param sigma: noise standard deviation.
    :param overshoot: relative extra step applied to the final perturbation.
    :param nb_candidate: number of top classes considered per step.
    :param max_iter: maximum number of DeepFool steps.
    :returns: ``(flag, adv_x, distance, elapsed)``. ``flag`` is True when the
        last iterate's smoothed prediction differs from ``label``.
        ``distance`` uses the same convention as ``norm_distance`` (channel 0
        only for single-channel inputs). If the perturbation becomes NaN,
        returns ``(False, None, 0.0, elapsed)``.
    """
    start_time = time.time()
    ori_images = images.detach().clone()
    # Non-3-channel inputs are tiled to 3 channels and kept channel-averaged.
    tiled = ori_images.shape[1] != 3
    nb_candidate = int(min(nb_candidate, classes))

    base = ori_images.repeat(1, 3, 1, 1) if tiled else ori_images

    def _project(x):
        """Clip to [0, 1] and, for tiled inputs, make all channels equal."""
        x = torch.clamp(x, 0, 1)
        if tiled:
            x = x.mean(dim=1, keepdim=True).repeat(1, 3, 1, 1)
        return x

    def _smoothed_scores(x):
        """Vote fractions over ``samples`` noisy copies of ``x``."""
        noisy = x.repeat(samples, 1, 1, 1)
        noisy = noisy + torch.randn_like(noisy) * sigma
        scores, _, _ = return_wrapper(model, noisy, device)
        return scores

    current = base.clone().requires_grad_()
    logits = _smoothed_scores(current)
    pred = torch.argmax(logits)

    # Used only if no candidate improves on ``pert`` (then r_i becomes NaN and
    # the attack reports failure below).
    w = torch.zeros_like(base[0])
    r_tot = torch.zeros_like(base)

    iteration = 0
    while (pred == label) and iteration < max_iter:
        predictions_val = torch.topk(logits, nb_candidate)[0]
        gradients = torch.stack(
            jacobian(predictions_val, current, nb_candidate), dim=1
        )
        with torch.no_grad():
            pert = 1e10
            for k in range(1, nb_candidate):
                w_k = gradients[0, k, ...] - gradients[0, 0, ...]
                f_k = predictions_val[k] - predictions_val[0]
                pert_k = (f_k.abs() + 0.00001) / w_k.view(-1).norm()
                if pert_k < pert:
                    pert = pert_k
                    w = w_k
            r_i = pert * w / w.view(-1).norm()
            r_tot += r_i

        if torch.isnan(r_tot).any():
            return False, None, 0.0, time.time() - start_time

        current = _project(base + r_tot).detach().requires_grad_()
        logits = _smoothed_scores(current)
        pred = torch.argmax(logits)

        iteration += 1

    adv_x = _project((1 + overshoot) * r_tot + base)
    distance = norm_distance(adv_x, ori_images, tiled).detach().cpu().numpy()
    flag = bool(pred != label)

    return flag, adv_x, distance, time.time() - start_time


def jacobian(predictions, x, classes):
    """Gradients of the first ``classes`` entries of ``predictions`` w.r.t. ``x``.

    Returns a list with one gradient tensor (same shape as ``x``) per entry.
    The graph is retained so it can be differentiated repeatedly.
    """
    grads = []
    for idx in range(classes):
        target = predictions[idx]
        grad_wrt_x = torch.autograd.grad(
            target, x, grad_outputs=torch.ones_like(target), retain_graph=True
        )[0]
        grads.append(grad_wrt_x)
    return grads


def pgd_attack(
    model,
    images,
    labels,
    samples,
    sigma,
    device,
    epsilon=20 / 255,
    iters=40,
    z_val=4,
    probabilistic=False,
):
    """Iterative L2 fast-gradient attack on the Monte-Carlo smoothed classifier.

    Each of ``iters`` steps first tests the current iterate, using
    ``samples`` noisy copies with std ``sigma``. The iterate counts as a
    certified adversarial example when:

    - its top class differs from ``labels[0]``, and
    - the lower multinomial confidence bound of that class exceeds the upper
      bound of the true class.

    The smallest such distance is recorded. The iterate then moves by
    ``epsilon`` along the L2-normalised gradient of the cross-entropy between
    the vote fractions and ``labels``, and is clipped to [0, 1].

    :param probabilistic: unused; kept for interface compatibility.
    :returns: ``(images, flag, pred, min_val, min_recorded, first_pgd,
        first_time, final_pgd_time)``:

        - ``images``: the last iterate.
        - ``flag``: True once any certified adversarial example was found.
        - ``pred``: the last smoothed prediction (None if ``iters == 0``).
        - ``min_val``: the smallest distance found (1e6 if none).
        - ``min_recorded``: the image at that distance (None if none).
        - ``first_pgd`` / ``first_time``: distance and time of the first hit.
        - ``final_pgd_time``: time of the best hit, or the total run time if
          there was no hit.
    """
    # This is now the Iterative Fast Gradient Method for L2 Norms
    inpt_flag = images.shape[1] == 1
    start_time = time.time()
    first_flag = False

    device = images.device
    # Private leaf copy. The loop sets ``requires_grad`` on it, which used to
    # mutate the caller's tensor (and failed for non-leaf inputs).
    images = images.detach().clone()
    labels = labels.to(device)
    loss = nn.CrossEntropyLoss()

    alpha_probability = 0.5 * (1 - stats_norm.cdf(z_val))

    ori_images = images.detach().clone()
    flag = False
    pred = None

    min_val = torch.tensor(1e6)
    min_recorded = None
    for _ in range(iters):
        images.requires_grad = True

        adv_input = images.repeat(samples, 1, 1, 1)
        adv_input += torch.randn_like(adv_input) * sigma

        if adv_input.shape[1] != 3:
            adv_input = adv_input.repeat(1, 3, 1, 1)

        sm, _, _ = return_wrapper(model, adv_input, device)
        pred = torch.argmax(sm)
        if pred != labels[0]:
            second_class = sm.detach().clone()
            second_class[pred] = 0
            second_class = torch.argmax(second_class)
            pred, second_class = (
                pred.detach().cpu().numpy(),
                second_class.detach().cpu().numpy(),
            )
            m_c = multi_conf(
                samples * sm.detach().cpu().numpy(), alpha=alpha_probability
            )
            lab = labels[0].detach().cpu().numpy()

            # Certified: lower bound of the new class beats the upper bound of
            # the true class.
            if m_c[pred, 0] > m_c[lab, 1]:
                flag = True
                delta = norm_distance(images, ori_images, inpt_flag)
                if delta < min_val:
                    final_pgd_time = time.time() - start_time
                    min_val = delta
                    min_recorded = images.detach()
                    if first_flag is False:
                        first_pgd = min_val.detach().clone().cpu().numpy()
                        first_time = time.time() - start_time
                        first_flag = True

        model.zero_grad()
        cost = loss(sm.reshape(1, -1), labels).to(device)
        cost.backward()

        # Normalised gradient-ascent step on the loss, then clip to [0, 1].
        grad = images.grad
        grad_norm = torch.linalg.norm(grad)
        images = torch.clamp(
            images.detach() + epsilon * (grad / (grad_norm + 1e-12)), min=0, max=1
        ).detach_()

    if first_flag is False:
        final_pgd_time = time.time() - start_time
        first_time = final_pgd_time
        first_pgd = min_val.detach().cpu().numpy()

    return (
        images,
        flag,
        pred,
        min_val.detach().cpu().numpy(),
        min_recorded,
        first_pgd,
        first_time,
        final_pgd_time,
    )


def direct_loop(
    model,
    classes,
    labels,
    samples,
    yp,
    images,
    z_val,
    sigma,
    device,
    iters=100,
    new=True,
    cutoff_step=0.5,
):
    """Run ``iters`` steps of the direct attack and track the best distance.

    ``class_set`` is built with the true label ``labels[0]`` first, followed
    by every other class index. The step function is ``direct_attack`` when
    ``new`` is True, and ``direct_attack_old`` otherwise.

    :returns: ``(new_d_time, min_recorded, first_time, first_radii)``:

        - ``min_recorded``: the smallest confident-adversarial distance found,
          or the sentinel 1000 if none was found (callers test ``>= 1000``).
        - ``new_d_time``: elapsed time at the last improvement.
        - ``first_time`` / ``first_radii``: time and distance of the first hit.

        Without a hit, all three times equal the total run time and
        ``first_radii`` is 1000.
    """
    start_time = time.time()
    not_found = 1000  # sentinel distance meaning "no adversarial found yet"
    first_flag = False

    # class_set[0] is the true label; the remaining entries are all other
    # class indices.
    class_set = torch.zeros(classes).to(device)
    class_set[0] = labels[0]
    class_dummy = torch.arange(classes).to(device)
    class_set[1:] = class_dummy[class_dummy != labels[0]]

    # Only the selected name is looked up, so ``direct_attack_old`` need not
    # exist when ``new`` is True.
    attack_step = direct_attack if new else direct_attack_old

    min_recorded = not_found
    yp_best, lambda_val, delta_val = None, 0.5, None
    inpt_flag = images.shape[1] == 1
    for _ in range(iters):
        min_recorded_old = min_recorded
        yp, _, min_recorded, yp_best, lambda_val, delta_val = attack_step(
            model,
            yp,
            images,
            samples,
            class_set,
            z_val,
            min_recorded,
            sigma,
            device,
            inpt_flag,
            yp_best=yp_best,
            lambda_val=lambda_val,
            delta=delta_val,
            cutoff_step=cutoff_step,
        )

        # Time of the most recent (non-negligible) improvement.
        if (min_recorded_old - 1e-5) > min_recorded:
            new_d_time = time.time() - start_time

        if min_recorded < not_found and first_flag is False:
            first_time = time.time() - start_time
            first_radii = min_recorded.detach().cpu().numpy()
            first_flag = True

    if min_recorded >= not_found:
        new_d_time = time.time() - start_time
        first_time = new_d_time
        first_radii = min_recorded

    return new_d_time, min_recorded, first_time, first_radii


def evaluation_loop(
    device,
    model,
    test_loader,
    dataset_name,
    sigma,
    samples,
    classes,
    total_cutoff=250,
    plotting=True,
    autoattack_radii=-1,
    pgd_radii=20 / 255,
):
    """Benchmark the attacks and the Cohen certificate on ``test_loader``.

    Each loader item is classified by majority vote over ``samples`` noisy
    copies. Items whose vote differs from the label are skipped and counted in
    the "rej" column. Up to ``total_cutoff`` correctly classified items are
    run through the following, logging distance / time / success columns to a
    ``LogSet``:

    - the CW attack;
    - the direct attack;
    - DeepFool;
    - the Cohen et al. radius;
    - the iterative L2 gradient attack.

    NOTE(review): comparisons such as ``pred_baseline == labels`` are used as
    booleans, which assumes the loader yields batches of size 1.
    NOTE(review): the Cohen radius reuses the same noise draws that selected
    the prediction.

    :param plotting: unused; kept for interface compatibility.
    :param autoattack_radii: only used in the log name.
    :param pgd_radii: step size of the iterative L2 attack; also in the log name.
    """
    z_val = 2.58
    # Two-sided significance level shared by the Clopper-Pearson bounds below.
    alpha = 0.5 * (1 - stats_norm.cdf(z_val))
    norm = torch.distributions.Normal(0, 1)

    bonus = ""
    if total_cutoff != 250:
        bonus = "cutoff-" + str(total_cutoff)

    log = LogSet(
        dataset_name
        + "-"
        + str(sigma)
        + "-samples-"
        + str(samples)
        + bonus
        + "_v2_"
        + str(autoattack_radii)
        + "-"
        + str(pgd_radii),
        decimals=4,
        means=True,
        console=True,
    )

    total = 0  # correctly classified images attacked so far
    overall_count = 0  # images drawn from the loader so far

    for images, labels in test_loader:
        print("In evaluation loop #3", flush=True)
        # Stop before spending a smoothed forward pass on an unused image.
        if total >= total_cutoff:
            break
        overall_count += 1

        inference_time = time.time()
        images, labels = images.to(device), labels.to(device)
        adv_input = images.repeat(samples, 1, 1, 1)
        adv_input = adv_input + torch.randn_like(adv_input) * sigma
        pred_baseline = basic_predict(model, adv_input, device)
        print("AAA", adv_input.shape, pred_baseline, labels, samples)

        inference_time = time.time() - inference_time

        if not (pred_baseline == labels):
            continue

        total += 1
        log.append(total, "ix")
        log.append(overall_count - total, "rej")
        log.append(labels, "lab")
        log.append(inference_time, "inf_t")

        # ---- Carlini-Wagner ----
        (
            _,
            cw_flag,
            min_val,
            _,
            first_cw,
            first_cw_time,
            final_cw_time,
        ) = cw_l2_attack(
            model,
            images,
            labels,
            samples,
            sigma,
            device,
            z_val=z_val,
            definitive=True,
        )
        if cw_flag is True:
            cw_distance = min_val
            cw_success = 1
        else:
            cw_distance, final_cw_time, first_cw, first_cw_time = (
                np.nan,
                np.nan,
                np.nan,
                np.nan,
            )
            cw_success = 0

        log.append(cw_distance, "cw_d")
        log.append(final_cw_time, "cw_t")
        log.append(first_cw, "cw_f_d")
        log.append(first_cw_time, "cw_f_t")
        log.append(cw_success, "cw_s")

        # ---- Direct attack ----
        yp = images.clone().detach()

        new_d_time, new_attack, first_time, first_radii = direct_loop(
            model, classes, labels, samples, yp, images, z_val, sigma, device
        )

        # direct_loop reports "not found" with the sentinel distance 1000.
        if new_attack >= 1000:
            new_attack, first_radii, first_time, new_d_time = (
                np.nan,
                np.nan,
                np.nan,
                np.nan,
            )
            direct_success = 0
        else:
            new_attack = new_attack.detach().cpu().numpy()
            direct_success = 1

        log.append(new_attack, "n_d")
        log.append(new_d_time, "n_t")
        log.append(first_radii, "n_f_d")
        log.append(first_time, "n_f_t")
        log.append(direct_success, "n_s")

        # ---- DeepFool ----
        deep_flag, _, deep_dist, deep_time = deepfool_attack(
            model,
            images,
            labels,
            device,
            classes,
            samples,
            sigma,
            overshoot=0.02,
            nb_candidate=10,
            max_iter=100,
        )
        if deep_flag:
            deep_success = 1
        else:
            deep_success = 0
            deep_time, deep_dist = np.nan, np.nan

        log.append(deep_dist, "d_d")
        log.append(deep_time, "d_t")
        log.append(deep_success, "d_s")

        # ---- Cohen et al. certified radius ----
        cohen_time = time.time()

        sm_output, _, _ = return_wrapper(model, adv_input, device)
        vals, _ = torch.topk(sm_output, 2, sorted=True)
        max_count = vals[0].detach().cpu().numpy() * samples
        second_count = vals[1].detach().cpu().numpy() * samples

        E_0c = binom_conf(max_count, samples, alpha=alpha, method="beta")[0]
        E_1c = binom_conf(second_count, samples, alpha=alpha, method="beta")[1]

        baseline_cohen = np.max(
            [
                (
                    0.5
                    * sigma
                    * (
                        norm.icdf(torch.tensor(E_0c))
                        - norm.icdf(torch.tensor(E_1c))
                    )
                )
                .detach()
                .cpu()
                .numpy(),
                0,
            ]
        )
        cohen_time = time.time() - cohen_time

        log.append(baseline_cohen, "co_d")
        log.append(cohen_time, "co_t")
        log.append(E_0c, "E_0")

        # ---- Iterative L2 gradient attack ----
        pgd_time = time.time()
        (
            _,
            pgd_flag,
            _,
            min_val,
            _,
            first_pgd,
            first_pgd_time,
            _,
        ) = pgd_attack(
            model,
            images,
            labels,
            samples,
            sigma,
            device,
            epsilon=pgd_radii,
            iters=100,
            probabilistic=True,
            z_val=z_val,
        )
        pgd_time = time.time() - pgd_time
        # pgd_attack reports failure with min_val == 1e6, so success must come
        # from pgd_flag. Testing only ``min_val > 1e-5`` counted failures as
        # successes at radius 1e6. Zero-distance hits stay excluded.
        if pgd_flag and min_val > 1e-5:
            pgd_success = 1
            pgd_radius = min_val
        else:
            pgd_success = 0
            pgd_radius, pgd_time, first_pgd_time, first_pgd = (
                np.nan,
                np.nan,
                np.nan,
                np.nan,
            )

        log.append(pgd_radius, "pgd_d")
        log.append(pgd_time, "pgd_t")
        log.append(first_pgd, "pgd_f_d")
        log.append(first_pgd_time, "pgd_f_t")
        log.append(pgd_success, "pgd_s")

        log.print()


def cutoff_test(device, model, test_loader, dataset_name, sigma, samples, classes):
    """Benchmark ``direct_loop`` over a sweep of ``cutoff_step`` values.

    Up to 200 correctly classified images are processed. An image counts as
    correct when ``basic_predict`` on ``samples`` Gaussian-noised copies matches
    its label. Each such image is attacked once per cutoff step with
    ``direct_loop``.

    Failed attacks have a recorded radius >= 1000. They are stored with a
    sentinel radius of 1e6, and their first-hit time is set to the total time.

    At the end, one summary per cutoff step is printed. The summary covers the
    successful fraction plus mean/median of the first-hit and final times and
    radii, restricted to successful runs.
    """
    z_val = 2.58
    limit = 200
    cutoff_radii = 1e6  # sentinel radius for unsuccessful attacks
    cutoff_steps = [0.05, 0.25, 0.5, 0.75, 1.0, 1.5, 3.0, 5.0, 10.0, 15.0]

    direct_time_set = {step: [] for step in cutoff_steps}
    direct_first_time_set = {step: [] for step in cutoff_steps}
    direct_attack_set = {step: [] for step in cutoff_steps}
    direct_first_attack_set = {step: [] for step in cutoff_steps}

    overall_count = 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        noisy_batch = images.repeat(samples, 1, 1, 1)
        noisy_batch = noisy_batch + torch.randn_like(noisy_batch) * sigma
        pred_baseline = basic_predict(model, noisy_batch, device)

        # Misclassified points are not attacked.
        if not (pred_baseline == labels):
            continue
        if overall_count >= limit:
            break
        print(overall_count, flush=True)
        overall_count += 1

        for step in cutoff_steps:
            start_point = images.detach().clone()
            run_time, min_recorded, first_time, first_radii = direct_loop(
                model,
                classes,
                labels,
                samples,
                start_point,
                images,
                z_val,
                sigma,
                device,
                iters=100,
                new=True,
                cutoff_step=step,
            )

            if min_recorded >= 1000:
                # Failure: store sentinels so the run is masked out below.
                found_radius = cutoff_radii
                first_radii = cutoff_radii
                first_time = run_time
            else:
                found_radius = min_recorded.detach().cpu().numpy()

            direct_time_set[step].append(run_time)
            direct_first_time_set[step].append(first_time)
            direct_attack_set[step].append(found_radius)
            direct_first_attack_set[step].append(first_radii)

    for step in cutoff_steps:
        final_radii = np.asarray(direct_attack_set[step])
        succeeded = final_radii < 1000

        first_times = np.asarray(direct_first_time_set[step])[succeeded]
        final_times = np.asarray(direct_time_set[step])[succeeded]
        first_radii_ok = np.asarray(direct_first_attack_set[step])[succeeded]
        final_radii_ok = final_radii[succeeded]

        print("#" * 20, dataset_name, flush=True)
        print(
            "Cutoff is : {}, Successful proportion was: {}".format(
                step, np.sum(succeeded) / len(succeeded)
            ),
            flush=True,
        )
        for template, values in (
            ("First Time Mean : {}, First Time Median : {}", first_times),
            ("Time Mean : {}, Time Median : {}", final_times),
            ("First Mean : {}, First Median : {}", first_radii_ok),
            ("Mean : {}, Median : {}", final_radii_ok),
        ):
            print(template.format(np.mean(values), np.median(values)), flush=True)


def cr_loop(
    device,
    model,
    test_loader,
    dataset_name,
    sigma,
    samples,
    classes,
    total_cutoff=250,
    plotting=True,
):
    """Compare the analytic certified radius with two optimisation-based radii.

    The smoothed prediction of an image is ``basic_predict`` over ``samples``
    Gaussian-noised copies. For every test image whose smoothed prediction
    matches its label, this computes:

    * the baseline radius ``0.5 * sigma * (icdf(E_0) - icdf(E_1))``. Here E_0 is
      the lower and E_1 the upper ``binom_conf(method="beta")`` bound on the
      top-two class frequencies;
    * the radius and time returned by ``cr_loop_control``;
    * the radius and time returned by ``secondary_cr_step_loop``.

    Each image produces one tab-separated row, written to stdout and to
    ``<dataset_name>-<sigma>-samples-<samples><bonus>VVV.txt``.

    Args:
        device: Torch device that images/labels are moved to.
        model: Classifier forwarded to the prediction helpers.
        test_loader: Iterable of ``(images, labels)``. Presumably batch size 1,
            since ``labels[0]`` is used downstream — TODO confirm.
        dataset_name: Prefix of the output file name.
        sigma: Standard deviation of the smoothing noise.
        samples: Number of noisy copies per image.
        classes: Number of classes, forwarded to the radius searches.
        total_cutoff: Currently ignored, see the note in the body.
        plotting: Unused; kept for interface compatibility.
    """
    # NOTE(review): the ``total_cutoff`` argument is overridden here, exactly as
    # before. Every call therefore processes up to 4000 correctly classified
    # images and writes a "cutoff-4000" file name. Confirm whether the argument
    # should be honoured instead.
    total_cutoff = 4000
    total = 0

    bonus = ""
    if total_cutoff != 250:
        bonus = "cutoff-" + str(total_cutoff)

    z_val = 2.58
    # Significance level handed to binom_conf. It is loop-invariant, so it is
    # computed once.
    alpha = 0.5 * (1 - stats_norm.cdf(z_val))
    norm = torch.distributions.Normal(0, 1)

    header = (
        "# \t Maxb \t E_0 \t E_0hat \t # \t C_t \t C \t O_t \t O \t S_t \t S \t R"
    )
    row_format = "A \t {:.2f} \t {:.2f} \t {:.2f} \t # \t {:.2f} \t {:.2f} \t {:.2f} \t {:.2f} \t {:.2f} \t {:.2f} \t {:.2f}"

    with open(
        dataset_name
        + "-"
        + str(sigma)
        + "-samples-"
        + str(samples)
        + bonus
        + "VVV.txt",
        "w",
    ) as f:
        print(header, flush=True)
        print(header, file=f, flush=True)

        for images, labels in test_loader:
            # Stop before spending a forward pass on an image that would be
            # discarded anyway.
            if total >= total_cutoff:
                break

            images, labels = images.to(device), labels.to(device)
            adv_input = images.repeat(samples, 1, 1, 1)
            adv_input = adv_input + torch.randn_like(adv_input) * sigma
            pred_baseline = basic_predict(model, adv_input, device)

            # Only correctly classified images are evaluated.
            if not (pred_baseline == labels):
                continue

            yp = images.clone().detach()
            total += 1

            # Baseline radius from confidence bounds on the top-two frequencies.
            cohen_time = time.time()
            sm_output, _, _ = return_wrapper(model, adv_input, device)
            vals, _ = torch.topk(sm_output, 2, sorted=True)
            max_count = vals[0].detach().cpu().numpy() * samples
            second_count = vals[1].detach().cpu().numpy() * samples

            E_0c = binom_conf(max_count, samples, alpha=alpha, method="beta")[0]
            E_1c = binom_conf(second_count, samples, alpha=alpha, method="beta")[1]

            baseline_cohen = (
                (
                    0.5
                    * sigma
                    * (
                        norm.icdf(torch.tensor(E_0c))
                        - norm.icdf(torch.tensor(E_1c))
                    )
                )
                .detach()
                .cpu()
                .numpy()
            )
            cohen_time = time.time() - cohen_time

            new_d_time, new_radii = cr_loop_control(
                model,
                classes,
                labels,
                samples,
                yp,
                images,
                z_val,
                sigma,
                device,
                new=False,
            )
            new_radii = new_radii.detach().cpu().numpy()

            secondary_time, secondary_radii = secondary_cr_step_loop(
                model,
                classes,
                labels,
                samples,
                yp,
                images,
                z_val,
                sigma,
                device,
                baseline_cohen,
            )

            max_baseline = np.max([baseline_cohen, 0])
            # Default to NaN so a non-positive baseline prints "nan". Without
            # this, ratio was undefined (NameError) or stale from a previous
            # image.
            ratio = np.nan
            if max_baseline > 0:
                ratio = np.max([secondary_radii, new_radii]) / max_baseline

            # NOTE(review): long runtimes are clamped to 9.49 in the table. The
            # reason is not visible here; presumably it keeps the column width.
            if new_d_time > 9.9999:
                new_d_time = 9.49

            row = row_format.format(
                max_baseline,
                max_count / samples,
                E_0c.item(),
                cohen_time,
                baseline_cohen,
                new_d_time,
                new_radii,
                secondary_time,
                secondary_radii,
                ratio,
            )
            print(row, flush=True)
            print(row, file=f, flush=True)


def cr_loop_control(
    model,
    classes,
    labels,
    samples,
    yp,
    images,
    z_val,
    sigma,
    device,
    iters=100,
    new=True,
    cutoff_step=0.5,
):
    """Run ``iters`` steps of ``new_cr_step`` and report the best objective.

    Only the best objective value is tracked; ``new_cr_step``'s best point and
    best radius are discarded, and ``cohen_best`` is always passed as 0.

    Args:
        model, samples, images, z_val, sigma, device: Forwarded to ``new_cr_step``.
        classes: Number of classes, used to build the class ordering.
        labels: True label(s); ``labels[0]`` is placed first in the ordering.
        yp: Starting point of the search. ``new_cr_step`` detaches it in place
            and sets ``requires_grad``.
        iters: Number of ``new_cr_step`` calls.
        new: Unused; kept for interface compatibility.
        cutoff_step: Forwarded to ``new_cr_step``.

    Returns:
        ``(new_d_time, max_recorded)``. ``new_d_time`` is the elapsed time at
        the last improvement larger than 1e-5 (or the total time if there was
        none). ``max_recorded`` is the best objective, or ``torch.tensor(0)``
        when no such improvement occurred.
    """
    start_time = time.time()
    new_d_time = None

    # Class ordering: true label first, then every other class index.
    class_set = torch.zeros(classes).to(device)
    class_set[0] = labels[0]
    class_dummy = torch.arange(classes).to(device)
    class_set[1:] = class_dummy[class_dummy != labels[0]]

    inpt_flag = images.shape[1] == 1
    max_recorded = 0
    cohen_best = 0

    for iii in range(iters):
        max_recorded_old = max_recorded

        yp, _, max_recorded, _, _ = new_cr_step(
            model,
            yp,
            images,
            samples,
            class_set,
            z_val,
            max_recorded,
            sigma,
            device,
            inpt_flag,
            cohen_best,
            yp_best=None,
            lambda_val=0.5,
            delta=None,
            cutoff_step=cutoff_step,
        )

        # Timestamp (and log) only meaningful improvements.
        if (max_recorded_old + 1e-5) < max_recorded:
            new_d_time = time.time() - start_time
            print(iii, max_recorded, flush=True)

    # Preserved from the original: a best value of exactly 1000 reports the
    # total elapsed time.
    if max_recorded == 1000:
        new_d_time = time.time() - start_time

    if new_d_time is None:
        new_d_time = time.time() - start_time
        max_recorded = torch.tensor(0)

    return new_d_time, max_recorded


def secondary_cr_step_loop(
    model,
    classes,
    labels,
    samples,
    yp,
    images,
    z_val,
    sigma,
    device,
    baseline_cohen,
    iters=100,
):
    """Two-phase radius search: free ``new_cr_step`` search, then a line search.

    Phase 1 calls ``new_cr_step`` ``iters`` times from ``yp``. It tracks the best
    objective (``max_recorded``), the point achieving it (``yp_best``) and the
    smoothed radius there (``cohen_best``).

    Phase 2 runs only when phase 1 found a best point distinct from the start.
    It places a probe at ``yp_original + 0.95 * baseline_cohen * u``, where
    ``u`` is the unit vector from ``yp_best`` towards ``yp_original``. It then
    takes 50 gradient steps whose displacement is projected onto ``u``. When
    the probe's top class is ``labels[0]`` and ``E0 > E1``, a candidate radius is
    computed from two balls on that line:

    * ball 2 is centred at ``x_2 = +||yp_best - yp_original||`` with radius
      ``cohen_best``;
    * ball 3 is centred at ``x_3 = -||yp - yp_original||`` with radius
      ``cohen_val``.

    Algebraically, ``sqrt((x_2 r_3^2 - x_3 r_2^2 + x_3 x_2^2 - x_3^2 x_2) / (x_2 - x_3))``
    is the distance from the origin (the original point) to an intersection
    point of the two circles. The candidate is kept if it beats
    ``max_recorded``.

    Args:
        model, samples, images, z_val, sigma, device: Used for sampling and
            forwarded to ``new_cr_step``.
        classes: Number of classes, used to build the class ordering.
        labels: True label(s); ``labels[0]`` is placed first in the ordering.
        yp: Starting point. ``new_cr_step`` detaches it in place and sets
            ``requires_grad``.
        baseline_cohen: Analytic radius, used to place the phase-2 probe.
        iters: Number of phase-1 ``new_cr_step`` calls.

    Returns:
        ``(elapsed_seconds, max_recorded)``, with ``max_recorded`` as a NumPy value.
    """
    yp_original = yp.detach().clone()
    start_time = time.time()

    relu = torch.nn.ReLU()
    norm = torch.distributions.Normal(0, 1)

    # Class ordering expected by new_cr_step: true label first, then the rest.
    class_set = torch.zeros(classes).to(device)
    class_set[0] = labels[0]
    class_dummy = torch.arange(classes).to(device)
    class_set[1:] = class_dummy[class_dummy != labels[0]]

    inpt_flag = images.shape[1] == 1
    max_recorded = torch.tensor(0.0)
    cohen_best = 0
    yp_best = None

    # Phase 1: unconstrained search.
    for _ in range(iters):
        yp, _, max_recorded, yp_best, cohen_best = new_cr_step(
            model,
            yp,
            images,
            samples,
            class_set,
            z_val,
            max_recorded,
            sigma,
            device,
            inpt_flag,
            cohen_best,
            yp_best=yp_best,
            lambda_val=0.5,
            delta=None,
            cutoff_step=0.5,
        )

    print(max_recorded, flush=True)
    # Debug trace: shape of the best point, or "111" if phase 1 never found one.
    if yp_best is not None:
        print(yp_best.shape, flush=True)
    else:
        print("111", flush=True)
        return time.time() - start_time, max_recorded.detach().cpu().numpy()

    yp_best_save = yp_best.detach().clone()
    vector_project = yp_original - yp_best_save
    direction_norm = torch.linalg.norm(vector_project)
    if not direction_norm > 0:
        # new_cr_step records the pre-step point as its best. If the very first
        # evaluation was never beaten, yp_best equals the start and the line
        # direction is undefined. Normalising it would turn the refinement into
        # NaNs.
        return time.time() - start_time, max_recorded.detach().cpu().numpy()
    vector_project = vector_project / direction_norm

    aux_iters = 50
    step_size = 0.15
    alpha = 0.5 * (1 - stats_norm.cdf(z_val))

    # Probe start: signed distance 0.95 * baseline_cohen from the original point.
    yp = yp_original + 0.95 * baseline_cohen * vector_project

    # Ball 2 is fixed for the whole refinement.
    x_2 = torch.linalg.norm(yp_best_save - yp_original)
    r_2 = cohen_best

    # Phase 2: line search along vector_project.
    for iii in range(aux_iters):
        yp.detach_()
        yp.requires_grad = True

        adv_input = yp.repeat(samples, 1, 1, 1)
        adv_input += torch.randn_like(adv_input) * sigma
        # Match new_cr_step: inputs that are not 3-channel are tiled to 3.
        if adv_input.shape[1] != 3:
            adv_input = adv_input.repeat(1, 3, 1, 1)

        sm_output, _, _ = return_wrapper(model, adv_input, device)
        vals, indices = torch.topk(sm_output, 2, sorted=True)
        max_class = indices[0]
        E0, E1 = vals[0], vals[1]

        # Shift E0 to its lower and E1 to its upper binom_conf bound. The shift
        # is added as a constant, so gradients still flow through E0/E1.
        E0_v, E1_v = E0.detach().cpu().numpy(), E1.detach().cpu().numpy()
        E0_shift = (
            binom_conf(E0_v * samples, samples, alpha=alpha, method="beta")[0] - E0_v
        )
        E1_shift = (
            binom_conf(E1_v * samples, samples, alpha=alpha, method="beta")[1] - E1_v
        )
        E0 = E0 + torch.tensor(E0_shift).to(device)
        E1 = E1 + torch.tensor(E1_shift).to(device)

        cohen_val = relu(0.5 * sigma * (norm.icdf(relu(E0)) - norm.icdf(relu(E1))))

        if max_class == class_set[0]:
            if E0 > E1:
                # Ball 3 (the probe) sits on the opposite side of the origin.
                x_3 = -1 * torch.linalg.norm(yp - yp_original)
                r_3 = cohen_val

                numerator = (
                    x_2 * (r_3**2)
                    - x_3 * (r_2**2)
                    + x_3 * (x_2**2)
                    - (x_3**2) * x_2
                )
                denominator = x_2 - x_3
                objective_recorded = torch.sqrt(numerator / denominator)

                if objective_recorded > max_recorded:
                    max_recorded = objective_recorded
                    print(
                        "max recorded is bob dylans lovely moment",
                        max_recorded,
                        iii,
                        flush=True,
                    )
            else:
                # Correct class but not confident: increase E0 relative to E1.
                objective_recorded = E0 - E1
        else:
            # Wrong class: move back towards the original point.
            objective_recorded = -torch.linalg.norm(yp - yp_original)

        grad_vector = torch.autograd.grad(objective_recorded, yp)[0].detach()
        grad_norm = torch.linalg.norm(grad_vector)
        if grad_norm > 0:
            perturbation = step_size * grad_vector / grad_norm
            # Keep only the component of the step along the search line.
            yp = yp + torch.sum(perturbation * vector_project) * vector_project

        step_size = 0.99 * step_size

    return time.time() - start_time, max_recorded.detach().cpu().numpy()


def new_cr_step(
    model,
    yp,
    images,
    samples,
    class_set,
    z_val,
    max_recorded,
    sigma,
    device,
    inpt_flag,
    cohen_best,
    yp_best=None,
    lambda_val=0.5,
    delta=None,
    cutoff_step=0.5,
):
    """Take one normalised gradient-ascent step on a smoothed-radius objective.

    First, ``yp`` is detached in place and re-marked ``requires_grad``. Then
    ``samples`` Gaussian-noised copies of it (tiled to 3 channels if needed) are
    scored with ``return_wrapper``. The top-two scores are shifted to the lower
    (E0) and upper (E1) ``binom_conf(method="beta")`` bounds. The objective is:

    * ``-||yp - images||`` when the top class is not ``class_set[0]``;
    * ``E0 - E1`` when the top class matches but ``E0 > E1`` does not hold;
    * otherwise ``cohen_val - ||yp - images||``, with
      ``cohen_val = relu(0.5 * sigma * (icdf(relu(E0)) - icdf(relu(E1))))``.
      If this beats ``max_recorded``, it becomes the new best, together with
      ``yp`` (the pre-step point) and ``cohen_val``.

    The point then moves by 0.1 along the gradient normalised by
    ``1e-5 + ||grad||`` and is clipped to [0, 1].

    ``device`` is overridden by ``yp.device``. ``inpt_flag``, ``lambda_val``,
    ``delta`` and ``cutoff_step`` are accepted but unused.

    Returns:
        ``(yp_new, yp_grad, max_recorded, yp_best, cohen_best)``. ``yp_new`` is a
        fresh leaf with ``requires_grad=True``.
    """
    device = yp.device
    relu = torch.nn.ReLU()
    std_normal = torch.distributions.Normal(0, 1)
    alpha = 0.5 * (1 - stats_norm.cdf(z_val))

    # Re-root the autograd graph at the current point.
    yp.detach_()
    yp.requires_grad = True

    noisy_batch = yp.repeat(samples, 1, 1, 1)
    noisy_batch += torch.randn_like(noisy_batch) * sigma
    if noisy_batch.shape[1] != 3:
        noisy_batch = noisy_batch.repeat(1, 3, 1, 1)

    sm_output, _, _ = return_wrapper(model, noisy_batch, device)
    top_vals, top_idx = torch.topk(sm_output, 2, sorted=True)
    leading_class = top_idx[0]
    p_top, p_runner = top_vals[0], top_vals[1]

    # Confidence-bound shifts are added as constants so gradients still flow
    # through the raw scores.
    p_top_np = p_top.detach().cpu().numpy()
    p_runner_np = p_runner.detach().cpu().numpy()
    top_lower = binom_conf(p_top_np * samples, samples, alpha=alpha, method="beta")[0]
    runner_upper = binom_conf(
        p_runner_np * samples, samples, alpha=alpha, method="beta"
    )[1]
    E0 = p_top + torch.tensor(top_lower - p_top_np).to(device)
    E1 = p_runner + torch.tensor(runner_upper - p_runner_np).to(device)

    cohen_val = relu(
        0.5 * sigma * (std_normal.icdf(relu(E0)) - std_normal.icdf(relu(E1)))
    )

    if not (leading_class == class_set[0]):
        # Wrong class: head back towards the original input.
        objective = -torch.linalg.norm(yp - images)
    elif not (E0 > E1):
        # Right class, not confident: widen the gap between E0 and E1.
        objective = E0 - E1
    else:
        objective = cohen_val - torch.linalg.norm(yp - images)
        if objective > max_recorded:
            max_recorded, yp_best, cohen_best = objective, yp, cohen_val

    gradient = torch.autograd.grad(objective, yp)[0].detach()
    direction = gradient / (1e-5 + torch.linalg.norm(gradient))

    stepped = torch.clip(yp + 0.1 * direction, 0, 1).detach()
    stepped.requires_grad = True

    return stepped, direction, max_recorded, yp_best, cohen_best


def scaling_analysis(
    device,
    model,
    test_loader,
    dataset_name,
    sigma,
    samples,
    classes,
    total_cutoff=250,
    plotting=True,
    autoattack_radii=-1,
    pgd_radii=20 / 255,
):
    """Log ``direct_loop`` radii for a sweep of ``scaling_sigma`` factors.

    For each test batch, ``samples`` Gaussian-noised copies are classified with
    ``basic_predict``. Images whose smoothed prediction matches the label are
    kept, up to ``total_cutoff`` of them. For each kept image, ``direct_loop``
    runs once per factor in ``scalings``, with ``scaling_sigma=scaling``.

    Per image, the following are logged through ``LogSet``:

    * the index (``"ix"``);
    * the number of rejected images so far (``"rej"``);
    * the label (``"lab"``);
    * the inference time (``"inf_t"``).

    Per scaling factor, the found radius (``"n_d_<scaling>"``) and a success
    flag (``"n_s_<scaling>"``) are logged. A radius >= 1000 is treated as a
    failure and logged as NaN with flag 0.

    Args:
        device: Torch device that images/labels are moved to.
        model: Classifier forwarded to ``basic_predict`` / ``direct_loop``.
        test_loader: Iterable of ``(images, labels)``. Presumably batch size 1 —
            TODO confirm.
        dataset_name: Used in the log name.
        sigma: Smoothing noise standard deviation.
        samples: Number of noisy copies per image.
        classes: Number of classes, forwarded to ``direct_loop``.
        total_cutoff: Maximum number of correctly classified images to process.
            Any value other than 250 also adds a "cutoff-<n>" tag to the log name.
        plotting: Not used in the per-image loop below.
        autoattack_radii, pgd_radii: Only used to build the log name in the
            loop below.
    """

    # NOTE(review): most of these accumulators are not read in the per-image
    # loop below; confirm whether any later code still needs them.
    new_generalised = True
    set_c, set_n = [], []
    set_c_p, set_n_p = [], []
    E0_set, E1_set = [], []
    pgd_set, our_set, new_attack_set = [], [], []
    autoattack_set = []
    counter = 0
    total = 0
    overall_count = 0
    total_cutoff = total_cutoff  # no-op rebinding, kept as-is

    fgsm_time_set, cw_time_set, autoattack_time_set, direct_time_set, pgd_time_set = (
        [],
        [],
        [],
        [],
        [],
    )
    direct_attack_set, cw_set, fgsm_set = [], [], []
    original_time, original_set = [], []

    cw_times_s, fgsm_times_s, pgd_times_s, o_times_s = [], [], [], []
    cw_result_s, fgsm_result_s, pgd_result_s, o_result_s = [], [], [], []
    cw_times_u, fgsm_times_u, pgd_times_u, o_times_u = [], [], [], []

    evaluate_new = False

    cw_set, cw_time_set, cw_set_first, cw_time_set_first, cw_success = (
        [],
        [],
        [],
        [],
        [],
    )
    fgsm_set, fgsm_time_set, fgsm_set_first, fgsm_time_set_first, fgsm_success = (
        [],
        [],
        [],
        [],
        [],
    )
    (
        direct_time_set,
        direct_time_set_first,
        direct_set,
        direct_set_first,
        direct_success,
    ) = ([], [], [], [], [])
    pgd_time_set, pgd_time_set_first, pgd_set, pgd_set_first, pgd_success = (
        [],
        [],
        [],
        [],
        [],
    )
    deepfool_time_set, deepfool_attack_set, deepfool_success = [], [], []
    cohen_set = []

    # Tag appended to the log name whenever a non-default cutoff is used.
    bonus = ""
    if total_cutoff != 250:
        bonus = "cutoff-" + str(total_cutoff)

    z_val = 2.58
    norm = torch.distributions.Normal(0, 1)
    relu = torch.nn.ReLU()

    # Log name encodes dataset, sigma, sample count, cutoff tag and attack radii.
    log = LogSet(
        "sigma_variance_test_"
        + dataset_name
        + "-"
        + str(sigma)
        + "-samples-"
        + str(samples)
        + bonus
        + "_v2_"
        + str(autoattack_radii)
        + "-"
        + str(pgd_radii),
        decimals=4,
        means=True,
        console=True,
    )

    # Factors passed to direct_loop as ``scaling_sigma``.
    scalings = [0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5]

    for images, labels in test_loader:
        inpt_flag = True if images.shape[1] == 1 else False

        # Time the smoothed prediction (noise sampling + basic_predict).
        inference_time = time.time()
        images, labels = images.to(device), labels.to(device)
        adv_input = images.repeat(samples, 1, 1, 1)
        adv_input = adv_input + torch.randn_like(adv_input) * sigma
        pred_baseline = basic_predict(model, adv_input, device)

        inference_time = time.time() - inference_time

        # Stop once total_cutoff correctly classified images have been processed.
        if total < total_cutoff:
            overall_count += 1
        else:
            break

        # Only correctly classified images are attacked.
        if pred_baseline == labels:
            if total < total_cutoff:
                total += 1
                log.append(total, "ix")
                log.append(overall_count - total, "rej")
                log.append(labels, "lab")
                log.append(inference_time, "inf_t")

                for inx, scaling in enumerate(scalings):
                    yp = images.clone().detach()

                    new_d_time, new_attack, first_time, first_radii = direct_loop(
                        model,
                        classes,
                        labels,
                        samples,
                        yp,
                        images,
                        z_val,
                        sigma,
                        device,
                        scaling_sigma=scaling,
                    )

                    # A radius >= 1000 is treated as a failed attack.
                    if new_attack >= 1000:
                        new_attack, first_radii, first_time, new_d_time = (
                            np.nan,
                            np.nan,
                            np.nan,
                            np.nan,
                        )
                        # NOTE(review): rebinds the list initialised above to an
                        # int flag.
                        direct_success = 0
                    else:
                        new_attack = new_attack.detach().cpu().numpy()
                        direct_success = 1

                    log.append(new_attack, "n_d_" + str(scaling))
                    log.append(direct_success, "n_s_" + str(scaling))

                log.print()
