from typing import Optional
import torch.nn as nn
import copy

from nn_compression._interfaces import quantisable
from nn_compression.quantisation import (
    AffineGridQuantiser,
    RatedGrid,
)

from ._tensor_quantisers import AffineGridQuantiser, RdQuantiserDeepCabac
import torch
from nn_compression.networks import (
    map_net_forward,
    map_net,
    LayerWiseHessian,
    recursively_find_named_children,
    TransformWeights,
)
from ._gptq import GptqLayer


def rtn_quantise_network(net: nn.Module, nbit, per_row_grid: bool = False) -> nn.Module:
    """Quantise the weights of every quantisable module of ``net`` on an affine grid.

    The per-module routine defined below is handed to ``map_net`` together with
    ``net``. Modules for which ``quantisable`` is false are returned unchanged.

    Args:
        net: Network whose modules are visited by ``map_net``.
        nbit: Forwarded to ``AffineGridQuantiser`` (presumably the bit width of
            the grid -- confirm against the quantiser).
        per_row_grid: When true, the weights are fetched through
            ``LayerWiseHessian``, quantised one column at a time and written
            back through ``LayerWiseHessian``. Otherwise ``module.weight`` is
            quantised in a single call.

    Returns:
        Whatever ``map_net(net, ...)`` returns.
    """

    def _quantise_module(layer):
        # Leave modules that are not quantisable exactly as they are.
        if not quantisable(layer):
            return layer

        if not per_row_grid:
            # Single grid for the whole weight tensor.
            grid = AffineGridQuantiser(nbit)
            grid.find_params(layer.weight, weight=True)
            layer.weight.data = grid.quantise(layer.weight).x
            return layer

        weights = LayerWiseHessian(layer).get_weights()
        grid = AffineGridQuantiser(nbit, per_row_grid=per_row_grid)
        grid.find_params(weights, weight=True)
        # Quantise column by column; the column index is passed on to the
        # quantiser through ``col``.
        for col in range(weights.shape[1]):
            column = slice(col, col + 1)
            weights[:, column] = grid.quantise(weights[:, column], col=col).x
        LayerWiseHessian(layer).set_weight(weights)
        return layer

    return map_net(net, _quantise_module)


def gptq_quantise_network(
    net: nn.Module,
    nbits: float | int,
    x_cal: dict | torch.Tensor | None,
    per_row_grid=False,
    inplace: bool = False,
    verification_fn=None,
    symmetric_grid: bool = True,
):
    """Quantise neural network using GPTQ, exactly as proposed by Frantar et al. (2022) https://arxiv.org/abs/2210.17323. The
    quantised network is returned.
    """
    if net.training is True:
        raise ValueError("Network must be in eval mode.")
    if x_cal is None:
        return gptq_quantise_precalculated(
            net, nbits, per_row_grid, inplace, verification_fn, symmetric_grid
        )

    quantize_layer = quantize_layer_gptq_fn(
        nbits,
        per_row_grid,
        verification_fn,
        precalculated=False,
        symmetric_grid=symmetric_grid,
    )

    return map_net_forward(
        net, x_cal, quantize_layer, inplace=inplace, require_name=True
    )


def gptq_quantise_precalculated(
    net: nn.Module,
    nbits: int | float,
    per_row_grid=False,
    inplace: bool = False,
    verification_fn=None,
    symmetric_grid: bool = True,
):
    """Run GPTQ over the named children of ``net`` without calibration input.

    The per-layer callback is created with ``precalculated=True``, so it never
    calls ``add_batch``; each layer is visited with ``None`` in place of the
    forward arguments and output.

    Args:
        net: Network to quantise.
        nbits: Forwarded to the per-layer quantiser.
        per_row_grid: Forwarded to the per-layer quantiser.
        inplace: When false, a deep copy of ``net`` is quantised and returned;
            when true, ``net`` itself is used.
        verification_fn: Optional predicate on the layer name; layers for
            which it returns a falsy value are skipped.
        symmetric_grid: Forwarded to the per-layer quantiser.

    Returns:
        The network that was processed (``net`` itself or its deep copy).
    """
    target = net if inplace else copy.deepcopy(net)

    layer_callback = quantize_layer_gptq_fn(
        nbits,
        per_row_grid,
        verification_fn,
        precalculated=True,
        symmetric_grid=symmetric_grid,
    )

    for layer_name, layer in recursively_find_named_children(target):
        # No forward pass takes place here, hence no args and no output.
        layer_callback(layer_name, layer, None, None)
    return target


def quantize_layer_gptq_fn(
    nbits, per_row_grid, verification_fn, precalculated, symmetric_grid
):
    """Build the per-layer GPTQ callback.

    Args:
        nbits: Forwarded to ``AffineGridQuantiser``.
        per_row_grid: Forwarded to ``AffineGridQuantiser``.
        verification_fn: Optional predicate on the layer name; when it returns
            a falsy value the callback does nothing for that layer.
        precalculated: When true, the callback skips ``add_batch`` and goes
            straight to ``fasterquant``.
        symmetric_grid: Forwarded to ``AffineGridQuantiser`` as ``symmetric``.

    Returns:
        A function ``quantize_layer(name, module, args, output)`` that returns
        ``None``.
    """

    def quantize_layer(name, module, args, output):
        # Layers rejected by the optional name filter are left untouched.
        if verification_fn is not None:
            if not verification_fn(name):
                return

        grid_quantiser = AffineGridQuantiser(
            nbits, symmetric=symmetric_grid, per_row_grid=per_row_grid
        )

        if not quantisable(module):
            return

        layer_gptq = GptqLayer(module, grid_quantiser)
        if not precalculated:
            layer_gptq.add_batch(*args)  # this could be made more general
        layer_gptq.fasterquant()

    return quantize_layer


