import copy
from typing import Callable, Type

from omegaconf import DictConfig, OmegaConf, open_dict, read_write

import torchvision.transforms as torchVT
import torchaudio.transforms as torchAT

from pado.core.base.transform import PadoTransform

__all__ = [
    "register_transform",
    "build_transform",
    "build_transforms",
]

# Registered transform name -> class stored by ``register_transform``;
# consulted first by ``build_transform``.
TRANSFORM_REGISTRY = dict()


def register_transform(name: str):
    """Return a class decorator that records the decorated class under ``name``.

    The decorator refuses ``name`` if it is already in TRANSFORM_REGISTRY, or
    if ``torchvision.transforms`` or ``torchaudio.transforms`` has an attribute
    with that name. ``build_transform`` checks the registry first, so this
    keeps a registered name from shadowing a TorchVision/TorchAudio transform.
    The decorated class is returned unchanged.

    Raises:
        ValueError: (when the decorator is applied) on any of the collisions above.
    """

    def _decorator(target_cls: Callable):
        # Checked in order: our own registry, then TorchVision, then TorchAudio.
        if name in TRANSFORM_REGISTRY:
            message = f"Cannot register duplicated transform {name}."
        elif hasattr(torchVT, name):
            message = f"Cannot register same transform {name} as TorchVision."
        elif hasattr(torchAT, name):
            message = f"Cannot register same transform {name} as TorchAudio."
        else:
            TRANSFORM_REGISTRY[name] = target_cls
            return target_cls
        raise ValueError(message)

    return _decorator


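# ------------------------------------------------------------------------  #
# These imports intentionally come after `register_transform` and
# TRANSFORM_REGISTRY are defined. Presumably the submodules apply the decorator
# at import time, so importing them here fills the registry without a
# circular-import failure — do not hoist them to the top of the file.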
from pado.data.transforms.compose import Compose  # noqa
from pado.data.transforms.audio import *  # noqa
from pado.data.transforms.text import *  # noqa


# ------------------------------------------------------------------------  #

def build_transform(cfg: DictConfig) -> Callable:
    """
    Build a single transform from its config.

    ``cfg["name"]`` selects the transform, and every other key becomes a
    constructor argument. Lookup order:

    1. A transform registered with ``register_transform``: built with its
       ``from_config`` classmethod, which receives the remaining keys (still a
       DictConfig when ``cfg`` was one).
    2. ``torchvision.transforms.<name>``.
    3. ``torchaudio.transforms.<name>``.

    For options 2 and 3 the remaining keys are passed as plain keyword
    arguments, with interpolations resolved.

    Args:
        cfg: Config with at least a ``name`` key. The caller's object is never
            modified.

    Returns:
        The constructed transform.

    Raises:
        ValueError: If ``cfg`` has no ``name`` key, or no registered,
            TorchVision, or TorchAudio transform has that name.
    """
    if "name" not in cfg:
        raise ValueError("Transform config dict should have name.")

    name = cfg["name"]
    # Work on a copy so the caller's config is left untouched.
    cfg = copy.deepcopy(cfg)
    if isinstance(cfg, DictConfig):
        # OmegaConf rejects key deletion in struct or read-only mode. The copy
        # keeps its parent link and so can inherit those flags; lift them on
        # the copy only, for the duration of the delete.
        with read_write(cfg), open_dict(cfg):
            del cfg["name"]
    else:
        del cfg["name"]

    if name in TRANSFORM_REGISTRY:
        transform_cls: Type[PadoTransform] = TRANSFORM_REGISTRY[name]
        return transform_cls.from_config(cfg)

    # Same precedence as before: TorchVision is consulted before TorchAudio.
    for module in (torchVT, torchAT):
        if hasattr(module, name):
            if isinstance(cfg, DictConfig):
                kwargs = OmegaConf.to_container(cfg, resolve=True)
            else:
                # to_container only accepts OmegaConf objects; a plain mapping
                # can be used as keyword arguments directly.
                kwargs = dict(cfg)
            return getattr(module, name)(**kwargs)

    raise ValueError(f"Transform {name} not found.")


def build_transforms(cfg: DictConfig) -> Compose:
    """
    Build every transform listed in ``cfg`` and chain them in a ``Compose``.

    ``Compose`` is used rather than ``nn.Sequential`` because some transforms
    are not ``nn.Module`` instances.

    Args:
        cfg: Ordered sequence of transform configs, each one accepted by
            ``build_transform``. Any iterable of configs works.

    Returns:
        A ``Compose`` that wraps the built transforms in list order.
    """
    return Compose([build_transform(transform_cfg) for transform_cfg in cfg])
