import copy
from typing import Tuple

import torch
from torch_geometric.data import Batch, Data

from ltsgns_mp.algorithms.abstract_algorithm import AbstractAlgorithm
from ltsgns_mp.algorithms.mgn import MGN
from ltsgns_mp.algorithms.util import _update_external_state
from ltsgns_mp.architectures.loss_functions.mse import mse
from ltsgns_mp.envs.train_iterator.step_train_iterator import StepTrainBatch
from ltsgns_mp.util import keys
from ltsgns_mp.util.graph_input_output_util import recompute_external_edges, add_and_update_node_features
from ltsgns_mp.util.graph_input_output_util import add_distances_from_positions, node_type_mask
from ltsgns_mp.util.own_types import ValueDict


class HistoryMGN(MGN):
    def predict_trajectory(self, data: Data, visualize: bool = False, eval_only: bool = False) -> Tuple[
        torch.Tensor, ValueDict]:
        """
        Roll out a mesh trajectory step by step, feeding the most recent velocities as extra node features.

        The positions up to and including the anchor step are copied from the context positions. Every later
        step is predicted by the simulator: the graph is refreshed for the current step, the flattened history
        velocities are appended to the node features, and the predicted dynamics are added to the mesh positions.
        The oldest history velocity is then replaced by the newly predicted one.

        Args:
            data: Graph of the anchor step. Must provide keys.CONTEXT_NODE_POSITIONS and keys.ANCHOR_INDICES.
                If keys.CONTEXT_COLLIDER_POSITIONS is present, collider positions are concatenated after the
                mesh positions when building the history velocities. Only batch index 0 of the context tensors
                is used. The graph may be modified during the rollout (node features and positions are updated).
            visualize:
                Because in LTSGNS_MP we sometimes have visualizations of the ELBO or the latent space which we
                only want to log in the visualization epochs (every 50-ish epochs). Unused here.
            eval_only: Not relevant here, but in LTSGNS. Unused here.

        Returns:
            A tuple of the mesh positions of all steps stacked along dimension 0 (context steps up to the anchor
            followed by the predicted steps), and an empty dictionary since no additional information is
            visualized.

        Raises:
            NotImplementedError: If second order dynamics are configured and at least one step must be predicted.
        """
        batch_index = 0
        history_length = self.config.train_iterator.history_length

        node_mask = node_type_mask(data, keys.MESH, as_index=True)  # assume consistent collider indices
        mesh_positions = data[keys.CONTEXT_NODE_POSITIONS][batch_index]

        # get positions up to the first predicted step (including anchor step)
        first_step = data[keys.ANCHOR_INDICES]
        output_trajectories = [mesh_positions[i] for i in range(first_step + 1)]

        # the remaining steps are predicted from the anchor graph onwards
        trajectory_length = mesh_positions.shape[0]

        # add the velocity features to the data if necessary and update the node features
        data = add_and_update_node_features(data, self.config.second_order_dynamics)

        # Collect the positions of the `history_length` steps before the anchor, oldest first.
        # Steps before the start of the trajectory are clamped to step 0.
        use_collider = keys.CONTEXT_COLLIDER_POSITIONS in data
        history_positions = []
        for hist_idx in range(first_step - history_length, first_step):
            hist_idx = max(hist_idx, 0)
            hist_pos = data[keys.CONTEXT_NODE_POSITIONS][batch_index, hist_idx]
            if use_collider:
                hist_collider = data[keys.CONTEXT_COLLIDER_POSITIONS][batch_index, hist_idx]
                hist_pos = torch.cat([hist_pos, hist_collider], dim=0)
            history_positions.append(hist_pos)

        # History velocities are the differences of consecutive positions, ending at the anchor positions.
        # Order: from oldest to newest.
        position_sequence = history_positions + [data.pos]
        history_vels = [newer - older for older, newer in zip(position_sequence, position_sequence[1:])]

        # Number of node feature columns before any history features are appended.
        no_hist_data_x_shape = data.x.shape[1]
        if use_collider:
            prev_collider = data[keys.CONTEXT_COLLIDER_POSITIONS][batch_index, first_step]

        with torch.no_grad():
            for current_step in range(first_step + 1, trajectory_length):
                if self.config.second_order_dynamics:
                    # Fail before touching the graph or running the simulator. A second order variant would
                    # have to add the predicted dynamics to the velocity features before updating positions.
                    raise NotImplementedError("Second order dynamics not implemented for history MGN")

                # update collider positions and potentially other external (sensory) information such as point_clouds
                if len(output_trajectories) > 1:
                    prev_pos = output_trajectories[-2]
                else:
                    prev_pos = output_trajectories[-1]
                data = self._update_external_state(batch_index, current_step, data, prev_pos)

                # recompute *all* edges, remove edge distances in preparation for next iteration
                data = recompute_external_edges(data, self.env.config, self._device)
                # compute and add relative distances
                data = add_distances_from_positions(data, self.config.train_iterator.add_euclidian_distance)

                # Flatten the history velocities into one feature row per node.
                if history_length > 0:
                    history_vels_tensor = torch.stack(history_vels, dim=1)
                else:
                    history_vels_tensor = torch.zeros((data.pos.shape[0], 0, data.pos.shape[1]), device=self._device)
                history_vels_tensor = history_vels_tensor.view(history_vels_tensor.shape[0], -1)

                # Drop the history features of the previous iteration and append the current ones.
                data.x = torch.cat([data.x[:, :no_hist_data_x_shape], history_vels_tensor], dim=1)

                # predict velocities and update the mesh positions
                predicted_dynamics = self.simulator(data)
                data[keys.POSITIONS][node_mask] += predicted_dynamics

                # Slide the history window: drop the oldest velocity, append the newest.
                if history_vels:
                    history_vels.pop(0)
                    if use_collider:
                        current_collider = data[keys.CONTEXT_COLLIDER_POSITIONS][batch_index, current_step]
                        collider_vel = current_collider - prev_collider
                        prev_collider = current_collider
                        history_vels.append(torch.cat((predicted_dynamics, collider_vel), dim=0))
                    else:
                        history_vels.append(predicted_dynamics)

                # Store a copy so that later in-place position updates cannot alter the recorded step.
                output_trajectories.append(data[keys.POSITIONS][node_mask].detach().clone())

        # finalize output trajectories
        output_trajectories = torch.stack(output_trajectories, dim=0)
        return output_trajectories, {}  # Never visualize additional information


