from collections.abc import Callable
from typing import List, Optional

import torch

from compute.matrix_2norm import compute_matrix_2norm_power_method_batched


def compute_jac_norm(model: Callable, x_batch: torch.Tensor) -> torch.Tensor:
    """Computes the 2-norm of the model Jacobian for each input in the batch.

    This method works for models with inputs of any per-sample shape: the Jacobian
    is flattened over the input dimensions before the matrix 2-norm is estimated
    with ``compute_matrix_2norm_power_method_batched``.

    Parameters
    ----------
    model
        Model function. Its output is expected to be of shape (batch_size x output_dim).
    x_batch
        Batched inputs tensor of shape (batch_size x input_shape).
        Input_shape can be a number or 3 numbers (channels x height x width), making x_batch of shape
        (batch_size x channels x height x width).
        The caller's tensor is not modified: its ``requires_grad`` flag is left as it was.

    Returns
    -------
        A vector of norms of the jacobian for each input in the batch.
    """
    # Differentiate w.r.t. a detached leaf that shares x_batch's storage. This leaves
    # the caller's requires_grad flag and autograd history untouched. It also works
    # when x_batch is a non-leaf tensor, where requires_grad_(False) would raise.
    x_leaf = x_batch.detach().requires_grad_(True)

    # enable_grad lets this work even when the caller is inside torch.no_grad().
    with torch.enable_grad():
        out = model(x_leaf)
        # jacobian of dimensions batch_size x output_dim x input_shape
        jac = compute_jacobian(out, x_leaf)

    # to compute the 2-norm, we flatten the jacobian in the dimension of the input
    jac = jac.flatten(start_dim=2, end_dim=-1)

    return compute_matrix_2norm_power_method_batched(jac)


@torch.jit.script
def compute_jacobian(y_batch: torch.Tensor, x_batch: torch.Tensor) -> torch.Tensor:
    """Computes the Jacobian of y wrt to x.
    Thanks to https://github.com/magamba/overparameterization/blob/530948c72662b062fcb0c5c084b857a3951efb63/core/metric_helpers.py#L235
    for providing code to this function.

    One reverse-mode pass is made per output dimension. Gradients are obtained with
    torch.autograd.grad restricted to x_batch. As a result, nothing is accumulated
    into x_batch.grad or into the .grad of any other leaf in the graph (e.g. model
    parameters).

    Each pass selects output index ``dim`` for every batch element at once. The
    gradient for sample b is therefore sum_b' d y[b', dim] / d x[b]. This equals the
    per-sample Jacobian only if the model treats batch elements independently
    (e.g. no BatchNorm in training mode).

    Parameters
    ----------
    y_batch
        Output tensor of the model, batched. Dimensions: batch_size x output_dim
        (must be 2-D).
    x_batch
        Input tensor of the model, batched. Dimensions: batch_size x input_dim
        (input_dim may itself be several dimensions).

    Returns
    -------
        Jacobian of y_batch wrt to x_batch. Dimensions: batch_size x output_dim x input_dim

    Raises
    ------
        RuntimeError if y_batch does not depend on x_batch through autograd.

    Note
    ----
        x_batch has to track gradients and y_batch must have been computed from it.
    """
    nclasses = y_batch.shape[1]

    # Allocate directly in (batch_size, output_dim, *input_dim) layout, so no
    # permutation is needed afterwards.
    jacobian_shape = list(x_batch.shape)
    jacobian_shape.insert(1, nclasses)
    jacobian = torch.zeros(jacobian_shape, dtype=x_batch.dtype, device=x_batch.device)

    for dim in range(nclasses):
        # one-hot mask selecting output index `dim` for every sample: computes ∇_x(f_dim)
        grad_mask = torch.zeros_like(y_batch)
        grad_mask[:, dim] = 1.0
        # TorchScript requires this exact element type for grad_outputs.
        grad_outputs: List[Optional[torch.Tensor]] = [grad_mask]
        grads = torch.autograd.grad(
            [y_batch], [x_batch], grad_outputs=grad_outputs, retain_graph=True
        )
        row = grads[0]
        if row is None:
            raise RuntimeError("compute_jacobian: y_batch does not depend on x_batch")
        jacobian[:, dim] = row

    return jacobian
