import os
from PIL import Image
import torchvision.transforms as transforms
from tqdm import tqdm
import torchvision.models as models
import torch
from nesim.utils.json_stuff import load_json_as_dict
from nesim.utils.getting_modules import get_module_by_name
from nesim.utils.hook import ForwardHook
from nesim.utils.grid_size import find_rectangle_dimensions
import matplotlib.pyplot as plt
import os


def get_filenames_in_a_folder(folder: str):
    """
    Return the paths of all entries directly inside `folder`.

    Every path is built as `<folder>/<entry name>` after dropping any trailing
    "/" from `folder`. Entries are whatever `os.listdir` reports, so nothing is
    filtered out (sub-directories and non-image files are included).

    The result is sorted by name: `os.listdir` returns entries in arbitrary
    order, and callers index into this list, so a stable order keeps the
    index -> file mapping reproducible between runs.

    Args:
        folder: directory to list. A bare "/" (filesystem root) is accepted.

    Returns:
        Sorted list of path strings.

    Raises:
        FileNotFoundError: if `folder` does not exist (this includes "").
    """
    trimmed = folder.rstrip("/")

    # rstrip reduces "/" to ""; list the root in that case instead of passing
    # an empty path to os.listdir. A genuinely empty `folder` stays empty and
    # fails with FileNotFoundError rather than an IndexError.
    directory = "/" if (folder and not trimmed) else trimmed

    return [f"{trimmed}/{name}" for name in sorted(os.listdir(directory))]


class CuratedDataset:
    """
    Folder-per-category image dataset.

    Each folder passed in is one category: its position in `folders` is the
    integer label and its base name is the label name.

    Attributes set by the constructor:
        filenames: paths of all entries found in the folders, in folder order.
        labels: integer label (folder index) for each entry of `filenames`.
        label_names: base name of each folder, indexed by label.
        imagenet_transforms: resize -> tensor -> normalize pipeline; it is NOT
            applied by `__getitem__`, callers apply it themselves.
    """

    def __init__(self, folders: list):
        """
        Args:
            folders: folder paths, one per category. A single path string is
                treated as a one-category dataset instead of being iterated
                character by character.
        """
        if isinstance(folders, str):
            folders = [folders]
        # Materialize once: `folders` is walked twice below, which would
        # silently yield nothing the second time for a one-shot iterable.
        folders = list(folders)

        self.labels = []
        self.label_names = [os.path.basename(os.path.normpath(f)) for f in folders]
        self.filenames = []

        for folder_idx, folder in enumerate(folders):
            filenames = get_filenames_in_a_folder(folder=folder)
            self.filenames.extend(filenames)
            self.labels.extend([folder_idx] * len(filenames))
        print(
            f"prepared a dataset of {len(self.filenames)} images with {len(self.label_names)} categories"
        )

        # NOTE(review): despite the attribute name, these mean/std values look
        # like the CLIP preprocessing statistics rather than the usual ImageNet
        # ones -- confirm they match what the image encoder was trained with.
        self.imagenet_transforms = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.48145466, 0.4578275, 0.40821073],
                    std=[0.26862954, 0.26130258, 0.27577711],
                ),
            ]
        )

    def __getitem__(self, idx):
        """
        Return one sample as a dict with keys "filename", "image", "label"
        and "label_name".

        "image" is the object returned by `Image.open`: no mode conversion and
        no transform is applied here.
        """
        label = self.labels[idx]
        filename = self.filenames[idx]
        return {
            "filename": filename,
            "image": Image.open(filename),
            "label": label,
            "label_name": self.label_names[label],
        }

    def __len__(self):
        """Number of samples (entries found across all folders)."""
        return len(self.filenames)


