import torch


def bit_flip_ic(model, w_flat, criterion, instances, target, weight_decay):
    """
    Compute current gradient with bit-flipping attack (on image classification task).

    The regular gradient (loss gradient plus the weight-decay term) is computed
    on the given batch and then negated element-wise, which is what this
    function returns in place of the honest gradient.

    :param model: current model
    :param w_flat: current model parameters as one flat tensor; the returned
        gradient is allocated with ``torch.zeros_like(w_flat)``, so it shares
        its shape, dtype and device
    :param criterion: loss function, called as ``criterion(output, target)``
    :param instances: a batch of training instances
    :param target: the labels of training instances
    :param weight_decay: weight-decay coefficient; ``weight_decay * w_flat``
        is added to the loss gradient before negation

    :return: the negated gradient, ``-(grad(loss) + weight_decay * w_flat)``,
        as a flat tensor shaped like ``w_flat``
    """
    # NOTE(review): inputs are moved to the default CUDA device whenever CUDA
    # is available, which assumes the model lives there as well — confirm.
    if torch.cuda.is_available():
        instances = instances.cuda()
        target = target.cuda()

    # Forward pass. torch.autograd.Variable wrappers are no-ops since
    # PyTorch 0.4, so plain tensors are passed straight to the model.
    output = model(instances)
    loss = criterion(output, target)

    # Backward pass; clear stale gradients first so only this batch counts.
    g_flat = torch.zeros_like(w_flat)
    model.zero_grad()
    loss.backward()
    flatten_g(model, g_flat)

    # g += weight_decay * w  (keyword `alpha` replaces the deprecated
    # add_(Number, Tensor) overload; the result is identical).
    g_flat.add_(w_flat, alpha=weight_decay)

    # Bit flipping: report the exact opposite of the honest gradient.
    g_flat.mul_(-1.)

    return g_flat


def flatten_g(model, vec):
    """
    Copy the gradients of all parameters of ``model`` into the flat tensor ``vec``.

    Gradients are written back to back in ``model.parameters()`` order, each
    flattened to 1-D, starting at index 0 of ``vec``. A parameter whose
    ``.grad`` is ``None`` (e.g. it did not take part in the loss, or gradients
    were cleared with ``zero_grad(set_to_none=True)``, the default since
    PyTorch 2.0) contributes zeros to its slice instead of raising.

    :param model: module whose parameter gradients are read
    :param vec: flat (1-D) destination tensor, modified in place; it must have
        room for ``sum(p.numel() for p in model.parameters())`` elements
    :return: None
    """
    pointer = 0
    for param in model.parameters():
        num_param = param.numel()
        chunk = vec[pointer:pointer + num_param]  # view into vec
        if param.grad is None:
            # No gradient was produced for this parameter: treat it as zero.
            chunk.zero_()
        else:
            # reshape (unlike view) also works for non-contiguous gradients.
            chunk.copy_(param.grad.detach().reshape(-1))
        pointer += num_param
