"""Pytorch Dataset object that loads MNIST and SVHN. It returns x,y,s where s=0 when x,y is taken from MNIST."""

import numpy as np
import torch.utils.data as data_utils
import torch
from torchvision import transforms
from PIL import Image
import os
import glob


class Causal3D(data_utils.Dataset):
    """Image dataset whose samples are grouped into domains by a hue latent.

    For every class ``c`` in ``range(7)`` the loader reads the PNG files in
    ``<root>/<split>/images_<c>/`` and the latent array
    ``<root>/<split>/latents_<c>.npy`` (``<split>`` is ``trainset`` or
    ``testset``). Column 9 of the latent array is used as the background hue,
    and a sample belongs to domain ``d`` when its hue lies inside the ``d``-th
    hard-coded hue range. Samples outside every selected range are dropped.

    Each item is ``(x, y, d, h)`` (or ``(x, y, d)`` when ``return_hue`` is
    false): the image tensor resized to 64x64, the class label, the domain
    index and the hue value.
    """

    # Number of classes; class ``c`` is stored in ``images_c/`` + ``latents_c.npy``.
    _NUM_CLASSES = 7
    # Column of the latent array that is read as the background hue.
    _BG_HUE_COLUMN = 9
    # Inclusive (lower, upper) hue range per domain; the position is the domain id.
    _DOMAIN_HUE_BOUNDS = ((-np.pi / 2, -0.8), (0.8, np.pi / 2))
    # Placeholder domain id for samples that fall in no selected hue range.
    _UNASSIGNED_DOMAIN = 100.0

    def __init__(self,
                 root,
                 list_train_domains=None,
                 train=True,
                 download=True,
                 max_samples=-1,
                 random_seed=0,
                 return_hue=True,
                 transform=None):
        """
        :param root: data directory (``~`` is expanded)
        :param list_train_domains: domain ids to load; ``None`` means
            ``[0, 1, 2]``. Only two hue ranges (ids 0 and 1) are currently
            defined, so the default raises ``IndexError`` while loading --
            pass an explicit list such as ``[0, 1]``.
        :param train: load ``<root>/trainset`` if true, else ``<root>/testset``
        :param download: stored on the instance; not used by this class
        :param max_samples: stored on the instance; subsampling is currently
            disabled, so it has no effect
        :param random_seed: stored on the instance; only relevant to the
            disabled subsampling step
        :param return_hue: whether ``__getitem__`` also returns the hue value
        :param transform: optional callable applied to the image tensor after
            the built-in ToTensor + Resize((64, 64)) step
        """
        # A ``None`` sentinel avoids sharing one mutable default list between instances.
        if list_train_domains is None:
            list_train_domains = [0, 1, 2]

        self.root = os.path.expanduser(root)
        self.train = train
        self.download = download
        self.list_train_domains = list_train_domains
        self.n_domains = len(list_train_domains)
        self.max_samples = max_samples
        self.random_seed = random_seed
        self.return_hue = return_hue
        self.transform = transform

        # Built once here instead of on every __getitem__ call.
        self._initial_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((64, 64)),
        ])

        self.data, self.label, self.domain, self.hue = self._get_data()

    @staticmethod
    def _natural_key(path):
        """Sort key ordering numeric file stems by value, other stems lexically after them."""
        stem = os.path.splitext(os.path.basename(path))[0]
        if stem.isdecimal():
            return (0, int(stem), stem)
        return (1, 0, stem)

    def _get_data(self):
        """Collect image paths, labels, domain ids and hues for the selected domains.

        :return: ``(images, label, domain, hue)`` -- a list of image paths, two
            long tensors and one float tensor, all of the same length
        :raises IndexError: if a requested domain id has no hue range defined
        """
        if self.train:
            root_dir = os.path.join(self.root, 'trainset')
        else:
            root_dir = os.path.join(self.root, 'testset')

        # Fail early, with an explanatory message, on domain ids without a hue range.
        n_defined = len(self._DOMAIN_HUE_BOUNDS)
        for d in self.list_train_domains:
            if not -n_defined <= d < n_defined:
                raise IndexError(
                    f'domain {d} is out of range: only {n_defined} hue ranges are defined')

        images = []
        label = []
        domain = []
        hue = []
        for cdx in range(self._NUM_CLASSES):
            # ---- get all X ---- #
            # glob gives no ordering guarantee, but images are matched to latent
            # rows by position below, so the paths are put in a deterministic order.
            # NOTE(review): assumes row i of latents_<c>.npy describes the i-th image
            # in natural filename order -- confirm against the dataset's naming scheme.
            img_cur_class = sorted(glob.glob(f'{root_dir}/images_{cdx}/*.png'),
                                   key=self._natural_key)

            # ---- get all D ---- #
            latent_code = np.load(f'{root_dir}/latents_{cdx}.npy')
            bg_hue_code = latent_code[:, self._BG_HUE_COLUMN]

            # Pseudo domain label used for filtering; unselected samples keep the placeholder.
            domain_cur_class = np.full(len(img_cur_class), self._UNASSIGNED_DOMAIN)
            for d in self.list_train_domains:
                lower, upper = self._DOMAIN_HUE_BOUNDS[d]
                in_range = (bg_hue_code >= lower) & (bg_hue_code <= upper)
                domain_cur_class[in_range] = d

            # ---- get samples based on chosen domain range ---- #
            for d in self.list_train_domains:
                selected = domain_cur_class == d
                images.extend(path for path, keep in zip(img_cur_class, selected) if keep)
                # ---- Y: every kept sample of this class gets label cdx ---- #
                label.append(np.full(int(selected.sum()), float(cdx)))
                domain.append(domain_cur_class[selected])
                hue.append(bg_hue_code[selected])

        label = np.concatenate(label)
        domain = np.concatenate(domain)
        hue = np.concatenate(hue)

        # NOTE: max_samples / random_seed are not applied here; the random
        # subsampling step that used them was left disabled.

        return images, torch.Tensor(label).long(), torch.Tensor(domain).long(), torch.Tensor(hue)

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.domain)

    def __getitem__(self, index):
        """Load and return the sample at ``index``.

        :return: ``(x, y, d, h)`` if ``return_hue`` is set, otherwise ``(x, y, d)``
        """
        # convert() returns an independent in-memory image, so the file handle
        # can be released right away instead of lingering until garbage collection.
        with Image.open(self.data[index]) as img:
            x = img.convert('RGB')
        y = self.label[index]
        d = self.domain[index]
        h = self.hue[index]

        x = self._initial_transform(x)
        if self.transform is not None:
            x = self.transform(x)

        if self.return_hue:
            return x, y, d, h
        return x, y, d

