import importlib
from .base_pipe import BasePipe
from abc import ABC     

# Pipe name -> pipe class; filled in by the @register_pipe decorator.
PIPE_REGISTRY = {}

# Pipe name -> dotted path of the module that defines it. try_import_pipe
# imports the module on demand when the name is not yet in PIPE_REGISTRY.
# Every module lives directly under ``ddm4signal.pipeline`` and is named
# after its pipe, so the paths are derived from the names.
SUPPORTED_PIPES = {
    name: "ddm4signal.pipeline." + name
    for name in (
        "moco_pretrain",
        "classification",
        "rf_ddm_pretrain",
        "spectrumfm_pretrain",
        "spectrumfm_tune",
        "ddae_trainer",
        "iqformer_trainer",
        "ldm_pretrain",
        "ldm_tune",
        "custom",
        "few_shot",
        "prob",
        "transfer",
        "anomaly",
        "signal",
    )
}

def register_pipe(name):
    """
    Register a pipe class in ``PIPE_REGISTRY`` under ``name``.

    New pipes are added with the :func:`register_pipe` class decorator.

    For example::

        @register_pipe('modulation_classification')
        class ModulationClassification(BasePipe):
            (...)

    Args:
        name (str): the name under which the pipe is registered.

    Returns:
        Callable: a class decorator. It stores the decorated class in
        ``PIPE_REGISTRY[name]`` and returns the class unchanged.

    Raises:
        ValueError: raised when the decorator is applied, if ``name`` is already
            registered or if the class subclasses neither ``BasePipe`` nor ``ABC``.
    """

    def register_pipe_cls(cls):
        # Refuse duplicates so one pipe can never silently replace another.
        if name in PIPE_REGISTRY:
            raise ValueError("Cannot register duplicate pipe ({})".format(name))
        # Any abc.ABC subclass is accepted too, not only BasePipe subclasses.
        if not issubclass(cls, (BasePipe, ABC)):
            raise ValueError("Pipe ({}: {}) must extend BasePipe or ABC".format(name, cls.__name__))
        PIPE_REGISTRY[name] = cls
        return cls

    return register_pipe_cls


def try_import_pipe(pipe):
    """
    Ensure the pipe named ``pipe`` is present in ``PIPE_REGISTRY``.

    If the name is not registered yet but is listed in ``SUPPORTED_PIPES``,
    the corresponding module is imported. Importing it is expected to
    register the pipe, presumably through ``@register_pipe``.

    Args:
        pipe (str): name of the pipe to look up.

    Returns:
        bool: True if the pipe is registered, either already or after the import.
        False, after printing a message, if the name is unknown or if its
        module was imported but did not register it.

    Raises:
        Exception: anything raised while importing the module (e.g. ImportError
            from ``importlib.import_module``) propagates unchanged.
    """
    if pipe in PIPE_REGISTRY:
        return True
    if pipe not in SUPPORTED_PIPES:
        print(f"Failed to import {pipe} pipes.")
        return False
    module_path = SUPPORTED_PIPES[pipe]
    importlib.import_module(module_path)
    # A successful import only helps if the module actually registered this
    # name. Otherwise the caller would later hit an opaque KeyError.
    if pipe not in PIPE_REGISTRY:
        print(f"Module {module_path} was imported but did not register a pipe named '{pipe}'.")
        return False
    return True


def build_pipe(args, pipe_name):
    """
    Instantiate the pipe registered under ``pipe_name``.

    Args:
        args: passed unchanged as the single constructor argument of the pipe class.
        pipe_name (str): name of the pipe to build, resolved through ``try_import_pipe``.

    Returns:
        The pipe instance, created as ``PIPE_REGISTRY[pipe_name](args)``.

    Raises:
        SystemExit: with status 1 if ``try_import_pipe`` reports that the pipe
            is unavailable.
    """
    import sys

    if not try_import_pipe(pipe_name):
        # Use sys.exit rather than the bare exit() builtin. exit() is injected
        # by the site module and is missing under ``python -S`` or in embedded
        # and frozen interpreters.
        sys.exit(1)
    return PIPE_REGISTRY[pipe_name](args)