import copy
from typing import Tuple

import torch
from torch_geometric.data import Batch, Data

from ltsgns_mp.algorithms.abstract_algorithm import AbstractAlgorithm
from ltsgns_mp.algorithms.np import NP
from ltsgns_mp.algorithms.util import _update_external_state
from ltsgns_mp.architectures.loss_functions.logsumexp_mse import logsumexp_mse
from ltsgns_mp.architectures.loss_functions.mse import mse
from ltsgns_mp.architectures.simulators.cnp_simulator import CNPSimulator
from ltsgns_mp.architectures.simulators.np_simulator import NPSimulator
from ltsgns_mp.envs import Env
from ltsgns_mp.envs.train_iterator.cnp_train_iterator import CNPTrainBatch
from ltsgns_mp.envs.train_iterator.step_train_iterator import StepTrainBatch
from ltsgns_mp.util import keys
from ltsgns_mp.util.graph_input_output_util import recompute_external_edges, add_and_update_node_features, remove_edge_distances
from ltsgns_mp.util.graph_input_output_util import add_distances_from_positions, node_type_mask
from ltsgns_mp.util.own_types import ValueDict, ConfigDict


class NP_MP(NP):
    """NP algorithm variant trained on simulator trajectory predictions.

    Training draws several latent task samples ``z`` from the posterior computed on a
    context batch. It predicts a trajectory on the target graph for each sample, starting
    at the anchor index, and scores the predictions against the ground-truth node
    positions. Prediction uses the posterior mean of ``z``.
    """

    def __init__(self, config: ConfigDict, simulator: NPSimulator,
                 env: Env, optimizer: torch.optim.Optimizer, loading_config: ConfigDict, device: str):
        # No additional state; all configuration is handled by the NP base class.
        super().__init__(config, simulator, env, optimizer, loading_config, device)

    def _single_train_step(self, batch: CNPTrainBatch) -> float:
        """Run one optimisation step on a single context/target task.

        Draws ``config.num_z_samples`` reparameterised latent samples from the task
        posterior of the context batch. Each sample is attached to its own copy of the
        (single) target graph, and the copies are batched together. The simulator then
        predicts the trajectory starting at the context batch's anchor index. The
        prediction is compared against the target's ``context_node_positions``.

        Args:
            batch: Training batch. Only the first graph of ``batch.target_batch`` is used.

        Returns:
            The loss value as a Python float (detached from the autograd graph).

        Raises:
            ValueError: If ``config.loss_function`` is neither ``"logsumexp_mse"``
                nor ``"mse"``.
        """
        context_batch = batch.context_batch
        target_batch = batch.target_batch
        num_z_samples = self.config.num_z_samples

        # rsample (reparameterisation) keeps gradients flowing into the posterior parameters.
        z_mean, z_var = self.simulator.compute_task_posterior(context_batch)
        z_sigma = torch.sqrt(z_var)
        z_samples = torch.distributions.Normal(z_mean, z_sigma).rsample((num_z_samples,))
        anchor_index = int(context_batch.anchor_index)

        # Only one graph is expected in the target batch; replicate it once per z sample.
        anchor_target = target_batch[0]
        # Use the value returned by add_z_to_batch (as predict_trajectory does) rather than
        # relying on the call mutating its argument in place.
        anchor_target_batch = Batch.from_data_list(
            [self.simulator.add_z_to_batch(anchor_target.clone(), z) for z in z_samples]
        )
        trajectory_prediction = self.simulator(anchor_target_batch, initial_time=anchor_index)

        # Swap axes 1 and 2 of the ground truth so it lines up with the simulator output.
        # NOTE(review): assumes batched context_node_positions is
        # (num_z_samples, time, nodes, dim) -> (num_z_samples, nodes, time, dim); TODO confirm.
        gth_positions = torch.permute(anchor_target_batch.context_node_positions, (0, 2, 1, 3))
        # split up the z dim from node dim
        trajectory_prediction = torch.reshape(trajectory_prediction, gth_positions.shape)

        if self.config.loss_function == "logsumexp_mse":
            # Subtracting log(num_z_samples) presumably turns the log-sum-exp over the
            # z samples (dim 0) into a log-mean-exp; verify against logsumexp_mse.
            loss = logsumexp_mse(trajectory_prediction, gth_positions, logsumexp_dim=0) - torch.log(torch.tensor(num_z_samples, device=self._device))
        elif self.config.loss_function == "mse":
            loss = mse(trajectory_prediction, gth_positions)
        else:
            raise ValueError(f"Loss function {self.config.loss_function} not supported")
        self._apply_loss(loss)
        return loss.detach().item()

    def predict_trajectory(self, data: Data, visualize: bool = False, eval_only: bool = False) -> Tuple[
        torch.Tensor, ValueDict]:
        """Predict the trajectory of a single graph without tracking gradients.

        The context and target are extracted from ``data``. The mean of the task posterior
        (a MAP-style estimate) is attached to the target as the latent ``z``. The target's
        edge distances are then removed and re-added from positions (via
        ``remove_edge_distances`` / ``add_distances_from_positions``), and the simulator
        predicts the trajectory starting at the anchor index. All steps up to and including
        the anchor step are overwritten with the ground-truth context node positions.

        Args:
            data: The graph to predict a trajectory for.
            visualize: Unused here. In LTSGNS_MP it gates visualizations of the ELBO or the
                latent space, which are only logged in visualization epochs (every 50-ish
                epochs).
            eval_only: Unused here; relevant in LTSGNS, where behavior may differ in
                eval-only mode.

        Returns:
            A tuple ``(trajectory_prediction, {})``. ``trajectory_prediction`` has dim order
            ``[num_timesteps, num_nodes, world_dim]``. No additional visualization
            information is ever returned.
        """
        with torch.no_grad():
            context_batch, target_data = self.extract_np_context(data)
            # The posterior variance is not needed: prediction uses the posterior mean only.
            z_mean, _ = self.simulator.compute_task_posterior(context_batch)
            z = z_mean
            target_data = self.simulator.add_z_to_batch(target_data, z)
            target_data = remove_edge_distances(target_data)
            target_data = add_distances_from_positions(target_data, self.config.train_iterator.add_euclidian_distance)
            target_data = Batch.from_data_list([target_data])
            anchor_index = int(target_data[keys.ANCHOR_INDICES][0])
            trajectory_prediction = self.simulator(target_data, initial_time=anchor_index)
            # permute to have dim order [num_timesteps, num_nodes, world_dim]
            trajectory_prediction = torch.permute(trajectory_prediction, (1, 0, 2))
            mesh_positions = target_data[keys.CONTEXT_NODE_POSITIONS][0]
            # get positions up to the first predicted step (including anchor step)
            trajectory_prediction[:anchor_index + 1] = mesh_positions[:anchor_index + 1]
        return trajectory_prediction, {}  # Never visualize additional information

    @property
    def simulator(self) -> NPSimulator:
        """Return the simulator.

        Raises:
            ValueError: If no simulator has been set.
        """
        if self._simulator is None:
            raise ValueError("Simulator not set")
        return self._simulator

