import torch
from omegaconf import open_dict
from torch_geometric.data import Batch, Data

from ltsgns_mp.architectures.prodmp.prodmp import ProDMPPredictor
from ltsgns_mp.architectures.simulators.abstract_simulator import AbstractSimulator
from ltsgns_mp.envs.train_iterator.ltsgns_step_train_iterator import LTSGNSStepTrainBatch
from ltsgns_mp.util import keys
from ltsgns_mp.util.graph_input_output_util import unpack_node_features, node_type_mask
from ltsgns_mp.util.own_types import ValueDict, ConfigDict


class LTSGNS_Step_Simulator(AbstractSimulator):
    """
    Step-based LTSGNS simulator: runs the GNN on a graph (batch), extracts the mesh node features and decodes them
    together with a latent variable z into predicted dynamics.

    The mesh features of the GNN can be cached via `gnn_forward` and reused in `forward`
    (`use_cached_gnn_output=True`), so repeated decoder evaluations for different z do not re-run the GNN.
    """

    def _input_dimensions(self) -> ValueDict:
        """
        :return: Dictionary with the processor (GNN latent) dimension and the dimension of the latent variable z.
        """
        return {keys.PROCESSOR_DIMENSION: self.config.gnn.latent_dimension,
                keys.Z_DIMENSION: self.d_z}

    def __init__(self, config: ConfigDict, example_input_batch: LTSGNSStepTrainBatch, d_z: int, loading_config: ConfigDict,
                 device: str):
        """
        :param config: Simulator config. `config.gnn.latent_dimension` is read by `_input_dimensions`.
        :param example_input_batch: Example batch; its first trajectory batch is used to derive the decoder
            output dimension and is passed on to the parent constructor.
        :param d_z: Dimension of the latent variable z.
        :param loading_config: Config passed to the parent constructor and to `load_weights`.
        :param device: Device passed to the parent constructor and to `load_weights`.
        """
        # Set before super().__init__(): `_input_dimensions` reads self.d_z and is presumably
        # called by the parent constructor (not visible here - verify against AbstractSimulator).
        self.d_z = d_z
        batch = example_input_batch.trajectory_batch[0]
        # the decoder predicts one value per positional dimension
        decoder_output_dim = batch[keys.POSITIONS].shape[1]

        super().__init__(config,
                         example_input_batch=batch,
                         decoder_output_dim=decoder_output_dim,
                         loading_config=loading_config,
                         device=device)

        # filled by gnn_forward(), consumed by forward(..., use_cached_gnn_output=True)
        self._cached_gnn_output_mesh = None

        self.load_weights(loading_config, device)

    def forward(self, batch: Batch | Data, z: torch.Tensor, use_cached_gnn_output: bool = False) -> torch.Tensor:
        """
        Predict the dynamics of the mesh nodes for the given graph(s) and latent variable.

        :param batch: A single graph or a batch of graphs. Even if the cached GNN output is used, the batch is still
            needed to determine the number of mesh nodes per graph.
        :param z: Shape (num_samples, num_tasks, d_z)
        :param use_cached_gnn_output: Flag if precomputed features should be used. Set to True for posterior fitting.
            Requires a previous call to `gnn_forward`.
        :return: Prediction of the next steps of the shape
            (num_samples, num_tasks, batch_timesteps, mesh_nodes_per_graph, dim), where num_tasks is always 1.
        :raises RuntimeError: If `use_cached_gnn_output` is True but no GNN output has been cached yet.
        """
        if use_cached_gnn_output:
            gnn_output_mesh = self._cached_gnn_output_mesh
            # Explicit check instead of `assert`: asserts are stripped when running with `python -O`.
            if gnn_output_mesh is None:
                raise RuntimeError("No cached GNN output available. "
                                   "Call gnn_forward(batch) before forward(..., use_cached_gnn_output=True).")
        else:
            gnn_output_mesh = self._compute_gnn_output_mesh(batch)

        # First dimension of the unpacked mesh features. Presumably the total number of mesh nodes over all
        # graphs of the batch, which all belong to the same task - TODO confirm against unpack_node_features.
        mesh_nodes_per_task = gnn_output_mesh.shape[0]

        # The number of mesh nodes per graph is read from the first graph only, i.e., all graphs of a batch
        # are treated as having the same number of mesh nodes.
        graph = batch[0] if isinstance(batch, Batch) else batch
        mesh_nodes_per_graph = graph[keys.POSITIONS][node_type_mask(graph, keys.MESH)].shape[0]

        return self._predict_dynamics(
            processor_output=gnn_output_mesh,
            nodes_per_task=mesh_nodes_per_task,
            nodes_per_graph=mesh_nodes_per_graph,
            z=z,
        )

    def _predict_dynamics(self, processor_output: torch.Tensor, nodes_per_task: int, z: torch.Tensor, nodes_per_graph: int) -> torch.Tensor:
        """
        Decode the processed mesh features and reshape the result to have explicit task and timestep dimensions.

        :param processor_output: Mesh node features as returned by the GNN.
        :param nodes_per_task: Number of mesh nodes per task, passed to the decoder as `vertices_per_task`.
        :param z: Latent variable, passed to the decoder.
        :param nodes_per_graph: Number of mesh nodes of a single graph. Used to split the node dimension of the
            decoder output into (timesteps, nodes_per_graph).
        :return: Tensor of shape (num_samples, 1, timesteps, nodes_per_graph, dim)
        """
        predicted_dynamics = self.decoder(processor_output=processor_output,
                                          z=z,
                                          vertices_per_task=nodes_per_task)
        # add task dimension: always 1 task in our case
        predicted_dynamics = predicted_dynamics.unsqueeze(1)
        sh = predicted_dynamics.shape
        # split the flat node dimension into (timesteps, nodes_per_graph); timesteps is inferred via -1
        predicted_dynamics = predicted_dynamics.reshape((sh[0], sh[1], -1, nodes_per_graph, sh[-1]))
        # shape (num_samples, tasks, timesteps, nodes_per_graph, dim)
        return predicted_dynamics

    def _compute_gnn_output_mesh(self, batch: Batch | Data) -> torch.Tensor:
        """
        Run the GNN on the batch and return the features of the mesh nodes only.

        :param batch: A single graph or a batch of graphs.
        :return: Processed features of the mesh nodes.
        """
        processed_graph = self.gnn(batch)
        return unpack_node_features(processed_graph, node_type=keys.MESH)

    def gnn_forward(self, batch: Batch | Data) -> None:
        """
        Precompute and cache the mesh features of the GNN for the given batch without tracking gradients.
        The cached features are used by `forward` if it is called with `use_cached_gnn_output=True`.

        :param batch: A single graph or a batch of graphs.
        """
        with torch.no_grad():
            self._cached_gnn_output_mesh = self._compute_gnn_output_mesh(batch)