def run_inference_on_dataset_and_get_hook_outputs(
    model,
    dataset: CuratedDataset,
    target_layer_names,
    transforms,
    device="cuda:0",
    progress=False,
):
    """
    Run `model` over every item of `dataset`, one image at a time, and collect
    what the forward hooks on the target layers captured.

    Args:
        model: model to run; it is called through `model.forward` on inputs
            moved to `device` (the model itself is not moved here).
        dataset: indexable dataset whose items expose a PIL image under the
            "image" key.
        target_layer_names: names of the sub-modules to hook, resolved with
            `get_module_by_name`.
        transforms: callable applied to the RGB-converted PIL image; its result
            must support `.to(device)` and `.unsqueeze(0)`.
        device: device the input batch is moved to.
        progress: show a tqdm progress bar when True.

    Returns:
        Dict mapping each layer name to a list with one entry per dataset item:
        the hook's `output` for that item, moved to CPU and detached.
    """
    all_forward_hooks = {}
    all_hook_outputs = {name: [] for name in target_layer_names}

    # Registration and inference both sit inside the try block so that every
    # hook that did get attached is closed again even when a later layer lookup
    # or a forward pass raises; otherwise stale hooks would stay on the model.
    try:
        for name in target_layer_names:
            layer = get_module_by_name(module=model, name=name)
            all_forward_hooks[name] = ForwardHook(module=layer)

        with torch.no_grad():
            for dataset_idx in tqdm(range(len(dataset)), disable=not progress):
                item = dataset[dataset_idx]

                # Add a leading batch dimension of size 1.
                batch = transforms(item["image"].convert("RGB")).to(device).unsqueeze(0)

                # The forward pass is run only for its side effect of filling
                # the hooks; its return value is not needed.
                model.forward(batch)

                for name, hook in all_forward_hooks.items():
                    all_hook_outputs[name].append(hook.output.cpu().detach())
    finally:
        for hook in all_forward_hooks.values():
            hook.close()

    return all_hook_outputs


import torch


def find_most_activating_label_avg(
    all_outputs_for_single_layer, dataset_labels, threshold=0.0
):
    """
    For every neuron, find the label whose samples activate it most on average.

    Per neuron, the activations are averaged per label, the per-label means are
    passed through a softmax over labels, and the label with the largest
    softmax value wins. That largest softmax value is reported as the neuron's
    "pureness".

    NOTE(review): the softmax is taken over raw mean activations, so the
    pureness value depends on the scale of the activations, not only on their
    ranking -- confirm this is the intended selectivity measure.

    Args:
        all_outputs_for_single_layer: 2-D activations, indexed as
            [sample, neuron]. Converted to float32 for the computation.
        dataset_labels: one non-negative integer label per sample.
        threshold: when > 0, neurons whose pureness is not strictly above it
            get the sentinel label `num_labels` ("no category") instead.

    Returns:
        (most_activating_labels_avg, pureness_heatmap): an int64 tensor and a
        float32 tensor, both of length num_neurons.

    Raises:
        AssertionError: if the number of samples and labels differ.
        ValueError: if there are no samples at all.
    """
    outputs = torch.as_tensor(all_outputs_for_single_layer, dtype=torch.float32)
    labels = torch.as_tensor(
        dataset_labels, dtype=torch.int64, device=outputs.device
    )

    # Ensure that the shapes of inputs are consistent
    assert outputs.shape[0] == labels.shape[0], (
        "The number of samples in all_outputs_for_single_layer and "
        "dataset_labels must match. But got: "
        f"{outputs.shape[0]} and {labels.shape[0]}"
    )
    if labels.numel() == 0:
        raise ValueError("Cannot compute label selectivity without any samples.")

    num_neurons = outputs.shape[1]
    # Size the label axis by the largest label index, not by the number of
    # distinct labels: with a gap in the indices (e.g. a category without any
    # images) counting distinct values would make the largest label index out
    # of bounds. Identical to the distinct count when labels are 0..k-1.
    num_labels = int(labels.max()) + 1

    # Per-label sum of activations for all neurons at once:
    # row `l` accumulates every sample whose label is `l`.
    label_sum_activations = torch.zeros(
        num_labels, num_neurons, dtype=outputs.dtype, device=outputs.device
    )
    label_sum_activations.index_add_(0, labels, outputs)
    label_count = torch.bincount(labels, minlength=num_labels)

    # Mean activation per (label, neuron). Labels without samples get -inf so
    # they end up with softmax probability 0 instead of producing NaN from 0/0.
    label_avg_activations = label_sum_activations / label_count.clamp(
        min=1
    ).unsqueeze(1)
    label_avg_activations[label_count == 0] = float("-inf")

    # Softmax across labels, independently for each neuron.
    label_probabilities = label_avg_activations.softmax(dim=0)
    pureness_heatmap = label_probabilities.amax(dim=0)
    most_activating_labels_avg = label_probabilities.argmax(dim=0)

    if threshold > 0:
        # Neurons that are not selective enough are mapped to the extra
        # "none" index, one past the last real label.
        most_activating_labels_avg = torch.where(
            pureness_heatmap > threshold,
            most_activating_labels_avg,
            torch.full_like(most_activating_labels_avg, num_labels),
        )

    return most_activating_labels_avg, pureness_heatmap


