from data_utils.statistics import entropy
import numpy as np
from torch.utils.data import DataLoader
from nn_compression._interfaces import quantisable
from nn_compression.networks import recursively_find_named_children, LayerWiseHessian
from typing import Callable, Optional
import torch.nn as nn
import torch
from dataclasses import dataclass


@dataclass
class Entropy:
    """Entropy statistics for a layer or for a whole network.

    Fields are positional and are filled in this order by the functions in
    this module.
    """

    # Average number of payload bits per parameter (grid overhead excluded).
    entropy: float
    # Bits charged for storing the quantisation grid(s).
    overhead: int
    # Bits per weight including overhead: (payload bits + overhead) / numel.
    bpw: float
    # Number of parameter elements covered by these statistics.
    numel: int
    # Payload bits attributed to parameters that are not quantised.
    bits_unquantised_params: int
    # Number of parameter elements that are not quantised.
    numel_unquantised_params: int
    # Number of quantised weights counted as zero.
    nzero_quant: int

    @property
    def entropy_quantised_params(self) -> float:
        """Average payload bits per quantised parameter.

        Evaluates to 0 when every counted parameter is unquantised.
        """
        payload_bits = self.entropy * self.numel
        quantised_bits = payload_bits - self.bits_unquantised_params
        quantised_count = self.numel - self.numel_unquantised_params
        # Guard the division: there may be no quantised parameters at all.
        return quantised_bits / quantised_count if quantised_count != 0 else 0

    @property
    def entropy_unquantised_params(self) -> float:
        """Overall ``entropy`` minus ``entropy_quantised_params``."""
        quantised_part = self.entropy_quantised_params
        return self.entropy - quantised_part


def _calculate_grid_overhead(nlevels, regular_grid: bool):
    if regular_grid:
        gridpoint_overhead_estimate = 5
    else:
        gridpoint_overhead_estimate = 16 + 5
    if regular_grid:
        if nlevels > 1:
            return nlevels * gridpoint_overhead_estimate + 32
        else:
            return 16
    else:
        return nlevels * gridpoint_overhead_estimate


def entropy_layer(child: nn.Module, axis_specialisation, regular_grid):
    """Compute entropy statistics for a single layer.

    If ``quantisable(child)`` is true, ``child.weight`` is measured: the value
    returned by ``entropy`` times the element count gives the payload bits,
    and the number of distinct values determines the grid overhead. When
    ``axis_specialisation`` is not ``None``, the weights returned by
    ``LayerWiseHessian(child).get_weights()`` are split into single-index
    slices along that axis, and every slice is measured with its own grid.

    Otherwise every parameter of ``child`` is assumed to be stored in 16-bit
    floating point and is counted as unquantised.

    Args:
        child: The layer to analyse.
        axis_specialisation: Axis along which each slice gets its own grid,
            or ``None`` to use one grid for the whole weight tensor.
        regular_grid: Forwarded to ``_calculate_grid_overhead``.

    Returns:
        An ``Entropy`` record for this layer.

    Raises:
        ValueError: If no parameter elements were counted.
    """
    total_bits = 0
    total_numel = 0
    overhead = 0
    nzero = 0
    unquant_bits = 0
    unquant_numel = 0

    if quantisable(child):
        # NOTE(review): only `child.weight` is counted here; any other
        # parameters of a quantisable layer (e.g. a bias) are not included.
        # Confirm this is intended.
        p = child.weight
        if axis_specialisation is not None:
            w = LayerWiseHessian(child).get_weights()
            # One group per index along the chosen axis, produced lazily so
            # only a single slice is alive at a time.
            groups = (
                torch.index_select(
                    w,
                    index=torch.tensor(j, device=w.device),
                    dim=axis_specialisation,
                )
                for j in range(w.shape[axis_specialisation])
            )
        else:
            groups = (p,)

        for group in groups:
            group_numel = group.numel()
            total_bits += entropy(group) * group_numel
            total_numel += group_numel
            # Each group pays for its own grid of distinct values.
            overhead += _calculate_grid_overhead(
                group.unique().numel(), regular_grid
            )

        # `.item()` yields a plain Python number, so `nzero_quant` is an
        # `int` here just as it is in the unquantised branch.
        nzero = int((p.abs() < 1e-32).sum().item())
    else:
        for p in child.parameters():
            # assume other params are saved in 16bit fp
            param_bits = 16 * p.numel()
            total_bits += param_bits
            total_numel += p.numel()

            unquant_bits += param_bits
            unquant_numel += p.numel()

    if total_numel == 0:
        raise ValueError("No parameters in layer")

    return Entropy(
        total_bits / total_numel,
        overhead,
        (total_bits + overhead) / total_numel,
        total_numel,
        unquant_bits,
        unquant_numel,
        nzero,
    )


