import torch

from src.diffusion import trans_to_224, trans_to_256, denoise

def classification_with_dds(images, model, diffusion_model, num_trials=10, reduction='mean'):
    """Classify ``images`` by aggregating ``model`` outputs over several DDS passes.

    Each trial purifies ``images`` with :func:`dds` (fast direct prediction),
    feeds the result to ``model``, and accumulates the outcome. The total is
    divided by ``num_trials``.

    Args:
        images: Input batch forwarded to :func:`dds`.
        model: Classifier called as ``model(dds_images)``. It is expected to
            return logits indexed as ``(batch, num_classes)``.
        diffusion_model: Diffusion model forwarded to :func:`dds`.
        num_trials: Number of independent purification passes (>= 1).
        reduction: ``'mean'`` averages the raw logits across trials.
            ``'max'`` counts per-class argmax votes, so each entry is the
            fraction of trials that predicted that class.

    Returns:
        Tensor shaped like one trial's logits, holding the reduced scores.

    Raises:
        ValueError: If ``reduction`` is not ``'mean'``/``'max'`` or
            ``num_trials`` is less than 1.
    """
    if reduction not in ('mean', 'max'):
        raise ValueError(f"reduction must be 'mean' or 'max', got {reduction!r}")
    if num_trials < 1:
        raise ValueError(f"num_trials must be >= 1, got {num_trials}")

    ret = None
    for _ in range(num_trials):
        dds_images = dds(images, diffusion_model, fast_predict=True)
        logits = model(dds_images)

        if ret is None:
            # Allocate lazily so the accumulator matches the model output's
            # shape, dtype and device.
            ret = torch.zeros_like(logits)

        # Accumulate EVERY trial (previously only the first was counted).
        if reduction == 'max':
            pred_step = logits.argmax(dim=1)
            rows = torch.arange(ret.size(0), device=ret.device)
            ret[rows, pred_step] += 1
        else:
            ret += logits

    return ret / num_trials


def dds(image, diffusion, denoising=True, smoothing=True, fast_predict=False,
        noise_level=5/255,
        steps=1000,
        start=0.0001,
        end=0.02):
    """Purify ``image`` with diffusion denoising and/or Gaussian smoothing.

    The caller's ``image`` tensor is never modified in place.

    Args:
        image: Input image tensor.
        diffusion: Diffusion model forwarded to ``denoise``.
        denoising: If True, resize with ``trans_to_256``, run ``denoise``, and
            resize back with ``trans_to_224``.
        smoothing: If True, add Gaussian noise with std ``noise_level``.
        fast_predict: Forwarded to ``denoise`` as ``direct_pred``.
        noise_level: Std of the smoothing noise. Also forwarded to ``denoise``.
        steps, start, end: Forwarded positionally to ``denoise``. They are
            presumably the schedule length and beta range; confirm in
            ``src.diffusion``.

    Returns:
        The processed tensor, with all size-1 dimensions squeezed out and
        values clamped to ``[-1, 1]``.
    """
    if denoising:
        image = trans_to_224(denoise(trans_to_256(image), diffusion, steps, start, end, noise_level,
                                     direct_pred=fast_predict))

    if smoothing:
        # Out-of-place add: with denoising=False, ``image`` is still the
        # caller's tensor, so ``+=`` would silently corrupt it. It would also
        # fail on leaf tensors that require grad.
        image = image + torch.randn_like(image) * noise_level

    # NOTE(review): squeeze() also removes a batch dimension of size 1.
    return torch.clamp(image.squeeze(), -1, 1)

def attack(image, model, noise_level, label_index=None,
           mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=1000):
    """Run a 10-step torchattacks PGD attack on a single image.

    Args:
        image: Input passed to ``model`` and to the attack. A batch of one is
            assumed, because the label tensor has shape ``(1, num_classes)``.
        model: Classifier to attack.
        noise_level: L-inf budget ``eps``. The step size is ``noise_level / 5``.
        label_index: Class index to attack. If None, the model's own argmax
            prediction on ``image`` is used.
        mean, std: Per-channel normalization given to
            ``set_normalization_used``.
        num_classes: Length of the one-hot label vector.

    Returns:
        The adversarial images produced by the PGD attack.
    """
    # Imported lazily, so loading this module does not require torchattacks.
    import torchattacks

    # Global side effect (kept as-is): forces deterministic cuDNN kernels.
    torch.backends.cudnn.deterministic = True
    atk = torchattacks.PGD(model, eps=noise_level, alpha=noise_level / 5, steps=10)
    atk.set_normalization_used(list(mean), list(std))

    if label_index is None:
        # Only the argmax is needed; skip building an autograd graph.
        with torch.no_grad():
            label_index = model(image).argmax()

    # One-hot float target of shape (1, num_classes).
    labels = torch.zeros(1, num_classes)
    labels[0, label_index] = 1
    return atk(image, labels)


def batch_pgd_attack(images, model, noise_level, label_indices=None,
                     mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=1000):
    """Run a 10-step torchattacks PGD attack on a batch of images.

    Args:
        images: Batch of inputs. Its first dimension is the batch size.
        model: Classifier to attack.
        noise_level: L-inf budget ``eps``. The step size is ``noise_level / 5``.
        label_indices: Per-image class indices (tensor or sequence). If None,
            the model's argmax predictions on ``images`` are used.
        mean, std: Per-channel normalization given to
            ``set_normalization_used``.
        num_classes: Width of the one-hot label matrix.

    Returns:
        The adversarial images produced by the PGD attack.
    """
    # Imported lazily, so loading this module does not require torchattacks.
    import torchattacks

    # Global side effect (kept as-is): forces deterministic cuDNN kernels.
    torch.backends.cudnn.deterministic = True
    atk = torchattacks.PGD(model, eps=noise_level, alpha=noise_level / 5, steps=10)
    atk.set_normalization_used(list(mean), list(std))

    if label_indices is None:
        # Only the argmax is needed; skip building an autograd graph.
        with torch.no_grad():
            label_indices = model(images).argmax(dim=1)

    batch_size = images.size(0)
    # Vectorized one-hot encoding. Indices are moved to CPU to match `labels`.
    label_indices = torch.as_tensor(label_indices, dtype=torch.long).cpu()
    labels = torch.zeros(batch_size, num_classes)
    labels[torch.arange(batch_size), label_indices] = 1

    return atk(images, labels)