# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import numpy as np
import torch.utils.data
import os
from torchvision import datasets, transforms, utils
from torch import nn
from torch.nn import functional as F

import pathlib
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler


def subsample_dsprite():
    """Build the flat dataset indices selected by ``conf/mini_dsprites.json``.

    The JSON file is expected to provide five keys:

    * ``shape`` and ``scale``: strings from which every run of digits is
      extracted and used as an explicit list of class indices.
    * ``orien``, ``xposi`` and ``yposi``: the number of evenly spaced values
      to keep out of 40, 32 and 32 respectively.

    Returns
    -------
    list of int
        One index per selected (shape, scale, orientation, x, y) combination,
        computed as ``shape*245760 + scale*40960 + orien*1024 + x*32 + y``
        and ordered with ``y`` varying fastest.
    """
    import itertools
    import json
    import re

    # The configuration is only read, so open it read-only; 'r+' would fail
    # on a file or directory without write permission.
    with open(os.path.join('conf', 'mini_dsprites.json'), 'r') as f:
        dsprite_conf = json.load(f)

    shape_list = [int(x) for x in re.findall(r"\d+", dsprite_conf['shape'])]
    scale_list = [int(x) for x in re.findall(r"\d+", dsprite_conf['scale'])]
    # Keep every k-th value so that the requested number of values is spread
    # evenly across the full range of each factor.
    orien_list = np.arange(0, 40, 40 // int(dsprite_conf['orien'])).tolist()
    xposi_list = np.arange(0, 32, 32 // int(dsprite_conf['xposi'])).tolist()
    yposi_list = np.arange(0, 32, 32 // int(dsprite_conf['yposi'])).tolist()

    # itertools.product iterates with the last factor varying fastest, which
    # is the same order as the equivalent nested for-loops.
    return [
        s1 * 245760 + s2 * 40960 + o * 1024 + x * 32 + y
        for s1, s2, o, x, y in itertools.product(
            shape_list, scale_list, orien_list, xposi_list, yposi_list
        )
    ]


class DSpritesDataset(Dataset):
    """dSprites images paired with their latent-class labels."""
    def __init__(self, path_to_data, subsample=1, transform=None):
        """
        Parameters
        ----------
        path_to_data : str
            Path to an ``.npz`` archive containing the arrays ``imgs`` and
            ``latents_classes``.
        subsample : int
            Accepted for backward compatibility; it is currently not used,
            so every image in the archive is loaded.
        transform : callable, optional
            Applied to each image (after the channel axis is added) before
            it is returned.
        """
        # Open the archive a single time and close it as soon as both arrays
        # have been read, instead of loading the file once per array and
        # leaving the handles open.
        with np.load(path_to_data) as archive:
            self.imgs = archive['imgs']
            self.labels = archive['latents_classes']
        self.transform = transform

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """Return ``(image, latent_classes)`` for the item at ``idx``."""
        # Each image in the dataset has binary values so multiply by 255 to get
        # pixel values
        sample = self.imgs[idx] * 255
        label = self.labels[idx]
        # Add extra dimension to turn shape into (H, W) -> (H, W, C)
        sample = sample.reshape(sample.shape + (1,))

        if self.transform:
            sample = self.transform(sample)
        return sample, label


class DSpritesDatasetRefer(Dataset):
    """DSprites data set for referential games"""

    def __init__(self, path_to_data: str, game_size: int, subsample: int = 1, transform=None) -> None:
        """
        Parameters
        ----------
        path_to_data:
            Path to an ``.npz`` archive containing the arrays ``imgs`` and
            ``latents_classes``.
        game_size:
            The number of candidates. If game_size is $n$, then there are $n-1$ distractors for every item.
        subsample:
            Accepted for backward compatibility; it is currently not used,
            so every image in the archive is loaded.
        transform:
            Optional callable applied to the target image and to every
            candidate image.
        """
        # Read both arrays from a single open archive and close it
        # afterwards, rather than loading the file twice and leaking handles.
        with np.load(path_to_data) as archive:
            self.imgs = archive['imgs']
            self.labels = archive['latents_classes']
        self.game_size = game_size
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of images in the dataset."""
        return len(self.imgs)

    def __getitem__(self, idx: int):
        """Return ``(target_img, labels, candidates)`` for one game.

        ``candidates`` is a list of ``game_size`` images drawn without
        replacement from the whole dataset and guaranteed to contain the
        target; ``labels`` is a one-element tensor holding the position of
        the target inside ``candidates``.
        """
        # Add a trailing channel axis: (H, W) -> (H, W, C).
        target_img = self.imgs[idx]
        target_img = target_img.reshape(target_img.shape + (1,))
        # target_latents = self.labels[idx] # NOTE: could be useful in the future

        sample_indices = np.random.choice(len(self.imgs), self.game_size, replace=False)
        if idx not in sample_indices:
            # The target was not drawn: overwrite a random slot with it so
            # that it appears exactly once among the candidates.
            sample_indices[np.random.randint(self.game_size, size=1)[0]] = idx

        candidates = [img.reshape(img.shape + (1,)) for img in self.imgs[sample_indices]]

        if self.transform:
            target_img = self.transform(target_img)
            candidates = [self.transform(img) for img in candidates]

        # build labels for the correct candidate
        target_position = np.where(sample_indices == idx)[0]
        assert len(target_position) == 1
        labels = torch.from_numpy(target_position)

        return target_img, labels, candidates

def get_dsprites_dataloader(batch_size=64,
                            validation_split=.2,
                            random_seed=42,
                            shuffle=True,
                            path_to_data=os.path.join('data', 'dsprites.npz'),
                            referential=False,
                            game_size=15,
                           ):
    """Create train and validation loaders over the dSprites archive.

    Parameters
    ----------
    batch_size:
        Batch size used by both loaders.
    validation_split:
        Fraction of the dataset (rounded down) assigned to validation.
    random_seed, shuffle:
        When ``shuffle`` is true, NumPy's global RNG is seeded with
        ``random_seed`` and the index list is shuffled before splitting.
    path_to_data:
        Location of the ``.npz`` archive handed to the dataset class.
    referential, game_size:
        When ``referential`` is true a ``DSpritesDatasetRefer`` with
        ``game_size`` candidates is used, otherwise a ``DSpritesDataset``.

    Returns
    -------
    tuple
        ``(train_loader, validation_loader)``; both draw from the same
        dataset object through a ``SubsetRandomSampler``.
    """
    to_tensor = transforms.ToTensor()
    if referential:
        data = DSpritesDatasetRefer(path_to_data, game_size, transform=to_tensor)
    else:
        data = DSpritesDataset(path_to_data, transform=to_tensor)

    n_total = len(data)
    order = list(range(n_total))
    n_val = int(np.floor(validation_split * n_total))

    if shuffle:
        np.random.seed(random_seed)
        np.random.shuffle(order)

    # The first n_val (possibly shuffled) indices form the validation set;
    # the remainder is used for training.
    val_part = order[:n_val]
    train_part = order[n_val:]

    train_loader = torch.utils.data.DataLoader(
        data,
        batch_size=batch_size,
        sampler=SubsetRandomSampler(train_part),
    )
    validation_loader = torch.utils.data.DataLoader(
        data,
        batch_size=batch_size,
        sampler=SubsetRandomSampler(val_part),
    )
    return train_loader, validation_loader


class SymbolicDataset(Dataset):
    """Synthetic dataset enumerating every attribute/value combination.

    Each sample is the concatenation of one one-hot vector per attribute;
    the matching label is the list of value indices that produced it.
    """
    def __init__(self, 
                 n_attributes:int, 
                 n_values:int, 
                 referential=False, 
                 game_size=2, 
                 transform=None,
                 contrastive=False,
    ) -> None:
        """Enumerate all samples and store the game configuration.

        Parameters
        ----------
        n_attributes, n_values:
            Number of attributes and number of values per attribute.
        referential:
            If false, items are ``(sample, label)`` pairs.
        game_size:
            Number of candidates returned per item in referential mode.
        transform:
            Optional callable applied to samples, labels and candidates.
        contrastive:
            In referential mode, return the sample itself in place of a
            candidate list.
        """
        self.transform = transform
        self.samples, self.labels = self._build_samples(n_attributes, n_values)
        self.referential = referential
        self.game_size = game_size
        self.contrastive = contrastive
        
    def __len__(self) -> int:
        """Return the number of stored samples."""
        return len(self.samples)

    def __getitem__(self, idx) -> tuple:
        """Return the item at ``idx`` in the shape required by the mode.

        * non-referential: ``(sample, generative_label)``
        * referential and contrastive: ``(sample, generative_label, sample)``
        * referential: ``(sample, [generative_label, candidate_label],
          candidates)`` where ``candidate_label`` is a one-element tensor
          with the position of the target inside ``candidates``.
        """
        sample, generative_label = self.samples[idx], self.labels[idx]

        if self.transform:
            sample = self.transform(sample)
            generative_label = self.transform(generative_label)

        if not self.referential:
            return sample, generative_label
        if self.contrastive:
            return sample, generative_label, sample

        # Draw distinct candidate rows; if the target is missing, drop it
        # into a random slot so it is present exactly once.
        picked = np.random.choice(len(self.samples), self.game_size, replace=False)
        if idx not in picked:
            slot = np.random.randint(self.game_size, size=1)[0]
            picked[slot] = idx

        candidates = list(self.samples[picked])
        if self.transform:
            candidates = [self.transform(cand) for cand in candidates]

        # Position of the target among the candidates.
        target_slot = np.where(picked == idx)[0]
        assert len(target_slot) == 1
        candidate_label = torch.from_numpy(target_slot)

        return sample, [generative_label, candidate_label], candidates

    @staticmethod
    def _build_samples(n_a, n_v) -> tuple:
        """Return ``(samples, labels)`` for ``n_a`` attributes of ``n_v`` values.

        ``samples`` stacks, for every index tuple yielded by ``np.ndindex``,
        the concatenation of the corresponding rows of an ``n_v`` identity
        matrix; ``labels`` stacks the index tuples themselves.
        """
        one_hot = np.eye(n_v)
        combos = list(np.ndindex(tuple([n_v] * n_a)))

        sample_rows = [
            np.concatenate([one_hot[v] for v in combo], axis=0)
            for combo in combos
        ]
        label_rows = [list(combo) for combo in combos]

        return np.stack(sample_rows), np.stack(label_rows)


def get_symbolic_dataloader(
    n_attributes:int=3,
    n_values:int=6,
    batch_size=64,
    validation_split=.2,
    random_seed=1234,
    shuffle=True,
    referential=False,
    game_size=2,
    contrastive=False,
):
    """
    Build train/validation loaders over disjoint parts of a SymbolicDataset.

    Key Parameters
        ----------
        n_attributes, n_values:
            The number of attributes and possible values in the manually built dataset, thus there will be 
            $n_{values}^{n_{attributes}}$ samples in total.
        validation_split:
            Fraction of the samples (rounded down) held out for validation.
        random_seed:
            Seed given to NumPy's global RNG before the indices are shuffled
            for the split. The split is always shuffled.
        shuffle:
            Forwarded to both ``DataLoader`` objects.
        game_size:
            The number of candidates. If game_size is $n$, then there are $n-1$ distractors for every item.

    Returns
        -------
        (train_loader, validation_loader); ``validation_loader`` is ``None``
        when the validation split is empty.
    """
    import copy

    # Enumerate the full sample/label matrices once; the train and validation
    # datasets are shallow copies that only swap in their own subset.
    dataset = SymbolicDataset(n_attributes, n_values, referential, game_size, torch.FloatTensor, contrastive)
    all_samples, all_labels = dataset.samples, dataset.labels

    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(validation_split * dataset_size))

    np.random.seed(random_seed)
    np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]

    def _subset(chosen):
        # Same configuration as ``dataset`` but restricted to ``chosen`` rows.
        part = copy.copy(dataset)
        part.samples = all_samples[chosen]
        part.labels = all_labels[chosen]
        return part

    train_loader = torch.utils.data.DataLoader(_subset(train_indices), shuffle=shuffle, batch_size=batch_size)
    # Only build a validation loader when there is something to validate on.
    validation_loader = (
        torch.utils.data.DataLoader(_subset(val_indices), shuffle=shuffle, batch_size=batch_size)
        if len(val_indices) > 0 else None
    )

    return train_loader, validation_loader


if __name__ == "__main__":
    # Smoke test: iterate over both dSprites loaders one image at a time,
    # plot the per-pixel mean of everything that was drawn and print how
    # many images were seen.
    from matplotlib import pyplot as plt

    def show_density(imgs):
        """Plot the mean over the first axis of ``imgs`` as a grayscale image."""
        _, ax = plt.subplots()
        ax.imshow(imgs.mean(axis=0), interpolation='nearest', cmap='Greys_r')
        # Pass a real boolean: the string 'off' is truthy and would switch
        # the grid on with current Matplotlib versions.
        ax.grid(False)
        ax.set_xticks([])
        ax.set_yticks([])

    train, valid = get_dsprites_dataloader(1)
    imgs_sampled = []
    # Validation first, then training, matching the original iteration order.
    for loader in (valid, train):
        for img, _ in loader:
            imgs_sampled.append(img.squeeze().numpy())
    cnt = len(imgs_sampled)

    show_density(np.array(imgs_sampled))
    print(cnt)
    # Without an explicit show() the figure is never displayed when this
    # file is run as a plain script.
    plt.show()
