import torch
from torch.nn.functional import binary_cross_entropy, cross_entropy

def conv_net_to_torch(means, devs):
    """Replace every entry of the nested ``means``/``devs`` containers by a tensor.

    Both arguments are indexed as ``container[layer][sub_layer]``; ``devs`` is
    walked with the indices found in ``means``. The containers are modified in
    place and also returned.

    Args:
        means: Nested, index-assignable container of values accepted by
            ``torch.tensor``.
        devs: Container with the same nesting as ``means``.

    Returns:
        The tuple ``(means, devs)`` — the very same objects that were passed in.
    """

    def _as_tensor(raw):
        # Rank-1 results get a trailing singleton axis; any other rank is kept.
        converted = torch.tensor(raw)
        if converted.dim() == 1:
            converted = converted.unsqueeze(-1)
        return converted

    for layer_idx, layer_means in enumerate(means):
        for sub_idx, raw_mean in enumerate(layer_means):
            layer_means[sub_idx] = _as_tensor(raw_mean)
            devs[layer_idx][sub_idx] = _as_tensor(devs[layer_idx][sub_idx])
    return means, devs


def spectral_norm(tensor):
    """Return the largest singular value of ``tensor``.

    Args:
        tensor: Floating-point tensor whose last two dimensions form the
            matrix (or matrices) to decompose.

    Returns:
        A 0-dim tensor holding the maximum over all singular values. If
        ``tensor`` carries leading batch dimensions, the maximum is taken
        over the whole batch, not per matrix.
    """
    # torch.svd is deprecated; svdvals yields the same singular values
    # without computing the singular vectors that were discarded anyway.
    return torch.linalg.svdvals(tensor).max()


def attack(model, x, x_cf, target_class, true_class, num_iterations, stepsize, dist_weight, use_multiclass, step_decay=0.9, l0_budget=None):
    """Iteratively move ``x_cf`` by gradient descent on a class/distance loss.

    Every iteration evaluates ``model.forward`` at the current point, forms

        (1 - dist_weight) * class_loss + dist_weight * sum(|x_cf - x| ** 2)

    takes one step of size ``stepsize`` against the gradient with respect to
    the current point, and clamps the result to ``[0, 1]``.

    ``class_loss`` depends on the mode:

    * ``use_multiclass`` and no ``l0_budget``:
      ``output[true_class] - output[target_class]``.
    * ``use_multiclass`` with ``l0_budget``: ``cross_entropy`` between
      ``output[:, 0, 0]`` and a one-hot float64 vector for ``target_class``.
    * otherwise: ``binary_cross_entropy(output, target_class)``, where
      ``target_class`` is turned into a float64 tensor with three leading
      singleton dimensions if it is not already shaped like ``output``.

    With ``l0_budget`` set, only the coordinates whose gradient magnitude is at
    least the ``l0_budget``-th largest are updated in an iteration, and a
    coordinate stops being updated once it has been selected three times.

    Args:
        model: Object exposing ``forward(tensor)``.
        x: Reference point for the distance term; assumed to have the same
            shape as ``x_cf`` (TODO confirm against callers).
        x_cf: Starting point of the search. It is not modified.
        target_class: Class to move towards (index, or a target value/tensor
            in the binary case).
        true_class: Index used only in the multiclass mode without L0 budget.
        num_iterations: Number of gradient steps.
        stepsize: Initial step size.
        dist_weight: Weight of the distance term; the class term is weighted
            with ``1 - dist_weight``.
        use_multiclass: Selects the multiclass losses instead of the binary
            cross entropy.
        step_decay: Factor applied to ``stepsize`` when no L0 budget is used.
        l0_budget: Optional number of coordinates to update per iteration.

    Returns:
        The point reached after ``num_iterations`` steps, as a new tensor.
    """
    # Upper bound on how often one coordinate may be selected in L0 mode.
    max_number_corrections = 3
    correction_counter = torch.zeros(x.size(), dtype=torch.int64, device=x_cf.device)

    # Work on a private copy: the caller's tensor must not pick up a partial
    # in-place step, a changed requires_grad flag or a stale .grad.
    x_cf = x_cf.detach().clone()

    for step in range(num_iterations):
        x_cf.requires_grad = True
        with torch.enable_grad():
            output = model.forward(x_cf)
            if use_multiclass:
                if l0_budget is None:
                    class_loss = output[true_class] - output[target_class]
                else:
                    output = output[:, 0, 0]
                    target_vec = torch.zeros(output.size(), dtype=torch.float64).to(output.device)
                    target_vec[target_class] = 1
                    class_loss = cross_entropy(output, target_vec)
            else:
                if not torch.is_tensor(target_class):
                    target_class = torch.tensor(target_class, dtype=torch.float64).unsqueeze(0).unsqueeze(0).unsqueeze(0).to(output.device)
                if target_class.size() != output.size():
                    target_class = target_class.unsqueeze(0).unsqueeze(0).unsqueeze(0).type(torch.float64).to(output.device)
                class_loss = binary_cross_entropy(output, target_class)
            loss = (1 - dist_weight) * class_loss + dist_weight * torch.sum(abs(x_cf - x) ** 2)
            loss.backward()

        with torch.no_grad():
            grad = x_cf.grad
            if l0_budget is not None:
                # Coordinates that used up their corrections are frozen.
                grad[correction_counter >= max_number_corrections] = 0
                magnitude = grad.abs()
                threshold = magnitude.flatten().sort(descending=True).values[l0_budget - 1]
                # One mask serves both the update and the bookkeeping; masking
                # the gradient does not change which entries reach the threshold.
                selected = magnitude >= threshold
                grad = grad * selected
                correction_counter += selected.int()
            # Out-of-place update so every iteration starts from a fresh leaf.
            x_cf = (x_cf - grad * stepsize).clamp(0, 1.)

        # Decay whenever the iteration index is a multiple of num_iterations / 10
        # (true division, so the period may be fractional).
        if l0_budget is None and step % (num_iterations / 10) == 0:
            stepsize *= step_decay
    return x_cf


def get_input_grad_(model, x, target, device):
    """Return the gradient of the model output with respect to the input ``x``.

    ``x`` is moved to ``device``, cast to float64 and passed through
    ``model._forward_op``. The differentiated scalar depends on the first
    dimension of the output:

    * size 1: the sum of the output is differentiated and the gradient is
      multiplied by ``(target - 0.5) * 2`` (``-1`` for target 0, ``+1`` for
      target 1).
    * otherwise: the sum of ``output[target, ...]`` is differentiated.

    Args:
        model: Object exposing ``_forward_op(tensor)``.
        x: Input tensor. It is left untouched (no ``requires_grad`` change,
            no ``.grad`` accumulation).
        target: Target value (single-output case) or index into the first
            output dimension.
        device: Device the computation runs on.

    Returns:
        A float64 tensor shaped like ``x`` holding the gradient.
    """
    # Detach before converting: when x already is float64 on `device`, the
    # conversion returns the very same tensor, and gradients would otherwise
    # accumulate on (and alias) the caller's tensor across calls. Detaching
    # also guarantees a leaf tensor, so `.grad` is populated by backward().
    x = x.detach().to(device).type(torch.float64)
    with torch.enable_grad():
        x.requires_grad = True
        output = model._forward_op(x)
        if output.shape[0] == 1:
            output.sum().backward()
            return x.grad.detach() * ((target - 0.5) * 2)
        output[target, ...].sum().backward()
        return x.grad.detach()
