"""
Definition of datasets.
"""

import io
import json
import os
import numpy as np
import pandas as pd
import torch
from collections import Counter, OrderedDict
from nltk.tokenize import sent_tokenize, word_tokenize
from torchvision.datasets.folder import pil_loader
from typing import Optional, Callable
from spaces import NBoxSpace
import spaces
import faiss


class OrderedCounter(Counter, OrderedDict):
    """A ``Counter`` whose keys keep the order in which they were first seen."""

    def __repr__(self):
        # Show the contents as an OrderedDict so the ordering is visible.
        ordered = OrderedDict(self)
        return "{}({!r})".format(self.__class__.__name__, ordered)

    def __reduce__(self):
        # Rebuild from an OrderedDict snapshot of the current contents.
        cls = self.__class__
        return cls, (OrderedDict(self),)


class MPI3D(torch.utils.data.Dataset):
    FACTORS = {
        "image":{
        0: "object_color",
        1: "object_shape",
        2: "object_size",
        3: "camera_height",
        4: "background_color",
        5: "horizontal_axis",
        6: "vertical_axis",
        }
    }

    FACTOR_SIZES = [
                4,
                4,
                2,
                3,
                3,
                40,
                40
            ] # first 7 latent factors, last 3 img # (460,800, 64, 64, 3)

    mean_per_channel = [0.0993, 0.1370, 0.1107] # values from MPI3d-realworld complex
    std_per_channel = [0.0945, 0.0935, 0.0887]   # values from MPI3_real-world complex

    DISCRETE_FACTORS = FACTORS.copy()

    def __init__(
        self,
        data_dir,
        change_lists,
        n_view: int,
        mode="train",
        transform: Optional[Callable] = None,
        dataset_name= "complex", # "real"
        device="cuda",
    ):
        npz = np.load(data_dir, allow_pickle=True)
        data = npz[
            "images"
        ]  #NOTE: you cannot shuffle the data otherwise the dim<->latent will be messed up!

        self.num_samples = len(data)
        MPI3D.data = data.reshape(self.FACTOR_SIZES + [64, 64, 3])

        # if dataset_name == "real":
        #     self.FACTOR_SIZES = [6, 6, 2, 3, 3, 40, 40] # original shape  # shape (1036800, 3, 64, 64)
        #     self.data = data.reshape(
        #     self.FACTOR_SIZES + [3, 64, 64]).swapaxes(-2, -3).swapaxes(-1, -2)  # first 7 latent factors, last 3 img
        #     MPI3D.mean_per_channel = [0.1285, 0.1648, 0.1397] # values from MPI3d-real
        #     MPI3D.std_per_channel = [0.3346, 0.3710, 0.3467]   # values from MPI3_real

        self.data_dir = data_dir
        # self.change_lists = change_lists
        self.n_view = n_view

        MPI3D.transform = transform or (lambda x: x) 
        self.device = "cuda"
        self.mode = mode

    def __len__(self):
        """Return the number of images counted when the archive was loaded."""
        return self.num_samples

    def __getitem__(self, idx):
        multidim_idx = np.unravel_index(idx, shape=self.FACTOR_SIZES)
        return {
            "image": self.data[multidim_idx],
            "multidim_idx": multidim_idx,
            "z_image": {v: multidim_idx[k] for k, v in self.FACTORS["image"].items()},
        }

    def sample(self, size, random_state=None):
        # Shape: [num_factors, size]
        latent_idx = np.vstack([np.random.choice(a=num_values_for_factor, size=size) for i, num_values_for_factor in enumerate(self.FACTOR_SIZES)])

        return latent_idx.T, np.stack([self.data[tuple(latent_idx[:, i])] for i in range(size)])

    
    @staticmethod
    def __collate_fn__random_pair__(batch):
        # Following Locatello2020weakly
        imgs, aug_imgs = [], []
        whole_data =MPI3D.data
        factors = MPI3D.FACTORS["image"]
        factor_sizes = MPI3D.FACTOR_SIZES

        z_images = {
            k: [] for k in factors.values()
        }  # key are the name of the latent factors, not the idx
        aug_z_imgs = {k: [] for k in factors.values()}

        num_factors = len(factors)
        # only perturbing one latent
        # style_size = num_factors - 1
        # index_list = np.random.choice(num_factors, style_size, replace=False)

        # perturb a random number of latents
        style_size = np.random.randint(1, num_factors)
        index_list = np.random.choice(
            num_factors, np.random.choice([1, style_size]), replace=False
        )

        factor_keys = np.asarray(list(factors.keys()))
        content_indices = [factor for factor in factor_keys if factor not in index_list]
        for b in batch:
            img = b["image"]  # [3, 64, 64]
            idx = b["multidim_idx"]  # [7, ]
            z_image = b["z_image"]  # dict: 7 keys, 7 values
            aug_idx = tuple(
                np.random.choice(np.delete(np.array(range(size)), idx[i]))
                if i in index_list
                else idx[i]
                for i, size in enumerate(factor_sizes)
            )
            imgs += [img]
            aug_img = whole_data[aug_idx]
            aug_imgs += [aug_img]
            for idx, k in factors.items():
                z_images[k] += [z_image[k]]  # zimage[k] is just a index
                aug_z_imgs[k] += [aug_idx[idx]]

        for k in factors.values():
            z_images[k] = torch.tensor(z_images[k])
            aug_z_imgs[k] = torch.tensor(aug_z_imgs[k])
        return {
            "image": [
                torch.stack([MPI3D.transform(i) for i in imgs], 0),
                torch.stack([MPI3D.transform(i) for i in aug_imgs], 0)
            ],
            "z_image": [z_images, aug_z_imgs],
            "content_indices": [content_indices],
        }

