from hmpn import get_hmpn_from_graph

from ltsgns_mp.architectures.simulators.gnn.egno.egno import EGNO, get_egno_input
from ltsgns_mp.util import keys


def get_gnn(gnn_config, example_input_batch, device):
    """Build the GNN selected by ``gnn_config.name``.

    Supported names:

    * ``"hmpn_gnn"``: built with ``get_hmpn_from_graph`` from the example
      batch, using ``keys.MESH`` as the node name and ``unpack_output=False``
      so the full graph is returned.
    * ``"egno"``: an ``EGNO`` model. Its node and edge input feature sizes are
      taken from the last axis of the features returned by
      ``get_egno_input(example_input_batch)``. Its ``num_timesteps`` is
      ``example_input_batch.context_node_positions.shape[1]``.

    Args:
        gnn_config: Config object. ``name`` is always read. ``"hmpn_gnn"`` also
            reads ``latent_dimension`` and ``base``. ``"egno"`` also reads
            ``num_steps`` (passed as the number of layers), ``latent_dimension``,
            ``use_time_conv`` and ``time_emb_dim``.
        example_input_batch: Example batch. It is handed to the hmpn builder,
            or inspected to size the EGNO inputs.
        device: Forwarded unchanged to the model constructor.

    Returns:
        The constructed GNN model.

    Raises:
        ValueError: If ``gnn_config.name`` is not one of the supported names.
    """
    name = gnn_config.name
    if name == "hmpn_gnn":
        return get_hmpn_from_graph(example_graph=example_input_batch,
                                   latent_dimension=gnn_config.latent_dimension,
                                   node_name=keys.MESH,
                                   unpack_output=False,  # return full graph
                                   base_config=gnn_config.base,
                                   device=device)
    if name == "egno":
        # Only the node features (h) and edge features are needed, to size the
        # model's inputs; the other values returned here are unused.
        _, _, h, _, edge_fea = get_egno_input(example_input_batch)
        # Second axis of the context positions (presumably the trajectory length).
        num_timesteps = example_input_batch.context_node_positions.shape[1]
        return EGNO(n_layers=gnn_config.num_steps, in_node_nf=h.shape[-1], in_edge_nf=edge_fea.shape[-1],
                    hidden_nf=gnn_config.latent_dimension, device=device, use_time_conv=gnn_config.use_time_conv,
                    num_timesteps=num_timesteps, time_emb_dim=gnn_config.time_emb_dim)
    # Previously this fell through and returned None, hiding config typos.
    raise ValueError(f"Unknown gnn_config.name {name!r}; expected 'hmpn_gnn' or 'egno'")

