import importlib
from .base_model import BaseModel
from torch import nn
import sys

sys.path.append("..")

# Registry of model classes keyed by name. It is filled in by the
# @register_model decorator (defined below) and read by build_model.
MODEL_REGISTRY = {}
# Models that can be loaded on demand: registry name -> dotted module path.
# try_import_model imports the module when the name is not yet present in
# MODEL_REGISTRY. Presumably each listed module registers its class with
# @register_model under the same key; verify this when adding entries.
SUPPORTED_MODELS = {
    "IQFormer": "ddm4signal.models.IQFormer.IQFormer",
    "RF_Diffusion": "ddm4signal.models.RF_Diffusion.RF_Diffusion",
    "SpectrumFM": "ddm4signal.models.SpectrumFM.SpectrumFM",
    "AMC_Net": "ddm4signal.models.AMC_Net.AMC_Net",
    "CGDNN": "ddm4signal.models.CGDNN.CGDNN",
    "CNN2": "ddm4signal.models.CNN2.CNN2",
    "DAE": "ddm4signal.models.DAE.DAE",
    "GRU2": "ddm4signal.models.GRU2.GRU2",
    "MCNet": "ddm4signal.models.MCNet.MCNet",
    "MSNet": "ddm4signal.models.MSNet.MSNet",
    "ResNet": "ddm4signal.models.ResNet.ResNet",
    "Transformer": "ddm4signal.models.Transformer.Transformer",
    "VGG": "ddm4signal.models.VGG.VGG",
    "LDM_Encoder": "ddm4signal.models.LDM.LDM_Encoder",
    "LDM": "ddm4signal.models.LDM.LDM",
    # NOTE: the key differs from the module's last path component here.
    "DDAE_Network": "ddm4signal.models.DDAE.Network",
    "custom_encoder": "ddm4signal.models.custom.custom_encoder",
    "custom": "ddm4signal.models.custom.custom",
    "signal": "ddm4signal.models.signal.signal"
}

def register_model(name):
    """
    Build a class decorator that adds a model class to MODEL_REGISTRY.

    New model types are registered by decorating them, for example::

        @register_model('gat')
        class GAT(BaseModel):
            (...)

    Args:
        name (str): the key under which the decorated class is stored.

    Returns:
        A decorator. When applied to a class, it validates the class, stores
        it as ``MODEL_REGISTRY[name]``, sets ``cls.model_name = name`` and
        returns the class itself.

    Raises:
        ValueError: raised by the returned decorator if ``name`` is already
            registered, or if the class does not subclass BaseModel.
    """

    def _decorate(model_cls):
        # Reject name collisions before looking at the class itself.
        already_taken = name in MODEL_REGISTRY
        if already_taken:
            raise ValueError("Cannot register duplicate models ({})".format(name))

        is_valid_model = issubclass(model_cls, BaseModel)
        if not is_valid_model:
            message = "Model ({}: {}) must extend BaseModel".format(
                name, model_cls.__name__
            )
            raise ValueError(message)

        MODEL_REGISTRY[name] = model_cls
        model_cls.model_name = name
        return model_cls

    return _decorate


def try_import_model(model):
    """
    Ensure ``model`` is present in MODEL_REGISTRY, importing it lazily.

    If ``model`` is not registered yet but has an entry in SUPPORTED_MODELS,
    the corresponding module is imported. Importing it is expected to
    register the class (presumably via :func:`register_model`).

    Args:
        model (str): registry key of the model.

    Returns:
        bool: True if ``model`` is registered once this call finishes.
        False otherwise, in which case a message is printed.

    Raises:
        Any exception raised while importing the module listed in
        SUPPORTED_MODELS (e.g. ImportError / ModuleNotFoundError) propagates.
    """
    if model in MODEL_REGISTRY:
        return True

    if model not in SUPPORTED_MODELS:
        print(f"Failed to import {model} models.")
        return False

    module_path = SUPPORTED_MODELS[model]
    importlib.import_module(module_path)

    # A module can import cleanly yet register its class under a different
    # key (or not at all). Report that here instead of letting callers fail
    # later with a bare KeyError on MODEL_REGISTRY[model].
    if model not in MODEL_REGISTRY:
        print(
            f"Module {module_path} was imported but did not register "
            f"a model named {model}."
        )
        return False
    return True


def build_model(model):
    """
    Resolve ``model`` into an object that exposes ``build_model_from_args``.

    Args:
        model: either an ``nn.Module`` instance or the registry name (str)
            of a model class.

    Returns:
        For an ``nn.Module``, the same instance. If it has no
        ``build_model_from_args`` attribute, one is attached that ignores
        its argument and returns the instance. For a name, the class
        registered under that name in MODEL_REGISTRY.

    Exits:
        Calls ``sys.exit(1)`` when :func:`try_import_model` cannot resolve
        the name.
    """
    if isinstance(model, nn.Module):
        if not hasattr(model, "build_model_from_args"):
            # Give a ready-made instance the same build_model_from_args(args)
            # entry point that registered classes are used through; args is
            # ignored and the existing instance is returned.
            def build_model_from_args(args):
                return model

            model.build_model_from_args = build_model_from_args
        return model

    if not try_import_model(model):
        # Use sys.exit rather than the site-provided exit() builtin, which is
        # meant for the interactive shell: it is missing under `python -S`
        # and in some embedded interpreters.
        sys.exit(1)
    return MODEL_REGISTRY[model]