from pathlib import Path
import pickle
from typing import Callable, Literal, Optional

from ._utils import recursively_find_named_children
from nn_compression._interfaces import quantisable
import torch
from ._fisher import estimate_fisher
from torch.utils.data import DataLoader
import torch.nn as nn


def calculate_alpha(
    net: nn.Module,
    qnet: nn.Module,
    dataloader: DataLoader,
    loss_fn: Optional[Callable],
    nbatches: int,
    device: Literal["cpu", "mps", "cuda"] = "cpu",
    verbose=False,
    save_intermediate=None,
) -> dict[str, float]:
    """Compute a per-layer scaling factor ``alpha`` for quantisation.

    For every quantisable layer, alpha measures how much the layer should be
    scaled in the quantisation process:

        alpha = <d, F d> / <d, H d>

    where ``F`` is the empirical Fisher, ``H`` is the layer-wise Hessian and
    ``d`` is the estimated quantisation direction.

    Args:
        net: full-precision network. Every quantisable layer must already
            carry a tracked ``hessian`` attribute.
        qnet: quantised counterpart of ``net``. Its layers are paired with
            those of ``net`` positionally, in traversal order.
        dataloader: yields ``(data, labels)`` tuples, where ``data`` has shape
            ``(batch_size, *input_shape)``.
        loss_fn: loss used for the Fisher estimate. For a Hugging Face
            transformer pass ``None``; cross-entropy is then used internally.
        nbatches: number of batches, forwarded to ``estimate_fisher``.
        device: device, forwarded to ``estimate_fisher``.
        verbose: if true, print the alphas and both quadratic forms.
        save_intermediate: optional directory in which the intermediate
            results are pickled.

    Returns:
        Mapping from layer name to its alpha value.

    Raises:
        ValueError: if a quantisable layer of ``net`` has no ``hessian``.
    """
    # Fail fast, before any expensive Fisher estimation, if Hessians are missing.
    for layer_name, module in recursively_find_named_children(net):
        if not quantisable(module):
            continue
        if not hasattr(module, "hessian"):
            raise ValueError(
                f"Hessian not computed for layer {layer_name}. Track Hessians first."
            )

    # Attach the quantisation direction ``delta`` to the layers of both networks.
    _assign_deltas(net, qnet)

    # Populate the Fisher-side term consumed by ``_calculate_scaling_alpha``.
    estimate_fisher(
        net,
        dataloader,
        loss_fn,
        nbatches,
        alpha_only=True,
        device=device,
    )

    alphas = _calculate_scaling_alpha(
        net,
        verbose=verbose,
        save_intermediate=save_intermediate,
    )
    return alphas


def _assign_deltas(orig_net, quant_net):
    """Attach the weight difference between two networks to their layers.

    Layers of ``orig_net`` and ``quant_net`` are paired positionally, in
    traversal order. For every pair whose original layer is quantisable,
    ``delta`` is computed as:

        orig.hessian.get_weights().flatten() - quant.hessian.get_weights().flatten()

    The quantised weights are first moved to the original weights' device.
    The same tensor object is stored as ``.delta`` on both layers.
    """
    layer_pairs = zip(
        recursively_find_named_children(orig_net),
        recursively_find_named_children(quant_net),
    )
    for (_, ref_module), (_, q_module) in layer_pairs:
        if not quantisable(ref_module):
            continue
        ref_weights = ref_module.hessian.get_weights().flatten()
        q_weights = q_module.hessian.get_weights().flatten().to(ref_weights.device)
        shared_delta = ref_weights - q_weights
        ref_module.delta = shared_delta
        q_module.delta = shared_delta


def _calculate_scaling_alpha(
    net, verbose: bool = False, save_intermediate: Optional[Path] = None
):
    """Compute ``alpha = <d, F d> / <d, H d>`` for every quantisable layer.

    Each quantisable layer of ``net`` must already carry these attributes:

    * ``delta``: the flattened weight difference.
    * ``hessian``: exposes ``get_weights()`` and the matrix ``H``.
    * ``alpha_fisher``: the ``<d, F d>`` term.

    Args:
        net: network whose quantisable layers are inspected.
        verbose: if true, print the alphas and both quadratic forms.
        save_intermediate: directory (``Path`` or ``str``) in which the
            ``alphas``, ``Fs`` and ``Hs`` dicts are pickled. Skipped if falsy.

    Returns:
        Mapping from layer name to alpha, as a Python float.
    """
    alphas = {}
    fs = {}
    hs = {}
    for name, layer in recursively_find_named_children(net):
        if not quantisable(layer):
            continue
        weight_shape = layer.hessian.get_weights().shape
        # One row per leading index, with the remaining dims flattened. This
        # matches the per-row ``d_r.reshape(1, -1)`` quadratic form.
        rows = layer.delta.reshape(weight_shape).flatten(start_dim=1)
        # sum_r d_r H d_r^T, computed as one matmul instead of a Python loop
        # over rows: (rows @ H)[r, j] * rows[r, j], summed over r and j.
        H = ((rows @ layer.hessian.H) * rows).sum()
        # NOTE(review): alpha_fisher is set by estimate_fisher; its exact type
        # (float vs. tensor) is not visible here, so it is stored as-is.
        F = layer.alpha_fisher
        alphas[name] = (F / H).item()
        hs[name] = H.item()
        fs[name] = F
    if verbose:
        print("Alpha:")
        print(alphas)
        print("<delta, F delta>:")
        print(fs)
        print("<delta, H delta>:")
        print(hs)
    if save_intermediate:
        out_dir = Path(save_intermediate)  # also accept plain string paths
        for filename, payload in (("alphas", alphas), ("Fs", fs), ("Hs", hs)):
            # ``with`` guarantees each file is flushed and closed.
            with open(out_dir / filename, "wb") as fh:
                pickle.dump(payload, fh)
    return alphas
