import copy
from nn_compression.evaluation import entropy_net_with_overhead
import torch.nn as nn
import torch
import pytest
import numpy as np


@pytest.fixture
def net():
    """Return a linear layer whose 2x8 weight matrix has known 0/1 entries.

    Row 0 contains a single 1 among its eight entries and row 1 contains
    four, so the full matrix holds five 1s and eleven 0s.
    """
    weight = torch.tensor(
        [
            [0.0, 0, 0, 0, 1, 0, 0, 0],
            [0, 1, 1, 0, 0, 1, 0, 1],
        ]
    )
    # nn.Linear stores its weight as (out_features, in_features), so a 2x8
    # weight belongs to Linear(in_features=8, out_features=2). Deriving the
    # sizes from the tensor keeps the layer metadata and bias consistent
    # with the weight that is assigned below.
    out_features, in_features = weight.shape
    layer = nn.Linear(in_features, out_features)
    layer.weight = nn.Parameter(weight)
    return layer


def test_overhead_total(net):
    """Check overhead, entropy and bpw of the "all" entry on a regular grid.

    The fixture weight holds five 1s and eleven 0s across 16 entries.
    """
    n_weights = net.weight.numel()
    result = entropy_net_with_overhead(net, regular_grid=True)["all"]

    # 2 grid points with 5 bits per count + 2*16 bits for start and end
    assert result.overhead == 2 * 5 + 32

    # Floating-point results are compared with a tolerance rather than with
    # exact equality, which depends on the order of the arithmetic.
    expected_entropy = -np.log2(5 / 16) * 5 / 16 - np.log2(11 / 16) * 11 / 16
    assert result.entropy == pytest.approx(expected_entropy)

    # bpw is expected to be the entropy plus the overhead spread over all weights.
    expected_bpw = result.entropy + result.overhead / n_weights
    assert result.bpw == pytest.approx(expected_bpw)


def test_overhead_rowwise(net):
    """Check overhead, entropy and bpw with ``axis_specialisation=0``.

    Row 0 of the fixture weight has one 1 in eight entries; row 1 has four.
    """
    n_weights = net.weight.numel()
    result = entropy_net_with_overhead(
        net, axis_specialisation=0, regular_grid=True
    )["all"]

    # The (2 * 5 + 32)-bit grid overhead is expected once per row.
    assert result.overhead == (2 * 5 + 32) * 2

    entropy_row_1 = -np.log2(1 / 8) * 1 / 8 - np.log2(7 / 8) * 7 / 8
    entropy_row_2 = -np.log2(4 / 8) * 4 / 8 - np.log2(4 / 8) * 4 / 8
    # Average of the per-row entropies, weighted by the 8 entries of each row.
    expected_entropy = (entropy_row_1 * 8 + entropy_row_2 * 8) / 16

    # Floating-point results are compared with a tolerance rather than with
    # exact equality, which depends on the order of the arithmetic.
    assert result.entropy == pytest.approx(expected_entropy)
    expected_bpw = result.entropy + result.overhead / n_weights
    assert result.bpw == pytest.approx(expected_bpw)


def test_collapsed_rows(net):
    """Check the row-wise results when row 0 holds a single repeated value.

    Row 0 is overwritten with zeros, so its entropy is 0; row 1 keeps four
    1s in eight entries.
    """
    net.weight.data[0, :] = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0])
    n_weights = net.weight.numel()
    result = entropy_net_with_overhead(
        net, axis_specialisation=0, regular_grid=True
    )["all"]

    # One full (2 * 5 + 32)-bit grid overhead plus 16 bits.
    # NOTE(review): presumably the 16 bits store the single value of the
    # constant row -- confirm against entropy_net_with_overhead.
    assert result.overhead == (2 * 5 + 32) + 16

    entropy_row_1 = 0
    entropy_row_2 = -np.log2(4 / 8) * 4 / 8 - np.log2(4 / 8) * 4 / 8
    # Average of the per-row entropies, weighted by the 8 entries of each row.
    expected_entropy = (entropy_row_1 * 8 + entropy_row_2 * 8) / 16

    # Floating-point results are compared with a tolerance rather than with
    # exact equality, which depends on the order of the arithmetic.
    assert result.entropy == pytest.approx(expected_entropy)
    expected_bpw = result.entropy + result.overhead / n_weights
    assert result.bpw == pytest.approx(expected_bpw)


def test_entropy_layerwise(net):
    """Check the per-layer entry reported for a model with two linear layers.

    ``net2`` is a deep copy of the fixture layer whose first row is
    overwritten with a row that again contains exactly one 1.
    """

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.net1 = net
            self.net2 = copy.deepcopy(net)
            self.net2.weight.data[0, :] = torch.tensor([0, 0, 0, 0, 0, 0, 0, 1])

    ent = entropy_net_with_overhead(Net(), axis_specialisation=None, regular_grid=True)

    entropy = ent["net1"]
    # 2 grid points with 5 bits per count + 2*16 bits for start and end
    assert entropy.overhead == 2 * 5 + 32

    # Floating-point results are compared with a tolerance rather than with
    # exact equality, which depends on the order of the arithmetic.
    expected_entropy = -np.log2(5 / 16) * 5 / 16 - np.log2(11 / 16) * 11 / 16
    assert entropy.entropy == pytest.approx(expected_entropy)
    assert entropy.bpw == pytest.approx(entropy.entropy + entropy.overhead / 16)

    # NOTE(review): the statements below ASSIGN to ent["all"]; they do not
    # assert anything, so the aggregated "all" entry is never verified by
    # this test. They look like intended assertions, but the expected
    # aggregate values cannot be confirmed from this file (e.g. whether
    # "all".entropy is a sum or a per-weight average, and whether bpw should
    # divide by 16 or by the 32 weights of both layers). They are kept
    # unchanged until that is checked against entropy_net_with_overhead.
    ent["all"].bits_unquantised_params = (
        ent["net1"].bits_unquantised_params + ent["net2"].bits_unquantised_params
    )
    ent["all"].numel_unquantised_params = (
        ent["net1"].numel_unquantised_params + ent["net2"].numel_unquantised_params
    )
    ent["all"].nzero_quant = ent["net1"].nzero_quant + ent["net2"].nzero_quant
    ent["all"].entropy = ent["net1"].entropy + ent["net2"].entropy
    ent["all"].overhead = ent["net1"].overhead + ent["net2"].overhead
    ent["all"].numel = ent["net1"].numel + ent["net2"].numel
    ent["all"].bpw = ent["all"].entropy + ent["all"].overhead / 16