class CategorySelectivityMapExperiment:
    """
    Runs a model over a curated dataset and exposes the hooked layer outputs.

    Typical use: construct, call `run()` once, then call
    `get_all_outputs_for_single_layer()` for each layer of interest.
    """

    def __init__(
        self, model, device: str, target_layer_names: list, dataset: CuratedDataset
    ):
        """
        Args:
            model: model to probe; it is switched to eval mode and moved to
                `device`.
            device: device used for the model and for the inference inputs.
            target_layer_names: names of the layers whose outputs are recorded.
            dataset: dataset to run; its `imagenet_transforms` attribute is
                used as the preprocessing pipeline in `run()`.
        """
        model_in_eval_mode = model.eval()
        self.model = model_in_eval_mode.to(device)
        self.device = device
        self.dataset = dataset
        self.target_layer_names = target_layer_names

    def run(self):
        """Run inference on the whole dataset and store the hook outputs on
        `self.all_hook_outputs` (layer name -> list of per-sample outputs)."""
        hook_outputs = run_inference_on_dataset_and_get_hook_outputs(
            model=self.model,
            dataset=self.dataset,
            target_layer_names=self.target_layer_names,
            transforms=self.dataset.imagenet_transforms,
            device=self.device,
        )
        self.all_hook_outputs = hook_outputs

    def get_all_outputs_for_single_layer(self, layer_name: str):
        """
        Average each recorded output of `layer_name` over its last two
        dimensions and stack the per-sample results along dim 0.

        Requires `run()` to have been called first.

        Returns:
            Tensor described by the original author as
            (len_dataset, num_neurons) -- this presumes each recorded output
            is shaped (1, channels, height, width); TODO confirm for the
            hooked layer.
        """
        # Collapse the last dimension first, then the (new) last one, for every
        # sample in dataset order.
        spatially_averaged = [
            self.all_hook_outputs[layer_name][sample_idx].mean(-1).mean(-1)
            for sample_idx in range(len(self.dataset))
        ]
        return torch.cat(spatially_averaged, dim=0)


##########################################################################

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
from neuro.models.neuro_model.config import NeuroModelConfig
from nesim.utils.json_stuff import load_json_as_dict
from neuro.models.neuro_model.model import NeuroModel
from neuro.utils.model_building import make_convmapper_config_from_size_sequences
from PIL import Image
from nesim.utils.grid_size import find_rectangle_dimensions
import os
import argparse
from nesim.vis.video import generate_video

# Command-line interface: which checkpoints to visualize, the video framerate,
# and the experiment config file.
parser = argparse.ArgumentParser(
    description=(
        "Renders category selectivity map frames for a range of training "
        "checkpoints and assembles them into a video"
    )
)

parser.add_argument("--start", type=int, help="start index", default=0)
parser.add_argument("--stop", type=int, help="stop index", default=90_000)
parser.add_argument("--step", type=int, help="step size", default=100)
parser.add_argument("--framerate", type=int, help="video framerate", default=10)
# store_false: `args.no_overwrite_frames` is True by default and turns False
# only when the flag is passed, so its value reads as "overwrite the frames".
parser.add_argument(
    "--no-overwrite-frames", help="disable frame overwriting", action="store_false"
)
# Required: without a config file there is nothing to load below, so fail at
# argument parsing with a usage message.
parser.add_argument("--config", type=str, help="config file", required=True)

args = parser.parse_args()

config = load_json_as_dict(filename=args.config)

overwrite_frames = args.no_overwrite_frames

if not overwrite_frames:
    print("Will not overwrite frames. Will just render video")

# Every model hyper-parameter lives under the "model" section of the config.
_model_section = config["model"]

# Settings for the brain-response predictor, derived from the layer size
# sequences given in the config.
conv_mapper_config = make_convmapper_config_from_size_sequences(
    conv_layer_size_sequence=_model_section["conv_layer_size_sequence"],
    linear_layer_size_sequence=_model_section["linear_layer_size_sequence"],
    reduce_fn=_model_section["conv_mapper_reduce_fn"],
    activation=_model_section["conv_mapper_activation"],
    conv_layer_kernel_size=_model_section["conv_mapper_kernel_size"],
)

# Full model config: the predictor config plus the image encoder and the name
# of the encoder layer that is hooked.
neuro_model_config = NeuroModelConfig(
    brain_response_predictor_config=conv_mapper_config,
    image_encoder=_model_section["neuro_model_config_image_encoder"],
    hook_layer_name=_model_section["intermediate_layer_name"],
    reduce_fn=_model_section["neuro_model_reduce_fn"],
)


model = NeuroModel(config=neuro_model_config, device=_model_section["device"])

frames_folder = config["category_selectivity_map_frames_folder"]
dataset_root = "/mindhive/nklab3/users/XXXX-1/repos/nesim/training/murty185/datasets/curated/nsd_curated"

# Create the output folder directly instead of shelling out to `mkdir -p`:
# no shell is involved, so paths with spaces or shell metacharacters are safe.
os.makedirs(frames_folder, exist_ok=True)
threshold = 0.0

