from pathlib import Path
import torch.nn as nn
import math
import torch
from transformers.pytorch_utils import Conv1D

from nn_compression._interfaces import quantisable
from ._utils import recursively_find_named_children
from typing import Callable, Optional
import copy


class TotalHessian:
    """An estimate of the diagonal of the full Hessian w.r.t. all quantisable weights.

    The Hessian is that of ``loss_fn`` w.r.t. the ``weight`` of every quantisable
    layer of ``net`` (in the order yielded by ``recursively_find_named_children``).
    Its diagonal is estimated with Hutchinson's method, diag(H) = E[z * (H @ z)]
    (element-wise product) over random probe vectors z.

    See https://curvlinops.readthedocs.io/en/latest/linops.html#curvlinops.HutchinsonDiagonalEstimator.
    """

    def __init__(self, net: nn.Module, loss_fn: Optional[Callable] = None) -> None:
        """Collect the weights of all quantisable layers of ``net``.

        Args:
            net: The network whose Hessian is estimated.
            loss_fn: Loss used for the Hessian. Defaults to a fresh
                ``torch.nn.CrossEntropyLoss()`` per instance (a module instance
                as a default argument would be shared by every TotalHessian).
        """
        self.net = net
        self.loss_fn = loss_fn if loss_fn is not None else torch.nn.CrossEntropyLoss()

        names = []
        params = []
        for name, layer in recursively_find_named_children(self.net):
            if quantisable(layer):
                params.append(layer.weight)
                names.append(name)

        self.params = params
        self.names = names

        self._H = None
        self._diagonal = None

    @property
    def diagonal(self):
        """The estimated Hessian diagonal (flat, concatenated over ``self.params``).

        Raises:
            ValueError: If ``sample()`` has not been called yet.
        """
        if self._diagonal is None:
            raise ValueError(
                "The diagonal has not been estimated. Call sample() first."
            )
        return self._diagonal

    @property
    def posterior_variances(self) -> dict[str, torch.Tensor]:
        """Per-layer variances computed as ``1 / (relu(diag(H)) + 1e-6)``.

        Returns:
            A dict mapping each layer name to a flat tensor with one entry per
            element of that layer's weight. Negative diagonal estimates are
            clipped to zero before inversion.

        Raises:
            ValueError: If ``sample()`` has not been called yet.
        """
        diagonal = self.diagonal  # raises ValueError if not yet sampled

        variances = {}
        offset = 0
        for name, param in zip(self.names, self.params):
            numel = param.numel()
            chunk = torch.as_tensor(diagonal[offset : offset + numel])
            variances[name] = 1 / (torch.relu(chunk) + 1e-6)
            # Move on to the next layer's slice of the concatenated diagonal.
            offset += numel

        return variances

    def _prepare_linear_operator(self, x: torch.Tensor, y: torch.Tensor):
        """Build (and cache in ``self._H``) the curvlinops Hessian operator on one batch."""
        from curvlinops import HessianLinearOperator

        data = [(x, y)]
        H = HessianLinearOperator(self.net, self.loss_fn, self.params, data)
        self._H = H
        return H

    def sample(self, x: torch.Tensor, y: torch.Tensor, nsamples: int = 5):
        """Estimate the Hessian diagonal on the batch ``(x, y)``.

        The estimate is the mean of ``nsamples + 1`` Hutchinson samples.

        Returns:
            ``self``, so calls can be chained.
        """
        from curvlinops import HutchinsonDiagonalEstimator

        H = self._prepare_linear_operator(x, y)
        estimator = HutchinsonDiagonalEstimator(H)

        diag_sample = estimator.sample()
        for _ in range(nsamples):
            diag_sample += estimator.sample()
        diag_sample /= nsamples + 1

        self._diagonal = diag_sample
        return self


