import torch
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm as tqdm
import math
import os
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import json
import torchvision.transforms.functional as ttf

def accuracy(outputs, labels):
    """Return the fraction of rows whose arg-max over dim 1 matches ``labels``.

    The result is a 0-d CPU tensor (built from a Python float).
    """
    predicted = torch.max(outputs, dim=1).indices
    n_correct = (predicted == labels).sum().item()
    return torch.tensor(n_correct / len(predicted))

def validation_step(mdl, batch):
    """Evaluate ``mdl`` on one ``(images, labels)`` batch.

    Returns a dict with the detached cross-entropy loss under ``'val_loss'``
    and the batch accuracy (via ``accuracy``) under ``'val_acc'``.
    """
    images, labels = batch
    logits = mdl(images)
    return {
        'val_loss': F.cross_entropy(logits, labels).detach(),
        'val_acc': accuracy(logits, labels),
    }
        
def validation_epoch_end(outputs):
    """Average per-batch ``val_loss`` / ``val_acc`` tensors into Python floats."""
    mean_loss = torch.stack([o['val_loss'] for o in outputs]).mean()
    mean_acc = torch.stack([o['val_acc'] for o in outputs]).mean()
    summary = {}
    summary['val_loss'] = mean_loss.item()
    summary['val_acc'] = mean_acc.item()
    return summary

def evaluate(model, val_loader):
    """Run ``validation_step`` over every batch and aggregate the results.

    Evaluation needs no gradients, so the forward passes run under
    ``torch.no_grad()``. This avoids building an autograd graph per batch;
    the returned numbers are unchanged.

    Returns:
        ``{'val_loss': float, 'val_acc': float}`` as produced by
        ``validation_epoch_end``.
    """
    with torch.no_grad():
        outputs = [validation_step(model, batch) for batch in val_loader]
    return validation_epoch_end(outputs)

