from torch_geometric.data import Data, Batch

from ltsgns_mp.architectures.simulators.abstract_simulator import AbstractSimulator
from ltsgns_mp.envs.train_iterator.step_train_iterator import StepTrainBatch
from ltsgns_mp.util import keys
from ltsgns_mp.util.graph_input_output_util import unpack_node_features
from ltsgns_mp.util.own_types import ValueDict


class StepSimulator(AbstractSimulator):
    """Simulator that decodes processed mesh-node features in a single step.

    The forward pass runs ``self.gnn`` on the input batch, extracts the
    features of the nodes tagged ``keys.MESH`` and feeds them to
    ``self.decoder``. Both sub-modules are not created in this class;
    presumably they are set up by ``AbstractSimulator`` -- confirm there.
    """

    def __init__(self, config, example_input_batch: StepTrainBatch, loading_config, device):
        """Set up the simulator from an example batch and load its weights.

        Args:
            config: Simulator configuration; ``config.gnn.latent_dimension``
                is read by ``_input_dimensions`` via ``self.config``.
            example_input_batch: Wrapper whose ``.batch`` attribute holds the
                example graph batch handed to the base class.
            loading_config: Forwarded to the base constructor and to
                ``load_weights``.
            device: Forwarded to the base constructor and to ``load_weights``.
        """
        # Only the wrapped graph batch is needed from here on.
        graph_batch = example_input_batch.batch
        # The decoder output size equals the last dimension of the ``pos`` tensor.
        output_dim = graph_batch.pos.shape[-1]

        super().__init__(
            config,
            graph_batch,
            decoder_output_dim=output_dim,
            loading_config=loading_config,
            device=device,
        )

        self.load_weights(loading_config, device)

    def forward(self, batch, **kwargs):
        """Run the GNN on ``batch`` and decode the mesh-node features.

        Args:
            batch: Graph batch passed unchanged to ``self.gnn``.
            **kwargs: Accepted but not used by this implementation.

        Returns:
            The decoder output for the mesh nodes. The original inline comment
            describes it as velocities of shape
            ``(num_nodes, action_dimension)``.
        """
        return self.decoder(
            unpack_node_features(self.gnn(batch), node_type=keys.MESH)
        )

    def _input_dimensions(self) -> ValueDict:
        """Return the input dimensions, keyed by ``keys.PROCESSOR_DIMENSION``.

        The only entry is the GNN latent dimension taken from the config.
        """
        latent_dim = self.config.gnn.latent_dimension
        return {keys.PROCESSOR_DIMENSION: latent_dim}
