import copy
from typing import Tuple

import torch
from torch_geometric.data import Batch, Data

from ltsgns_mp.algorithms.abstract_algorithm import AbstractAlgorithm
from ltsgns_mp.algorithms.cnp import CNP
from ltsgns_mp.algorithms.util import _update_external_state
from ltsgns_mp.architectures.loss_functions.mse import mse
from ltsgns_mp.architectures.simulators.cnp_simulator import CNPSimulator
from ltsgns_mp.architectures.simulators.gnn import get_egno_input
from ltsgns_mp.envs import Env
from ltsgns_mp.envs.train_iterator.cnp_train_iterator import CNPTrainBatch
from ltsgns_mp.envs.train_iterator.step_train_iterator import StepTrainBatch
from ltsgns_mp.util import keys
from ltsgns_mp.util.graph_input_output_util import recompute_external_edges, add_and_update_node_features, remove_edge_distances
from ltsgns_mp.util.graph_input_output_util import add_distances_from_positions, node_type_mask
from ltsgns_mp.util.own_types import ValueDict, ConfigDict
from ltsgns_mp.util.util import to_numpy


class EGNO(AbstractAlgorithm):
    """Algorithm wrapper that trains and evaluates an EGNO-style trajectory simulator.

    The simulator is run on a whole target graph and returns a trajectory prediction.
    ``cut_prediction`` aligns that prediction with the ground truth by keeping only the
    mesh nodes and prepending the ground-truth context steps up to the anchor index.
    """

    def __init__(self, config: ConfigDict, simulator: CNPSimulator,
                 env: Env, optimizer: torch.optim.Optimizer, loading_config: ConfigDict, device: str):
        super().__init__(config, simulator, env, optimizer, loading_config, device)

    def _single_train_step(self, batch: CNPTrainBatch) -> float:
        """Run one optimisation step on ``batch`` and return the scalar loss.

        Args:
            batch: Training batch. ``context_batch`` is only read for its ``anchor_index``.
                ``target_batch`` is fed to the simulator, and its ``context_node_positions``
                serve as the ground truth.

        Returns:
            The MSE loss as a plain Python number (``loss.detach().item()``).
        """
        # unpack batch
        context_batch = batch.context_batch
        target_batch = batch.target_batch
        anchor_index = int(context_batch.anchor_index)
        traj_pred = self.simulator(target_batch)
        traj_pred = self.cut_prediction(traj_pred, anchor_index, target_batch.context_node_positions)
        # ground truth of the first trajectory in the batch (the same one cut_prediction uses)
        gth_positions = target_batch.context_node_positions[0]
        loss = mse(traj_pred, gth_positions)
        self._apply_loss(loss)
        return loss.detach().item()

    def cut_prediction(self, traj_pred: torch.Tensor, anchor_index: int,
                       gth_mesh_positions: torch.Tensor) -> torch.Tensor:
        """
        Align an EGNO trajectory prediction with the ground-truth mesh trajectory.

        EGNO always predicts P steps (P = len(traj)), but only the part of the trajectory after the
        anchor index is of interest. EGNO also predicts the collider, while only the mesh is wanted.
        This method therefore:
          1. keeps only the first ``gth_mesh_positions.shape[2]`` nodes of every predicted step,
          2. drops the last ``anchor_index`` predicted steps (they would lie beyond the ground truth),
          3. prepends the first ``anchor_index`` ground-truth steps (the context set).

        :param traj_pred: Simulator output, indexed as ``[step, node, ...]``.
        :param anchor_index: Number of ground-truth steps used as context. It is expected to be
            non-negative. ``0`` means "no context": the full prediction is kept.
        :param gth_mesh_positions: Ground-truth mesh positions, indexed as
            ``[trajectory, step, node, ...]``. Only trajectory ``0`` is used.
        :return: Ground-truth context steps followed by the remaining predicted steps. The result has
            ``traj_pred.shape[0]`` steps when the ground truth has at least ``anchor_index`` steps.
        """
        # remove collider: keep the first num_nodes nodes
        # (assumes mesh nodes precede collider nodes in the simulator output -- TODO confirm)
        num_nodes = gth_mesh_positions.shape[2]
        traj_pred = traj_pred[:, :num_nodes]
        # cut part outside gth. Use an explicit end index rather than ``[:-anchor_index]``:
        # for anchor_index == 0 that would be ``[:0]`` and silently drop every predicted step.
        # The clamp at 0 keeps the old "empty" result when anchor_index exceeds the length.
        num_kept_steps = max(traj_pred.shape[0] - anchor_index, 0)
        traj_pred = traj_pred[:num_kept_steps]
        # insert ground truth in context set
        return torch.cat((gth_mesh_positions[0, :anchor_index], traj_pred), dim=0)

    def predict_trajectory(self, data: Data, visualize: bool = False, eval_only: bool = False) -> Tuple[
        torch.Tensor, ValueDict]:
        """
        Predict the trajectory for a single graph ``data`` without tracking gradients.

        Edge distances are rebuilt via ``remove_edge_distances`` / ``add_distances_from_positions``.
        The graph is then wrapped in a single-element ``Batch`` and the simulator is run once. The
        output is aligned with the ground truth via ``cut_prediction``.

        Args:
            data: Graph of one trajectory. After batching it must expose ``anchor_indices`` and
                ``context_node_positions``.
            visualize: Unused here. In LTSGNS_MP some algorithms have visualizations of the ELBO
                or the latent space, which are only logged in visualization epochs (every 50-ish epochs).
            eval_only: Unused here; only relevant for LTSGNS, where behaviour may differ in eval-only mode.

        Returns:
            Tuple of the aligned trajectory prediction and an empty dict (no additional
            information is ever visualized).
        """
        with torch.no_grad():
            target_data = remove_edge_distances(data)
            target_data = add_distances_from_positions(target_data, self.config.train_iterator.add_euclidian_distance)
            target_data = Batch.from_data_list([target_data])
            traj_pred = self.simulator(target_data)
            anchor_index = int(target_data.anchor_indices)
            traj_pred = self.cut_prediction(traj_pred, anchor_index, target_data.context_node_positions)

        return traj_pred, {}  # Never visualize additional information

    @property
    def simulator(self) -> CNPSimulator:
        """The simulator held in ``self._simulator``.

        :raises ValueError: if no simulator has been set.
        """
        if self._simulator is None:
            raise ValueError("Simulator not set")
        return self._simulator

