from torch import nn
from omegaconf import OmegaConf

from .registry import get_model_class
from .mesh import PointNet, GraphSAGE, Transolver, UPT


# Public API of the models package: the registry lookup, the config-driven
# factory defined below, and the mesh model classes re-exported from `.mesh`.
__all__ = [
    "get_model_class",
    "get_model",
    "PointNet",
    "GraphSAGE",
    "Transolver",
    "UPT",
]


def get_model(cfg, dataset) -> nn.Module:
    """Instantiate the model selected by ``cfg.model.name``.

    The model class is looked up with ``get_model_class`` and constructed with
    arguments derived from ``dataset`` / ``cfg.dataset``, plus any extra
    hyperparameters listed under ``cfg.model.hparams``.

    Args:
        cfg: OmegaConf config providing ``model.name``, optionally
            ``model.hparams``, and ``dataset.out_deformation`` /
            ``dataset.space``.
        dataset: Dataset object exposing ``n_conds`` and ``n_channels``;
            ``n_materials`` is optional and passed as ``None`` when absent.

    Returns:
        The constructed model instance.

    Raises:
        TypeError: If ``cfg.model.hparams`` does not convert to a mapping, or
            if it sets a keyword that is already derived from the dataset.
    """
    model_cls = get_model_class(cfg.model.name)

    # A missing or null ``hparams`` section means "no extra hyperparameters";
    # ``OmegaConf.to_container`` would reject ``None`` outright.
    raw_hparams = getattr(cfg.model, "hparams", None)
    if raw_hparams is None:
        hparams = {}
    else:
        # resolve=True expands ``${...}`` interpolations; without it the model
        # constructor would receive the literal interpolation strings.
        hparams = OmegaConf.to_container(raw_hparams, resolve=True)
    if not isinstance(hparams, dict):
        raise TypeError(
            f"cfg.model.hparams must be a mapping, got {type(hparams).__name__}"
        )

    # Arguments derived from the dataset and its config. ``n_materials`` is
    # optional because not every dataset defines it.
    derived = {
        "n_conds": dataset.n_conds,
        "output_channels": dataset.n_channels,
        "n_materials": getattr(dataset, "n_materials", None),
        "out_deformation": cfg.dataset.out_deformation,
        "space": cfg.dataset.space,
    }
    clashes = derived.keys() & hparams.keys()
    if clashes:
        # Fail with an actionable message instead of Python's generic
        # "got multiple values for keyword argument" error.
        raise TypeError(
            "cfg.model.hparams must not set arguments derived from the dataset: "
            + ", ".join(sorted(clashes))
        )

    return model_cls(**derived, **hparams)


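# Reference configuration that worked for overfitting. UPT2 is not imported in
# this module, so this line is kept only as a note:
# return UPT2(radius=0.05, output_dim=dataset.n_channels, input_dim=None, dec_depth=0, num_supernodes=4096)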