from typing import Dict, Any

import torch.nn as nn

from .codebook import Codebook
from .mid_encoder import MidEncoder
from .adapter import get_adapter


def register_mid_encoder(codebook_cfg: Dict[str, Any], model: nn.Module) -> MidEncoder:
    """Build a :class:`MidEncoder` around ``model`` from a configuration mapping.

    Args:
        codebook_cfg: Configuration mapping that must provide three keys:
            ``"codebook_cfg"`` -- keyword arguments unpacked into ``Codebook``;
            ``"adapter_name"`` -- the name passed to ``get_adapter``;
            ``"encoder_layer"`` -- forwarded as the second positional argument
            of ``MidEncoder`` (its exact meaning is defined by ``MidEncoder``;
            NOTE(review): presumably a layer selector -- confirm there).
        model: The module passed to ``MidEncoder`` as its first argument.

    Returns:
        The freshly constructed ``MidEncoder`` instance.

    Raises:
        KeyError: If ``codebook_cfg`` is a plain dict missing one of the
            required keys. Errors raised by ``Codebook``, ``get_adapter`` or
            ``MidEncoder`` themselves propagate unchanged.
    """
    # Each key is read immediately before the component that consumes it is
    # built, so construction order (and which missing key is reported first)
    # is codebook -> adapter -> encoder.
    codebook_kwargs = codebook_cfg["codebook_cfg"]
    built_codebook = Codebook(**codebook_kwargs)

    resolved_adapter = get_adapter(codebook_cfg["adapter_name"])

    layer_spec = codebook_cfg["encoder_layer"]
    return MidEncoder(model, layer_spec, built_codebook, resolved_adapter)