# --------------------------------- Independent 3DIdent -------------------------------
class Indepdenent3DIdent(torch.utils.data.Dataset):
    FACTORS = {
        "image":{
            0: "object_xpos", #"object_shape",
            1: "object_ypos",
            2: "object_zpos",
            3: "object_alpharot",
            4: "object_betarot",
            5: "object_gammarot",
            6: "object_color",
            7: "background_color",
            8: "spotlight_pos",
            9: "spotlight_color"
            }
        }
    DISCRETE_FACTORS = {
        "image": {
        }
    }

    POSITIONS = [0, 1, 2, 8]
    ROTATIONS = [3, 4, 5]
    HUES = [6, 7, 9]
    CHANGE_LISTS = [HUES, POSITIONS + ROTATIONS]

    mean_per_channel = [0.4363, 0.2818, 0.3045]
    std_per_channel = [0.1197, 0.0734, 0.0919]

    def __init__(self,
        data_dir:str,
        change_lists,
        n_view: int,
        mode="train",
        transform: Optional[Callable] = None,
        loader: Optional[Callable] = pil_loader,
        approximate_mode: Optional[bool] = True,
        latent_dimensions_to_use = None) -> None:
        """Load latents, derive image paths and build a nearest-neighbour index.

        Args:
            data_dir: dataset directory; the split sub-directory is appended.
            change_lists: currently ignored, the class-level ``CHANGE_LISTS``
                is used instead (see TODO below).
            n_view: number of views returned per item.
            mode: split name; ``"val"`` is mapped to ``"test"``.
            transform: optional callable applied to each loaded image.
            loader: callable turning an image path into an image.
            approximate_mode: build an approximate faiss index when true,
                an exact L2 index otherwise.
            latent_dimensions_to_use: optional column selection applied to the
                loaded latents.
        """
        super(Indepdenent3DIdent, self).__init__()

        self.mode = "test" if mode == "val" else mode

        # Keep the split directory on the instance; __repr__ reports it.
        self.root = os.path.join(data_dir, f"{self.mode}")
        self.latents = np.load(os.path.join(self.root, "raw_latents.npy"))
        self.unfiltered_latents = self.latents
        self.change_lists = self.CHANGE_LISTS #TODO: input change list
        self.n_view = n_view

        if latent_dimensions_to_use is not None:
            self.latents = np.ascontiguousarray(
                self.latents[:, latent_dimensions_to_use]
            )

        self.space = NBoxSpace(n=1, min_=-1.0, max_=1.0)
        self.sigma = 1.0
        self.transform = transform or torch.nn.Identity()

        # Image files are named by zero-padded sample number.
        max_length = int(np.ceil(np.log10(len(self.latents))))
        self.image_paths = [
            os.path.join(self.root, "images", f"{str(i).zfill(max_length)}.png")
            for i in range(self.latents.shape[0])
        ]
        self.loader = loader

        n_latent_dims = self.latents.shape[1]
        if approximate_mode:
            self._index = faiss.index_factory(n_latent_dims, "IVF1024_HNSW32,Flat")
            self._index.efSearch = 8
            self._index.nprobe = 10
            # Only the approximate index is trained before vectors are added.
            self._index.train(self.latents)
        else:
            self._index = faiss.IndexFlatL2(n_latent_dims)
        self._index.add(self.latents)

    def __len__(self) -> int:
        """Return the number of latent vectors (one per image)."""
        n_latents = len(self.latents)
        return n_latents

    def __repr__(self) -> str:
        """Return a multi-line summary of the dataset.

        ``root``, ``extra_repr``, ``transforms`` and ``_repr_indent`` are all
        optional: each is looked up defensively, so the summary works even
        when the instance does not define them.
        """
        head = "Dataset " + self.__class__.__name__
        body = ["Number of datapoints: {}".format(len(self.latents))]

        root = getattr(self, "root", None)
        if root is not None:
            body.append("Root location: {}".format(root))

        extra_repr = getattr(self, "extra_repr", None)
        if callable(extra_repr):
            body += extra_repr().splitlines()

        transforms = getattr(self, "transforms", None)
        if transforms is not None:
            body += [repr(transforms)]

        indent = " " * getattr(self, "_repr_indent", 4)
        lines = [head] + [indent + line for line in body]
        return "\n".join(lines)

    def __getitem__(self, item):
        """Return ``n_view`` images with related latents; ``item`` is ignored.

        A latent proposal is drawn from ``self.space`` and snapped to the
        nearest stored latent through the faiss index. For each additional
        view the latents listed in the matching change list are redrawn, and
        the nearest stored sample other than the first view is used.

        Returns a dict with "image" (list of transformed images) and "z_image"
        (list of dicts mapping factor name -> latent value), one per view.
        """
        del item

        factor_names = self.FACTORS["image"]

        # Sample z first, then map it to the closest grid point with an image.
        proposal = self.space.uniform(size=len(factor_names), device="cpu").T
        _, nearest = self._index.search(proposal.numpy(), 1)
        index_z = nearest[0, 0]
        z = self.latents[index_z]
        path_z = self.image_paths[index_z]
        imgs = [self.transform(self.loader(path_z))]
        zs = [z]

        for view in range(self.n_view - 1):
            z_tilde = np.copy(z)
            for j in self.change_lists[view]:
                redrawn = self.space.uniform(size=1, device="cpu")
                z_tilde[j] = redrawn.numpy().flatten()
            _, neighbours = self._index.search(z_tilde[None], 2)

            # Don't use the same sample for z and z~.
            if neighbours[0, 0] == index_z:
                index_z_tilde = neighbours[0, 1]
            else:
                index_z_tilde = neighbours[0, 0]

            z_tilde = self.latents[index_z_tilde]
            path_z_tilde = self.image_paths[index_z_tilde]
            x_tilde = self.transform(self.loader(path_z_tilde))
            zs.append(z_tilde)
            imgs.append(x_tilde)

        z_image = []
        for latent in zs:
            z_image.append({factor_names[i]: v for i, v in enumerate(latent)})
        return {"image": imgs, "z_image": z_image}

    def sample(self, size, random_state=None):
        """Draw ``size`` latent proposals and return the nearest stored samples.

        ``random_state`` is accepted but not used. Returns a tuple
        ``(latents, images)`` where ``latents`` stacks one stored latent
        vector per row and ``images`` stacks the transformed images.
        """
        box = NBoxSpace(n=len(self.FACTORS["image"]), min_=-1.0, max_=1.0)
        proposals = box.uniform(size=size, device="cpu").numpy()

        zs, imgs = [], []
        for proposal in proposals:
            _, hit = self._index.search(proposal[None], 1)
            row = hit[0, 0]
            z = self.latents[row]
            image_path = self.image_paths[row]
            imgs.append(self.transform(self.loader(image_path)))
            zs.append(z)

        # Shape: [size, num_factors]
        return np.vstack(zs), torch.stack(imgs)



