import torch
import torch.nn as nn
import copy


class QuantisedLinearLayer(nn.Module):
    """A bias-free linear layer whose weight matrix takes only n distinct values.

    The weights are represented as a sum of quantised weights, therefore
    (all sums from i=1 to n)
        W = sum_i a_i * I_i
    where I_i is the indicator function for the i-th quantised weight, with
        I_i = 1 where W = a_i
        I_i = 0 where W != a_i
        sum_i I_i = 1

    The values ``a_i`` are trainable parameters. The indicators ``I_i`` are
    fixed and registered as non-persistent buffers: they follow the module
    through ``.to()``, ``.cuda()``, ``.double()`` etc., but do not appear in
    ``state_dict()``.
    """

    def __init__(self, a: list[torch.Tensor], I_a: list[torch.Tensor]) -> None:
        """Initialise the quantised linear layer with a list of quantised weights and their indicator functions.

        Args:
            a: The quantisation values. Each entry becomes one trainable
                ``nn.Parameter``; ``grid`` expects each to hold a single element.
            I_a: One indicator tensor per entry of ``a``, all shaped like the
                weight matrix. Together they must sum to one at every position.

        Raises:
            AssertionError: If the lists are empty or of different length, or
                if the indicators do not sum to one everywhere.
        """
        super().__init__()

        # Validate before touching I_a[0] so bad input gives a clear message.
        assert len(a) == len(
            I_a
        ), "The number of quantisation values must match the number of indicator functions."
        assert len(a) > 0, "At least one quantisation value is required."

        self.a = nn.ParameterList([nn.Parameter(ai) for ai in a])
        self.I_a = I_a  # goes through the property setter -> registered as buffers

        # Boolean masks cannot count overlaps (True + True is still True), so
        # accumulate those as integers; other dtypes are summed as given.
        first = I_a[0]
        sum_dtype = torch.int64 if first.dtype == torch.bool else first.dtype
        I_sum = torch.zeros_like(first, dtype=sum_dtype)
        for I in I_a:
            I_sum += I
        assert (
            I_sum == 1
        ).all(), "The indicator functions must sum to one at every weight position."

    @property
    def I_a(self) -> list[torch.Tensor]:
        """The indicator tensors, in the same order as ``a``.

        A new list is built on every access; to replace the indicators assign
        a whole list to ``I_a`` instead of mutating the returned one.
        """
        return [getattr(self, name) for name in self._indicator_names]

    @I_a.setter
    def I_a(self, indicators: list[torch.Tensor]) -> None:
        # Drop buffers from a previous assignment so stale masks do not linger.
        for name in getattr(self, "_indicator_names", ()):
            delattr(self, name)
        self._indicator_names = [f"_I_a_{i}" for i in range(len(indicators))]
        for name, indicator in zip(self._indicator_names, indicators):
            # persistent=False keeps the state_dict limited to the `a` entries.
            self.register_buffer(name, indicator, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer: return ``x @ W.T`` (no bias term is added)."""
        return x @ self.weight.T

    @property
    def weight(self) -> torch.Tensor:
        """The quantised weight matrix W as
        W = sum_i a_i * I_i.

        The matrix is rebuilt on every access, in the dtype of the
        quantisation values, so it stays differentiable w.r.t. ``a``.
        """
        indicators = self.I_a
        # A 0-dim value does not promote the dtype of a same-category mask,
        # so cast the masks explicitly; this also admits bool/int masks.
        dtype = self.a[0].dtype
        W_quant = torch.zeros_like(indicators[0], dtype=dtype)
        for ai, I_ai in zip(self.a, indicators):
            W_quant += ai * I_ai.to(dtype)
        return W_quant

    @property
    def grid(self) -> torch.Tensor:
        """The quantisation grid as a tensor of the (detached) quantisation values."""
        return torch.tensor([a.detach() for a in self.a])

    @property
    def counts(self) -> torch.Tensor:
        """The sum of each indicator tensor, i.e. the number of weights that
        attain each quantisation value, as an integer tensor."""
        return torch.tensor([I_a.sum().item() for I_a in self.I_a], dtype=torch.int)


def transform_to_quantised_linear_layers(net: nn.Sequential) -> nn.Sequential:
    """Quantise all linear layers in a neural network using the QuantisedLinearLayer class
    (see corresponding class for further documentation).

    Every distinct value in a linear layer's weight matrix becomes one
    quantisation value, paired with an indicator tensor marking where that
    value occurs. The input network is deep-copied and left untouched.

    Args:
        net: The neural network to quantise. Layers with a ``weight``
            attribute must be ``nn.Linear``; layers without one (e.g.
            activations) are passed through as deep copies.

    Returns:
        A new ``nn.Sequential`` with every ``nn.Linear`` replaced by a
        ``QuantisedLinearLayer``.

    Raises:
        AssertionError: If a layer has a ``weight`` but is not ``nn.Linear``.

    Warns:
        UserWarning: If a linear layer has a non-zero bias. The bias is not
            carried over, so the returned network computes a different function.
    """
    import warnings

    prepared_net = copy.deepcopy(net)
    prepared_net_layers = []

    for layer in prepared_net:
        if not hasattr(layer, "weight"):
            prepared_net_layers.append(layer)
            continue

        assert isinstance(
            layer, nn.Linear
        )  # Currently only works on nn.Linear layers

        # Only the weight is handed to QuantisedLinearLayer; make the loss of
        # a meaningful bias visible instead of silently changing the output.
        if layer.bias is not None and bool((layer.bias.detach() != 0).any()):
            warnings.warn(
                "transform_to_quantised_linear_layers discards the non-zero bias "
                "of an nn.Linear layer; the quantised network will not reproduce "
                "the original outputs.",
                stacklevel=2,
            )

        # Work on a detached view so the new parameters are clean leaves that
        # do not share storage or autograd history with the copied layer.
        weight = layer.weight.detach()
        a_layer = [value.clone() for value in weight.unique()]
        # Indicators use the weight's own dtype so the reconstructed matrix
        # keeps the precision of the original layer (not always float32).
        I_a_layer = [(weight == value).to(weight.dtype) for value in a_layer]
        prepared_net_layers.append(QuantisedLinearLayer(a_layer, I_a_layer))

    return nn.Sequential(*prepared_net_layers)
