# Public API: construct a registered model by name, or restore one from a run's checkpoint.
__all__ = ['build_model', 'load_model']

from .ff_net import FeedForwardNet
from .adv_debias import AdvDebiasing
from utils import get_checkpoint_path

# Model classes that can be built or restored by name. Each class must expose
# a ``name`` attribute, which serves as its registry key.
CLASSES = [FeedForwardNet, AdvDebiasing]

# Lookup table: registered name -> model class. If two classes shared the same
# ``name``, the one listed later in CLASSES would silently win.
CLASS_DICT = {model_cls.name: model_cls for model_cls in CLASSES}


def build_model(name=None, **kwargs):
    """Instantiate a registered model class by name.

    Args:
        name: Registry key of the model class, i.e. the ``name`` attribute of
            one of the classes in ``CLASS_DICT``.
        **kwargs: Forwarded unchanged to the model class constructor.

    Returns:
        A new instance of the selected model class.

    Raises:
        ValueError: If ``name`` is not a key of ``CLASS_DICT``.
    """
    try:
        cls = CLASS_DICT[name]
    except KeyError:
        # ``from None`` suppresses the internal KeyError in the traceback; the
        # ValueError message already explains the failure and lists options.
        raise ValueError(
            f"Model class {name} not found. "
            f"Available classes: {list(CLASS_DICT.keys())}"
        ) from None

    return cls(**kwargs)


def load_model(run_id, name=None, **kwargs):
    """Restore a registered model class from the checkpoint of a run.

    Args:
        run_id: Identifier passed to ``get_checkpoint_path`` to locate the
            checkpoint.
        name: Registry key of the model class, i.e. the ``name`` attribute of
            one of the classes in ``CLASS_DICT``.
        **kwargs: Forwarded unchanged to ``cls.load_from_checkpoint``.

    Returns:
        Whatever ``cls.load_from_checkpoint`` returns (presumably the restored
        model instance; NOTE(review): confirm against the model base class).

    Raises:
        ValueError: If ``name`` is not a key of ``CLASS_DICT``. This mirrors
            the error raised by ``build_model`` for an unknown name.
    """
    # Resolve the class before touching the checkpoint, so an unknown name
    # fails fast with a descriptive error instead of a bare KeyError.
    try:
        cls = CLASS_DICT[name]
    except KeyError:
        raise ValueError(
            f"Model class {name} not found. "
            f"Available classes: {list(CLASS_DICT.keys())}"
        ) from None

    checkpoint_path = get_checkpoint_path(run_id)
    return cls.load_from_checkpoint(checkpoint_path, **kwargs)