# ----------------------------------- Causal 3d ident --------------------------------
class Causal3DIdent(torch.utils.data.Dataset):
    """Load Causal3DIdent dataset"""
    FACTORS = {
        "image": {
            0: "object_shape",
            1: "object_ypos",
            2: "object_xpos",
            3: "object_zpos", 
            4: "object_alpharot",
            5: "object_betarot",
            6: "object_gammarot",
            7: "spotlight_pos",
            8: "object_color",
            9: "spotlight_color",
            10: "background_color",
        }
    }
    CLASSES = range(7) # number of object shapes

    mean_per_channel = [0.4327, 0.2689, 0.2839]
    std_per_channel = [0.1201, 0.1457, 0.1082]

    POSITIONS = [1, 2, 3]
    ROTATIONS = [4, 5, 6]
    HUES = [7, 8, 9]

    DISCRETE_FACTORS = {
        "image": {
            0: "object_shape"
        }
    }

    def __init__(
        self,
        change_lists,
        n_view: int,
        data_dir: str,
        mode: str = "train",
        transform: Optional[Callable] = None,
        loader: Optional[Callable] = pil_loader,
        latent_dimensions_to_use=range(10),
        approximate_mode: Optional[bool] = True,
    ):
        """Load per-class latents and build one nearest-neighbour index per class.

        Args:
            change_lists: one list of factor indices per augmented view; the
                listed factors are resampled for that view.
            n_view: number of views, stored on the instance.
            data_dir: dataset directory; the split sub-directory is appended.
            mode: split name; ``"val"`` is mapped to ``"test"``.
            transform: optional callable applied to each loaded image.
            loader: callable turning an image path into an image.
            latent_dimensions_to_use: optional column selection applied to
                each class's latents.
            approximate_mode: build approximate faiss indices when true,
                exact L2 indices otherwise.
        """
        super(Causal3DIdent, self).__init__()

        self.change_lists = change_lists

        self.mode = mode
        if self.mode == "val":
            self.mode = "test"

        self.space = NBoxSpace(n=1, min_=-1.0, max_=1.0)
        self.sigma = 1.0
        self.root = os.path.join(data_dir, self.mode)
        self.n_view = n_view

        self.classes = self.CLASSES
        self.latent_classes = [
            np.load(os.path.join(self.root, "raw_latents_{}.npy".format(i)))
            for i in self.classes
        ]
        # Copy the list: the entries of ``latent_classes`` are replaced in
        # place below, and the unfiltered arrays must not be replaced with them.
        self.unfiltered_latent_classes = list(self.latent_classes)

        if latent_dimensions_to_use is not None:
            for i in self.classes:
                self.latent_classes[i] = np.ascontiguousarray(
                    self.latent_classes[i][:, latent_dimensions_to_use]
                )

        # Image files are named by zero-padded sample number, per class.
        self.image_paths_classes = []
        for i in self.classes:
            max_length = int(np.ceil(np.log10(len(self.latent_classes[i]))))
            self.image_paths_classes.append(
                [
                    os.path.join(
                        self.root, "images_{}".format(i), f"{str(j).zfill(max_length)}.png"
                    )
                    for j in range(self.latent_classes[i].shape[0])
                ]
            )
        self.loader = loader

        self._index_classes = []

        self.transform = transform or (lambda x: x)
        for i in self.classes:
            class_latents = self.latent_classes[i]
            if approximate_mode:
                _index = faiss.index_factory(
                    class_latents.shape[1], "IVF1024_HNSW32,Flat"
                )
                _index.efSearch = 8
                _index.nprobe = 10
                # Only the approximate index is trained before vectors are added.
                _index.train(class_latents)
            else:
                _index = faiss.IndexFlatL2(class_latents.shape[1])
            _index.add(class_latents)
            self._index_classes.append(_index)

    def __len__(self) -> int:
        """Return the size of the first class times the number of classes."""
        per_class = len(self.latent_classes[0])
        n_classes = len(self.classes)
        return per_class * n_classes

    def __getitem__(self, item):  # item is the number of image
        """Return the image at ``item`` plus one augmented view per change list.

        ``item`` is split into a class id and a position inside that class,
        using the length of the first class. For every change list the listed
        factors are redrawn from ``self.space``, and the nearest stored sample
        of the same class (other than the original) becomes the extra view.

        Factor index ``i`` (``i >= 1``) of ``FACTORS["image"]`` is stored at
        position ``i - 1`` of a latent vector; factor 0 is the class id and is
        not part of the vector.

        Returns a dict with "image" (list of transformed images) and "z_image"
        (list of dicts mapping factor name -> value), one entry per view. In
        the augmented dicts, factors outside the change list keep the value of
        the original view.
        """
        n_per_class = len(self.latent_classes[0])
        class_id = item // n_per_class
        within_class = item % n_per_class

        class_latents = self.latent_classes[class_id]
        class_paths = self.image_paths_classes[class_id]
        class_index = self._index_classes[class_id]
        factor_names = self.FACTORS["image"]

        z = class_latents[within_class]
        x1 = self.transform(self.loader(class_paths[within_class]))

        z_flat = z.flatten()
        z_dict = {factor_names[0]: class_id}
        for i, value in enumerate(z_flat):
            z_dict[factor_names[i + 1]] = value

        z_dicts = [z_dict]
        tuple_x = [x1]
        for change_list in self.change_lists:
            z_tilde = np.copy(z)
            for j in change_list:
                # Factor 0 (the class) has no slot in the latent vector.
                assert j > 0
                z_tilde[j - 1] = (
                    self.space.uniform(size=1, device="cpu").numpy().flatten()[0]
                )

            # Search within the same class; take the runner-up if the nearest
            # neighbour is the original sample itself.
            _, neighbours = class_index.search(z_tilde[np.newaxis], 2)
            if neighbours[0, 0] != within_class:
                index_z_tilde = neighbours[0, 0]
            else:
                index_z_tilde = neighbours[0, 1]

            z_tilde = class_latents[index_z_tilde]
            x2 = self.transform(self.loader(class_paths[index_z_tilde]))

            z_tilde_flat = z_tilde.flatten()
            z_tilde_dict = {factor_names[0]: class_id}
            # Cover every latent (including the last one) and apply the same
            # factor-index -> vector-position offset as for the first view.
            for i in range(1, len(z_tilde_flat) + 1):
                source = z_tilde_flat if i in change_list else z_flat
                z_tilde_dict[factor_names[i]] = source[i - 1]

            z_dicts.append(z_tilde_dict)
            tuple_x.append(x2)

        return {
            "image": tuple_x,
            "z_image": z_dicts}
        
    def sample(self, size, random_state=None):
        """Draw ``size`` random (class, latent) proposals and return stored samples.

        ``random_state`` is accepted but not used. For each proposal the
        nearest stored latent of the drawn class is looked up. Returns a tuple
        ``(latents, images)``; each latent row is the class id followed by the
        stored latent vector.
        """
        classes = np.random.choice(Causal3DIdent.CLASSES, size=size)
        box = NBoxSpace(n=len(self.FACTORS["image"]) - 1, min_=-1.0, max_=1.0)
        proposals = box.uniform(size=size, device="cpu").numpy()

        zs, imgs = [], []
        for class_id, proposal in zip(classes, proposals):
            _, hit = self._index_classes[class_id].search(proposal[None], 1)
            row = hit[0, 0]
            z = self.latent_classes[class_id][row]
            image_path = self.image_paths_classes[class_id][row]
            imgs.append(self.transform(self.loader(image_path)))
            zs.append(np.concatenate([[class_id], z]))

        # Shape: [size, num_factors]
        return np.vstack(zs), torch.stack(imgs)


