import os
from tqdm import tqdm
import torch

from neuro.models.intermediate_output_extractor.model import IntermediateOutputExtractor
from ....utils.json_stuff import dict_to_json
from .dataset import MurtyNSD


class NaturalScenesImageEncodingDatasetBuilder:
    """Caches intermediate image encodings and brain responses of a MurtyNSD dataset on disk.

    For every dataset item, the image is passed through an
    ``IntermediateOutputExtractor`` and the result is saved as a ``.pth``
    tensor file. Alongside it, a ``.json`` file stores the item's brain
    response and the path of that tensor file.
    """

    def __init__(
        self,
        intermediate_output_extractor: IntermediateOutputExtractor,
        device: str,
        image_filenames_and_labels_folder: str,
        image_encodings_folder: str,
    ):
        """Prepares the data required to make a dataset consisting of the intermediate representations from image encoders as input.

        Args:
            intermediate_output_extractor (IntermediateOutputExtractor): Used to extract the outputs from a specific layer of a model given an input. It is switched to eval mode here.
            device (str): Device that input tensors are moved to before the forward pass, e.g. 'cpu' or 'cuda:0'.
            image_filenames_and_labels_folder (str): Existing folder that will receive one ``<index>.json`` file per dataset item, holding the brain response and the path of the matching encoding file.
            image_encodings_folder (str): Existing folder that will receive one ``<index>.pth`` file per dataset item, holding the saved intermediate encoding tensor.

        Raises:
            AssertionError: If either folder does not exist.
        """
        assert os.path.exists(image_filenames_and_labels_folder), (
            f"Folder not found: {image_filenames_and_labels_folder}"
        )
        assert os.path.exists(image_encodings_folder), (
            f"Folder not found: {image_encodings_folder}"
        )

        self.image_filenames_and_labels_folder = image_filenames_and_labels_folder
        self.image_encodings_folder = image_encodings_folder

        self.device = device
        # NOTE(review): only the inputs are moved to `device` (see `forward`);
        # the extractor is assumed to already live on that device — confirm at
        # the call site.
        self.intermediate_output_extractor = intermediate_output_extractor.eval()

    def forward(self, x: torch.Tensor):
        """Runs a forward pass through the extractor without tracking gradients.

        Args:
            x (torch.Tensor): Either a single image of shape
                (channel, height, width) or a batch of shape
                (batch, channel, height, width). A 3d input is given a leading
                batch dimension of size 1.

        Returns:
            Whatever ``intermediate_output_extractor`` returns for the batch.
            ``build`` treats this output as a tensor.

        Raises:
            AssertionError: If ``x`` is neither 3d nor 4d.
        """
        if x.ndim == 3:
            # A single image: add the batch dimension the model expects.
            x = x.unsqueeze(0)
        else:
            assert x.ndim == 4, "Expected 4d input batch, channel, height, width"
        with torch.no_grad():
            y = self.intermediate_output_extractor(x.to(self.device))
        return y

    def build(self, dataset: MurtyNSD):
        """Iterates over the entire dataset, one item at a time, and saves the labels and intermediate output tensors.

        For each dataset index ``i`` two files are written:

        * ``<image_encodings_folder>/i.pth``: the extractor output, moved to
          CPU and cast to float32.
        * ``<image_filenames_and_labels_folder>/i.json``: a dict with keys
          ``"brain_response"`` (the item's brain response converted with
          ``.tolist()``) and ``"image_encoding_filename"`` (the path of the
          ``.pth`` file above).

        Args:
            dataset (MurtyNSD): Dataset whose items are dicts with at least
                the keys ``"image_tensor"`` and ``"brain_response"``.

        Raises:
            AssertionError: If ``dataset`` is not a ``MurtyNSD`` instance.
        """
        assert isinstance(dataset, MurtyNSD)

        num_samples = len(dataset)
        for dataset_idx in tqdm(range(num_samples)):
            item = dataset[dataset_idx]
            image_tensor = item["image_tensor"]
            hook_output = self.forward(x=image_tensor)

            image_filename_and_label_filename = os.path.join(
                self.image_filenames_and_labels_folder, f"{dataset_idx}.json"
            )

            image_encoding_filename = os.path.join(
                self.image_encodings_folder, f"{dataset_idx}.pth"
            )

            json_data = {
                # presumably a numpy array or tensor, hence .tolist() for JSON — confirm
                "brain_response": item["brain_response"].tolist(),
                "image_encoding_filename": image_encoding_filename,
            }

            torch.save(hook_output.cpu().float(), image_encoding_filename)
            dict_to_json(json_data, filename=image_filename_and_label_filename)

        # Report from `num_samples` rather than the loop variable, which is
        # never bound when the dataset is empty.
        print(f"Saved: {num_samples} samples")
