from itertools import islice

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from nn_compression._interfaces import quantisable

from ._utils import recursively_find_named_children


def _flattened_module_gradient(module: nn.Module) -> torch.Tensor:
    """Concatenate the gradients of all parameters of ``module`` into one 1-D tensor."""
    flat = torch.zeros(module.total_number_parameters)  # type: ignore
    offset = 0
    for p in module.parameters():
        if p.grad is None:
            # Gradients are reset to None before every backward pass, so a
            # missing gradient means the parameter did not take part in the loss.
            raise RuntimeError(
                "A parameter of a quantisable module received no gradient."
            )
        count = p.numel()
        flat[offset : offset + count] = p.grad.flatten()
        offset += count
    return flat


def estimate_fisher(
    net: nn.Module,
    samples: DataLoader,
    loss_fn,
    nbatches: int,
    damp: float = 0.001,
    verbose=False,
    alpha_only: bool = False,
    device="cpu",
):
    """Estimates the layer-wise empirical fisher matrix of a network from
    batches drawn from a dataloader.

    The fisher is defined as
        F := E_{x~p_x}[\\grad_p L(x) \\grad_p L(x) ^T]
    where \\grad_p L(x) is the gradient of the loss with respect to the parameters,
    and p_x is the data distribution. We calculate a monte-carlo estimate of the
    expectation using the provided batches.

    What is actually accumulated: for every processed batch one backward pass is
    run on the batch loss, and the outer product of the resulting (flattened)
    per-layer gradient is added to that layer's matrix, which is initialised to
    ``damp * I``. After the last batch the accumulated matrix is divided by the
    total number of samples seen (the sum of ``xs.shape[0]`` over all batches).
    Note that the damping term is divided by that count as well.

    As the fisher is built from rank-one updates, an inverse could be tracked
    with the Sherman-Morrison formula:
        F_{i+1}^{-1} = \\left(F_i + \\nabla \\mathcal{L} \\nabla \\mathcal{L}^T \\right)^{-1}
         = F_i^{-1} - \\frac{F_i^{-1} \\nabla \\mathcal{L} \\nabla \\mathcal{L}^T F_i^{-1}} {1 + (\\nabla \\mathcal{L}^T F_i^{-1} \\nabla \\mathcal{L})}
    This function does NOT compute that inverse; only the fisher itself is stored.
    For more information, see the WoodFisher method by Singh et al. 2020, https://arxiv.org/abs/2004.14340.

    We restrict the fisher to the layer-wise interactions to keep computational tractability.
    The fisher of each quantisable layer is stored on the layer as the attribute "fisher".
    Every quantisable layer additionally receives the attribute
    "total_number_parameters".

    If only the fisher-component of alpha (<delta, F delta>) should be computed, pass
    ``alpha_only=True`` and ensure that each quantisable layer has a ``delta`` attribute.
    The result is then stored on the layer as "alpha_fisher"; it is built from the
    gradient of ``module.weight`` only and no damping is applied to it.

    Args:
        net: The network to calculate the fisher for. It is not moved to ``device``
            by this function.
        samples: The dataloader to sample from. Shall return (xs, ys)
        loss_fn: The loss function to calculate the loss. If set to None, we assume that the model
            should be passed the labels directly via model(xs, labels=ys), and the loss is then
            taken from the output via outputs.loss. This is relevant for transformers
        nbatches: The maximum number of batches to draw from ``samples``. Must be positive.
        damp: The damping factor to add to the fisher.
        verbose: Whether to print progress.
        alpha_only: Whether to calculate the alpha fisher component only or calculate the whole (layerwise) fisher.
        device: The device the input batches are moved to.

    Returns:
        ``net``, whose quantisable layers carry the estimated quantities.

    Raises:
        ValueError: If ``nbatches`` is not positive, if ``alpha_only`` is set but a
            quantisable layer has no ``delta``, if the model returns no loss while
            ``loss_fn`` is None, or if no sample was processed.
        RuntimeError: If a parameter of a quantisable layer received no gradient.
    """
    if nbatches <= 0:
        raise ValueError("nbatches must be a positive integer.")

    # The set of modules does not change while estimating, so resolve it once.
    modules = [
        module
        for _, module in recursively_find_named_children(net)
        if quantisable(module)
    ]

    for module in modules:
        num_params = sum(p.numel() for p in module.parameters())
        module.total_number_parameters = num_params  # type: ignore
        if alpha_only:
            if getattr(module, "delta", None) is None:
                raise ValueError(
                    "Alpha fisher is requested, but no delta value is provided."
                )
            # NOTE: no damping term (e.g. <delta, delta> * damp) is added here.
            module.alpha_fisher = 0.0  # type: ignore
        else:
            module.fisher = torch.eye(num_params) * damp  # type: ignore

    N = 0
    # islice stops after nbatches without fetching a surplus batch from the loader.
    batches = islice(samples, nbatches)
    for xs, ys in tqdm(batches, total=nbatches, disable=not verbose):
        xs = xs.to(device)
        ys = ys.to(device)
        N += xs.shape[0]

        # Reset to None (rather than zero) so that parameters which do not take
        # part in the loss can be detected after the backward pass.
        for p in net.parameters():
            p.grad = None

        if loss_fn is None:
            out = net(xs, labels=ys)
            if not hasattr(out, "loss"):
                raise ValueError(
                    "The model does not return a loss value. Please provide a loss function."
                )
            loss = out.loss
        else:
            loss = loss_fn(net(xs), ys)
        loss.backward()

        for module in modules:
            if alpha_only:
                projection = torch.inner(
                    module.delta.flatten(), module.weight.grad.flatten()
                )
                module.alpha_fisher += (projection**2).item()
            else:
                grads = _flattened_module_gradient(module)
                module.fisher += torch.outer(grads, grads)

    if N == 0:
        raise ValueError("No samples were processed; cannot estimate the fisher.")

    # Normalise EVERY quantisable module, not just the last one visited.
    for module in modules:
        if alpha_only:
            module.alpha_fisher = module.alpha_fisher / N  # type: ignore
        else:
            module.fisher = module.fisher / N  # type: ignore
    return net


def estimate_layerwise_fisher(net, xs, ys, loss_fn, damp=0.001):
    """Estimate the fisher of every quantisable layer from a single batch.

    Thin convenience wrapper around ``estimate_fisher``: the pair ``(xs, ys)``
    is treated as one batch, the estimation is run on it, and the "fisher"
    attributes of the quantisable layers are collected.

    Args:
        net: The network to calculate the fisher for. It is iterated directly,
            so it must be iterable over its layers (e.g. ``nn.Sequential``).
        xs: The input batch. Its ``device`` is forwarded to ``estimate_fisher``
            so that the batch stays where it is.
        ys: The labels belonging to ``xs``.
        loss_fn: The loss function, forwarded to ``estimate_fisher``.
        damp: The damping factor, forwarded to ``estimate_fisher``.

    Returns:
        A list with the "fisher" attribute of each quantisable layer, in
        iteration order of ``net``.
    """
    # estimate_fisher consumes an iterable of (xs, ys) batches together with a
    # batch count; arguments are passed by keyword so they cannot be mis-bound.
    estimate_fisher(
        net,
        [(xs, ys)],
        loss_fn,
        nbatches=1,
        damp=damp,
        device=xs.device,
    )
    return [layer.fisher for layer in net if quantisable(layer)]
