from nesim.sparsity.conv import SparseConv2d
import torch.nn as nn
import torchvision.models as models
import torch
import matplotlib.pyplot as plt
from nesim.utils.getting_modules import get_module_by_name


# Dotted name (resolved with get_module_by_name) of the conv layer whose sparse
# approximation is compared between the pretrained and the checkpointed model.
layer_name = "layer3.0.conv1"
# layer_name = "conv1"  # NOTE(review): if re-enabled, the channel dim of `x` below must match this layer's input

# Checkpoint whose weights are loaded into ResNet-50 to produce the "ours" layer.
# checkpoint_filename = "/mindhive/nklab3/users/mayukh/repos/nesim/training/imagenet/resnet50/checkpoints/imagenet/torchvision_recipe_shrink_factor_[5.0]_loss_scale_50_layers_all_conv_layers__bimt_scale_None_from_pretrained_False_apply_every_30_steps/all/train_step_idx_120000.pth"
checkpoint_filename = "/mindhive/nklab3/users/mayukh/repos/nesim/training/imagenet/resnet50/checkpoints/imagenet/torchvision_recipe_shrink_factor_[5.0]_loss_scale_200_layers_layer3__bimt_scale_None_from_pretrained_False_apply_every_30_steps/all/train_step_idx_100000.pth"

# Seed the RNG so the random probe input (and therefore every reported delta)
# is identical from run to run.
torch.manual_seed(0)
# Random probe input fed to the layer. The channel dimension (512) presumably
# matches the in_channels of `layer_name` -- TODO confirm when changing layers.
x = torch.randn(1, 512, 32, 32)

# cluster_similarity_threshold values to sweep; these become the x-axis of the
# saved plot. Lower values are kept commented out for quick re-enabling.
thresholds = [
    # 0.1,
    # 0.2,
    # 0.3,
    # 0.4,
    # 0.5,
    # 0.6,
    # 0.7,
    # 0.8,
    0.9,
    0.925,
    0.95,
    0.975,
    0.98,
]

# Mean absolute difference between the pretrained layer's sparse output and
# the checkpointed layer's sparse output, one entry per threshold.
deltas = []

# Single device used for every model, layer and tensor in the sweep.
device = "cuda:0"
# Load the checkpoint once rather than once per threshold. map_location places
# its tensors on `device` regardless of where they were saved. load_state_dict
# copies values into the model's parameters, so reusing this dict is safe.
checkpoint_state_dict = torch.load(checkpoint_filename, map_location=device)
x_device = x.to(device)

with torch.no_grad():
    for threshold in thresholds:
        # Fresh pretrained model per threshold, so each SparseConv2d is built
        # from untouched weights.
        model = models.resnet50(weights="DEFAULT").to(device)
        model.eval()
        pretrained_layer = get_module_by_name(module=model, name=layer_name)
        sparse_conv_baseline = SparseConv2d(
            conv_layer=pretrained_layer,
            cluster_similarity_threshold=threshold,
            device=device,
        )

        # Evaluate the dense and sparse *pretrained* layer BEFORE the
        # checkpoint is loaded: load_state_dict below replaces this layer's
        # weights, so computing y_original afterwards would check the sparse
        # baseline against the checkpoint's dense layer instead.
        y_original = pretrained_layer(x_device)
        y_baseline = sparse_conv_baseline.forward(x_device)
        assert torch.allclose(y_original, y_baseline, atol=1e-2), (
            f"sparse baseline deviates from dense layer at threshold {threshold}: "
            f"max abs diff {(y_original - y_baseline).abs().max().item():.4g}"
        )

        model.load_state_dict(checkpoint_state_dict)

        sparse_conv_ours = SparseConv2d(
            conv_layer=get_module_by_name(module=model, name=layer_name),
            cluster_similarity_threshold=threshold,
            device=device,
        )
        y_ours = sparse_conv_ours.forward(x_device)

        delta = torch.abs(y_baseline - y_ours).mean().item()
        deltas.append(delta)

# Plot how far the checkpoint's sparse layer output drifts from the pretrained
# baseline's sparse output as the clustering threshold varies.
fig, ax = plt.subplots()
ax.plot(thresholds, deltas, marker="o")
ax.set_xlabel("cluster_similarity_threshold")
ax.set_ylabel("mean |y_baseline - y_ours|")
ax.set_title(f"SparseConv2d output delta ({layer_name})")
fig.savefig("deltas.jpg")
# Release the figure; pyplot otherwise keeps it alive until the process exits.
plt.close(fig)