def rd_quantise_gptq_order_deepcabac(
    net_to_quantise: nn.Module,
    nbits: int,
    x_cal: torch.Tensor | None,
    lm: float | dict,
    uniform_posterior: bool = False,
    per_row_grid: bool = False,
    inplace: bool = False,
    verification_fn=None,
    blocksize: int = 1,
    gptq_class=GptqLayer,
):
    """Quantise a network layer by layer with ``RdQuantiserDeepCabac`` inside GPTQ.

    Args:
        net_to_quantise: Network to quantise; it must not be in training mode.
        nbits: Forwarded to ``RdQuantiserDeepCabac``.
        x_cal: Calibration input handed to ``map_net_forward``. When ``None``
            the named children are visited directly and ``add_batch`` is never
            called.
        lm: Second argument of ``RdQuantiserDeepCabac``. A dict is indexed by
            layer name; any other value is used for every layer.
        uniform_posterior: Must be false; the option is not implemented.
        per_row_grid: Forwarded to ``RdQuantiserDeepCabac``.
        inplace: When false, a copy is quantised: a deep copy in the
            ``x_cal is None`` case, otherwise as handled by
            ``map_net_forward``.
        verification_fn: Optional predicate on the layer name; layers for
            which it returns a falsy value are skipped.
        blocksize: Forwarded to ``RdQuantiserDeepCabac``.
        gptq_class: Class instantiated as ``gptq_class(module, quantiser)``
            for every quantisable layer.

    Returns:
        The processed network, or the result of ``map_net_forward`` when
        calibration input is given.

    Raises:
        ValueError: If ``net_to_quantise.training`` is ``True``.
        NotImplementedError: If ``uniform_posterior`` is truthy.
    """
    if net_to_quantise.training is True:
        raise ValueError("Network must be in eval mode.")
    if uniform_posterior:
        raise NotImplementedError("uniform posterior not supported for deepCABAC")

    def quantise_layer(name, module, args, output):
        # Layers rejected by the optional name filter are left untouched.
        if verification_fn is not None and not verification_fn(name):
            return
        if not quantisable(module):
            return

        # A dict supplies one value per layer name; anything else is shared.
        if isinstance(lm, dict):
            layer_lm = lm[name]
        else:
            layer_lm = lm

        rd_quantiser = RdQuantiserDeepCabac(
            nbits, layer_lm, per_row_grid=per_row_grid, blocksize=blocksize
        )
        layer_gptq = gptq_class(module, rd_quantiser)
        if x_cal is not None:
            layer_gptq.add_batch(args[0])
        layer_gptq.fasterquant()

    if x_cal is not None:
        return map_net_forward(
            net_to_quantise, x_cal, quantise_layer, require_name=True, inplace=inplace
        )

    # No calibration input: walk the named children directly.
    target = net_to_quantise if inplace else copy.deepcopy(net_to_quantise)
    for layer_name, layer in recursively_find_named_children(target):
        quantise_layer(layer_name, layer, None, None)
    return target


def rd_quantise_direct_deepcabac(
    net, nbits, lmbda, per_row_grid=False, inplace: bool = False
):
    """Quantise each quantisable layer's weight directly with ``RdQuantiserDeepCabac``.

    For every named child of the network for which ``quantisable`` is true, a
    fresh quantiser is fitted on ``layer.weight`` and the quantised values,
    reshaped to the original weight shape, replace ``layer.weight.data``.

    Args:
        net: Network to quantise.
        nbits: Forwarded to ``RdQuantiserDeepCabac``.
        lmbda: Second argument of ``RdQuantiserDeepCabac`` (presumably the
            rate-distortion trade-off -- confirm against the quantiser).
        per_row_grid: Must be falsy; per-row grids are not supported here.
        inplace: When false, a deep copy of ``net`` is quantised and returned;
            when true, ``net`` itself is modified.

    Returns:
        The quantised network (``net`` itself or its deep copy).

    Raises:
        NotImplementedError: If ``per_row_grid`` is truthy.
    """
    # Reject the unsupported configuration before copying, so that a call
    # which cannot succeed never pays for a deep copy of the network.
    if per_row_grid:
        raise NotImplementedError(
            "Per row grid not supported for direct RD quantisation"
        )
    if not inplace:
        net = copy.deepcopy(net)
    for _, layer in recursively_find_named_children(net):
        if quantisable(layer):
            # per row grid would need TransformWeights here
            quant = RdQuantiserDeepCabac(nbits, lmbda, per_row_grid=per_row_grid)
            quant.find_params(layer.weight)
            new_weights = quant.quantise(layer.weight).x.reshape(
                layer.weight.data.shape
            )
            layer.weight.data = new_weights
    return net
