import os
from PIL import Image
import torchvision.transforms as transforms
from tqdm import tqdm
import torchvision.models as models
import torch
from nesim.utils.json_stuff import load_json_as_dict
from nesim.utils.getting_modules import get_module_by_name
from nesim.utils.hook import ForwardHook
from nesim.utils.grid_size import find_rectangle_dimensions
import matplotlib.pyplot as plt
import os


def get_filenames_in_a_folder(folder: str):
    """
    Return one path per entry of `folder`, built as "<folder>/<entry>".

    A single trailing "/" on `folder` is dropped first so the joined paths
    never contain a double slash. Entries come back in whatever order
    `os.listdir` yields them and are not filtered, so sub-directories are
    included alongside regular files.
    """
    has_trailing_slash = folder[-1] == "/"
    base = folder[:-1] if has_trailing_slash else folder
    return [f"{base}/{entry}" for entry in os.listdir(base)]


class CuratedDataset:
    """
    Indexable collection of image files grouped into categories by folder.

    Each folder is one category: its position in `folders` is the integer
    label and its base name is the label name. Items are returned as dicts
    holding the file path, the PIL image as opened (not converted or
    transformed), the integer label and the label name.
    """

    def __init__(self, folders: "list[str] | str"):
        """
        Args:
            folders: folder paths, one per category. A single path given as a
                plain string is treated as a one-category dataset.
        """
        # A bare string would otherwise be iterated character by character,
        # and a one-shot iterable would be exhausted by the label_names
        # comprehension before the loop below ever ran.
        folders = [folders] if isinstance(folders, str) else list(folders)

        self.labels = []
        self.label_names = [os.path.basename(os.path.normpath(f)) for f in folders]
        self.filenames = []

        for folder_idx, folder in enumerate(folders):
            filenames = get_filenames_in_a_folder(folder=folder)
            self.filenames.extend(filenames)
            self.labels.extend([folder_idx] * len(filenames))
        print(
            f"prepared a dataset of {len(self.filenames)} images with {len(self.label_names)} categories"
        )

        # Not applied in __getitem__; exposed so callers can run the
        # preprocessing themselves on the returned PIL image.
        self.imagenet_transforms = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def __getitem__(self, idx):
        """Return the sample at `idx` as a dict (see class docstring)."""
        label = self.labels[idx]
        return {
            "filename": self.filenames[idx],
            # Returned as opened by PIL: no mode conversion, no transform.
            "image": Image.open(self.filenames[idx]),
            "label": label,
            "label_name": self.label_names[label],
        }

    def __len__(self):
        """Number of image files across all folders."""
        return len(self.filenames)


def run_inference_on_dataset_and_get_hook_outputs(
    model,
    dataset: CuratedDataset,
    target_layer_names,
    transforms,
    device="cuda:0",
):
    """
    Run `model` over every item of `dataset`, one image at a time, and collect
    what the forward hooks on the target layers captured.

    Args:
        model: the network to run. It is not moved to `device` here; only the
            inputs are.
        dataset: indexable object with `len()` whose items are dicts with an
            "image" entry exposing `.convert("RGB")`.
        target_layer_names: names resolved on `model` via `get_module_by_name`.
        transforms: callable turning the RGB image into a tensor; a batch
            dimension of size 1 is added afterwards.
        device: device the input tensor is moved to before the forward pass.

    Returns:
        dict mapping each layer name to a list with one entry per dataset
        item: the hook's `output` after that item's forward pass, moved to
        CPU and detached.
    """
    all_forward_hooks = {}
    all_hook_outputs = {name: [] for name in target_layer_names}

    # Hook creation sits inside the try block so that hooks registered before
    # a failure (bad layer name, inference error) are still removed.
    try:
        for name in target_layer_names:
            layer = get_module_by_name(module=model, name=name)
            all_forward_hooks[name] = ForwardHook(module=layer)

        with torch.no_grad():
            for dataset_idx in tqdm(range(len(dataset))):
                item = dataset[dataset_idx]
                batch = transforms(item["image"].convert("RGB")).to(device).unsqueeze(0)

                # Call the module rather than .forward() so that hooks
                # registered on the model itself also run; the logits are
                # not needed, only the hooked activations.
                model(batch)

                for name, hook in all_forward_hooks.items():
                    all_hook_outputs[name].append(hook.output.cpu().detach())
    finally:
        for hook in all_forward_hooks.values():
            hook.close()

    return all_hook_outputs


import torch


