"""
the goal here is to init the nsd dataset and then iterate over each item and
check if it's highly selective over one category such as the face (softmax > threshold)
then save them into a folder. 
"""
import os
import torch
from nesim.experiments.brain_model.natural_scenes_dataset import MurtyNSD
from pprint import pprint


class NSDCuratedDatasetBuilder:
    """Sort category-selective dataset images into per-category folders.

    Each dataset item is assigned to the category whose ``brain_response``
    component is largest (argmax). The image is kept only when that component
    is at least ``threshold`` and the category has not yet reached
    ``num_max_images_per_category`` images.

    Resulting layout::

        output_folder/
            scene/
                0.jpg
                1.jpg
                2.jpg
            face/
                0.jpg
                ...
    """

    def __init__(
        self,
        nsd_dataset: "MurtyNSD",
        output_folder: str,
        num_max_images_per_category: int = 50,
        threshold: float = 0.0002,
    ) -> None:
        """Store the configuration; nothing is written until ``build()``.

        Args:
            nsd_dataset: Indexable, sized dataset whose items are mappings
                with a ``"brain_response"`` entry (passed to ``torch.argmax``)
                and an ``"image_tensor"`` entry exposing ``.save(path)``
                (presumably a PIL image -- TODO confirm against MurtyNSD).
            output_folder: Existing directory that receives the category
                sub-folders.
            num_max_images_per_category: Upper bound on saved images per
                category.
            threshold: Minimum value the strongest response component must
                reach for the image to be kept.

        Raises:
            FileNotFoundError: If ``output_folder`` does not exist.
        """
        # Raise explicitly instead of `assert`, which is stripped under -O.
        if not os.path.exists(output_folder):
            raise FileNotFoundError(
                f"output_folder does not exist: {output_folder!r}"
            )

        self.output_folder = output_folder
        self.dataset = nsd_dataset
        self.num_max_images_per_category = num_max_images_per_category
        self.threshold = threshold

        # Index of the response component -> name of the category sub-folder.
        self.label_map = {0: "scene", 1: "face", 2: "food", 3: "body", 4: "text"}

    @staticmethod
    def _reset_category_folder(folder_name: str) -> None:
        """Create ``folder_name`` if needed and delete the ``.jpg`` files in it."""
        os.makedirs(folder_name, exist_ok=True)
        for entry in os.listdir(folder_name):
            # Mirror the shell glob `*.jpg`: hidden files are not matched.
            if entry.startswith(".") or not entry.endswith(".jpg"):
                continue
            path = os.path.join(folder_name, entry)
            # Only unlink files/symlinks; a directory named "x.jpg" is left alone.
            if os.path.isfile(path) or os.path.islink(path):
                os.remove(path)

    def build(self) -> dict:
        """Populate the category folders and report how many images were saved.

        Existing ``.jpg`` files in each category folder are removed first so
        that the numbering (``0.jpg``, ``1.jpg``, ...) starts fresh. The scan
        over the dataset stops early once every category is full.

        Returns:
            Mapping of label index to the number of images saved for it
            (the same mapping that is pretty-printed at the end).
        """
        num_samples = {}
        for key, category in self.label_map.items():
            num_samples[key] = 0
            # Filesystem calls instead of shell strings: safe for paths with
            # spaces/metacharacters and silent when there is nothing to remove.
            self._reset_category_folder(os.path.join(self.output_folder, category))

        # Loop-invariant: total number of images once every category is full.
        total_cap = self.num_max_images_per_category * len(self.label_map)

        # Index-based access: only __len__ and __getitem__ are required of
        # the dataset.
        for dataset_idx in range(len(self.dataset)):
            if sum(num_samples.values()) >= total_cap:
                break

            item = self.dataset[dataset_idx]
            response = item["brain_response"]
            label_idx = torch.argmax(response).item()

            if (
                response[label_idx] >= self.threshold
                and num_samples[label_idx] < self.num_max_images_per_category
            ):
                filename = os.path.join(
                    self.output_folder,
                    self.label_map[label_idx],
                    f"{num_samples[label_idx]}.jpg",
                )
                item["image_tensor"].save(filename)
                num_samples[label_idx] += 1

        pprint(num_samples)
        return num_samples


# Load the NSD component responses together with their ordered test images.
# Both paths are relative, so they resolve against the current working
# directory of the process.
nsd_dataset = MurtyNSD(
    image_data_filename="../natural_scenes_dataset/datasets/nsd/test_images_ordered.npy",
    brain_signals_filename="../natural_scenes_dataset/datasets/nsd/component_responses.npy",
    transforms=None,
)

# Curate up to 50 images per category; the output folder must already exist.
builder = NSDCuratedDatasetBuilder(
    nsd_dataset,
    "./datasets/curated/nsd_curated/",
    threshold=0.0002,
    num_max_images_per_category=50,
)
builder.build()
