from typing import List, Union

from pathlib import Path
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from src.datasets import (
    CondColoredMNIST
)


# Default dataset root: a ``data`` folder inside the grandparent of the
# directory that contains this module.
DATA_PATH = Path(__file__).resolve().parents[2].joinpath("data")

# Per-dataset ``(mean, std)`` channel statistics handed to
# ``transforms.Normalize`` by the loaders in this module; each list has one
# entry per image channel (1 for MNIST, 3 for the coloured variant).
# NOTE(review): presumably measured on the raw 28x28 images — confirm.
_CHANNEL_STATS = {
    "MNIST": (
        [0.130660,],
        [0.308108,],
    ),
    "COND-COLOR-MNIST": (
        [0.518488, 0.42173, 0.516175],
        [0.499658, 0.493836, 0.499738],
    ),
}

MEANS = {name: mean for name, (mean, _std) in _CHANNEL_STATS.items()}
STDS = {name: std for name, (_mean, std) in _CHANNEL_STATS.items()}

def get_dataset(dataset, **kwargs):
    """Build a dataset by (case-insensitive) name.

    Args:
        dataset: Dataset name, matched case-insensitively. Supported values are
            ``"MNIST"`` and ``"COND-COLOR-MNIST"``.
        **kwargs: Forwarded unchanged to the matching builder
            (``mnist`` for MNIST, ``mnist_rgb`` for COND-COLOR-MNIST).

    Returns:
        Whatever the selected builder returns.

    Raises:
        ValueError: If ``dataset`` does not name a supported dataset.
    """
    key = dataset.upper()
    if key == "MNIST":
        return mnist(**kwargs)
    if key == "COND-COLOR-MNIST":
        return mnist_rgb(**kwargs)
    raise ValueError(
        f"Unknown dataset {dataset!r}; expected one of 'MNIST', 'COND-COLOR-MNIST'."
    )


#==========MNIST====================
def mnist(
        root: Union[str, Path, None]=None,
        train: bool=True,
        desired_classes: Union[int, List[int], None]=None,
    ):
    """Load (downloading if needed) MNIST, padded to 32x32 and normalized.

    Args:
        root: Directory passed to ``torchvision.datasets.MNIST``. A falsy value
            falls back to ``DATA_PATH``.
        train: Load the training split if True, otherwise the test split.
        desired_classes: Digit label(s) to keep. ``None`` keeps all ten
            digits; a single int keeps only that digit. Labels are NOT
            remapped, so kept samples retain their original digit values.

    Returns:
        A 7-tuple ``(dataset, 1, 32, n_classes, mean, std, True)``:
        the filtered dataset, channel count, padded image side length,
        ``len(desired_classes)``, and the normalization mean/std lists.
        NOTE(review): the meaning of the trailing ``True`` is not visible
        here — confirm with callers.
    """
    dataset_name = "MNIST"

    root = root or DATA_PATH
    mean, std = MEANS[dataset_name], STDS[dataset_name]

    if desired_classes is None:
        desired_classes = list(range(10))
    elif isinstance(desired_classes, int):
        desired_classes = [desired_classes]

    dataset = datasets.MNIST(
        root=root,
        train=train,
        transform=transforms.Compose([
            # Pad the 28x28 PIL image with black *before* normalizing so the
            # 2-pixel border ends up with the same normalized value as the
            # background. Padding after Normalize would fill with 0, which in
            # normalized space is the dataset mean (a grey frame).
            transforms.Pad(2),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]),
        download=True,
    )

    # Vectorized membership mask: one tensor comparison per requested class
    # instead of a Python-level `in` test for every sample.
    keep = torch.zeros_like(dataset.targets, dtype=torch.bool)
    for cls in desired_classes:
        keep |= dataset.targets == cls
    dataset.data = dataset.data[keep]
    dataset.targets = dataset.targets[keep]

    return dataset, 1, 32, len(desired_classes), mean, std, True


#==========RGB MNIST====================
def mnist_rgb(
        root: str=None,
        train: bool=True,
        desired_classes: Union[int, List[int]]=None,
        p: float=0.2
    ):
    """Load the conditionally coloured MNIST variant, padded and normalized.

    Args:
        root: Directory passed to ``CondColoredMNIST``. A falsy value falls
            back to ``DATA_PATH``.
        train: Forwarded to ``CondColoredMNIST`` as ``train``.
        desired_classes: Label(s) to keep. ``None`` keeps labels 0-9; a single
            int keeps only that label. Labels are not remapped.
        p: Forwarded to ``CondColoredMNIST`` as ``p``; its meaning is defined
            by that class.

    Returns:
        A 7-tuple ``(dataset, 3, 32, dataset.n_classes, mean, std, True)``.
        NOTE(review): unlike ``mnist``, the class count comes from
        ``dataset.n_classes`` rather than the number of kept classes — confirm
        this is intended. The trailing ``True`` is not explained here.
    """
    stats_key = "COND-COLOR-MNIST"
    channel_mean = MEANS[stats_key]
    channel_std = STDS[stats_key]

    if not root:
        root = DATA_PATH

    if isinstance(desired_classes, int):
        wanted = [desired_classes]
    elif desired_classes is None:
        wanted = list(range(10))
    else:
        wanted = desired_classes

    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
        # NOTE(review): Pad runs after Normalize, so the 2-pixel border is 0
        # in normalized space (the per-channel mean), not the image
        # background. Left unchanged because the image type and colour scheme
        # of CondColoredMNIST are not visible from here.
        transforms.Pad(2),
    ])
    colored = CondColoredMNIST(
        root=root,
        train=train,
        transform=pipeline,
        download=True,
        p=p,
    )

    # Boolean mask of samples whose label is among the requested ones.
    keep = torch.tensor([target in wanted for target in colored.targets])
    colored.data = colored.data[keep]
    colored.targets = colored.targets[keep]

    return colored, 3, 32, colored.n_classes, channel_mean, channel_std, True