# Category folders under `dataset_root`; the position in this list is the
# integer label of the category.
folders = [
    "scene",
    "face",
    # 'food',
    "body",
    # 'text',
]


folder_colors = {
    "scene": "#00B140",  # Tropical Green for the scene
    "face": "#00A8E8",  # Tropical Ocean Blue for the face
    "body": "#F5DEB3",  # Pastel Sea Sand for the body
}


# One checkpoint per visualized training step.
step_indices = list(range(args.start, args.stop, args.step))
checkpoint_filenames = [f"train_step_idx_{i}.pth" for i in step_indices]

print(f"start: {args.start} stop: {args.stop} step: {args.step}")
print(f"Will visualize {len(checkpoint_filenames)} checkpoints")
checkpoint_paths = [
    os.path.join(config["save_checkpoint_every_n_steps_folder"], f)
    for f in checkpoint_filenames
]

target_layer_name = "brain_response_predictor.conv_layers.0"
dataset = CuratedDataset(folders=[os.path.join(dataset_root, f) for f in folders])

# With a threshold, neurons that are not selective enough are assigned the
# index one past the last category; give that index a name and a color. This
# must happen after the dataset is built so "none" is not treated as a folder.
if threshold > 0:
    print("Threshold is not 0")
    folder_colors["none"] = "#808080"  ## gray
    folders.append("none")


# Colors in label-index order: follow `folders` (which defines the indices)
# rather than the dict's insertion order, so a category without a configured
# color fails loudly instead of silently shifting the colors.
colors = [folder_colors[f] for f in folders]

# Create a ListedColormap with the specified colors
cmap = mcolors.ListedColormap(colors)


# Render one frame per checkpoint (unless frame overwriting is disabled), then
# assemble all frames into a video.
frame_filenames = []
pbar = tqdm(total=len(checkpoint_filenames))
for filename, path, step_idx in zip(
    checkpoint_filenames, checkpoint_paths, step_indices
):
    frame_filename = os.path.join(frames_folder, f"{filename}.jpg")

    if overwrite_frames:
        ##################################################
        model.brain_response_predictor.load(path)

        # NOTE(review): the device is hard-coded here while the model was
        # built with config["model"]["device"] -- confirm they should differ.
        experiment = CategorySelectivityMapExperiment(
            model=model,
            device="cuda:0",
            target_layer_names=[target_layer_name],
            dataset=dataset,
        )

        experiment.run()

        # Computed once per checkpoint (it used to be computed twice, with the
        # first result never used).
        all_outputs_for_single_layer_ours = experiment.get_all_outputs_for_single_layer(
            layer_name=target_layer_name
        )
        grid_size = find_rectangle_dimensions(
            all_outputs_for_single_layer_ours.shape[1]
        )
        (
            most_activating_labels_avg_ours,
            pureness_heatmap_ours,
        ) = find_most_activating_label_avg(
            all_outputs_for_single_layer_ours, dataset.labels, threshold=threshold
        )

        pureness_heatmap_ours = pureness_heatmap_ours.reshape(
            grid_size.height, grid_size.width
        ).numpy()
        # Plot our result
        image_ours = most_activating_labels_avg_ours.reshape(
            grid_size.height, grid_size.width
        ).numpy()

        fig = plt.figure(figsize=(10, 10))
        fig.suptitle(f"Training step: {int(step_idx)}", fontsize=25)

        # Pin the color range to the full set of category indices. Without
        # vmin/vmax, imshow scales to the values present in this frame, so a
        # frame in which some category never wins would map the remaining
        # indices to the wrong colors. The half-unit margins put every integer
        # index in the middle of its color band.
        cax = plt.imshow(
            image_ours,
            cmap=cmap,
            vmin=-0.5,
            vmax=len(folder_colors) - 0.5,
        )
        plt.axis("off")
        cbar = plt.colorbar(cax, ticks=list(range(len(folder_colors))))
        cbar.ax.tick_params(labelsize=25)

        cbar.set_ticklabels(folders)

        fig.savefig(frame_filename)
        # Release the figure; otherwise one figure per checkpoint stays alive
        # for the whole run.
        plt.close(fig)
        ##############################################

    frame_filenames.append(frame_filename)
    pbar.update(1)
pbar.close()

# When frames are not re-rendered, they are expected to exist already on disk.
list_of_pil_images = [Image.open(name) for name in frame_filenames]
generate_video(
    list_of_pil_images=list_of_pil_images,
    framerate=args.framerate,
    filename="category_selectivity_map.mp4",
    size=(1024, 1024),
)
