import os
import torch
from torch.utils.data import DataLoader
from nesim.utils.hook import ForwardHook
import torch.nn as nn
import numpy as np
from nesim.utils.grid_size import find_rectangle_dimensions
import matplotlib
import matplotlib.pyplot as plt
from tqdm import tqdm
import math

import torchvision
import torchvision.transforms as transforms
import argparse
from nesim.utils.json_stuff import load_json_as_dict
from nesim.utils.getting_modules import get_module_by_name
from nesim.utils.checkpoint import load_and_filter_state_dict_keys
from nesim.experiments.mnist import get_untrained_model

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Command-line interface: the checkpoint path is mandatory because the model
# weights are loaded from it unconditionally further down.
parser = argparse.ArgumentParser(description="vis layer output")
parser.add_argument(
    "--checkpoint-filename",
    type=str,
    required=True,
    help="path to the model checkpoint whose layer output is visualised",
)

args = parser.parse_args()

# Folder that receives the saved figure; exist_ok makes this a no-op when the
# folder is already there (no separate existence check needed).
os.makedirs("layer_output_plots", exist_ok=True)

# Non-interactive backend so the figure can be rendered without a display.
matplotlib.use("Agg")

# Preprocessing applied to every test image: convert to a tensor, then
# normalise the single channel with the mean/std given below
# (presumably the usual MNIST statistics -- confirm against training code).
transform_test = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.1307,), std=(0.3081,))]
)

# MNIST test split (train=False), downloaded into ./data when not present.
mnist_test = torchvision.datasets.MNIST(
    transform=transform_test,
    download=True,
    train=False,
    root="./data",
)

# One image per batch, served in dataset order.
test_dataloader = DataLoader(dataset=mnist_test, shuffle=False, batch_size=1)

# Build the architecture, move it to the chosen device, then restore the
# weights from the checkpoint given on the command line.
model = get_untrained_model(hidden_size=1024)
model = model.to(device)
model.load_state_dict(load_and_filter_state_dict_keys(args.checkpoint_filename))

# This script only runs inference: put the model in evaluation mode so any
# train-time-only behaviour (e.g. dropout, batch-norm updates) is disabled.
model.eval()

# Names (as understood by get_module_by_name) of the layers that can be hooked.
layer_names = ["4"]

topo_layers = [get_module_by_name(module=model, name=name) for name in layer_names]

# Index into layer_names / topo_layers selecting the layer to visualise.
pinwheel_layer_idx = 0
target_layer = topo_layers[pinwheel_layer_idx]

# Collect one forward pass per digit class. Each record stores the hooked
# layer output, the logits, the label tensor and the input tensor.
inference_data = []
num_outputs_to_plot = 10
terminate = [0] * 10  # terminate[d] == 1 once a sample labelled d is stored

for inputs, labels in test_dataloader:
    if sum(terminate) == num_outputs_to_plot:
        break

    # The dataloader yields batches of size 1, so there is exactly one label.
    # Convert it to a plain int before it is used to index a Python list.
    digit = int(labels.item())
    if terminate[digit]:
        continue

    hook = ForwardHook(module=target_layer)
    try:
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass (no gradients needed for visualisation)
        with torch.no_grad():
            logits = model(inputs)
    finally:
        # Always detach the hook, even when the forward pass raises.
        hook.close()

    inference_data.append(
        {
            "hook_output": hook.output,
            "logits": logits,
            "label": labels,
            "input": inputs,
        }
    )
    terminate[digit] = 1

# Order the collected samples by label so the plot columns read 0..9.
inference_data.sort(key=lambda record: record["label"].item())

# Figure layout: one column per collected sample and three rows
# (input image, class probabilities, hooked layer output).
# squeeze=False keeps `ax` two-dimensional even with a single column.
fig, ax = plt.subplots(
    ncols=num_outputs_to_plot, nrows=3, figsize=(25, 12), squeeze=False
)

# Reshape every hooked output into a 2-D grid up front so that one colour
# range shared by all columns can be computed before anything is drawn.
layer_maps = []
for plot_col_idx in range(num_outputs_to_plot):
    activation = inference_data[plot_col_idx]["hook_output"][0]
    rect_dims = find_rectangle_dimensions(area=activation.shape[0])
    layer_maps.append(
        activation.reshape(rect_dims.height, rect_dims.width).detach().cpu()
    )

# A single colorbar is only meaningful when every image uses the same
# value-to-colour mapping, so fix vmin/vmax across all columns.
shared_vmin = min(float(layer_map.min()) for layer_map in layer_maps)
shared_vmax = max(float(layer_map.max()) for layer_map in layer_maps)

images_for_colorbar = []

for plot_col_idx in range(num_outputs_to_plot):
    record = inference_data[plot_col_idx]
    img = record["input"][0][0]
    label = record["label"][0].item()
    logits_to_plot = record["logits"][0].detach().cpu()
    class_probs = torch.nn.functional.softmax(logits_to_plot, dim=0)

    # Row 0: the input image with its ground-truth label.
    ax[0, plot_col_idx].imshow(img.cpu(), cmap="gray")
    ax[0, plot_col_idx].set_title(f"input image\nlabel={label}")

    # Row 1: softmax of the logits, one bar per class.
    ax[1, plot_col_idx].bar(
        [str(i) for i in range(class_probs.shape[0])], class_probs.numpy()
    )
    ax[1, plot_col_idx].set_title("softmax(logits)")

    # Row 2: the hooked layer output, drawn once with the shared colour range.
    images_for_colorbar.append(
        ax[2, plot_col_idx].imshow(
            layer_maps[plot_col_idx].numpy(),
            cmap="viridis",
            vmin=shared_vmin,
            vmax=shared_vmax,
        )
    )
    ax[2, plot_col_idx].set_title(f"topo layer output\nlayer idx {pinwheel_layer_idx}")

# All bottom-row images share one normalisation, so any of them can back the
# colorbar that spans the whole row.
cbar = fig.colorbar(images_for_colorbar[0], ax=ax[2, :], orientation="horizontal")
cbar.set_label("colorbar (topo layer output)")

fig.savefig(
    f"./layer_output_plots/pinwheel_layer_idx_{pinwheel_layer_idx}_vis_layer_output.jpg"
)
# Release the figure's resources now that it has been written to disk.
plt.close(fig)