import torch
from .dit import DiT
from .ar_transformer import ARTransformer
from .dit_orig import DIT as DiTOrig
from .ar_orig import AR as AROrig


def get_backbone(config, vocab_size: int) -> torch.nn.Module:
    """Instantiate the backbone network selected by ``config.model.type``.

    Args:
        config: Experiment configuration. Must expose ``config.model.type``;
            the ``"ddit"`` branch additionally reads ``config.time_conditioning``.
        vocab_size: Token vocabulary size forwarded to the backbone constructor.

    Returns:
        The constructed backbone module.

    Raises:
        ValueError: If ``config.model.type`` is not one of ``"ddit"``,
            ``"ar_transformer"``, ``"ddit-orig"`` or ``"ar_orig"``.
    """
    mtype = config.model.type
    if mtype == "ddit":
        # Adaptive (conditional) layers are enabled only when the config asks
        # for time conditioning.
        backbone = DiT(config, vocab_size=vocab_size, adaptive=config.time_conditioning)
    elif mtype == "ar_transformer":
        # Adaptive means there is conditional information -> not the case with
        # AR modeling, so no `adaptive` flag is passed.
        backbone = ARTransformer(config, vocab_size=vocab_size)
    elif mtype == "ddit-orig":
        backbone = DiTOrig(config, vocab_size)
    elif mtype == "ar_orig":
        # NOTE(review): vocab_size + 1 and vocab_size are passed as separate
        # positional args; presumably one extra token id is reserved (e.g. a
        # special/mask token) -- confirm against AROrig's signature.
        backbone = AROrig(config, vocab_size + 1, vocab_size)
    else:
        # Report the value actually dispatched on (config.model.type). Reading
        # any other config field here could itself fail and mask this error.
        raise ValueError(
            f"Unknown backbone: {mtype!r} (config.model.type); expected one of "
            "'ddit', 'ar_transformer', 'ddit-orig', 'ar_orig'"
        )

    return backbone
