import copy
import warnings

from omegaconf import DictConfig, OmegaConf, open_dict

import torchvision.datasets as torchVD
import torchaudio.datasets as torchAD

from pado.core.base.dataset import PadoDataset

DATASET_REGISTRY = {}  # {name: dataset_cls}

__all__ = ["register_dataset", "build_dataset", "build_datasets"]


def register_dataset(name: str):
    """
    Return a class decorator that stores the decorated class in DATASET_REGISTRY under `name`.

    Registration is refused (ValueError) when `name` is already registered, or when it
    collides with an attribute of `torchvision.datasets` or `torchaudio.datasets`.
    Without that check, `build_dataset` would silently prefer the registry entry over
    the library dataset of the same name.
    """

    def _decorator(dataset_cls):
        if name in DATASET_REGISTRY:
            raise ValueError(f"Cannot register duplicated dataset {name}.")

        # Checked in this order: TorchVision first, then TorchAudio.
        shadowed_libraries = (
            (torchVD, f"Cannot register same dataset {name} as TorchVision."),
            (torchAD, f"Cannot register same dataset {name} as TorchAudio."),
        )
        for library, message in shadowed_libraries:
            if hasattr(library, name):
                raise ValueError(message)

        DATASET_REGISTRY[name] = dataset_cls
        return dataset_cls

    return _decorator


# ------------------------------------------------------------------------  #
from pado.data.datasets.concat import ConcatDataset  # noqa
from pado.data.datasets.vision import *  # noqa
from pado.data.datasets.speech import *  # noqa


# ------------------------------------------------------------------------  #

def build_dataset(cfg: DictConfig, transform=None, target_transform=None):
    """
    Build a single dataset from its config.

    ``cfg["name"]`` is looked up in this order:
      1. datasets registered with ``register_dataset`` (built via ``from_config``),
      2. ``torchvision.datasets`` (remaining keys passed as constructor kwargs),
      3. ``torchaudio.datasets`` (remaining keys passed as constructor kwargs).

    :param cfg: dataset config; must contain a ``name`` key. The caller's object is
        not modified.
    :param transform: forwarded to registered and TorchVision datasets. TorchAudio
        datasets do not receive it; a warning is emitted if one is supplied.
    :param target_transform: same forwarding rules as ``transform``.
    :return: the constructed dataset.
    :raises ValueError: if ``name`` is missing or matches no known dataset.
    """
    if "name" not in cfg:
        raise ValueError("Dataset config dict should have name.")

    name = cfg["name"]
    # Work on a copy so the caller's config is left untouched.
    cfg = copy.deepcopy(cfg)
    if isinstance(cfg, DictConfig):
        # deepcopy keeps the struct flag (Hydra's default), and struct-mode DictConfigs
        # refuse key deletion; lift the flag only while removing "name".
        with open_dict(cfg):
            del cfg["name"]
    else:
        del cfg["name"]

    if name in DATASET_REGISTRY:
        dataset_cls = DATASET_REGISTRY[name]  # a registered PadoDataset subclass
        return dataset_cls.from_config(cfg, transform=transform, target_transform=target_transform)

    if hasattr(torchVD, name):
        kwargs = OmegaConf.to_container(cfg, resolve=True)
        return getattr(torchVD, name)(**kwargs, transform=transform, target_transform=target_transform)

    if hasattr(torchAD, name):
        # These datasets are constructed without transform arguments here, so tell the
        # caller instead of dropping them silently.
        if transform is not None or target_transform is not None:
            warnings.warn(
                f"TorchAudio dataset {name} does not accept transform/target_transform; they are ignored.",
                UserWarning,
                stacklevel=2,
            )
        kwargs = OmegaConf.to_container(cfg, resolve=True)
        return getattr(torchAD, name)(**kwargs)

    raise ValueError(f"Dataset {name} not found.")


def build_datasets(cfg: DictConfig, transform=None, target_transform=None) -> ConcatDataset:
    """
    Build multiple datasets and merge them into one ConcatDataset.

    :param cfg: a sequence (e.g. ``ListConfig`` or list) of per-dataset configs, each
        accepted by ``build_dataset``. A single dataset config (``DictConfig``) is also
        accepted and treated as a one-element sequence.
    :param transform: forwarded to every ``build_dataset`` call.
    :param target_transform: forwarded to every ``build_dataset`` call.
    :return: ConcatDataset over the built datasets, in config order.
    """
    # Iterating a lone DictConfig would yield its keys rather than dataset configs.
    if isinstance(cfg, DictConfig):
        cfg = [cfg]

    datasets = [
        build_dataset(dataset_cfg, transform=transform, target_transform=target_transform)
        for dataset_cfg in cfg
    ]
    return ConcatDataset(datasets)
