
import numpy as np
import torch
import torchvision
import torch.utils.data as data

from tqdm import tqdm
from torch.utils.data.dataloader import default_collate


class DisentangledDataset(data.Dataset):
    """Base dataset for disentanglement experiments.

    The base class leaves ``data``, ``latents_values`` and ``latents_classes``
    as ``None``; subclasses are expected to fill them in and to implement
    ``__getitem__``, ``factor_to_idx`` and ``dataset_get``. This class only
    provides sampling and lookup helpers built on those attributes.
    """

    def __init__(self,
                 root,
                 transforms_list=None):
        """Store the dataset root and compose the per-sample transforms.

        Args:
            root: Dataset location; stored as-is, not read here.
            transforms_list: Transforms applied in order through
                ``torchvision.transforms.Compose``. ``None`` (the default)
                means no transforms.
        """
        self.root = root
        # A fresh empty list per instance instead of a shared mutable default.
        self.transforms = torchvision.transforms.Compose(
            [] if transforms_list is None else transforms_list
        )
        # Left as None here; expected to be filled in by subclasses before use.
        self.data, self.latents_values, self.latents_classes = None, None, None
        # Maps latent-class row bytes -> row index; filled by set_classes_dict().
        self.classes_dict = {}
        # NOTE: download() is a no-op hook and is not called automatically.

    def set_classes_dict(self):
        """Index the rows of ``latents_classes`` by their raw bytes.

        Each row's ``tobytes()`` becomes a key mapping to that row's position,
        which ``find_index_from_factor`` uses for the reverse lookup. If two
        rows are identical, the later index wins.
        """
        for i, classes in enumerate(self.latents_classes):
            self.classes_dict[classes.tobytes()] = i

    def __getitem__(self, idx):
        raise NotImplementedError("Build getitem function")

    def download(self):
        """Download the dataset. """
        pass

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _fresh_generator():
        """Return a PCG64 ``Generator`` seeded from the global NumPy RNG.

        Seeding from ``np.random`` keeps the result reproducible under
        ``np.random.seed``. ``dtype=np.int64`` is required because the
        exclusive upper bound 2**32 overflows a 32-bit default int (e.g.
        Windows with NumPy < 2).
        """
        seed = int(np.random.randint(0, 2**32, dtype=np.int64))
        return np.random.Generator(np.random.PCG64(seed=seed))

    # set for disentanglement metric
    def random_sampling_for_disen_global_variance(self, batch_size, replace=False):
        """Draw ``batch_size`` random transformed samples as one stacked tensor.

        Args:
            batch_size: Number of samples to draw.
            replace: Whether to sample indices with replacement.

        Returns:
            A ``torch.float32`` tensor stacking the transformed samples along
            a new leading dimension of size ``batch_size``.
        """
        indices = self._fresh_generator().choice(
            len(self.data), batch_size, replace=replace
        )
        return torch.stack(
            [self.transforms(self.data[idx]).to(torch.float32) for idx in indices],
            dim=0,
        )

    def sampling_factors_and_img(self, batch_size, num_train):
        """Build ``num_train`` random batches of images and their factor classes.

        The index array is shuffled in place with the global ``np.random``
        state on every iteration (cumulatively), and the first
        ``batch_size`` entries form each batch. If ``batch_size`` exceeds the
        dataset size, every batch contains the whole dataset.

        Args:
            batch_size: Samples per batch.
            num_train: Number of batches to build.

        Returns:
            ``(imgs, factors)``: ``imgs`` stacks the per-batch image tensors,
            shape ``(num_train, batch_size, *sample_shape)`` (presumably
            ``(C, H, W)`` samples, TODO confirm); ``factors`` is a float32
            tensor of shape ``(num_train, batch_size, n_factor_columns)``.
        """
        # An ndarray lets np.random.shuffle use its in-place fast path rather
        # than swapping Python objects in a list each iteration.
        idxs = np.arange(len(self.data))
        factors, imgs = [], []
        for _ in tqdm(range(num_train)):
            np.random.shuffle(idxs)
            # A view into idxs; only used before the next shuffle.
            factor_idxs = idxs[:batch_size]
            factors.append(torch.Tensor(self.latents_classes[factor_idxs]))
            imgs.append(
                torch.stack(
                    [
                        self.transforms(self.data[int(idx)]).to(torch.float32)
                        for idx in factor_idxs
                    ],
                    dim=0,
                )
            )
        return torch.stack(imgs, dim=0), torch.stack(factors, dim=0)

    def find_index_from_factor(self, factor):
        """Return the row index whose latent classes equal ``factor``.

        Requires ``set_classes_dict()`` to have been called. The lookup is by
        raw bytes, so ``factor`` must be an array with the same dtype and
        shape as a row of ``latents_classes``. Raises ``KeyError`` when no row
        matches.
        """
        return self.classes_dict[factor.tobytes()]

    def img_from_idx(self, idx):
        """Return the raw (untransformed) sample at ``idx``."""
        return self.data[idx]

    def factor_from_idx(self, idx):
        """Return the latent-class row at ``idx``."""
        return self.latents_classes[idx]

    def dataset_sample_batch(self, num_samples, mode, replace=False):
        """Collate ``num_samples`` randomly chosen items via ``dataset_get``.

        Args:
            num_samples: Number of indices to draw from ``range(len(self))``.
            mode: Forwarded unchanged to ``dataset_get``.
            replace: Whether to sample indices with replacement.
        """
        indices = self._fresh_generator().choice(
            len(self), num_samples, replace=replace
        )
        return self.dataset_batch_from_indices(indices, mode=mode)

    def dataset_batch_from_indices(self, indices, mode):
        """Fetch each index with ``dataset_get`` and merge via ``default_collate``."""
        return default_collate([self.dataset_get(idx, mode=mode) for idx in indices])

    def dataset_get(self, idx, mode: str):
        """Validate that ``idx`` is integer-like; subclasses supply the lookup.

        The base implementation only validates and returns ``None``, so a
        subclass override is needed for ``dataset_batch_from_indices`` to
        produce a real batch.

        Raises:
            TypeError: if ``int(idx)`` fails.
        """
        try:
            int(idx)
        except Exception as err:
            # int() on arbitrary array-likes can raise several exception types;
            # normalize them to TypeError but keep the original as the cause.
            # (A bare except would also have swallowed KeyboardInterrupt.)
            raise TypeError(f"Indices must be integer-like ({type(idx)}): {idx}") from err

    def factor_to_idx(self, factor):
        """Map a latent-factor vector to a dataset index (subclass hook)."""
        raise NotImplementedError("Build factor_to_idx function")
