import torch
from torch.autograd import Variable

def apply_attack(args, idx, images, labels, net, delta_data, delta_smash, batch_idx,
                 *, full_batch_size=None):
    """Generate attacked client activations for a single batch.

    This is a generator: each iteration perturbs ``images`` with the stored
    data delta, runs ``net`` on the result, perturbs the output with the
    stored smashed-data delta and yields a detached copy of it.

    Args:
        args: Configuration object. The attributes read here are
            ``dataset_attack``, ``smashed_attack``, ``label_attack``, ``iur``,
            ``classes`` (only when ``label_attack`` is set) and ``iter_num``
            (only when ``dataset_attack`` or ``smashed_attack`` is set).
        idx: Key used to index ``delta_data`` and ``delta_smash``.
        images: Input batch; ``images.size(0)`` is taken as the batch length.
        labels: Labels for the batch.
        net: Callable applied to the perturbed images.
        delta_data: Indexable by ``idx``; ``delta_data[idx]`` must support
            slice reads and slice assignment along its first dimension.
        delta_smash: Same as ``delta_data`` but added to the output of ``net``.
        batch_idx: Zero-based position of this batch.
        full_batch_size: Nominal batch size used to compute the offset into
            the delta buffers (``batch_idx * full_batch_size``). Pass this
            when the final batch may be shorter than the others; otherwise
            the offset of a short batch is computed from its own length and
            points at the wrong rows. Defaults to ``images.size(0)``, which
            matches the previous behaviour.

    Yields:
        Tuple ``(client_fx, labels, num, iter_num, dd, ds)`` where
        ``client_fx`` is a detached leaf copy of the perturbed network output
        with ``requires_grad`` enabled, ``num`` is the current iteration,
        ``iter_num`` the total number of iterations, and ``dd`` / ``ds`` are
        the delta slices used for this batch.

    Note:
        The deltas are written back to ``delta_data`` / ``delta_smash`` only
        after the last iteration, i.e. when the generator is exhausted. A
        caller that stops iterating early skips the write-back.
    """
    batch_size = images.size(0)

    # Offset of this batch inside the per-client delta buffers. The stride is
    # the nominal batch size when the caller supplies it, so that a shorter
    # last batch still maps to the correct rows.
    stride = batch_size if full_batch_size is None else full_batch_size
    start = batch_idx * stride
    stop = start + batch_size

    dd = delta_data[idx][start:stop]
    ds = delta_smash[idx][start:stop]

    # Only the delta-based attacks iterate more than once per batch.
    iter_num = args.iter_num if (args.dataset_attack or args.smashed_attack) else 1

    # Flip the labels once per batch. Doing this inside the loop would shift
    # them by one additional class on every iteration.
    if args.label_attack:
        labels = (labels + 1) % args.classes

    for num in range(iter_num):
        # Fresh leaf tensor each iteration so it picks up any in-place change
        # the caller made to ``dd`` between iterations.
        images_mod = (images + args.iur * dd).detach().requires_grad_(True)

        with torch.backends.cudnn.flags(enabled=False):
            fx = net(images_mod)

        fx = fx + args.iur * ds
        client_fx = fx.clone().detach().requires_grad_(True)

        yield client_fx, labels, num, iter_num, dd, ds  # yield back to the training loop

    # Persist the deltas for the attacks that are enabled.
    if args.dataset_attack:
        delta_data[idx][start:stop] = dd

    if args.smashed_attack:
        delta_smash[idx][start:stop] = ds
