import copy
from typing import Tuple

import torch
from torch_geometric.data import Batch, Data

from ltsgns_mp.algorithms.abstract_algorithm import AbstractAlgorithm
from ltsgns_mp.algorithms.util import _update_external_state
from ltsgns_mp.architectures.loss_functions.mse import mse
from ltsgns_mp.architectures.simulators.cnp_simulator import CNPSimulator
from ltsgns_mp.envs import Env
from ltsgns_mp.envs.train_iterator.cnp_train_iterator import CNPTrainBatch
from ltsgns_mp.envs.train_iterator.step_train_iterator import StepTrainBatch
from ltsgns_mp.util import keys
from ltsgns_mp.util.graph_input_output_util import recompute_external_edges, add_and_update_node_features, remove_edge_distances
from ltsgns_mp.util.graph_input_output_util import add_distances_from_positions, node_type_mask
from ltsgns_mp.util.own_types import ValueDict, ConfigDict


class CNP(AbstractAlgorithm):
    def __init__(self, config: ConfigDict, simulator: CNPSimulator,
                 env: Env, optimizer: torch.optim.Optimizer, loading_config: ConfigDict, device: str):
        """
        Initialize the CNP algorithm and reject configs that still carry deprecated aggregation keys.

        Args:
            config: Algorithm config. Must not contain "node_aggregation" or "context_aggregation".
            simulator: CNP simulator, forwarded to the parent constructor.
            env: Environment, forwarded to the parent constructor.
            optimizer: Optimizer, forwarded to the parent constructor.
            loading_config: Loading config, forwarded to the parent constructor.
            device: Device identifier, forwarded to the parent constructor.

        Raises:
            ValueError: If one of the deprecated aggregation keys is present in `config`.
        """
        super().__init__(config, simulator, env, optimizer, loading_config, device)
        # These options are expected in the simulator config, so their presence here marks an outdated config.
        deprecated_keys = ("node_aggregation", "context_aggregation")
        if any(deprecated_key in config for deprecated_key in deprecated_keys):
            raise ValueError("You are using a deprecated config. Node aggregation and context aggregation should be in Simulator config!")

    def _single_train_step(self, batch: CNPTrainBatch) -> float:
        """
        Run a single optimization step on one context/target batch.

        The context batch is encoded into z, z is attached to the target batch, and the simulator
        prediction for the target batch is regressed onto `target_batch.y` with an MSE loss.

        Args:
            batch: Training batch providing `context_batch` and `target_batch`.

        Returns:
            The detached loss of this step as a Python scalar (the result of `.item()`), not a tensor.
        """
        # unpack batch
        context_batch = batch.context_batch
        target_batch = batch.target_batch

        # encode the context and attach the result to the target graphs
        z = self.simulator.compute_z(context_batch)
        target_batch = self.simulator.add_z_to_batch(target_batch, z)

        # is usually velocities, or accelerations if second order dynamics are used
        predicted_dynamics = self.simulator(target_batch)
        gth_dynamics = target_batch.y
        loss = mse(predicted_dynamics, gth_dynamics)
        self._apply_loss(loss)

        # `.item()` converts the one-element loss tensor into a plain Python number
        return loss.detach().item()

    def predict_trajectory(self, data: Data, visualize: bool = False, eval_only: bool = False) -> Tuple[
        torch.Tensor, ValueDict]:
        """
        Roll out the mesh trajectory step by step, starting from the anchor step of the given data.

        The context is encoded once into z. Afterwards the simulator is applied iteratively: before every
        prediction the external state and the external edges are refreshed, and after every prediction the
        mesh positions stored in the graph are integrated in place. The whole rollout runs without gradients.

        Args:
            data: Data object in trajectory format. Only batch index 0 is evaluated.
            visualize:
                Because in LTSGNS_MP we sometimes have visualizations of the ELBO or the latent space which we
                only want to log in the visualization epochs (every 50-ish epochs). Not used by this method.
            eval_only: Not relevant here, but in LTSGNS. Not used by this method.

        Returns:
            Tuple of
              - the mesh positions stacked along dim 0 (one entry per step): the stored positions up to and
                including the anchor step, followed by the predicted positions of the remaining steps,
              - an empty dict, since no additional information is ever visualized.
        """
        with torch.no_grad():
            context_batch, target_data = self.extract_cnp_context(data)
            z = self.simulator.compute_z(context_batch)
            target_data = self.simulator.add_z_to_batch(target_data, z)
            batch_index = 0
            node_mask = node_type_mask(target_data, keys.MESH, as_index=True)  # assume consistent collider indices
            mesh_positions = target_data[keys.CONTEXT_NODE_POSITIONS][batch_index]

            # get positions up to the first predicted step (including anchor step)
            first_step = target_data[keys.ANCHOR_INDICES]
            output_trajectories = [mesh_positions[i] for i in range(first_step + 1)]

            # predict the remaining steps from the anchor graph onwards in the future
            trajectory_length = mesh_positions.shape[0]

            # add the velocity features to the data if necessary and update the node features
            target_data = add_and_update_node_features(target_data, self.config.second_order_dynamics)
            for current_step in range(first_step + 1, trajectory_length):
                # update collider positions and potentially other external (sensory) information such as point_clouds
                # the mesh positions of the step before the most recent one are handed over, falling back to
                # the most recent ones if only a single step is available
                if len(output_trajectories) > 1:
                    prev_pos = output_trajectories[-2]
                else:
                    prev_pos = output_trajectories[-1]
                target_data = self._update_external_state(batch_index, current_step, target_data, prev_pos)

                # recompute the external edges for the updated state
                target_data = recompute_external_edges(target_data, self.env.config, self._device)
                # compute and add relative distances
                target_data = add_distances_from_positions(target_data, self.config.train_iterator.add_euclidian_distance)

                # predict velocities/accelerations
                predicted_dynamics = self.simulator(target_data)

                # update positions
                if self.config.second_order_dynamics:
                    velocities_indices = [index for index, value in enumerate(target_data.x_description) if
                                          value == keys.VELOCITIES]
                    # indexing with a list yields a copy, so the in-place addition below does not modify x
                    all_vel = target_data.x[:, velocities_indices]
                    mesh_vel = all_vel[node_mask]
                    mesh_vel += predicted_dynamics
                    target_data[keys.POSITIONS][node_mask] += mesh_vel
                else:
                    target_data[keys.POSITIONS][node_mask] += predicted_dynamics

                # add updated mesh positions to output trajectories; clone so that later in-place position
                # updates cannot alter the stored step
                new_positions = target_data[keys.POSITIONS][node_mask].detach().clone()
                output_trajectories.append(new_positions)
        # finalize output trajectories
        output_trajectories = torch.stack(output_trajectories, dim=0)
        return output_trajectories, {}  # Never visualize additional information

    def _update_external_state(self, batch_index: int, current_step: int, data: Data,
                               last_mesh_pos: torch.Tensor | None = None) -> Data:
        """
        Refresh the external state of the graph (collider information and, if necessary, new x features)
        by delegating to the module-level helper of the same name.

        Args:
            batch_index: Index of the batch to evaluate. Currently only supports a single batch with index 0
            current_step: Current step in the trajectory
            data: Data object containing the anchor graph and auxiliary information about the external trajectory state
            last_mesh_pos: Optional mesh positions that are passed through unchanged to the helper.
                NOTE(review): how the helper uses them is not visible here - confirm against its implementation.

        Returns:
            The Data object returned by the helper.
        """
        # Inside the method body this name resolves to the imported module-level function, not to this method.
        return _update_external_state(batch_index, current_step, data, last_mesh_pos)


    def extract_cnp_context(self, data: Data) -> Tuple[Batch, Data]:
        """
        Build the batch of context graphs from data in trajectory format.

        One graph is created for every context step except the last one. Each graph uses the positions of
        its step and gets the displacement towards the directly following step appended to its node
        features. If `context_history_vel` is enabled in the train iterator config, the displacement from the
        directly preceding step is appended before that as well.

        Args:
            data: Data object containing the context and target data in trajectory format

        Returns:
            Tuple of the batched context graphs and the input `data` object itself (the context graphs are
            built from clones of it).
        """

        def _stacked_positions(graph: Data, step, with_collider: bool) -> torch.Tensor:
            # Mesh positions of the requested step, followed by the collider positions if they exist.
            mesh_part = graph.context_node_positions[0, step]
            if not with_collider:
                return mesh_part
            collider_part = graph.context_collider_positions[0, step]
            return torch.cat([mesh_part, collider_part], dim=0)

        context_steps = torch.where(data.mesh_indices[0])[0]
        graphs = []
        # the last context step is skipped
        for step in context_steps[:-1]:
            # work on a fresh clone per step; the `+=` on x_description below extends the list in place
            snapshot = data.clone()
            has_collider = "context_collider_positions" in snapshot.keys()
            current_pos = _stacked_positions(snapshot, step, has_collider)

            graph = Data(
                x=snapshot.x,
                pos=current_pos,
                edge_index=snapshot.edge_index,
                edge_attr=snapshot.edge_attr,
                edge_type=snapshot.edge_type,
                edge_type_description=snapshot.edge_type_description,
                node_type_description=snapshot.node_type_description,
                node_type=snapshot.node_type,
                x_description=snapshot.x_description,
            )
            graph = recompute_external_edges(graph, self.env.config, self._device)

            # optionally append the displacement from the preceding step as history feature
            if self.config.train_iterator.context_history_vel:
                # NOTE(review): for step 0 the index `step - 1` wraps around to the last step - confirm that
                # step 0 can never be a context step when this option is enabled
                history_vel = current_pos - _stacked_positions(snapshot, step - 1, has_collider)
                graph.x = torch.cat([graph.x, history_vel], dim=1)
                graph.x_description += ["context_vel"] * history_vel.shape[1]

            # append the displacement towards the following step as velocity feature
            future_vel = _stacked_positions(snapshot, step + 1, has_collider) - current_pos
            graph.x = torch.cat([graph.x, future_vel], dim=1)
            graph.x_description += [keys.VELOCITIES] * future_vel.shape[1]
            graphs.append(graph)

        # batch all context graphs, then rebuild the distance features and update the node features
        context_batch = Batch.from_data_list(graphs)
        context_batch = remove_edge_distances(context_batch)
        context_batch = add_distances_from_positions(context_batch, self.config.train_iterator.add_euclidian_distance)
        context_batch = add_and_update_node_features(context_batch, second_order_dynamics=False)
        return context_batch, data

    @property
    def simulator(self) -> CNPSimulator:
        """
        The simulator used by this algorithm.

        Raises:
            ValueError: If no simulator has been set, i.e. `self._simulator` is None.
        """
        current_simulator = self._simulator
        if current_simulator is not None:
            return current_simulator
        raise ValueError("Simulator not set")

