import importlib

from custom_models.CustomCausalModel import CustomCausalModel
import networkx as nx

# Dynamic loader for custom causal models

def load_custom_model(model_name: str, causal_graph: nx.DiGraph) -> CustomCausalModel:
    """
    Import a model module from ``custom_models`` and instantiate its model class.

    The module ``custom_models.<model_name>`` is imported, the attribute named
    ``model_name.upper()`` is looked up on it, and that attribute is called
    with ``causal_graph`` as its only argument.

    Args:
        model_name (str): Name of the module inside ``custom_models``; its
            upper-cased form is used as the class name.
        causal_graph (nx.DiGraph): Causal graph passed to the model's constructor.

    Returns:
        CustomCausalModel: The newly constructed model. The class is expected
            to derive from CustomCausalModel; this is not checked here.

    Raises:
        ImportError: If ``custom_models.<model_name>`` cannot be imported.
        AttributeError: If the module has no attribute named
            ``model_name.upper()``.
    """
    # Import first so a missing module is reported before the class lookup.
    model_module = importlib.import_module(f"custom_models.{model_name}")

    # Naming convention: the class is the upper-cased module name.
    model_factory = getattr(model_module, model_name.upper())

    return model_factory(causal_graph)