def to_device(data, device):
    """Move a tensor/module (or nested lists/tuples of them) to ``device``.

    Lists and tuples are handled recursively and always come back as lists;
    anything else must provide ``.to(device, non_blocking=...)``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return [to_device(item, device) for item in data]

class DeviceDataLoader:
    """Wrap a data loader so every batch it yields is moved to ``device``."""

    def __init__(self, dl, device):
        self.dl = dl          # underlying iterable of batches
        self.device = device  # target passed to to_device for each batch

    def __iter__(self):
        # Lazy: nothing is iterated until the first batch is requested.
        yield from (to_device(b, self.device) for b in self.dl)

    def __len__(self):
        return len(self.dl)

def compute_fgsm(x, y, k, eps = 1e-2, norm = np.inf, xi = 10, device = device):
    """One Fast Gradient Sign Method step: ``eps * sign(d CE(k(x), y) / dx)``.

    ``norm``, ``xi`` and ``device`` are unused. They are accepted so the
    function can serve as ``adv_func`` in ``standard_uap``, which forwards them.

    Args:
        x: input batch (does not need to require grad).
        y: labels for the cross-entropy loss.
        k: model returning logits.
        eps: step size.

    Returns:
        Perturbation with the same shape as ``x`` and no autograd history.
    """
    # A fresh leaf that requires grad; retain_grad() on a caller tensor that
    # does not require grad (e.g. ``x_i + v``) raises a RuntimeError.
    x_leaf = x.detach().requires_grad_(True)
    loss = F.cross_entropy(k(x_leaf), y)
    # autograd.grad returns dL/dx without accumulating .grad on model params.
    (grad,) = torch.autograd.grad(loss, x_leaf)
    return eps * torch.sign(grad)

def scale_im(im):
    """Min-max rescale ``im`` to ``[0, 1]`` (a constant image gives 0/0)."""
    lo = im.min()
    span = im.max() - lo
    return (im - lo) / span

def expand_first(im):
    """Prepend a length-1 leading axis (e.g. turn one image into a batch of one)."""
    new_shape = (1,) + tuple(im.shape)
    return im.reshape(new_shape)

def show_attack(ims, attack, mdl):
    """Plot clean images, the perturbation and the perturbed images.

    Each row shows the clean image, the perturbation and their sum. Each
    panel is titled with ``mdl``'s predicted class name, looked up in the
    module-level ``class_idx`` mapping.

    Args:
        ims: one image ``(C, H, W)`` or a batch ``(N, C, H, W)``.
        attack: a single perturbation ``(C, H, W)`` / ``(1, C, H, W)`` shared
            by all images, or one perturbation per image ``(N, C, H, W)``.
        mdl: model returning logits.
    """
    if ims.dim() == 3:
        ims = expand_first(ims)
    if attack.dim() == 3:
        attack = expand_first(attack)
    # Predictions only; no autograd graph needed.
    with torch.no_grad():
        labels = torch.argmax(mdl(ims), dim = 1).cpu().numpy()
        labels_att = torch.argmax(mdl(attack), dim = 1).cpu().numpy()
        labels_adv = torch.argmax(mdl(ims + attack), dim = 1).cpu().numpy()
    ims = ims.detach().cpu()
    attack = attack.detach().cpu()
    n_rows = ims.shape[0]
    # squeeze=False keeps axarr 2-D, so a single row needs no special indexing.
    _, axarr = plt.subplots(n_rows, 3, figsize = (20, 5 * n_rows), squeeze = False)
    for i in range(n_rows):
        a = i if attack.shape[0] > 1 else 0  # shared or per-image attack
        im = ims[i]
        att = attack[a]
        axarr[i, 0].imshow(scale_im(im).permute(1, 2, 0), interpolation='nearest')
        axarr[i, 0].set_title(class_idx[int(labels[i])])
        axarr[i, 1].imshow(scale_im(att).permute(1, 2, 0), interpolation='nearest')
        axarr[i, 1].set_title(class_idx[int(labels_att[a])])
        axarr[i, 2].imshow(scale_im(im + att).permute(1, 2, 0), interpolation='nearest')
        axarr[i, 2].set_title(class_idx[int(labels_adv[i])])
    plt.show()

def project_lp(v, norm, xi, exact = False, device = device):
    """Project (or rescale) a perturbation onto an L_p ball of radius ``xi``.

    For ``norm == 2`` a 4-D ``v`` is treated as a batch and each sample is
    handled independently; any other rank is treated as one perturbation.
    With ``exact=True`` the L2 norm is set to exactly ``xi``; otherwise ``v``
    is only shrunk when its norm exceeds ``xi``. An all-zero perturbation
    stays zero instead of becoming NaN. For ``norm == np.inf`` every entry
    is clipped to ``[-xi, xi]``; ``exact`` has no extra effect there.

    Args:
        v: perturbation tensor.
        norm: 2 or ``np.inf``.
        xi: radius of the ball.
        exact: rescale onto the L2 sphere instead of projecting into the ball.
        device: kept for backward compatibility; results live on ``v``'s device.

    Returns:
        Tensor with the same shape as ``v``.

    Raises:
        ValueError: if ``norm`` is neither 2 nor ``np.inf``.
    """
    if norm == np.inf:
        # Same as sign(v) * min(|v|, xi), without allocating on a fixed device.
        return torch.clamp(v, -xi, xi)
    if norm != 2:
        raise ValueError('L_{} norm not implemented'.format(norm))
    if v.dim() == 4:
        norms = v.flatten(1).norm(p = 2, dim = 1).reshape(-1, 1, 1, 1)
    else:
        norms = torch.norm(v, p = 2)
    # Guard against 0/0 for an all-zero v (which then simply stays zero).
    safe_norms = torch.where(norms > 0, norms, torch.ones_like(norms))
    scale = xi / safe_norms
    if not exact:
        scale = torch.clamp(scale, max = 1.0)
    return v * scale

def err(ds, v, k, verbose = False):
    """Misclassification rate of ``k`` on ``ds`` when ``v`` is added to every input.

    Args:
        ds: iterable of ``(images, labels)`` batches.
        v: perturbation broadcastable to the image batches.
        k: model returning logits.
        verbose: wrap ``ds`` in a tqdm progress bar.

    Returns:
        ``1 - accuracy`` over all samples, as a 0-d tensor.

    Raises:
        ValueError: if ``ds`` yields no samples.
    """
    tot = 0
    acc = 0
    # Pure evaluation: no autograd graph per batch.
    with torch.no_grad():
        for batch in (tqdm(ds) if verbose else ds):
            images, labels = batch[0], batch[1]
            out = k(images + v)
            acc += accuracy(out, labels) * len(labels)
            tot += len(labels)
    if tot == 0:
        raise ValueError('err: dataset yielded no samples')
    return 1 - acc / tot

def robust_err(ds, v, k, times = 5, vals = [10, (2,2), 0.02, 0.02, (0.02, 0.001)], verbose = False):
    """Average error rate of ``k`` on ``ds`` over random transformations of ``v``.

    Each of the ``times`` trials applies one random affine plus
    contrast/brightness transformation (``apply_rand_perm`` with ``vals``) to
    the perturbation, then measures ``err`` on the whole dataset.

    Args:
        ds: iterable of ``(images, labels)`` batches.
        v: perturbation, ``(C, H, W)`` or batched ``(N, C, H, W)``.
        k: model returning logits.
        times: number of random transformations to average over.
        vals: transformation ranges forwarded to ``apply_rand_perm``.
        verbose: forwarded to ``err``.

    Returns:
        Mean error rate over the trials.

    Raises:
        ValueError: if ``times < 1``.
    """
    if times < 1:
        raise ValueError('times must be at least 1')
    # apply_rand_perm uses x.shape[0] as the batch size and warps 4-D inputs,
    # so give a single (C, H, W) perturbation a batch axis; it still
    # broadcasts over the data batches inside err.
    batched_v = v if v.dim() == 4 else v.unsqueeze(0)
    total_err = 0
    with torch.no_grad():
        for _ in range(times):
            temp_v = apply_rand_perm(batched_v, vals)
            total_err += err(ds, temp_v, k, verbose)  # accumulate (was `+`, a no-op)
    return total_err / times

def err_batch(batch, v, k, verbose = False):
    """Misclassification rate of ``k`` on one ``(images, labels)`` batch with ``v`` added.

    ``verbose`` is accepted for signature parity with ``err`` and ignored.

    Returns:
        ``1 - accuracy`` on the batch, as a 0-d tensor in ``[0, 1]``.
    """
    with torch.no_grad():
        out = k(batch[0] + v)
        # accuracy() is already a fraction; scaling it by the batch size (as
        # before) produced negative "error rates".
        return 1 - accuracy(out, batch[1])

def robust_err_batch(batch, v, k, times = 5, vals = [10, (2,2), 0.02, 0.02, (0.02, 0.001)], verbose = False):
    """Average error rate of ``k`` on one batch over random transformations of ``v``.

    Like ``robust_err`` but evaluated with ``err_batch`` on a single
    ``(images, labels)`` batch.

    Returns:
        Mean error rate over ``times`` random transformations.

    Raises:
        ValueError: if ``times < 1``.
    """
    if times < 1:
        raise ValueError('times must be at least 1')
    # apply_rand_perm needs a leading batch axis (it uses x.shape[0]).
    batched_v = v if v.dim() == 4 else v.unsqueeze(0)
    total_err = 0
    with torch.no_grad():
        for _ in range(times):
            temp_v = apply_rand_perm(batched_v, vals)
            total_err += err_batch(batch, temp_v, k, verbose)  # accumulate (was `+`, a no-op)
    return total_err / times

def compute_pgd(x, y, k, norm = np.inf, xi = 10, epochs = 40, random_restart = 4, step_size = 1e-2, device = device):
    """Untargeted PGD with random restarts; returns the strongest perturbation per sample.

    Each sample gets ``random_restart`` random starting points inside the L_p
    ball of radius ``xi``. Each start takes ``epochs`` signed-gradient ascent
    steps on the cross-entropy loss and is projected back into the ball after
    every step. For each sample, the restart with the highest final loss under
    ``k`` is returned.

    Args:
        x: input batch; ``repeat(random_restart, 1, 1, 1)`` requires it to be 4-D.
        y: labels for the loss.
        k: model returning logits.
        norm, xi: L_p ball (see ``project_lp``).
        epochs: number of ascent steps.
        random_restart: random starting points per sample.
        step_size: size of each signed-gradient step.
        device: device for the random initialisation.

    Returns:
        Tensor with the shape of ``x``: one perturbation per sample.
    """
    batch_size = x.shape[0]
    # Restart r of sample b lives at row r * batch_size + b.
    x_rep = x.detach().repeat(random_restart, 1, 1, 1)
    y_rep = y.repeat(random_restart)
    delta = torch.rand(size = (batch_size * random_restart, *x.shape[1:]), device = device) - 0.5
    delta = project_lp(delta, norm = norm, xi = xi, exact = True, device = device)
    for _ in range(epochs):
        pert_x = (x_rep + delta).detach().requires_grad_(True)
        loss = F.cross_entropy(k(pert_x), y_rep)
        # autograd.grad avoids accumulating .grad on the model parameters.
        (grad,) = torch.autograd.grad(loss, pert_x)
        delta = project_lp(delta + step_size * torch.sign(grad), norm = norm, xi = xi)
    with torch.no_grad():
        # Score restarts with the attacked model k (previously the global `mdl`).
        losses = F.cross_entropy(k(x_rep + delta), y_rep, reduction = 'none')
        best = torch.argmax(losses.reshape(random_restart, batch_size), dim = 0)
    rows = best * batch_size + torch.arange(batch_size, dtype = torch.int64, device = best.device)
    return delta[rows.to(delta.device)]

def jacobian_idx(x, k, idx, num_classes, tot_classes, device = device):
    """Per-sample gradients of logit differences, computed one sample at a time.

    For sample ``i`` with class list ``idx[i]``, row ``j`` of the result is
    ``d(logit[idx[i][j + 1]] - logit[idx[i][0]]) / dx_i``.

    Args:
        x: input batch, indexed per sample.
        k: model returning logits.
        idx: per-sample class indices; column 0 is the reference class.
        num_classes: classes per row of ``idx`` (gives ``num_classes - 1`` rows).
        tot_classes: total number of logits; kept for interface compatibility.
        device: device on which the result is allocated.

    Returns:
        Tensor of shape ``(batch, num_classes - 1, *x.shape[1:])``.
    """
    batch_size = x.shape[0]
    image_size = x.shape[1:]
    jacobian = torch.zeros((batch_size, num_classes - 1, *image_size), device = device)
    for i, classes in enumerate(idx):
        x_i = x[i:i+1].detach().requires_grad_(True)
        out = k(x_i)
        reference = out[0, classes[0]]
        targets = classes[1:]
        for j, c in enumerate(targets):
            # Differentiate each difference directly. There is no shared one-hot
            # buffer whose stale entries could leak between samples, and the
            # graph is freed after the last target.
            (grad,) = torch.autograd.grad(out[0, c] - reference, x_i, retain_graph = j < len(targets) - 1)
            jacobian[i, j] = grad[0]
    return jacobian

def compute_deepfool(x, y, k, norm = 2, xi = 5, num_classes = 10, deepfool_iter = 50, overshoot = 0.02, device = device):
    """Batched DeepFool: find a small perturbation that changes k's top prediction.

    Only the ``num_classes`` highest-scoring classes of the clean prediction
    are considered. At each iteration, every sample that still predicts its
    original top class is pushed to the nearest linearised decision boundary
    among those candidate classes.

    Args:
        x: input batch. The Jacobian norms over dims (2, 3, 4) imply a 4-D
            tensor — presumably (N, C, H, W), TODO confirm.
        y: NOTE(review): unused; the clean prediction is the reference class.
        k: model returning logits.
        norm, xi: NOTE(review): unused — no L_p projection is applied here.
        num_classes: number of top classes tracked.
        deepfool_iter: maximum number of iterations.
        overshoot: the accumulated perturbation is scaled by (1 + overshoot).
        device: device for the `unfinished` index tensor; also forwarded to
            jacobian_idx_fast.

    Returns:
        ``pert * (1 + overshoot)`` with the shape of ``x``. It still has
        ``requires_grad=True``.
    """
    # Clean logits; sort so column 0 is the predicted class.
    y_pred = k(x)
    tot_classes = y_pred.shape[1]
    y_pred, class_idx = torch.sort(y_pred, axis = 1, descending = True)
    # Keep only the top num_classes logits and their class ids.
    y_pred, class_idx = y_pred[:,0:num_classes], class_idx[:,0:num_classes]
    batch_size = x.shape[0]
    pred_label = class_idx[:,0]  # NOTE(review): unused
    itr = 0
    # Leaf perturbation; it is updated in place under no_grad below.
    pert = torch.zeros_like(x)
    pert.requires_grad = True
    pert_img = x + pert
    # Columns follow class_idx order, so column 0 is the original prediction.
    y_pred_pert = y_pred
    # Indices of samples whose prediction has not flipped yet.
    unfinished = torch.arange(batch_size, device = device)
    # Row indices used to gather the tracked classes' logits (created without device=).
    pred_base_idx = torch.arange(batch_size).reshape(batch_size, 1).expand(batch_size, num_classes)
    #pbar = tqdm(total = deepfool_iter)
    while len(unfinished) and itr < deepfool_iter:
        remain = len(unfinished)
        #pbar.set_description('Remaining in Batch: {}'.format(remain))
        # w_c = d(logit_c - logit_orig)/d input for each other tracked class c;
        # shape (remain, num_classes - 1, *x.shape[1:]).
        gradient_diffs = jacobian_idx_fast(pert_img[unfinished], k, class_idx[unfinished], num_classes, tot_classes, device)
        with torch.no_grad():
            # f_c - f_orig for each candidate class.
            pred_diffs = y_pred_pert[unfinished,1:] - y_pred_pert[unfinished,0].reshape(remain, 1).expand(remain, num_classes - 1)
            # Distance to each linearised boundary: |f_c - f_orig| / ||w_c||_2.
            iter_pert = torch.abs(pred_diffs)/torch.norm(gradient_diffs, p = 2, dim = (2, 3, 4))
            # Closest boundary per sample.
            pert_min, idx_min = torch.min(iter_pert, axis = 1)
            gradient_min = gradient_diffs[torch.arange(remain),idx_min]
            gradient_min_norm = torch.norm(gradient_min, p = 2, dim = (1, 2, 3))
            #pert.requires_grad = False
            # Step of length pert_min along the unit normal of that boundary.
            pert[unfinished] += pert_min.reshape(-1, 1, 1, 1) * gradient_min/gradient_min_norm.reshape(-1, 1, 1, 1)
            #pert = pert.detach()
            #pert.requires_grad = True
        pert_img = x + (1 + overshoot) * pert
        # NOTE(review): this forward pass runs with autograd enabled, although the
        # result is only read inside the no_grad block above.
        y_pred_pert = k(pert_img)[pred_base_idx, class_idx]
        # Still unfinished while column 0 (the original class) is the argmax among
        # the tracked classes; classes outside the top num_classes are not checked.
        unfinished = torch.where(torch.argmax(y_pred_pert, axis = 1) == 0)[0]
        itr += 1
        #pbar.update(1)
    #pbar.close()
    return pert * (1 + overshoot)

def jacobian_idx(x, k, idx, num_classes, tot_classes, device = device):
    """Per-sample gradients of logit differences, computed one sample at a time.

    For sample ``i`` with class list ``idx[i]``, row ``j`` of the result is
    ``d(logit[idx[i][j + 1]] - logit[idx[i][0]]) / dx_i``.

    Args:
        x: input batch, indexed per sample.
        k: model returning logits.
        idx: per-sample class indices; column 0 is the reference class.
        num_classes: classes per row of ``idx`` (gives ``num_classes - 1`` rows).
        tot_classes: total number of logits; kept for interface compatibility.
        device: device on which the result is allocated.

    Returns:
        Tensor of shape ``(batch, num_classes - 1, *x.shape[1:])``.
    """
    batch_size = x.shape[0]
    image_size = x.shape[1:]
    jacobian = torch.zeros((batch_size, num_classes - 1, *image_size), device = device)
    for i, classes in enumerate(idx):
        x_i = x[i:i+1].detach().requires_grad_(True)
        out = k(x_i)
        reference = out[0, classes[0]]
        targets = classes[1:]
        for j, c in enumerate(targets):
            # Differentiate each difference directly. The previous shared
            # one-hot buffer never cleared its -1 entry, so later samples
            # picked up gradients of earlier samples' reference classes.
            (grad,) = torch.autograd.grad(out[0, c] - reference, x_i, retain_graph = j < len(targets) - 1)
            jacobian[i, j] = grad[0]
    return jacobian

def jacobian_idx_fast(x, k, idx, num_classes, tot_classes, device = device):
    """Batched gradients of logit differences w.r.t. the inputs.

    Row ``j - 1`` of sample ``b`` is
    ``d(logit[b, idx[b, j]] - logit[b, idx[b, 0]]) / dx_b`` for
    ``j = 1 .. num_classes - 1``. The whole batch is differentiated at once
    via the batch-summed difference. This assumes ``k`` treats samples
    independently (e.g. eval mode) — TODO confirm for the models used.

    Args:
        x: input batch. If it does not require grad, a detached copy that
            does is used.
        k: model returning logits.
        idx: ``(batch, >= num_classes)`` class indices; column 0 is the reference.
        num_classes: number of classes used from ``idx``.
        tot_classes: total number of logits; kept for interface compatibility.
        device: device on which the result is allocated.

    Returns:
        Tensor of shape ``(batch, num_classes - 1, *x.shape[1:])``.
    """
    batch_size = x.shape[0]
    image_size = x.shape[1:]
    jacobian = torch.zeros((batch_size, num_classes - 1, *image_size), device = device)
    if not x.requires_grad:
        x = x.detach().requires_grad_(True)
    out = k(x)
    rows = torch.arange(batch_size, device = out.device)
    reference = out[rows, idx[:, 0]]
    for j in range(1, num_classes):
        diff = out[rows, idx[:, j]] - reference
        # autograd.grad w.r.t. x only: no .grad accumulation on model params or
        # upstream leaves, and the graph is released after the last class.
        (grad,) = torch.autograd.grad(diff.sum(), x, retain_graph = j < num_classes - 1)
        jacobian[:, j - 1] = grad
    return jacobian

def random_angle(batch_size, angle_max = 10):
    """Draw ``batch_size`` random 2x3 rotation matrices for ``F.affine_grid``.

    Angles are uniform in ``[-angle_max, angle_max)`` degrees. The matrices
    have a zero translation column and are created on the module-level
    ``device``.

    Returns:
        Tensor of shape ``(batch_size, 2, 3)``.
    """
    max_rad = angle_max * (np.pi / 180)
    theta = 2 * (torch.rand(batch_size, device = device) - 0.5) * max_rad
    c, s = torch.cos(theta), torch.sin(theta)
    zero = torch.zeros_like(c)
    first_row = torch.stack((c, -s, zero), dim = 1)
    second_row = torch.stack((s, c, zero), dim = 1)
    return torch.stack((first_row, second_row), dim = 1)

def random_shear(batch_size, shear_factor_max = 0.02):
    """Draw ``batch_size`` random 2x3 shear matrices for ``F.affine_grid``.

    The x- and y-shear factors are independently uniform in
    ``[-shear_factor_max, shear_factor_max)``. The diagonal is 1 and the
    translation is zero. Created on the module-level ``device``.

    Returns:
        Tensor of shape ``(batch_size, 2, 3)``.
    """
    def _uniform():
        return 2 * (torch.rand(batch_size, device = device) - 0.5) * shear_factor_max

    shear_x = _uniform()
    shear_y = _uniform()
    one = torch.ones_like(shear_x)
    zero = torch.zeros_like(shear_x)
    return torch.stack(
        (torch.stack((one, shear_x, zero), dim = 1),
         torch.stack((shear_y, one, zero), dim = 1)),
        dim = 1,
    )

def random_scale(batch_size, scale_percentage_max = 0.02):
    """Draw ``batch_size`` random isotropic 2x3 scaling matrices.

    Each scale factor is ``1 + u`` with ``u`` uniform in
    ``[-scale_percentage_max, scale_percentage_max)``. Created on the
    module-level ``device``.

    Returns:
        Tensor of shape ``(batch_size, 2, 3)``.
    """
    factor = 1 + 2 * (torch.rand(batch_size, device = device) - 0.5) * scale_percentage_max
    zero = torch.zeros_like(factor)
    return torch.stack(
        (torch.stack((factor, zero, zero), dim = 1),
         torch.stack((zero, factor, zero), dim = 1)),
        dim = 1,
    )

def random_contrast_brightness(x, contrast = 0.02, brightness = 0.001):
    """Apply a random per-sample contrast factor and brightness offset.

    For sample ``i``: ``out[i] = c_i * x[i] + b_i``, with ``c_i`` uniform in
    ``[1 - contrast, 1 + contrast)`` and ``b_i`` uniform in
    ``[-brightness, brightness)``. Random numbers are drawn on ``x``'s device
    (previously hard-coded to CUDA, which failed on CPU tensors).

    Args:
        x: batch tensor with the batch along dim 0 (any number of trailing dims).
        contrast: maximum relative contrast change.
        brightness: maximum absolute brightness offset.

    Returns:
        Tensor with the shape of ``x``.
    """
    batch_size = x.shape[0]
    rand_contrast = ((torch.rand(batch_size, device = x.device) - 0.5) * 2) * contrast + 1
    rand_brightness = ((torch.rand(batch_size, device = x.device) - 0.5) * 2) * brightness
    # Reshape the per-sample factors so they broadcast over all non-batch dims.
    per_sample = (batch_size,) + (1,) * (x.dim() - 1)
    return rand_contrast.reshape(per_sample) * x + rand_brightness.reshape(per_sample)

def random_translate(x, trans_max_x = 2, trans_max_y = 2):
    """Draw random 2x3 translation matrices for ``F.affine_grid``.

    Args:
        x: the batch size (callers pass an int). A tensor is also accepted,
            in which case its first dimension is used.
        trans_max_x, trans_max_y: translations are uniform in
            ``[-trans_max, trans_max)``. NOTE(review): ``F.affine_grid`` uses
            normalised coordinates in which the image spans ``[-1, 1]``, so a
            value of 2 is a shift of a full image width. If pixel shifts were
            intended, scale by ``2 / width`` — confirm.

    Returns:
        Tensor of shape ``(batch_size, 2, 3)`` on the module-level ``device``.
        The linear part is the identity.
    """
    # The body previously read an undefined global `batch_size`.
    batch_size = x.shape[0] if torch.is_tensor(x) else int(x)
    trans_x = 2 * (torch.rand(batch_size, device = device) - 0.5) * trans_max_x
    trans_y = 2 * (torch.rand(batch_size, device = device) - 0.5) * trans_max_y
    thetas = torch.zeros(batch_size, 2, 3, device = device)
    # Identity linear part. With a zero diagonal, composing this matrix with
    # the others in apply_rand_perm collapsed every image to a single point.
    thetas[:, 0, 0] = 1
    thetas[:, 1, 1] = 1
    thetas[:, 0, 2] = trans_x
    thetas[:, 1, 2] = trans_y
    return thetas

def extend_perm_matrix(a):
    """Append the homogeneous row ``[0, 0, 1]`` to each 2x3 affine matrix.

    Turns a ``(B, 2, 3)`` batch into ``(B, 3, 3)`` matrices that can be
    composed with ``torch.bmm``. The extra row is created with ``a``'s own
    device and dtype, so no global device is needed.
    """
    last_row = torch.tensor([0.0, 0.0, 1.0], dtype = a.dtype, device = a.device)
    last_row = last_row.expand(a.shape[0], 1, 3)
    return torch.cat((a, last_row), dim = 1)

def apply_rand_perm(x, vals = [10, (2, 2), 0.02, 0.02, (0.02, 0.001)]):
    """Randomly warp ``x`` and then jitter its contrast and brightness.

    One random transform is drawn per sample (``x.shape[0]``). The affine
    part is the product rotation · translation · shear · scale, applied with
    ``diff_affine``. After that, ``random_contrast_brightness`` is applied.

    Args:
        x: batch to transform (warped via ``F.affine_grid``).
        vals: ``[angle_max, (trans_max_x, trans_max_y), shear_max, scale_max,
            (contrast, brightness)]``.

    Returns:
        Transformed tensor with the shape of ``x``.
    """
    n = x.shape[0]
    # Draw order matters for reproducibility: rotation, translation, shear, scale.
    factors = [
        random_angle(n, angle_max = vals[0]),
        random_translate(n, trans_max_x = vals[1][0], trans_max_y = vals[1][1]),
        random_shear(n, shear_factor_max = vals[2]),
        random_scale(n, scale_percentage_max = vals[3]),
    ]
    composed = extend_perm_matrix(factors[0])
    for factor in factors[1:]:
        composed = torch.bmm(composed, extend_perm_matrix(factor))
    warped = diff_affine(x, composed[:, :2, :])
    contrast, brightness = vals[4][0], vals[4][1]
    return random_contrast_brightness(warped, contrast = contrast, brightness = brightness)

def diff_affine(x, thetas):
    """Warp each image in ``x`` by its 2x3 affine matrix (differentiable sampling).

    Uses ``F.affine_grid`` / ``F.grid_sample`` with ``align_corners=False``
    and the default bilinear interpolation and zero padding.
    """
    sampling_grid = F.affine_grid(thetas, x.shape, align_corners = False)
    return F.grid_sample(x, sampling_grid, align_corners = False)

def compute_deepfool_perms(x, y, k, norm = 2, xi = 5, num_classes = 10, deepfool_iter = 50, overshoot = 0.02, device = device, vals = [10, (2,2), 0.02, 0.02, (0.02, 0.001)]):
    """DeepFool variant where the perturbation is randomly transformed before being added.

    Follows ``compute_deepfool``, except that the perturbed input is built as
    ``x + (1 + overshoot) * apply_rand_perm(pert, vals)`` (the first step
    uses no overshoot). Gradients are taken w.r.t. that transformed input,
    and the DeepFool step is added to the untransformed ``pert``.

    Args:
        x: input batch (4-D, as required by ``apply_rand_perm``).
        y: NOTE(review): unused; the clean prediction is the reference class.
        k: model returning logits.
        norm, xi: NOTE(review): unused — no L_p projection is applied.
        num_classes: number of top classes tracked.
        deepfool_iter: maximum number of iterations.
        overshoot: scaling applied to the perturbation.
        device: device for index tensors; forwarded to ``jacobian_idx_fast``.
        vals: transformation ranges for ``apply_rand_perm``.

    Returns:
        ``pert * (1 + overshoot)``, the untransformed perturbation, with the
        shape of ``x``.
    """
    y_pred = k(x)
    tot_classes = y_pred.shape[1]
    # Column 0 becomes the clean predicted class; keep the top num_classes.
    y_pred, class_idx = torch.sort(y_pred, dim = 1, descending = True)
    y_pred, class_idx = y_pred[:, 0:num_classes], class_idx[:, 0:num_classes]
    batch_size = x.shape[0]
    itr = 0
    pert = torch.zeros_like(x)
    pert.requires_grad = True
    pert_img = x + apply_rand_perm(pert, vals)
    y_pred_pert = y_pred
    unfinished = torch.arange(batch_size, device = device)
    pred_base_idx = torch.arange(batch_size).reshape(batch_size, 1).expand(batch_size, num_classes)
    while len(unfinished) and itr < deepfool_iter:
        remain = len(unfinished)
        gradient_diffs = jacobian_idx_fast(pert_img[unfinished], k, class_idx[unfinished], num_classes, tot_classes, device)
        with torch.no_grad():
            pred_diffs = y_pred_pert[unfinished, 1:] - y_pred_pert[unfinished, 0].reshape(remain, 1).expand(remain, num_classes - 1)
            # Distance to each linearised boundary; step to the closest one.
            iter_pert = torch.abs(pred_diffs) / torch.norm(gradient_diffs, p = 2, dim = (2, 3, 4))
            pert_min, idx_min = torch.min(iter_pert, dim = 1)
            gradient_min = gradient_diffs[torch.arange(remain), idx_min]
            gradient_min_norm = torch.norm(gradient_min, p = 2, dim = (1, 2, 3))
            pert[unfinished] += pert_min.reshape(-1, 1, 1, 1) * gradient_min / gradient_min_norm.reshape(-1, 1, 1, 1)
        # Previously called the undefined `apply_rand_perm2` (NameError); use
        # the same transform family as the initial step.
        pert_img = x + (1 + overshoot) * apply_rand_perm(pert, vals)
        y_pred_pert = k(pert_img)[pred_base_idx, class_idx]
        unfinished = torch.where(torch.argmax(y_pred_pert, dim = 1) == 0)[0]
        itr += 1
    return pert * (1 + overshoot)

def compute_batch_pgd_perm(x, y, k, norm = 2, xi = 100, epochs = 40, random_p = 5, step_size = 1e-1, device = device, vals = [10, (2,2), 0.02, 0.02, (0.02, 0.001)], decay = 1.0):
    """PGD for one shared perturbation that stays adversarial under random transforms.

    A single perturbation ``delta`` of shape ``(1, *x.shape[1:])`` is
    optimised. At each step it is copied ``random_p`` times and each copy is
    randomly transformed (``apply_rand_perm``). Copy ``r`` is added to every
    sample of the ``r``-th repetition of the batch. One signed-gradient
    ascent step is then taken on the mean cross-entropy, followed by
    projection onto the L_p ball.

    Args:
        x: input batch (4-D).
        y: labels for the loss.
        k: model returning logits.
        norm, xi: L_p ball (see ``project_lp``).
        epochs: number of ascent steps.
        random_p: random transformations per step.
        step_size: initial step size.
        device: device for the random initialisation.
        vals: transformation ranges for ``apply_rand_perm``.
        decay: multiplicative step-size decay per epoch. Previously an
            undefined name; ``1.0`` keeps the step size constant.

    Returns:
        Detached perturbation of shape ``(1, *x.shape[1:])``.
    """
    batch_size = x.shape[0]
    delta = torch.rand(size = (1, *x.shape[1:]), device = device) - 0.5
    delta = project_lp(delta, norm = norm, xi = xi, exact = True, device = device)
    # Repetition r of the batch occupies rows r * batch_size ... (r + 1) * batch_size - 1.
    x_rep = x.detach().repeat(random_p, 1, 1, 1)
    y_rep = y.repeat(random_p)
    for _ in range(epochs):
        delta = delta.detach().requires_grad_(True)
        transformed = apply_rand_perm(delta.repeat(random_p, 1, 1, 1), vals)
        # repeat_interleave lines transform r up with repetition r of the batch.
        pert_x = x_rep + transformed.repeat_interleave(batch_size, dim = 0)
        loss = F.cross_entropy(k(pert_x), y_rep)
        # autograd.grad avoids accumulating .grad on the model parameters.
        (grad,) = torch.autograd.grad(loss, delta)
        delta = project_lp(delta.detach() + step_size * torch.sign(grad), norm = norm, xi = xi)
        step_size = step_size * decay
    return delta.detach()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Train and val currently use identical preprocessing.
# NOTE(review): torchvision documents inception_v3 as expecting 299x299 inputs,
# while these transforms crop to 224x224 — confirm this is intended.
transform_train = transforms.Compose([transforms.ToTensor(), transforms.Resize(256), transforms.CenterCrop(224), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
transform_val = transforms.Compose([transforms.ToTensor(), transforms.Resize(256), transforms.CenterCrop(224), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
imagenet_train = torchvision.datasets.ImageNet(root='D:/datasets/', split='train', transform=transform_train)
imagenet_val = torchvision.datasets.ImageNet(root='D:/datasets/', split='val', transform=transform_val)

# Map integer class id -> human-readable name (second field of each JSON entry).
# The context manager closes the file handle.
with open("imagenet_class_index.json", encoding="utf-8") as class_index_file:
    class_idx = json.load(class_index_file)
class_idx = {i: class_idx[str(i)][1] for i in range(1000)}

# Hold out test_size training images as a test split, then keep a random
# 2000-image subset of the remainder as the training split.
test_size = 1000
train_size = len(imagenet_train) - test_size  # was the undefined name `val_size`
imagenet_train, imagenet_test = torch.utils.data.random_split(imagenet_train, [train_size, test_size])
imagenet_train, _ = torch.utils.data.random_split(imagenet_train, [2000, len(imagenet_train) - 2000])

train_ds = torch.utils.data.DataLoader(imagenet_train, batch_size=16, shuffle=True, pin_memory=True, num_workers=8)
val_ds = torch.utils.data.DataLoader(imagenet_val, batch_size=16, shuffle=True, pin_memory=True, num_workers=8)
test_ds = torch.utils.data.DataLoader(imagenet_test, batch_size=16, shuffle=True, pin_memory=True, num_workers=8)

train_loader = DeviceDataLoader(train_ds, device)
val_loader = DeviceDataLoader(val_ds, device)
test_loader = DeviceDataLoader(test_ds, device)

mdl = torchvision.models.inception_v3(pretrained=True)
mdl.eval()
mdl = to_device(mdl, device)

def standard_uap(dl, k, adv_func = compute_pgd, norm = np.inf, xi = 10, delta = 0.05, max_iter = 5, device = device, verbose = True, **kwargs):
    """Compute a universal adversarial perturbation (UAP) via per-sample attacks.

    For every sample whose prediction is not changed by the current
    perturbation ``v``, ``adv_func`` computes an extra perturbation for
    ``x + v``. That perturbation is added to ``v``, and the result is
    projected onto the L_p ball of radius ``xi``. Passes over ``dl`` repeat
    until the error rate (``err``) exceeds ``1 - delta`` or ``max_iter``
    passes are done.

    Args:
        dl: iterable of ``(images, labels)`` batches.
        k: model returning logits.
        adv_func: called as ``adv_func(x, y, k, norm=..., xi=..., device=..., **kwargs)``.
        norm, xi: L_p constraint on the UAP.
        delta: stop once the error rate exceeds ``1 - delta``.
        max_iter: maximum number of passes over ``dl``.
        device: device for ``v``; also forwarded to ``adv_func``.
        verbose: show progress bars and print the error rate after each pass.
        **kwargs: extra arguments forwarded to ``adv_func``.

    Returns:
        The perturbation ``v``.
    """
    image_shape = next(iter(dl))[0][0].shape
    v = torch.zeros(size = image_shape, device = device)
    error = -1  # sentinel: guarantees at least one pass
    kwargs['norm'] = norm
    kwargs['xi'] = xi
    kwargs['device'] = device
    itr = 0
    if verbose:
        print("Iteration {}: Error Rate - {}".format(itr, error))
    while error <= 1 - delta and itr < max_iter:
        itr += 1
        for batch_x, batch_y in (tqdm(dl) if verbose else dl):
            for j in range(len(batch_x)):
                x_i = batch_x[j:j+1]
                y_i = batch_y[j:j+1]
                x_mod = x_i + v
                # Only the comparison runs without autograd; adv_func needs gradients.
                with torch.no_grad():
                    unchanged = torch.max(k(x_mod), dim = 1)[1] == torch.max(k(x_i), dim = 1)[1]
                if unchanged:
                    r = adv_func(x_mod, y_i, k, **kwargs)
                    v = project_lp(v + r, norm = norm, xi = xi).detach()
        error = err(dl, v, k)
        if verbose:
            print("Iteration {}: Error Rate - {}".format(itr, error))
    return v

def standard_uap_rp(dl, k, adv_func = compute_batch_pgd_perm, norm = np.inf, xi = 10, delta = 0.05, max_iter = 5, vals = [10, (2,2), 0.02, 0.02, (0.02, 0.001)], device = device, verbose = True, **kwargs):
    """UAP search whose stopping criterion is the error under random transformations.

    Same loop as ``standard_uap``, but each pass is scored with
    ``robust_err(dl, v, k, vals=vals)``. If ``adv_func`` has a ``vals``
    parameter and none is given in ``kwargs``, ``vals`` is forwarded to it,
    so the attack uses the same transformation ranges as the evaluation.

    Args:
        dl: iterable of ``(images, labels)`` batches.
        k: model returning logits.
        adv_func: called as ``adv_func(x, y, k, norm=..., xi=..., device=..., **kwargs)``.
        norm, xi: L_p constraint on the UAP.
        delta: stop once the robust error rate exceeds ``1 - delta``.
        max_iter: maximum number of passes over ``dl``.
        vals: transformation ranges for ``apply_rand_perm``.
        device: device for ``v``; also forwarded to ``adv_func``.
        verbose: show progress bars and print the error rate after each pass.
        **kwargs: extra arguments forwarded to ``adv_func``.

    Returns:
        The perturbation ``v``.
    """
    import inspect

    image_shape = next(iter(dl))[0][0].shape
    v = torch.zeros(size = image_shape, device = device)
    error = -1  # sentinel: guarantees at least one pass
    kwargs['norm'] = norm
    kwargs['xi'] = xi
    kwargs['device'] = device
    try:
        accepts_vals = 'vals' in inspect.signature(adv_func).parameters
    except (TypeError, ValueError):
        accepts_vals = False
    if accepts_vals:
        kwargs.setdefault('vals', vals)
    itr = 0
    if verbose:
        print("Iteration {}: Error Rate - {}".format(itr, error))
    while error <= 1 - delta and itr < max_iter:
        itr += 1
        for batch_x, batch_y in (tqdm(dl) if verbose else dl):
            for j in range(len(batch_x)):
                x_i = batch_x[j:j+1]
                y_i = batch_y[j:j+1]
                x_mod = x_i + v
                # Only the comparison runs without autograd; adv_func needs gradients.
                with torch.no_grad():
                    unchanged = torch.max(k(x_mod), dim = 1)[1] == torch.max(k(x_i), dim = 1)[1]
                if unchanged:
                    r = adv_func(x_mod, y_i, k, **kwargs)
                    v = project_lp(v + r, norm = norm, xi = xi).detach()
        error = robust_err(dl, v, k, vals = vals)
        if verbose:
            print("Iteration {}: Error Rate - {}".format(itr, error))
    return v

def sgd_uap(dl, k, norm = 2, xi = 100, max_iter = 5, random_p = 5, device = device, delta = 0.05, lr = 0.01, vals = [10, (2,2), 0.02, 0.02, (0.02, 0.001)]):
    """Signed-gradient UAP training against random input transformations.

    For each batch, the cross-entropy of ``k`` on ``random_p`` randomly
    transformed copies of ``x + v`` (``apply_rand_perm`` with ``vals``) is
    summed. ``v`` then takes a step of ``lr * sign(gradient)`` and is
    projected onto the L_p ball of radius ``xi``. After each pass the robust
    error rate is printed. Stops once it exceeds ``1 - delta`` or after
    ``max_iter`` passes.

    Returns:
        The perturbation ``v`` (shape of one image).
    """
    image_shape = next(iter(dl))[0][0].shape
    v = torch.zeros(size = image_shape, device = device)
    error = -1  # sentinel: guarantees at least one pass
    itr = 0
    while error <= 1 - delta and itr < max_iter:
        itr += 1
        for x_i, y_i in tqdm(dl):
            x_mod = x_i + v
            # Zero probe whose gradient is d loss / d v at the current v.
            delta_v = torch.zeros(size = image_shape, device = device, requires_grad = True)
            loss = 0
            for _ in range(random_p):
                loss += F.cross_entropy(k(apply_rand_perm(x_mod + delta_v, vals)), y_i)
            # autograd.grad avoids accumulating .grad on the model parameters.
            (grad,) = torch.autograd.grad(loss, delta_v)
            v = project_lp(v + lr * torch.sign(grad), norm = norm, xi = xi)
        # Previously err(dl, v, mdl, vals=vals): err takes no `vals`, and the
        # global `mdl` was used instead of k. robust_err matches the objective.
        error = robust_err(dl, v, k, vals = vals)
        print(error)
    return v

def compute_n(phi, psi):
    """Hoeffding sample count ``ln(2 / phi) / (2 * psi**2)`` (returned as a float).

    Number of samples needed so that an empirical mean of [0, 1] variables
    is within ``psi`` of its expectation with probability at least ``1 - phi``.
    """
    inv_two_psi_sq = 1 / (2 * psi ** 2)
    return inv_two_psi_sq * np.log(2 / phi)

def robust_uap(dl, k, norm = 2, xi = 100, max_iter = 5, phi = 0.05, psi = 0.1, device = device, delta = 0.05, lr = 0.01, vals = [10, (2,2), 0.02, 0.02, (0.02, 0.001)]):
    """UAP search that is robust to random transformations, sized by a Hoeffding bound.

    ``n = ceil(compute_n(phi, psi))`` random transformations are used both
    per PGD step (``compute_batch_pgd_perm``'s ``random_p``) and for the
    per-pass robust error estimate. A batch triggers an update only while its
    robust error (``robust_err_batch``) is at most ``1 - delta``.

    NOTE(review): with the defaults, n = 185, so every PGD step pushes 185
    copies of the batch through ``k``. This may need a smaller n in practice.
    NOTE(review): ``lr`` is currently unused.

    Returns:
        The perturbation ``v``.
    """
    # compute_n returns a float; repeat counts and range() need an int.
    n = int(math.ceil(compute_n(phi, psi)))
    image_shape = next(iter(dl))[0][0].shape
    v = torch.zeros(size = image_shape, device = device)
    error = -1  # sentinel: guarantees at least one pass
    itr = 0
    # Previously written into an undefined `kwargs` dict (NameError).
    pgd_kwargs = {'norm': norm, 'xi': xi, 'device': device, 'random_p': n, 'vals': vals}
    while error <= 1 - delta and itr < max_iter:
        itr += 1
        for batch in tqdm(dl):
            x_i, y_i = batch
            if robust_err_batch(batch, v, k, vals = vals) <= 1 - delta:
                r = compute_batch_pgd_perm(x_i + v, y_i, k, **pgd_kwargs)
                v = project_lp(v + r, norm = norm, xi = xi).detach()
        error = robust_err(dl, v, k, times = n, vals = vals)
        print(error)
    return v