# ----------------------------------- Multimodal3DIdent --------------------------------
class Multimodal3DIdent(torch.utils.data.Dataset):
    """Multimodal3DIdent Dataset.

    Attributes:
        FACTORS (dict): names of factors for image and text modalities.
        DISCRETE_FACTORS (dict): names of discrete factors, respectively.
    """

    FACTORS = {
        "image": {
            0: "object_shape",
            1: "object_ypos",
            2: "object_xpos",
            3: "object_zpos",  # is constant
            4: "object_alpharot",
            5: "object_betarot",
            6: "object_gammarot",
            7: "spotlight_pos",
            8: "object_color",
            9: "spotlight_color",
            10: "background_color",
        },
        "text": {
            0: "object_shape",
            1: "object_ypos",
            2: "object_xpos",
            3: "object_zpos",  # is constant
            4: "object_color_index",
            5: "text_phrasing",
        },
    }

    DISCRETE_FACTORS = {
        "image": {
            0: "object_shape",
            1: "object_ypos",
            2: "object_xpos",
            3: "object_zpos",  # is constant
        },
        "text": {
            0: "object_shape",
            1: "object_ypos",
            2: "object_xpos",
            3: "object_zpos",  # is constant
            4: "object_color_index",
            5: "text_phrasing",
        },
    }
    IMAGE_SPACES = {
        0: spaces.DiscreteSpace(n_choices=7),
        1: spaces.DiscreteSpace(n_choices=3),
        2: spaces.DiscreteSpace(n_choices=3),
        3: spaces.DiscreteSpace(n_choices=1),
        4: spaces.NBoxSpace(n=1, min_=0., max_=1.),
        5: spaces.NBoxSpace(n=1, min_=0., max_=1.),
        6: spaces.NBoxSpace(n=1, min_=0., max_=1.),
        7: spaces.NBoxSpace(n=1, min_=0., max_=1.),
        8: spaces.NBoxSpace(n=1, min_=0., max_=1.),
        9: spaces.NBoxSpace(n=1, min_=0., max_=1.),
        10: spaces.NBoxSpace(n=1, min_=0., max_=1.)    
    }
    mean_per_channel = [0.4327, 0.2689, 0.2839]  # values from Causal3DIdent
    std_per_channel = [0.1201, 0.1457, 0.1082]  # values from Causal3DIdent

    def __init__(
        self,
        data_dir,
        change_lists,
        n_view: int,
        mode="train",
        has_labels=True,
        vocab_filepath=None,
        transform: Optional[Callable] = None,
        loader: Optional[Callable] = pil_loader,
        approximate_mode: Optional[bool] = True,
    ):
        """Set up paths, tokenized text, vocabulary, labels and the latent index.

        Args:
            data_dir (string): path to the dataset directory.
            change_lists: one list of image-factor indices per augmented view.
            n_view (int): number of views, stored on the instance.
            mode (string): name of data split, 'train', 'val', or 'test'.
            has_labels (bool): must be true; ground-truth labels are required.
            vocab_filepath (str): optional path to a saved vocabulary. If None,
              the vocabulary is (re-)created under ``data_dir``.
            transform (callable): optional transform applied to each image.
            loader (callable): image loader, stored on the instance.
            approximate_mode (bool): build an approximate faiss index when
              true, an exact L2 index otherwise.
        """
        assert has_labels, "must have latent labels"
        self.mode = mode
        self.has_labels = has_labels
        self.data_dir = data_dir

        # All per-split files live under <data_dir>/<mode>.
        split_dir = os.path.join(data_dir, mode)
        self.data_dir_mode = split_dir
        self.latents_text_filepath = os.path.join(split_dir, "latents_text.csv")
        self.latents_image_filepath = os.path.join(split_dir, "latents_image.csv")
        self.text_filepath = os.path.join(split_dir, "text", "text_raw.txt")
        self.image_dir = os.path.join(split_dir, "images")

        # Sentence- and word-tokenized text.
        self.text_in_sentences, self.text_in_words = self._load_text()

        self.num_samples = len(self.text_in_sentences)
        # +1 leaves room for the "{eos}" token.
        self.max_sequence_length = 1 + max(len(sent) for sent in self.text_in_words)

        # Word <-> index maps, loaded from disk or created.
        self.w2i, self.i2w = self._load_vocab(vocab_filepath)
        self.vocab_size = len(self.w2i)
        self.vocab_filepath = vocab_filepath or os.path.join(
            self.data_dir, "vocab.json"
        )

        if has_labels:
            self.labels = self._load_labels()
            self.z_image = self.labels["z_image"]

        # Image files are named by zero-padded sample number.
        width = int(np.ceil(np.log10(self.num_samples)))
        self.image_paths = [
            os.path.join(self.image_dir, str(i).zfill(width) + ".png")
            for i in range(self.num_samples)
        ]

        # Settings for perturbing latent variables.
        self.change_lists = change_lists
        self.space = NBoxSpace(n=1, min_=0.0, max_=1.0)
        self.sigma = 1.0
        self.n_view = n_view

        self._index = self._retrieve_idx(approximate_mode)

        self.loader = loader
        self.transform = transform or (lambda x: x)
    
    def _retrieve_idx(self, approximate_mode=True):
        """Build a faiss index over the values of the image-latent table.

        The index dimensionality is the number of image factors. An exact L2
        index is built when ``approximate_mode`` is false; otherwise an
        approximate index is created, trained and then filled.
        """
        search_ndim = len(self.FACTORS["image"])
        train_data = self.z_image.values

        if not approximate_mode:
            exact_index = faiss.IndexFlatL2(search_ndim)
            exact_index.add(train_data)
            return exact_index

        approx_index = faiss.index_factory(search_ndim, "IVF1024_HNSW32,Flat")
        approx_index.efSearch = 8
        approx_index.nprobe = 10
        approx_index.train(train_data)
        approx_index.add(train_data)
        return approx_index

    def get_w2i(self, word):
        """Return the vocabulary index of ``word``.

        Words missing from the vocabulary map to the index of the special
        "{unk}" token, so the result can always be used as a token index.
        """
        try:
            return self.w2i[word]
        except KeyError:
            # Return the index of the unknown-word token, not the token itself.
            return self.w2i["{unk}"]

    def _load_text(self):
        """Read the raw text file and tokenize it into sentences and words.

        Returns a tuple ``(sentences, words)`` where ``words`` holds one list
        of word tokens per sentence.
        """
        print(f"Tokenization of {self.mode} data...")

        with open(self.text_filepath, "r") as text_file:
            raw_text = text_file.read()

        sentences = sent_tokenize(raw_text)
        words_per_sentence = []
        for sentence in sentences:
            words_per_sentence.append(word_tokenize(sentence))

        return sentences, words_per_sentence

    def _load_labels(self):
        """Load the image and text latent tables from their CSV files.

        Asserts that every factor name listed in ``FACTORS`` is a column of
        the matching table, then returns both tables in a dict with the keys
        "z_image" and "z_text".
        """
        z_image = pd.read_csv(self.latents_image_filepath)
        z_text = pd.read_csv(self.latents_text_filepath)

        # Check that all factors are present, image table first.
        for modality, table in (("image", z_image), ("text", z_text)):
            for factor_name in self.FACTORS[modality].values():
                assert factor_name in table.keys()

        return {"z_image": z_image, "z_text": z_text}

    def _load_labels_by_class(self):
        """Load the image latent table grouped by its "object_shape" column."""
        image_latents = pd.read_csv(self.latents_image_filepath)
        # The object shape plays the role of the class.
        return image_latents.groupby("object_shape")

    def _create_vocab(self, vocab_filepath):
        """Build the word <-> index maps from the text and save them as JSON.

        The special tokens "{pad}", "{eos}" and "{unk}" get the first indices;
        the remaining words follow in order of first occurrence.

        Returns a dict with the keys "w2i" and "i2w".

        Raises:
            ValueError: when called for a split other than "train", or when
                the two maps end up with different sizes.
        """
        print(f"Creating vocabulary as '{vocab_filepath}'...")

        if self.mode != "train":
            raise ValueError("Vocabulary should be created from training data")

        w2i = {}  # word-to-index map
        i2w = {}  # index-to-word map

        # Special tokens for padding, end-of-string, and unknown words.
        special_tokens = ["{pad}", "{eos}", "{unk}"]
        for token in special_tokens:
            i2w[len(w2i)] = token
            w2i[token] = len(w2i)

        # Count word occurrences while remembering first-seen order.
        word_counts = OrderedCounter()
        for sentence_words in self.text_in_words:
            word_counts.update(sentence_words)

        # Words of the text that coincide with a special token are not
        # re-added; they are only collected for the diagnostic below.
        clashing_words = []
        for word in word_counts:
            if word in special_tokens:
                clashing_words.append(word)
                continue
            i2w[len(w2i)] = word
            w2i[word] = len(w2i)

        if len(w2i) != len(i2w):
            print(clashing_words)
            raise ValueError("Mismatch between w2i and i2w mapping")

        # Save the vocabulary to disk as UTF-8 encoded JSON.
        vocab = dict(w2i=w2i, i2w=i2w)
        serialized = json.dumps(vocab, ensure_ascii=False)
        with io.open(vocab_filepath, "wb") as vocab_file:
            vocab_file.write(serialized.encode("utf8", "replace"))

        return vocab

    def _load_vocab(self, vocab_filepath=None):
        """Return ``(w2i, i2w)``, read from JSON or freshly created.

        When ``vocab_filepath`` is None a new vocabulary is created and saved
        as ``vocab.json`` inside ``self.data_dir``.
        """
        if vocab_filepath is None:
            target_path = os.path.join(self.data_dir, "vocab.json")
            vocab = self._create_vocab(vocab_filepath=target_path)
        else:
            with open(vocab_filepath, "r") as vocab_file:
                vocab = json.load(vocab_file)
        return (vocab["w2i"], vocab["i2w"])
    
    
    def _get_augmented_view(self, idx, z, change_list):
        """Return the index of a stored sample close to a perturbed copy of ``z``.

        Every factor in ``change_list`` is redrawn from its entry in
        ``IMAGE_SPACES``; for discrete factors the current value is passed to
        the space as ``original``. The two nearest stored latents are looked
        up and the closest one that is not ``idx`` itself is returned.
        """
        z_tilde = np.copy(z)
        discrete_factors = self.DISCRETE_FACTORS["image"]
        for j in change_list:
            if j in discrete_factors:
                draw = self.IMAGE_SPACES[j].uniform(
                    size=1, original=z[j], device="cpu"
                )
            else:
                draw = self.IMAGE_SPACES[j].uniform(size=1, device="cpu")
            z_tilde[j] = draw.numpy().flatten()

        # Search for the 2 nearest neighbours so the original can be skipped.
        _, neighbours = self._index.search(z_tilde[np.newaxis], 2)
        if neighbours[0, 0] == idx:
            return neighbours[0, 1]
        return neighbours[0, 0]

    def __getitem__(self, idx):
        """Return the sample at ``idx`` together with its augmented views.

        Every value of the result is a list: the original view first, then one
        entry per change list. "text" and "z_text" are cut back to the
        original view only.
        """
        anchor = self.__get_img_text__(idx)
        z_values = np.fromiter(anchor["z_image"].values(), dtype=float)

        samples = {key: [value] for key, value in anchor.items()}

        # One augmented view per change list, each perturbing other latents.
        for change_list in self.change_lists:
            neighbour_idx = self._get_augmented_view(
                idx=idx,
                z=z_values,
                change_list=change_list,
            )
            view = self.__get_img_text__(neighbour_idx)
            for key, value in view.items():
                samples[key].append(value)

        # Only the text of the original view is kept.
        samples["text"] = samples["text"][:1]
        samples["z_text"] = samples["z_text"][:1]

        return samples

    def __get_img_text__(self, idx):
        """Load the image, one-hot encoded text and labels of sample ``idx``.

        The image is read with the loader given to the constructor and passed
        through ``self.transform``. The sentence is terminated with "{eos}",
        padded with "{pad}" up to ``max_sequence_length`` and one-hot encoded
        over the vocabulary.

        Returns a dict with the keys "image", "text", "z_image" and "z_text";
        the two label entries are None when ``has_labels`` is false.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Load the image with the configured loader rather than a fixed one.
        img_name = self.image_paths[idx]
        image = self.loader(img_name)
        if self.transform is not None:
            image = self.transform(image)

        # Load text: append end-of-string, then pad to the common length.
        words = self.text_in_words[idx] + ["{eos}"]
        n_padding = self.max_sequence_length - len(words)
        words = words + ["{pad}"] * n_padding
        indices = [self.get_w2i(word) for word in words]
        indices_onehot = torch.nn.functional.one_hot(
            torch.Tensor(indices).long(), self.vocab_size
        ).float()

        # Load labels: one value per column of each latent table.
        if self.has_labels:
            z_image = {k: v[idx] for k, v in self.labels["z_image"].items()}
            z_text = {k: v[idx] for k, v in self.labels["z_text"].items()}
        else:
            z_image, z_text = None, None

        sample = {
            "image": image,
            "text": indices_onehot,
            "z_image": z_image,
            "z_text": z_text,
        }
        return sample

    def __len__(self):
        """Return the number of samples (one per tokenized sentence)."""
        return self.num_samples