import os
from typing import Tuple, List

import torch
import torch_scatter
from hmpn import get_hmpn_from_graph
from torch_geometric.data import Data, Batch

from ltsgns_mp.architectures.decoder import get_decoder, Decoder
from ltsgns_mp.architectures.simulators.abstract_simulator import AbstractSimulator
from ltsgns_mp.architectures.simulators.gnn import get_egno_input
from ltsgns_mp.envs.train_iterator.cnp_train_iterator import CNPTrainBatch
from ltsgns_mp.util import keys
from ltsgns_mp.util.graph_input_output_util import unpack_node_features, node_type_mask
from ltsgns_mp.util.loading import get_checkpoint_iteration
from ltsgns_mp.util.own_types import ValueDict


class EGNOSimulator(AbstractSimulator):
    """Simulator that runs an EGNO-style network (``self.gnn``) on a graph.

    A single forward pass yields predicted node locations for
    ``self.gnn.num_timesteps`` steps at once. ``self.gnn`` itself is not
    created in this class; presumably ``AbstractSimulator`` builds it from
    ``config`` — confirm against the base class.
    """

    def __init__(self, config, example_input_batch: CNPTrainBatch, loading_config, device, trajectory_length: int | None = None):
        """Build the simulator from the first target graph, then load weights.

        Args:
            config: Simulator configuration; ``config.gnn.latent_dimension`` is
                read by ``_input_dimensions``.
            example_input_batch: A CNP training batch. Only the first element of
                its ``target_batch`` is forwarded to the base class as the
                example input.
            loading_config: Passed to the base class and to ``load_weights``.
            device: Passed to the base class and to ``load_weights``.
            trajectory_length: Accepted for interface compatibility; it is not
                used by this class.
        """
        # The base class receives a single target graph, not the full CNP batch.
        first_target_graph = example_input_batch.target_batch[0]
        # -1 is forwarded as decoder_output_dim; its meaning is defined by
        # AbstractSimulator (NOTE(review): presumably "no fixed decoder output").
        super().__init__(
            config,
            first_target_graph,
            decoder_output_dim=-1,
            loading_config=loading_config,
            device=device,
        )
        self.load_weights(loading_config, device)

    def forward(self, batch: Batch | Data, initial_time: int | None = None) -> torch.Tensor:
        """Predict node locations for every timestep the network produces.

        Args:
            batch: Graph input that ``get_egno_input`` unpacks into locations,
                velocities, node features, edge index and edge features.
            initial_time: Accepted for interface compatibility; not used here.

        Returns:
            The network's location prediction, reshaped to
            ``(self.gnn.num_timesteps, num_nodes, world_dim)``.
        """
        coords, vel, node_h, edges, edge_attr = get_egno_input(batch)
        n_nodes, spatial_dim = coords.shape[0], coords.shape[-1]

        # One centroid row broadcast (by copying) to every node.
        # NOTE(review): the mean is taken over ALL nodes in ``batch``. If
        # ``batch`` holds several graphs, their centroids are mixed. Confirm
        # whether a per-graph mean (e.g. torch_scatter.scatter_mean over
        # ``batch.batch``) is intended.
        centroid = coords.mean(dim=0, keepdim=True)
        per_node_centroid = centroid.repeat(n_nodes, 1)

        predicted, _, _ = self.gnn(
            coords,
            node_h,
            edges,
            edge_attr,
            loc_mean=per_node_centroid,
            v=vel,
        )
        # The network output is assumed to be laid out timestep-major so that
        # this view splits it into (timesteps, nodes, dims) — TODO confirm.
        return predicted.view(self.gnn.num_timesteps, n_nodes, spatial_dim)

    def _input_dimensions(self) -> ValueDict:
        """Return the processor width, taken from ``config.gnn.latent_dimension``."""
        latent_width = self.config.gnn.latent_dimension
        return {keys.PROCESSOR_DIMENSION: latent_width}



