import argparse
import os
import shutil

import matplotlib
import matplotlib.pyplot as plt
import torch
import torchvision.models as models
from einops import rearrange
from PIL import Image
from tqdm import tqdm

from generate_gratings import StripesDataset, imagenet_transforms
from nesim.utils.checkpoint import load_and_filter_state_dict_keys
from nesim.utils.getting_modules import get_module_by_name
from nesim.utils.grid_size import find_rectangle_dimensions
from nesim.utils.hook import ForwardHook
from nesim.utils.single_conv_filter_output_extractor import (
    SingleConvFilterOutputExtractor,
    FilterLocation,
)

# Folder where the final orientation-map figures are written.
outputs_folder = "./orientation_maps/"
os.makedirs(outputs_folder, exist_ok=True)

# Non-interactive backend: figures are only saved to disk, never shown.
matplotlib.use("Agg")

parser = argparse.ArgumentParser(description="generate orientation selectivity maps")
parser.add_argument(
    "--checkpoint-filename",
    type=str,
    help="Path to the model checkpoint file",
    required=True,
)  # /research/XXXX-3/repos/nesim/training/imagenet/resnet18/checkpoints/imagenet/shrink_factor_[5.0]_loss_scale_150_layers_all_conv_layers__bimt_scale_None_from_pretrained_False_apply_every_20_steps/best/best_model.ckpt

# Previously this was both `required=True` and had a default, which made the
# default unreachable; it is now optional so the default actually applies.
parser.add_argument(
    "--layer-name",
    type=str,
    help="name of layer to visualize",
    default="layer3.0.conv1",
)

args = parser.parse_args()

# Fall back to CPU so the script still runs on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Per-layer cache for the raw forward-hook outputs.
hook_outputs_folder = f"./hook_outputs/{args.layer_name}"

# No pretrained download is needed: load_state_dict is strict by default, so
# every parameter and buffer is overwritten with the checkpoint's values anyway.
model = models.resnet18(weights=None)
model.load_state_dict(load_and_filter_state_dict_keys(args.checkpoint_filename))
model.to(device)
# Inference only: BatchNorm must use the checkpoint's running statistics.
model.eval()

# Grating orientations in degrees; one "<angle>.jpg" image is expected per angle.
# A coarser sampling used previously: [i * 22.5 for i in range(8)]
orientation_angles = list(range(180))


class GratingsDataset:
    """Indexable collection of grating images stored as ``<folder>/<angle>.jpg``.

    Item ``idx`` is a ``(image, label)`` pair: the PIL image loaded from
    ``filenames[idx]`` and ``torch.tensor(angles[idx])`` (the orientation angle).

    Args:
        folder: Directory holding one ``<angle>.jpg`` file per angle.
        angles: Orientation values. Each is formatted into a filename with
            ``f"{angle}.jpg"``, so it must match the on-disk name exactly
            (e.g. ``22.5`` -> ``22.5.jpg``).

    Raises:
        FileNotFoundError: If ``folder`` or any expected image file is missing.
    """

    def __init__(self, folder: str, angles: list):
        # Explicit raises instead of `assert`, so validation survives `python -O`.
        if not os.path.exists(folder):
            raise FileNotFoundError(f"Gratings folder does not exist: {folder}")
        self.filenames = [
            os.path.join(folder, f"{orientation_angle}.jpg")
            for orientation_angle in angles
        ]
        self.orientation_angles = angles

        for f in self.filenames:
            if not os.path.exists(f):
                raise FileNotFoundError(f"Invalid filename: {f}")

    def __getitem__(self, idx):
        # Image.open is lazy and keeps the file handle open until the image is
        # garbage-collected; copy the pixels out and close the file right away.
        with Image.open(self.filenames[idx]) as image_file:
            image = image_file.copy()
        label = torch.tensor(self.orientation_angles[idx])  # label is orientation angle
        return image, label

    def __len__(self):
        return len(self.filenames)


stripes_dataset = GratingsDataset(folder="./stripes/", angles=orientation_angles)

target_layer = get_module_by_name(module=model, name=args.layer_name)
num_input_channels = target_layer.in_channels
num_output_channels = target_layer.out_channels

# Start from a clean per-layer cache of hook outputs. shutil/os are used instead
# of shell commands so the user-supplied layer name is never parsed by a shell.
shutil.rmtree(hook_outputs_folder, ignore_errors=True)
os.makedirs(hook_outputs_folder, exist_ok=True)

hook = ForwardHook(module=target_layer)

# Inference must run in eval mode: in train mode the BatchNorm layers upstream
# of the target layer would normalise with each single grating image's own
# statistics (and update their running stats) instead of the checkpoint's.
model.eval()

hook_output_filenames = []
all_reduced_hook_outputs = []

with torch.no_grad():
    for dataset_idx in tqdm(range(len(stripes_dataset))):
        pil_image, _ = stripes_dataset[dataset_idx]
        image_tensor = imagenet_transforms(pil_image).to(device).unsqueeze(0)
        # Only the hook's capture of the target layer is needed; logits are unused.
        model(image_tensor)
        hook_output = hook.output.detach().cpu()

        filename = os.path.join(hook_outputs_folder, f"{dataset_idx}.pth")
        torch.save(hook_output, filename)
        hook_output_filenames.append(filename)

        # Average over the last two (spatial) dims -> one mean response per
        # channel. Conv output presumably (1, out_channels, h, w) -> (1, out_channels).
        all_reduced_hook_outputs.append(hook_output.mean(-1).mean(-1))

# (num_angles, out_channels): mean response of every channel to every grating.
all_reduced_hook_outputs = torch.cat(all_reduced_hook_outputs, dim=0)

grid_size = find_rectangle_dimensions(area=num_output_channels)

# argmax yields an index into orientation_angles; map it back to the angle so
# the map is correct for any sampling (not only 1-degree steps starting at 0).
top_activating_angle_indices = torch.argmax(all_reduced_hook_outputs, dim=0)
top_activating_angles_for_each_neuron = torch.tensor(orientation_angles)[
    top_activating_angle_indices
]
top_activating_angles_for_each_neuron_grid = (
    top_activating_angles_for_each_neuron.reshape(grid_size.height, grid_size.width)
)

output_filename = os.path.join(
    outputs_folder, f"{args.layer_name}_orientation_selectivity.jpg"
)

fig = plt.figure()
fig.suptitle(f"Top activating orientation\nlayer: {args.layer_name}\nnum unique gratings: {len(orientation_angles)}")
# hsv is cyclic. Pinning the colour range to one orientation period
# (assumes angles in [0, 180)) makes 0 and 180 share a colour and keeps colours
# comparable across layers; otherwise imshow rescales to each map's min/max.
plt.imshow(
    top_activating_angles_for_each_neuron_grid.cpu(), cmap="hsv", vmin=0, vmax=180
)
plt.colorbar()
plt.tight_layout()
fig.savefig(output_filename)
plt.close(fig)
print(f"plot saved -> {output_filename}")

"""
python3 generate_orientation_map.py --checkpoint-filename /research/XXXX-3/repos/nesim/training/imagenet/resnet18/checkpoints/imagenet/shrink_factor_[5.0]_loss_scale_150_layers_all_conv_layers__bimt_scale_None_from_pretrained_False_apply_every_20_steps/best/best_model.ckpt  --layer-name layer4.0.conv1
"""