def track_hessians(net, normalize: bool = False, is_large_net: bool = False):
    """Start recording layer-wise Hessians for every quantisable layer of ``net``.

    Each quantisable layer gets a ``hessian`` attribute (a ``LayerWiseHessian``,
    created here unless one already exists) and a forward hook that feeds the
    layer's first positional input and its output into ``hessian.add_batch``.
    The hook handles are stored on ``net._handles_hessian`` so that
    ``untrack_hessians`` can remove them.

    Args:
        net: The network to instrument.
        normalize: Passed to newly created ``LayerWiseHessian`` objects.
        is_large_net: Passed to newly created ``LayerWiseHessian`` objects.

    Returns:
        A dict mapping layer names to their ``LayerWiseHessian``.
    """
    # If hooks from an earlier call were never removed, drop them: otherwise each
    # forward pass would be recorded once per stale hook and the handles would leak.
    for stale_handle in getattr(net, "_handles_hessian", ()):
        stale_handle.remove()

    def record_hessian(module, inputs, output):
        if not quantisable(module):
            raise ValueError(
                f"This layer should be quantisable. Internal error: {module}"
            )

        if not hasattr(module, "hessian"):
            # Same settings as the eagerly created Hessians below.
            module.hessian = LayerWiseHessian(
                module, normalize=normalize, is_large_net=is_large_net
            )
        with torch.no_grad():
            # Forward hooks receive the positional args as a tuple; the layer input is first.
            module.hessian.add_batch(inputs[0], output)

    handles = []
    hessians = {}
    for name, layer in recursively_find_named_children(net):
        if quantisable(layer):
            handles.append(layer.register_forward_hook(record_hessian))
            if not hasattr(layer, "hessian"):
                layer.hessian = LayerWiseHessian(layer, normalize=normalize, is_large_net=is_large_net)  # type: ignore
            hessians[name] = layer.hessian

    net._handles_hessian = handles
    return hessians


def untrack_hessians(net):
    """Stop recording Hessians on ``net`` and finalise the recorded ones.

    Removes every forward hook stored in ``net._handles_hessian``. Then, for each
    quantisable layer:
    - if it carries a ``hessian``, its Cholesky factor is precomputed;
    - otherwise a warning is printed and the layer is flagged ``quantisable = False``.
    Finally ``net._handles_hessian`` is deleted.

    NOTE(review): ``track_hessians`` creates a ``hessian`` for every quantisable
    layer up front, so the warning branch only fires for layers that lack one for
    some other reason; a layer that was never called keeps an all-zero Hessian.
    """
    hook_handles = net._handles_hessian
    for hook_handle in hook_handles:
        hook_handle.remove()

    for layer_name, module in recursively_find_named_children(net):
        if not quantisable(module):
            continue
        if hasattr(module, "hessian"):
            module.hessian.precalculate()
        else:
            print(
                f"WARNING: Hessian not computed for layer {layer_name}. This is likely because the layer was never called during the forward pass. This layer will not be quantised."
            )
            module.quantisable = False  # type: ignore
    del net._handles_hessian


def estimate_hessians(net, dataloader, nbatches: int, device=torch.device("cpu")):
    """Estimate the layer-wise Hessians of ``net`` from calibration data.

    This is useful for large networks, where batched inputs are necessary.
    The network is moved to ``device`` and run on the first ``nbatches`` inputs of
    ``dataloader``, which must yield ``(x, target)`` pairs; targets are ignored.
    Afterwards the hooks are removed, the Hessians are finalised and the network
    is moved back to the CPU.

    Args:
        net: The network to calibrate.
        dataloader: Iterable of ``(x, target)`` pairs.
        nbatches: Number of batches to use; values <= 0 run no batches.
        device: Device on which the forward passes run.
    """
    from itertools import islice

    net.to(device)
    track_hessians(net)

    try:
        # Only the hooks' inputs are needed, so building an autograd graph is wasted memory.
        with torch.no_grad():
            # islice fetches exactly nbatches items (no extra batch is loaded and moved).
            for x, _ in islice(dataloader, max(nbatches, 0)):
                net(x.to(device))
    except BaseException:
        # Don't leave the recording hooks attached to the network if calibration fails.
        for handle in net._handles_hessian:
            handle.remove()
        del net._handles_hessian
        raise

    untrack_hessians(net)
    net.to("cpu")


