from models.base import BaseModel
from models.masked_mlp import MaskedMLP, MaskedMLP_no_onehot
from models.logZ import LogZModule
from models.rope_vit import get_rope_vit_model, RopeVIT
from models.ema import ExponentialMovingAverage

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import torch
    from omegaconf import DictConfig
    from targets.base import BaseTarget


def create_model(
    cfg: "DictConfig", target: "BaseTarget", device: "torch.device"
) -> tuple["BaseModel", "LogZModule | None"]:
    """Build the network selected by ``cfg.target.model`` and, for TB loss, a log-Z module.

    Args:
        cfg: Hydra configuration. Reads ``cfg.target.model`` (``name`` plus the
            hyperparameters of the chosen model), ``cfg.target.name`` and
            ``cfg.algorithm.loss_type``. When ``loss_type == "tb"`` it also
            reads ``cfg.algorithm.log_Z_init`` and ``cfg.algorithm.log_Z_lr``
            (the latter only for the log message).
        target: Target distribution instance. ``target.q`` is read for both
            models, ``target.ndim`` for "mlp" and ``target.L`` for "rope_vit".
        device: Device the model and the log-Z module are moved to.

    Returns:
        A ``(model, logZ_module)`` tuple. ``logZ_module`` is a ``LogZModule``
        when ``cfg.algorithm.loss_type == "tb"`` and ``None`` otherwise.

    Raises:
        ValueError: If ``cfg.target.model.name`` is not a known model, or if
            "rope_vit" is requested for a target other than "ising"/"potts".
    """
    model_cfg = cfg.target.model  # We select model based on targets
    if model_cfg.name == "mlp":
        model = MaskedMLP(
            ndim=target.ndim,
            vocab_size=target.q + 1,  # target's vocab size + Mask
            hidden_dim=model_cfg.hidden_dim,
            n_layers=model_cfg.n_layers,
        )
        print(f"Model: MaskedMLP with {model_cfg.hidden_dim} hidden dim")

    elif model_cfg.name == "rope_vit":
        # Raise explicitly instead of asserting: asserts are stripped under
        # ``python -O``, which would let an unsupported target slip through.
        if cfg.target.name not in ("ising", "potts"):
            raise ValueError(
                "Model 'rope_vit' only supports 'ising' or 'potts' targets, "
                f"got: {cfg.target.name}"
            )

        model = get_rope_vit_model(
            L=target.L,
            embed_dim=model_cfg.hidden_dim,
            depth=model_cfg.n_blocks,
            n_heads=model_cfg.n_heads,
            vocab_size=target.q + 1,  # target's vocab size + Mask
            dtype=model_cfg.dtype,
            device_type=device.type,
        )
        print(f"Model: RopeViT with {model_cfg.n_blocks} blocks, {model_cfg.n_heads} heads")
    else:
        raise ValueError(f"Unknown model type: {model_cfg.name}")

    model = model.to(device)

    # The log-Z module is only created for the "tb" loss; every other loss
    # type gets None in the second slot of the returned tuple.
    logZ_module = None
    if cfg.algorithm.loss_type == "tb":
        log_Z_init = float(cfg.algorithm.log_Z_init)
        logZ_module = LogZModule(init_value=log_Z_init).to(device)
        print(f"Using TB loss with log_Z_lr={cfg.algorithm.log_Z_lr}, log_Z_init={log_Z_init}")

    return model, logZ_module