def find_most_activating_label_avg(
    all_outputs_for_single_layer, dataset_labels, threshold=0.0
):
    """
    For every neuron, find the label with the highest average activation.

    Activations are averaged per label, a softmax is taken over the label
    averages of each neuron, and the label with the largest softmax value
    wins. That largest value is reported as the neuron's "pureness".

    Args:
        all_outputs_for_single_layer: 2-D array/tensor of shape
            (num_samples, num_neurons).
        dataset_labels: one integer label per sample. Labels must be the
            contiguous ids 0 .. num_labels - 1, where num_labels is the number
            of distinct labels present.
        threshold: when > 0, neurons whose pureness is not strictly above it
            get the id `num_labels` (one past the last valid label) as a
            "no label" marker. When <= 0 every neuron keeps its best label.

    Returns:
        (most_activating_labels_avg, pureness_heatmap): an int64 tensor and a
        float32 tensor, both of shape (num_neurons,).

    Raises:
        AssertionError: if the number of samples and labels differ.
        IndexError: if a label lies outside [0, num_labels).
    """
    num_samples = all_outputs_for_single_layer.shape[0]
    assert num_samples == len(dataset_labels), (
        "The number of samples in all_outputs_for_single_layer and dataset_labels "
        f"must match. But got: {num_samples} and {len(dataset_labels)}"
    )

    # as_tensor reuses an existing tensor instead of copy-constructing it,
    # which torch.tensor() does with a UserWarning.
    outputs = (
        torch.as_tensor(all_outputs_for_single_layer)
        .detach()
        .to(device="cpu", dtype=torch.float32)
    )
    labels = (
        torch.as_tensor(dataset_labels).detach().to(device="cpu", dtype=torch.int64)
    )

    num_neurons = outputs.shape[1]
    num_labels = len(torch.unique(labels))

    # Labels index rows of the per-label accumulators below, so they must be
    # valid row ids; negative values would otherwise alias other labels.
    if num_samples and (int(labels.min()) < 0 or int(labels.max()) >= num_labels):
        raise IndexError(
            "dataset_labels must be contiguous integers in "
            f"[0, {num_labels}), but got values in "
            f"[{int(labels.min())}, {int(labels.max())}]"
        )

    # Per-label sums for all neurons at once: row `l` accumulates the
    # activations of every sample labelled `l`.
    label_sum_activations = torch.zeros(num_labels, num_neurons).index_add_(
        0, labels, outputs
    )
    label_count = torch.bincount(labels, minlength=num_labels).to(torch.float32)
    label_avg_activations = label_sum_activations / label_count.unsqueeze(1)

    # Softmax across labels (dim 0), independently for each neuron column.
    label_avg_activations = label_avg_activations.softmax(dim=0)
    most_activating_labels_avg = label_avg_activations.argmax(dim=0)
    pureness_heatmap = label_avg_activations.amax(dim=0)

    if threshold > 0:
        # `~(x > t)` rather than `x <= t` so NaN pureness also counts as
        # "not above threshold".
        not_pure_enough = ~(pureness_heatmap > threshold)
        most_activating_labels_avg[not_pure_enough] = num_labels

    return most_activating_labels_avg, pureness_heatmap


class CategorySelectivityMapExperiment:
    """
    Loads a checkpoint into a model, runs it over a dataset while recording
    the outputs of selected layers, and exposes those outputs per layer.

    Typical use: construct, call `run()`, then call
    `get_all_outputs_for_single_layer(name)` for each layer of interest.
    """

    def __init__(
        self,
        model,
        checkpoint_filename: str,
        device: str,
        target_layer_names: list,
        dataset: "CuratedDataset",
    ):
        """
        Args:
            model: module whose weights are overwritten from the checkpoint.
            checkpoint_filename: file readable by `torch.load` that contains a
                "state_dict" entry.
            device: device the model is moved to and inference runs on.
            target_layer_names: names of the layers whose outputs are recorded.
            dataset: the dataset to run over; its `imagenet_transforms`
                attribute is used as preprocessing in `run()`.
        """
        self.model = model
        self.device = device

        # NOTE(security): torch.load unpickles the file, so only trusted
        # checkpoints should be passed in.
        # map_location="cpu" lets checkpoints saved from a GPU load on
        # machines without that GPU; the model is moved to `device` below.
        checkpoint = torch.load(checkpoint_filename, map_location="cpu")
        state_dict = checkpoint["state_dict"]

        # NOTE(review): this removes every occurrence of "model." in a key,
        # not only a leading prefix. Presumably meant to strip a wrapper
        # prefix; confirm no parameter name contains "model." elsewhere.
        state_dict_with_fixed_keys = {
            key.replace("model.", ""): value for key, value in state_dict.items()
        }

        self.model.load_state_dict(state_dict_with_fixed_keys)
        self.model = self.model.eval().to(device)
        self.dataset = dataset
        self.target_layer_names = target_layer_names

    def run(self):
        """Run inference over the dataset and store the recorded layer outputs
        in `self.all_hook_outputs` (layer name -> list, one entry per item)."""
        self.all_hook_outputs = run_inference_on_dataset_and_get_hook_outputs(
            model=self.model,
            dataset=self.dataset,
            target_layer_names=self.target_layer_names,
            device=self.device,
            transforms=self.dataset.imagenet_transforms,
        )

    def get_all_outputs_for_single_layer(self, layer_name: str):
        """
        Return the recorded outputs of `layer_name`, averaged over the last
        two dimensions of each output and concatenated along dim 0.

        Requires `run()` to have been called, since it reads
        `self.all_hook_outputs`. With one (1, C, H, W) output per dataset
        item the result has shape (len_dataset, C).
        """
        layer_outputs = self.all_hook_outputs[layer_name]

        # Mean over the last dimension twice, i.e. over the two trailing
        # (spatial, for conv feature maps) dimensions.
        all_outputs_for_single_layer = [
            layer_outputs[dataset_idx].mean(-1).mean(-1)
            for dataset_idx in range(len(self.dataset))
        ]

        ## all_outputs_for_single_layer.shape: len_dataset, num_neurons
        return torch.cat(all_outputs_for_single_layer, dim=0)