def entropy_net_with_overhead(
    net: nn.Module,
    axis_specialisation: Optional[int] = None,
    regular_grid: bool = False,
    filter: Optional[Callable] = None,
) -> dict[str, Entropy]:
    """Calculate per-layer and overall entropy statistics of a network.

    Gridpoint overhead is calculated as 16 bits for the grid point plus 5 bits
    for the count. If the grid is regular, we only need to store the first
    grid point and the step size, which amounts to 32 bits.

    Args:
        net: The network to analyse.
        axis_specialisation: Forwarded to ``entropy_layer``; the axis along
            which each slice gets its own grid, or ``None`` for one grid per
            layer.
        regular_grid: Forwarded to ``entropy_layer``.
        filter: Predicate called with each layer name; only layers for which
            it returns a truthy value are included. ``None`` includes every
            layer.

    Returns:
        A dict mapping each included layer name to its ``Entropy``, plus the
        key ``"all"`` holding the aggregate over the included layers. A layer
        that is itself named ``"all"`` would be overwritten by the aggregate.

    Raises:
        ValueError: If no layer with parameters is selected, or if
            ``entropy_layer`` rejects a layer.
    """
    accept = filter if filter is not None else (lambda _: True)

    total_bits = 0
    total_numel = 0
    overhead = 0
    unquant_numel = 0
    unquant_bits = 0
    nzero = 0
    entropies = {}

    for name, child in recursively_find_named_children(net):
        if not accept(name):
            continue
        # Skip parameter-free layers without building the full parameter list.
        if next(iter(child.parameters()), None) is None:
            continue

        layer = entropy_layer(child, axis_specialisation, regular_grid)
        entropies[name] = layer
        total_bits += layer.entropy * layer.numel
        total_numel += layer.numel
        overhead += layer.overhead
        unquant_numel += layer.numel_unquantised_params
        unquant_bits += layer.bits_unquantised_params
        nzero += layer.nzero_quant

    if total_numel == 0:
        # Without this guard the averages below fail with a bare
        # ZeroDivisionError that hides the actual cause.
        raise ValueError("No parameters selected in network")

    entropies["all"] = Entropy(
        total_bits / total_numel,
        overhead,
        (total_bits + overhead) / total_numel,
        total_numel,
        unquant_bits,
        unquant_numel,
        nzero,
    )
    return entropies


def entropy_net(
    net: nn.Module,
    add_overhead: bool = True,
    per_row_grid: bool = False,
    quant_weights_only: bool = False,
    filter: Optional[Callable] = None,
) -> float:
    """Return a single bits-per-parameter figure for ``net``.

    Args:
        net: The network to analyse.
        add_overhead: If true, return ``bpw`` (grid overhead included);
            otherwise return the plain ``entropy``.
        per_row_grid: If true, ``entropy_net_with_overhead`` is called with
            ``axis_specialisation=1``; otherwise with ``None``.
        quant_weights_only: If true, return ``entropy_quantised_params``
            instead. Must be combined with ``add_overhead=False``.
        filter: Optional predicate on layer names, forwarded unchanged
            (``None`` selects every layer).

    Returns:
        The selected statistic of the ``"all"`` entry.

    Raises:
        AssertionError: If ``quant_weights_only`` is combined with
            ``add_overhead``; that combination is not implemented.
    """
    # Validate up front with an explicit raise: an `assert` statement is
    # removed under `python -O`, and checking first avoids analysing the
    # whole network only to reject the arguments afterwards. The exception
    # type is kept so existing callers see the same error.
    if quant_weights_only and add_overhead:
        raise AssertionError("Combination not implemented")

    summary = entropy_net_with_overhead(
        net, 1 if per_row_grid else None, filter=filter
    )["all"]

    if quant_weights_only:
        return summary.entropy_quantised_params
    return summary.bpw if add_overhead else summary.entropy
