from nn_compression.evaluation._rd_quant import entropy_net
from nn_compression.quantisation import (
    extract_quant_weights,
    RatedGrid,
)
from nn_compression.quantisation._network_quantisation import gptq_quantise_network
from nn_compression.quantisation._gptq import GptqLayer
from nn_compression.networks import recursively_find_named_children
from data_utils.arrays import take_batches
import torch.nn as nn
import pytest
import torch
import copy
import timm
import torchvision
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from pathlib import Path
import detectors  # do not delete this


@pytest.fixture
def net_nested():
    """Return a tiny MLP whose second linear layer lives in a nested ``nn.Sequential``.

    Layout: ``Linear(2, 10) -> ReLU -> Sequential(Linear(10, 3) -> ReLU)``.
    """
    # Build the layers in network order so that random weight initialisation
    # consumes the RNG in the same sequence as a single inline constructor call.
    first_linear = nn.Linear(2, 10)
    first_activation = nn.ReLU()
    nested_block = nn.Sequential(nn.Linear(10, 3), nn.ReLU())
    return nn.Sequential(first_linear, first_activation, nested_block)


@pytest.fixture
def quant_net_nested(net_nested):
    """Return a deep copy of ``net_nested`` with hand-set linear weights.

    Each linear layer's weight becomes all ones, except its first row, which
    is set to 2 for the outer layer (index ``0``) and to 3 for the nested
    layer (index ``2.0``). The ``net_nested`` instance itself is not modified.
    """
    clone = copy.deepcopy(net_nested)
    first_row_values = ((clone[0], 2), (clone[2][0], 3))
    for linear, first_row_value in first_row_values:
        replacement = torch.ones_like(linear.weight.data)
        replacement[0] = first_row_value
        linear.weight.data = replacement
    return clone


def test_extract_grid(quant_net_nested):
    """Check that ``extract_quant_weights`` returns the weights of both linear layers.

    The result is indexed by the layers' dotted module names (``"0"`` for the
    outer linear layer, ``"2.0"`` for the nested one) and each entry must
    equal the weights set up by the ``quant_net_nested`` fixture.
    """
    grids = extract_quant_weights(quant_net_nested)

    # Expected weight of the outer Linear(2, 10): ones with the first row = 2.
    grid0 = torch.ones((10, 2))
    grid0[0] = 2

    # Expected weight of the nested Linear(10, 3): ones with the first row = 3.
    grid1 = torch.ones((3, 10))
    grid1[0] = 3

    assert (grids["0"] == grid0).all()
    assert (grids["2.0"] == grid1).all()


def cifar_acc(net, steps, dataset):
    """Estimate the top-1 accuracy of ``net`` on up to ``steps`` shuffled batches.

    Batches of 128 samples are drawn from ``dataset`` in random order. The
    accuracy is the number of correct arg-max predictions divided by the
    number of samples actually evaluated, so a short final batch or a dataset
    with fewer than ``steps`` batches does not skew the result.

    Args:
        net: Callable mapping an input batch to per-class scores; the
            prediction is the arg-max over ``dim=1``.
        steps: Maximum number of batches to evaluate; must be at least 1.
        dataset: Dataset of ``(input, label)`` pairs, passed to
            ``torch.utils.data.DataLoader``.

    Returns:
        Fraction of evaluated samples that were classified correctly, or
        ``0.0`` if no samples were evaluated.

    Raises:
        ValueError: If ``steps`` is smaller than 1.
    """
    if steps < 1:
        raise ValueError(f"steps must be at least 1, got {steps}")

    bs = 128
    correct = 0
    seen = 0

    dataloader = DataLoader(dataset, batch_size=bs, shuffle=True)
    # Pure evaluation: no autograd graph is needed. The network's
    # train/eval mode is left untouched and remains the caller's choice.
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloader):
            # Softmax is monotonic, so the arg-max of the raw scores already
            # gives the predicted class.
            predictions = net(inputs).argmax(dim=1)
            correct += (predictions == labels).sum().item()
            # Count the real batch size; the last batch may hold fewer than bs.
            seen += labels.size(0)
            if i + 1 >= steps:
                break

    return correct / seen if seen else 0.0
