import os

import torch
from omegaconf import open_dict
from torch_geometric.data import Batch

from ltsgns_mp.architectures.prodmp import build_prodmp
from ltsgns_mp.architectures.prodmp.prodmp import ProDMPPredictor
from ltsgns_mp.architectures.simulators.abstract_simulator import AbstractSimulator
from ltsgns_mp.envs.train_iterator.trajectory_train_iterator import TrajectoryTrainBatch
from ltsgns_mp.util import keys
from ltsgns_mp.util.graph_input_output_util import unpack_node_features, node_type_mask
from ltsgns_mp.util.loading import get_checkpoint_iteration
from ltsgns_mp.util.own_types import ValueDict, ConfigDict


class LTSGNS_MP_Simulator(AbstractSimulator):
    """Simulator that decodes GNN node features and a latent task variable ``z`` into ProDMP trajectories.

    The GNN processes the input graph, the decoder maps the per-node GNN output together with ``z`` to
    ProDMP basis weights, and the ProDMP predictor rolls these weights out into node positions over time.
    Mesh nodes and (if the environment has them) collider nodes are predicted separately.
    """

    def _input_dimensions(self) -> ValueDict:
        """Return the input dimensions of this simulator (GNN latent dimension and ``z`` dimension).

        NOTE(review): presumably consumed by ``AbstractSimulator`` to size the decoder inputs — confirm.
        """
        return {keys.PROCESSOR_DIMENSION: self.config.gnn.latent_dimension,
                keys.Z_DIMENSION: self.d_z}

    def __init__(self, config: ConfigDict, example_input_batch: TrajectoryTrainBatch, d_z: int, loading_config: ConfigDict,
                 device: str):
        """
        Args:
            config: Simulator configuration; also passed to the ProDMP builder.
            example_input_batch: Example training batch. Its ``.batch`` attribute is used to infer the
                trajectory length and to build the ProDMP predictor and the base simulator.
            d_z: Dimension of the latent task variable ``z``.
            loading_config: Configuration passed to ``load_weights`` to (optionally) restore weights.
            device: Device the networks are built on.
        """
        example_graph_batch = example_input_batch.batch
        # Set before super().__init__(), which presumably queries _input_dimensions() (reads self.d_z).
        self.d_z = d_z
        # CONTEXT_NODE_POSITIONS has shape (batch_dim, traj_length, num_nodes, dim)
        self._trajectory_length = example_graph_batch[keys.CONTEXT_NODE_POSITIONS].shape[1]
        # Built before the base class so its output size can define the decoder output dimension.
        mp_predictor = build_prodmp(example_input_batch=example_graph_batch,
                                    simulator_config=config,
                                    trajectory_length=self._trajectory_length,
                                    device=device)
        super().__init__(config,
                         example_input_batch=example_graph_batch,
                         decoder_output_dim=mp_predictor.output_size,
                         loading_config=loading_config,
                         device=device)
        self._mp_predictor = mp_predictor
        self._mode = "gnn_step"  # "training mode". May be either "posterior_step" or "gnn_step"

        # GNN features stored by gnn_forward() and reused by forward(..., use_cached_gnn_output=True)
        self._cached_gnn_output_mesh = None
        self._cached_gnn_output_collider = None

        self.load_weights(loading_config, device)

    @staticmethod
    def _env_has_collider(batch) -> bool:
        """Return True if the environment of ``batch`` contains collider nodes."""
        # batch[0] removes the batch dim; the node type description is presumably identical for all graphs.
        return keys.COLLIDER in batch[0].node_type_description

    @staticmethod
    def _get_context_timesteps(batch):
        """Return the context-timestep selector of ``batch`` that matches its context type.

        Raises:
            ValueError: If the context type is neither mesh nor point cloud.
        """
        context_type = batch[keys.CONTEXT_TYPE][0]
        if context_type == keys.MESH:
            return batch["mesh_indices"]
        if context_type == keys.POINT_CLOUD:
            return batch["point_cloud_indices"]
        raise ValueError("Unknown context type: {}".format(context_type))

    def forward(self, batch, z: torch.Tensor, use_cached_gnn_output: bool = False,
                predict_context_timesteps: bool = False) -> ValueDict:
        """
        Predict node trajectories for every task in ``batch`` and every sample of ``z``.

        Args:
            batch: Batched task graphs.
            z: Latent task samples of shape (num_samples, num_tasks, d_z).
            use_cached_gnn_output: If True, reuse the GNN features stored by ``gnn_forward`` instead of
                running the GNN again. Set to True for posterior fitting.
            predict_context_timesteps: If True, the mesh is only predicted at the context timesteps;
                otherwise the full mesh trajectory is predicted.

        Returns:
            Dictionary with the following keys:
            - "context_predictions": Shape (num_samples, num_tasks, context_timesteps, mesh_nodes_per_task, dim).
              Only set if ``predict_context_timesteps`` is True, otherwise None.
            - "collider_context_predictions": Shape (num_samples, num_tasks, timesteps, collider_nodes_per_task, dim),
              predicted over all timesteps. Set whenever the environment has collider nodes (independent of
              ``predict_context_timesteps``), otherwise None.
            - "full_trajectory_predictions": Shape (num_samples, num_tasks, trajectory_length, mesh_nodes_per_task, dim).
              Only set if ``predict_context_timesteps`` is False, otherwise None.

        Raises:
            RuntimeError: If ``use_cached_gnn_output`` is True but the required features were not cached
                by ``gnn_forward``.
            ValueError: If the batch has an unknown context type.
        """
        env_has_collider = self._env_has_collider(batch)

        if use_cached_gnn_output:
            gnn_output_mesh = self._cached_gnn_output_mesh
            gnn_output_collider = self._cached_gnn_output_collider
            if gnn_output_mesh is None or (env_has_collider and gnn_output_collider is None):
                raise RuntimeError("No cached GNN output available. Call gnn_forward() before "
                                   "forward(..., use_cached_gnn_output=True).")
        else:
            processed_graph = self.gnn(batch)
            gnn_output_mesh = unpack_node_features(processed_graph, node_type=keys.MESH)
            if env_has_collider:
                gnn_output_collider = unpack_node_features(processed_graph, node_type=keys.COLLIDER)
            else:
                gnn_output_collider = None

        anchor_indices = batch[keys.ANCHOR_INDICES]
        num_tasks = batch.num_graphs

        # shape (batch_dim, traj_length, NUM_NODES, dim), want NUM_NODES
        mesh_nodes_per_task = batch[keys.CONTEXT_NODE_POSITIONS].shape[2]
        mesh_init_pos = batch[keys.POSITIONS][node_type_mask(batch, keys.MESH)]
        mesh_init_vel = self._get_init_vel(batch, batch[keys.CONTEXT_NODE_POSITIONS], mesh_init_pos)
        if env_has_collider:
            collider_nodes_per_task = batch[keys.CONTEXT_COLLIDER_POSITIONS].shape[2]
            collider_init_pos = batch[keys.POSITIONS][node_type_mask(batch, keys.COLLIDER)]
            collider_init_vel = self._get_init_vel(batch, batch[keys.CONTEXT_COLLIDER_POSITIONS],
                                                   collider_init_pos)
        # Validated for every call (raises on unknown context types), even if only used for context predictions.
        context_timesteps = self._get_context_timesteps(batch)

        # The mesh is predicted either at the context timesteps or over the full trajectory (None).
        mesh_prediction = self._predict_node_trajectories(
            anchor_node_group=gnn_output_mesh,
            nodes_per_task=mesh_nodes_per_task,
            z=z,
            node_group_init_pos=mesh_init_pos,
            node_group_init_vel=mesh_init_vel,
            anchor_indices=anchor_indices,
            context_timesteps=context_timesteps if predict_context_timesteps else None,
            batch_size=num_tasks,
        )
        results_dict = {
            "context_predictions": mesh_prediction if predict_context_timesteps else None,
            "collider_context_predictions": None,
            "full_trajectory_predictions": None if predict_context_timesteps else mesh_prediction,
        }

        if env_has_collider:
            results_dict["collider_context_predictions"] = self._predict_node_trajectories(
                anchor_node_group=gnn_output_collider,
                nodes_per_task=collider_nodes_per_task,
                z=z,
                node_group_init_pos=collider_init_pos,
                node_group_init_vel=collider_init_vel,
                anchor_indices=anchor_indices,
                # predict over all timesteps as the collider is given everytime
                context_timesteps=None,
                batch_size=num_tasks,
            )
        return results_dict

    def _predict_node_trajectories(self, anchor_node_group: torch.Tensor, nodes_per_task: int, z: torch.Tensor,
                                   node_group_init_pos, node_group_init_vel, anchor_indices, context_timesteps,
                                   batch_size) -> torch.Tensor:
        """
        Decode ProDMP basis weights for one node group and roll them out into trajectories.

        Args:
            anchor_node_group: GNN output features of the node group (mesh or collider nodes).
            nodes_per_task: Number of nodes of this group per task.
            z: Latent task samples of shape (num_samples, num_tasks, d_z).
            node_group_init_pos: Initial positions of shape (num_tasks * nodes_per_task, dim).
            node_group_init_vel: Initial velocities, same shape as ``node_group_init_pos``.
            anchor_indices: Start timestep per task (assumed shape (num_tasks,) — TODO confirm).
            context_timesteps: Optional selector of the timesteps to predict. It is converted with
                ``nonzero`` and reshaped to (batch_size, -1), so it is assumed to be a
                (num_tasks, trajectory_length) mask selecting the same number of timesteps per task.
                ``None`` lets the predictor use its default prediction times.
            batch_size: Number of tasks in the batch.

        Returns:
            Tensor of shape (num_samples, num_tasks, timesteps, nodes_per_task, dim).
        """
        num_samples = z.shape[0]
        # every node of a task starts at its task's anchor time; one copy per sample of z
        initial_time = torch.repeat_interleave(anchor_indices, nodes_per_task, dim=0)
        initial_time = initial_time[None, :].repeat(num_samples, 1)

        basis_weights = self.decoder(processor_output=anchor_node_group,
                                     z=z,
                                     vertices_per_task=nodes_per_task)
        # add a leading sample dimension to the initial conditions
        node_group_init_pos = node_group_init_pos.repeat(num_samples, 1, 1)
        node_group_init_vel = node_group_init_vel.repeat(num_samples, 1, 1)
        if context_timesteps is not None:
            # column indices of the selected timesteps
            context_timesteps = context_timesteps.nonzero(as_tuple=False)[:, 1]
            # reshape to (batch_dim, num_context_points)
            context_timesteps = context_timesteps.reshape(batch_size, -1)
            # one row per node, then one copy per sample of z
            context_timesteps = torch.repeat_interleave(context_timesteps, nodes_per_task, dim=0)
            context_timesteps = context_timesteps[None, :, :].repeat(num_samples, 1, 1)
        prediction = self._mp_predictor(node_group_init_pos, node_group_init_vel, basis_weights=basis_weights,
                                        prediction_times=context_timesteps,
                                        output_vel=False,
                                        initial_time=initial_time, )
        # shape (num_samples, tasks*num_nodes_per_task, timesteps, dim)
        # disentangle tasks and nodes
        prediction = prediction.reshape(prediction.shape[0], batch_size,
                                        nodes_per_task, *prediction.shape[2:])
        # swap nodes and timesteps -> (num_samples, tasks, timesteps, num_nodes_per_task, dim)
        return prediction.permute(0, 1, 3, 2, 4)

    def _get_init_vel(self, batch, context_positions, init_pos):
        """
        Approximate the initial node velocities by a backward difference at each task's anchor timestep.

        Args:
            batch: Batched graph providing ``anchor_indices`` (one anchor timestep per task).
            context_positions: Node positions of shape (batch_dim, num_context_timesteps, num_nodes, dim).
            init_pos: Node positions at the anchor timestep, shape (batch_dim * num_nodes, dim).

        Returns:
            ``init_pos - prev_pos`` with the shape of ``init_pos``, where ``prev_pos`` are the context
            positions one timestep before the anchor (timestep 0 for anchors at timestep 0).
        """
        anchor_indices = batch.anchor_indices
        # clamp to 0: velocity = 0 at the first timestep (assuming init_pos equals the position there)
        prev_index = torch.clamp(anchor_indices - 1, 0)

        # Select context_positions[b, prev_index[b]] for every task b. Pairing both index tensors gathers
        # (batch_dim, num_nodes, dim) directly instead of materialising the
        # (batch_dim, batch_dim, num_nodes, dim) cross product and taking its diagonal.
        task_index = torch.arange(prev_index.shape[0], device=prev_index.device)
        prev_pos = context_positions[task_index, prev_index]
        # merge batch dim with node dim
        prev_pos = prev_pos.reshape((-1, *prev_pos.shape[2:]))
        assert prev_pos.shape == init_pos.shape, \
            f"Shape mismatch: previous positions {tuple(prev_pos.shape)} vs. initial positions {tuple(init_pos.shape)}"
        # linear approximation
        return init_pos - prev_pos

    def gnn_forward(self, batch):
        """
        Run the GNN without gradients and cache its node features for
        ``forward(..., use_cached_gnn_output=True)``.

        Mesh features are always cached. Collider features are cached if the environment has collider
        nodes; otherwise the collider cache is reset to None so no stale features from an earlier batch
        are kept alive.
        """
        with torch.no_grad():
            processed_graph = self.gnn(batch)
            self._cached_gnn_output_mesh = unpack_node_features(processed_graph, node_type=keys.MESH)
            if self._env_has_collider(batch):
                self._cached_gnn_output_collider = unpack_node_features(processed_graph, node_type=keys.COLLIDER)
            else:
                self._cached_gnn_output_collider = None
