import torch
import itertools
import numpy as np
from pathlib import Path
from torch.utils import data
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

def add_dataset_args(parser):
    """Register the dataset command-line options on an argparse *parser*.

    Adds ``--data_dir`` (string, default None). ``Dataset`` reads it as
    ``args.data_dir`` and passes it to torchvision's ``MNIST`` with
    ``download=True``.
    """
    parser.add_argument(
        '--data_dir', type=str,
        help='Root directory for the MNIST dataset (downloaded here if missing).',
    )

class Dataset(data.Dataset):
    """MNIST images binarized at 0.5, optionally subsampled into digit clusters.

    ``__getitem__`` returns only the transformed, thresholded image. Labels
    are kept in ``self.features`` and are not returned per item.

    When ``with_hierarchy`` is true, the split is reduced to ``n_samples``
    randomly drawn images per digit cluster (see ``_CLUSTER_GROUPS``).
    ``self.colors`` then holds one cluster label string per sample, such as
    ``'0,2,6'``.
    """

    # Digit groupings for each supported ``num_cluster``. The tuple order fixes
    # the concatenation order of the sampled subset and the order of the random
    # draws. ','.join of a group's digits is that group's color label.
    _CLUSTER_GROUPS = {
        3: ((0, 2, 6), (1, 4, 7, 9), (3, 5, 8)),
        5: ((0,), (1,), (2, 6), (3, 5, 8), (4, 7, 9)),
    }

    def __init__(self, args,
                 is_train=True, with_hierarchy=False, num_cluster=3,
                 n_samples=300) -> None:
        """Load an MNIST split and optionally build the clustered subset.

        Args:
            args: object with a ``data_dir`` attribute. It is passed to
                torchvision ``MNIST`` with ``download=True``.
            is_train: load the training split if True, else the test split.
            with_hierarchy: if True, subsample per cluster and set ``colors``.
            num_cluster: number of digit clusters; must be 3 or 5.
            n_samples: images drawn per cluster when ``with_hierarchy`` is set.
        """
        super().__init__()

        self.args = args
        self.is_train = is_train
        self.transform = ToTensor()
        self.data_dir = Path(args.data_dir)

        # Named ``mnist`` so it does not shadow the ``torch.utils.data`` module.
        mnist = MNIST(str(self.data_dir), train=self.is_train, download=True)
        self.data = mnist.data.numpy()   # images as a NumPy array
        self.features = mnist.targets    # label tensor

        if with_hierarchy:
            self.data, self.features = self.filter_data(
                mnist, num_cluster, n_samples=n_samples)
            self.colors = self.get_colors(self.features, num_cluster)

    @classmethod
    def _groups_for(cls, num_cluster):
        """Return the digit groups for *num_cluster*; raise ValueError if unsupported."""
        try:
            return cls._CLUSTER_GROUPS[num_cluster]
        except KeyError:
            raise ValueError(
                f"num_cluster must be one of {sorted(cls._CLUSTER_GROUPS)}, "
                f"got {num_cluster!r}") from None

    def filter_class(self, c=0):
        """Keep only the samples whose label equals *c*, in place.

        ``data`` and ``features`` are filtered together so they stay aligned.
        ``colors`` is filtered too when it has been set.
        """
        keep = np.flatnonzero(np.asarray(self.features) == c)
        self.data = self.data[keep]
        self.features = self.features[keep]
        colors = getattr(self, 'colors', None)
        if colors is not None:
            self.colors = [colors[i] for i in keep]

    def filter_data(self, data, num_cluster, n_samples=300):
        """Draw ``n_samples`` random images from each digit cluster.

        Args:
            data: unused; accepted so existing callers keep working.
            num_cluster: 3 or 5, selecting a grouping in ``_CLUSTER_GROUPS``.
            n_samples: number of indices drawn per cluster.

        Returns:
            ``(images, labels)`` indexed from ``self.data`` and
            ``self.features``. Samples are laid out cluster by cluster, in
            ``_CLUSTER_GROUPS`` order.

        Raises:
            ValueError: if ``num_cluster`` is not a supported value.
        """
        groups = self._groups_for(num_cluster)
        targets = np.asarray(self.features)
        # NOTE: np.random.choice samples *with replacement* by default, so a
        # cluster's subset may contain duplicate images. This is kept to
        # preserve the existing sampling; use replace=False for distinct picks.
        picks = [
            np.random.choice(np.flatnonzero(np.isin(targets, group)), n_samples)
            for group in groups
        ]
        filter_index = np.concatenate(picks, axis=0)
        return self.data[filter_index], self.features[filter_index]

    def get_colors(self, targets, num_cluster):
        """Map each label in *targets* to its cluster's label string.

        For example, with ``num_cluster=3`` the digit 2 maps to ``'0,2,6'``.
        Labels that belong to no group map to None.

        Raises:
            ValueError: if ``num_cluster`` is not a supported value.
        """
        lookup = {
            digit: ','.join(map(str, group))
            for group in self._groups_for(num_cluster)
            for digit in group
        }
        return [lookup.get(int(label)) for label in targets]

    def __len__(self):
        """Return the number of images currently held."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return image *idx* transformed and binarized (>= 0.5 -> 1., else 0.)."""
        x = self.transform(self.data[idx])
        # One vectorized threshold, keeping the transform's dtype.
        return (x >= 0.5).to(x.dtype)