def hessians_to(net: nn.Module, device: torch.device | str):
    """Move every ``hessian`` attached to a module of ``net`` to ``device``.

    Modules without a ``hessian`` attribute are skipped.
    """
    modules = (module for _, module in recursively_find_named_children(net))
    for module in modules:
        if not hasattr(module, "hessian"):
            continue
        module.hessian.to(device)


class LayerwiseHessianTracker:
    """Context manager that records layer-wise Hessians while active.

    On entry, ``track_hessians`` instruments the network. On exit,
    ``untrack_hessians`` removes the hooks and finalises the Hessians. If
    ``save_to`` is given, the Hessian dict is then written with ``torch.save``.

    Example::

        with LayerwiseHessianTracker(net, save_to=path) as tracker:
            for x in calibration_batches:
                net(x)
        hessians = tracker.hessians
    """

    def __init__(
        self,
        net,
        normalize: bool = False,
        save_to: Optional[Path] = None,
        is_large_net: bool = False,
    ) -> None:
        """
        Args:
            net: The network to instrument.
            normalize: Forwarded to ``track_hessians``.
            save_to: Optional file to ``torch.save`` the Hessian dict into on exit.
            is_large_net: Forwarded to ``track_hessians``.
        """
        self.net = net
        self.normalize = normalize
        self.save_to = save_to
        self.is_large_net = is_large_net

    def __enter__(self):
        self.hessians = track_hessians(
            self.net, normalize=self.normalize, is_large_net=self.is_large_net
        )
        # Return the tracker so `with ... as tracker` gives access to `.hessians`.
        return self

    def __exit__(self, *args):
        untrack_hessians(self.net)
        if self.save_to is None:
            return

        # Detach the layer pointers before saving; otherwise the saved Hessians
        # would needlessly include the layers' whole weight tensors.
        layer_pointers = {}
        for name, hessian in self.hessians.items():
            layer_pointers[name] = hessian.layer
            hessian.layer = None
        try:
            torch.save(self.hessians, self.save_to)
        finally:
            # Reattach the layers even if saving failed.
            for name, hessian in self.hessians.items():
                hessian.layer = layer_pointers[name]


def clear_hessians(net):
    """Delete the ``hessian`` attribute from every module of ``net`` that has one."""
    for _name, module in recursively_find_named_children(net):
        if not hasattr(module, "hessian"):
            continue
        delattr(module, "hessian")


class TransformWeights:
    """Transforms weights between the 2D ``(out, in)`` shape used in GPTQ and the
    original shape used by the PyTorch module.

    Supported kinds: ``nn.Linear`` (already 2D), ``nn.Conv2d`` (all but the first
    dim flattened) and transformers ``Conv1D`` (stored as ``(in, out)``, so it is
    transposed).
    """

    def __init__(self, kind: type):
        """
        Args:
            kind: The module class whose weights are transformed.
        """
        self.kind = kind

    def into_2d(self, W: torch.Tensor):
        """Transforms given weights into the 2D form used by the GPTQ quantisation method.

        Raises:
            ValueError: If ``kind`` is not a supported layer type.
        """
        if issubclass(self.kind, nn.Linear):
            return W
        if issubclass(self.kind, nn.Conv2d):
            return W.flatten(1)
        if issubclass(self.kind, Conv1D):
            return W.t()
        raise ValueError(
            "Only nn.Conv2d, nn.Linear and transformers Conv1D layers are supported"
        )

    def from_2d(self, W: torch.Tensor, orig_shape: torch.Size):
        """Transforms 2D weights back to the module's original shape ``orig_shape``.

        Raises:
            ValueError: If ``kind`` is not a supported layer type (consistent with
                ``into_2d`` instead of silently returning ``W`` unchanged).
        """
        if issubclass(self.kind, nn.Conv2d):
            return W.reshape(orig_shape)
        if issubclass(self.kind, nn.Linear):
            return W
        if issubclass(self.kind, Conv1D):
            return W.t()
        raise ValueError(
            "Only nn.Conv2d, nn.Linear and transformers Conv1D layers are supported"
        )


class LayerWiseHessian:
    """A representation of the layer-wise Hessian used by the GPTQ quantisation method.

    The Hessian is accumulated from calibration inputs as the running average
    ``H = 2 / N * sum_j x_j x_j^T``. Here ``x_j`` are the columns returned by
    :meth:`transform_input`, and ``N`` is the total size of the leading (batch)
    dimension of all inputs passed to :meth:`add_batch`. Working on the layer's
    2D weight view lets linear, ``nn.Conv2d`` and transformers ``Conv1D`` layers
    be treated the same way. The class also provides the upper Cholesky factor
    of the dampened inverse Hessian, which GPTQ uses.
    """

    def __init__(
        self,
        layer: nn.Module,
        percdamp: float = 0.01,
        normalize: bool = False,
        is_large_net: bool = False,
    ) -> None:
        """Initialise the LayerWiseHessian for a layer of a neural network.

        Args:
            layer: The layer for which the Hessian should be estimated
                (``nn.Linear``, ``nn.Conv2d`` or transformers ``Conv1D``).
            percdamp: Dampening when calculating the inverse, as a fraction of
                the mean of the Hessian's diagonal.
            normalize: If True, :meth:`add_batch` divides the input by the square
                root of the output's variance along its last dimension.
            is_large_net: If True, the Hessian is accumulated on the CPU while
                the Cholesky factorisation runs on the layer's device.

        Raises:
            ValueError: If the layer type is not supported.
        """
        self.normalize = normalize
        self.layer = layer
        if isinstance(self.layer, nn.Conv2d):
            self.type = "conv"
        elif isinstance(self.layer, nn.Linear):
            self.type = "linear"
        elif isinstance(self.layer, Conv1D):
            self.type = "conv1d-t"
        else:
            raise ValueError(
                "Only nn.Conv2d, nn.Linear and transformers Conv1D layers are supported"
            )

        self.is_large_net = is_large_net

        # Where H is accumulated vs. where the Cholesky factorisation runs.
        self.device = layer.weight.device if not is_large_net else "cpu"
        self.chol_device = layer.weight.device

        self.columns = self.get_weights().shape[1]
        self._H = torch.zeros((self.columns, self.columns), device=self.device)
        self._chol_inv = None
        self.nsamples = 0
        self.percdamp = percdamp

    def to(self, device):
        """Move the Hessian (and a cached Cholesky factor) to ``device``.

        Later batches are accumulated there too. ``chol_device`` is left unchanged.
        Returns ``self``.
        """
        self.device = device
        self._H = self._H.to(device)
        if self._chol_inv is not None:
            self._chol_inv = self._chol_inv.to(device)
        return self

    def deepcopy(self, layer):
        """Return an independent copy of this Hessian attached to ``layer``.

        The accumulated matrix, the sample count and the settings are copied, so
        further :meth:`add_batch` calls on the copy continue the running average.
        The Cholesky factor is computed on ``self`` first if needed and copied too.
        """
        hessian = LayerWiseHessian(
            layer,
            self.percdamp,
            normalize=self.normalize,
            is_large_net=self.is_large_net,
        )
        # The copied matrix lives on self.device, so the copy must accumulate there.
        hessian.device = self.device
        hessian._H = self._H.detach().clone()
        hessian.nsamples = self.nsamples
        hessian._chol_inv = self.precalculate().detach().clone()
        return hessian

    @staticmethod
    def load_into_model(model: nn.Module, hessians: dict[str, "LayerWiseHessian"]):
        """Attach previously computed Hessians (keyed by layer name) to ``model``'s quantisable layers.

        Raises:
            KeyError: If a quantisable layer of ``model`` has no entry in ``hessians``.
        """
        for n, layer in recursively_find_named_children(model):
            if quantisable(layer):
                layer.hessian = hessians[n]  # type: ignore
                layer.hessian.layer = layer  # else we still point to the old layer

    @property
    def H(self):
        """The Hessian estimate ``2 / N * X X^T`` (``X``: transformed inputs, ``N``: sample count)."""
        return self._H

    def precalculate(self):
        """Compute (and cache) the Cholesky factor; see :attr:`cholesky_inverse`."""
        return self.cholesky_inverse

    @staticmethod
    def _upper_chol_of_inverse(H: torch.Tensor) -> torch.Tensor:
        """Return the upper Cholesky factor of ``H^{-1}`` for a positive-definite ``H``."""
        return torch.linalg.cholesky(
            torch.cholesky_inverse(torch.linalg.cholesky(H)), upper=True
        )

    @property
    def cholesky_inverse(self) -> torch.Tensor:
        """Return the upper Cholesky factor ``U`` of the dampened inverse (``U^T U = H^{-1}``).

        Columns with a zero diagonal ("dead" inputs) get a unit diagonal, and
        ``percdamp * mean(diag(H))`` is added to the diagonal. If the factorisation
        fails, 100x that dampening is added once more before retrying. The work
        is done on ``chol_device``. The result is stored on ``device`` and cached
        until the Hessian changes.
        """
        if self._chol_inv is None:
            # Tensor.to is not in-place; copy=True also keeps self._H untouched.
            H = self._H.to(self.chol_device, copy=True)
            dead = torch.diag(H) == 0
            H[dead, dead] = 1

            damp = self.percdamp * torch.mean(torch.diag(H))
            H.diagonal().add_(damp)
            try:
                chol_inv = self._upper_chol_of_inverse(H)
            except torch.linalg.LinAlgError:
                print("WARNING: Cholesky decomposition failed. Adding more dampening.")
                H.diagonal().add_(damp * 100)
                chol_inv = self._upper_chol_of_inverse(H)
            self._chol_inv = chol_inv.to(self.device)

        return self._chol_inv

    def reset(self):
        """Sets the Hessian matrix to zero and forgets all samples."""
        self._H = torch.zeros((self.columns, self.columns), device=self.device)
        self.nsamples = 0
        self._chol_inv = None

    def get_weights(self):
        """Return the layer's weights in 2D ``(out, in)`` form (see ``TransformWeights.into_2d``)."""
        return TransformWeights(type(self.layer)).into_2d(self.layer.weight.data)

    def set_weight(self, W: torch.Tensor):
        """Sets the weights of the layer. The given weights should have the same shape as
        the ones returned from get_weights."""
        self.layer.weight.data = TransformWeights(type(self.layer)).from_2d(
            W, self.layer.weight.shape
        )

    def transform_input(self, inp: torch.Tensor):
        """Return the input as a 2D ``(in_features, samples)`` matrix ``f(X)``.

        Linear-like layers collapse all leading dimensions into one row per sample
        (so ``(batch, seq, features)`` and deeper inputs work). Convolutions are
        unfolded into patches.
        """
        if self.type in ("linear", "conv1d-t"):
            if inp.dim() != 2:
                inp = inp.reshape(-1, inp.shape[-1])
            return inp.t()
        if self.type == "conv":
            unfold = nn.Unfold(
                self.layer.kernel_size,
                dilation=self.layer.dilation,
                padding=self.layer.padding,  # type: ignore
                stride=self.layer.stride,
            )
            inp = unfold(inp)
            inp = inp.permute([1, 0, 2])
            inp = inp.flatten(1)
        return inp

    def add_batch(self, inp: torch.Tensor, output: Optional[torch.Tensor] = None):
        """Adds a calibration batch to the running Hessian estimate.

        Args:
            inp: The layer input. Its leading dimension counts as the number of
                samples in this batch. An empty batch is ignored.
            output: The layer output; required when ``normalize`` is True.
        """
        n_new = inp.shape[0]
        if n_new == 0:
            # Nothing to add; also avoids 0/0 in the rescale below when nsamples == 0.
            return
        inp = inp.detach().to(self.device)
        self._chol_inv = None  # have to recalculate the inverse
        if self.normalize:
            assert output is not None, "normalize=True requires the layer output"
            # The output may live on the layer's device (e.g. is_large_net), so move it too.
            std = torch.var(output.detach().to(self.device), dim=-1, keepdim=True).sqrt()
            inp = inp / std
        inp = self.transform_input(inp)  # after this, samples are the last dim
        # Running average: rescale the old estimate before adding the new batch.
        self._H *= self.nsamples / (self.nsamples + n_new)
        self.nsamples += n_new
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        self._H += inp.matmul(inp.t())
