import copy
import os
from typing import Tuple

import torch
from gmm_util.gmm import GMM
from multi_daft_vi.util_multi_daft import create_initial_gmm_parameters
from torch_geometric.data import Data, Batch

from ltsgns_mp.algorithms.abstract_algorithm import AbstractAlgorithm
from ltsgns_mp.algorithms.posterior_learners.get_posterior_learner import get_posterior_learner
from ltsgns_mp.algorithms.lnpdfs.get_lnpdf import get_lnpdf
from ltsgns_mp.algorithms.posterior_learners.multi_daft_learner import MultiDaftLearner
from ltsgns_mp.algorithms.util import _update_external_state
from ltsgns_mp.architectures.loss_functions.mse import mse
from ltsgns_mp.architectures.simulators.ltsgns_step_simulator import LTSGNS_Step_Simulator
from ltsgns_mp.envs.env import Env
from ltsgns_mp.envs.train_iterator.ltsgns_step_train_iterator import LTSGNSStepTrainBatch
from ltsgns_mp.util.graph_input_output_util import add_distances_from_positions, node_type_mask, \
    add_and_update_node_features, recompute_external_edges, add_label
from ltsgns_mp.util.own_types import ValueDict, ConfigDict
from ltsgns_mp.util import keys


class LTSGNS_Step(AbstractAlgorithm):

    def __init__(self, config: ConfigDict, simulator: LTSGNS_Step_Simulator,
                 env: Env, optimizer: torch.optim.Optimizer, loading_config: ConfigDict, device: str):
        """
        Set up the algorithm: base-class state, the posterior learner and the prior over the latent z.
        Args:
            config: Algorithm config. `config.posterior_learner` is forwarded to `get_posterior_learner`.
            simulator: Simulator, forwarded to the base class.
            env: Environment, forwarded to the base class and to the posterior learner factory.
            optimizer: Optimizer, forwarded to the base class.
            loading_config: Loading config, forwarded to the base class and used by `get_prior` to decide
                whether the prior is loaded from a checkpoint or created from scratch.
            device: Device string, forwarded to the base class and to the posterior learner factory.
        """
        super().__init__(config, simulator, env, optimizer, loading_config, device)
        learner_config = config.posterior_learner
        self._posterior_learner = get_posterior_learner(config=learner_config, env=env, device=device)
        # must come after super().__init__: get_prior reads self.config, self.simulator and self._device
        self._prior = self.get_prior(loading_config)

    def _single_train_step(self, batch: LTSGNSStepTrainBatch) -> float:
        """
        Run one training step on a single batch.

        The posterior is fitted on the context part of the batch, latent samples z are drawn from it, the
        simulator predicts the trajectory part conditioned on z, and the MSE between prediction and ground
        truth is handed to `self._apply_loss`.
        Args:
            batch: Training batch providing `context_batch` (used to fit the posterior) and
                `trajectory_batch` (used to compute the loss).

        Returns: The loss value of this step as a plain Python float (detached via `.item()`).

        """
        training_config = self.config.training
        context_batch = batch.context_batch
        trajectory_batch = batch.trajectory_batch

        z, _ = self._condition_model_and_get_posterior_samples(
            batch=context_batch,
            num_posterior_fit_steps=training_config.num_posterior_fit_steps,
            num_posterior_samples=training_config.num_z_samples_for_elbo_estimate,
            mode="train",
            logging=False,
        )

        prediction = self.simulator(trajectory_batch, z=z, use_cached_gnn_output=False)

        # The number of mesh nodes is read from the first graph of the batch only; the reshape below
        # therefore assumes that every graph in the batch has the same number of mesh nodes.
        first_graph = trajectory_batch[0]
        mesh_mask = node_type_mask(first_graph, keys.MESH)
        num_nodes_per_graph = first_graph[keys.POSITIONS][mesh_mask].shape[0]

        gth = trajectory_batch.y
        gth = gth.reshape(-1, num_nodes_per_graph, gth.shape[-1])
        # MC estimate of the ELBO with std=1.0.
        loss = mse(prediction, gth)
        self._apply_loss(loss)
        return loss.detach().item()

    def predict_trajectory(self, data: Data, visualize: bool = False, eval_only: bool = False) -> Tuple[
        torch.Tensor, ValueDict]:
        """
        Roll out a full mesh trajectory for a single task.

        The posterior is fitted on the context steps of `data`, one latent sample z is drawn, and the simulator
        is then applied step by step, starting after the anchor step, each time on the positions it predicted
        in the previous step.
        Args:
            data: Graph of the task. Read here: the anchor index, the context node positions and the node
                positions/features that are updated during the rollout.
            visualize: If True, the posterior fit is logged and its logging results are returned.
            eval_only: Together with `visualize`, selects the posterior learner mode: if either flag is set,
                the mode is "eval_from_prior", otherwise "eval_from_checkpoint".

        Returns: Tuple of the stacked mesh positions (one entry along dim 0 per time step of the context
            trajectory: the given positions up to and including the anchor step, followed by the predicted
            ones) and the logging results of the posterior learner.

        """
        context_batch = self._build_context_step_batch(graph=data)

        # each mode comes with its own number of posterior fit steps from the eval config
        if visualize or eval_only:
            mode = "eval_from_prior"
            num_posterior_fit_steps = self.config.eval.num_posterior_fit_steps_from_prior
        else:
            mode = "eval_from_checkpoint"
            num_posterior_fit_steps = self.config.eval.num_posterior_fit_steps_from_checkpoint
        # a single z sample is drawn and reused for every step of the rollout
        z, additional_visualization = self._condition_model_and_get_posterior_samples(batch=context_batch,
                                                                                      num_posterior_fit_steps=num_posterior_fit_steps,
                                                                                      num_posterior_samples=1,
                                                                                      mode=mode,
                                                                                      logging=visualize)

        # get positions up to the first predicted step (including anchor step)
        first_step = data[keys.ANCHOR_INDICES]
        node_mask = node_type_mask(data, keys.MESH, as_index=True)
        # only the first (and presumably only) entry of the leading dimension is rolled out
        batch_index = 0
        mesh_positions = data[keys.CONTEXT_NODE_POSITIONS][batch_index]
        output_trajectories = [mesh_positions[i] for i in range(first_step + 1)]

        # predict the remaining steps from the anchor graph onwards in the future
        trajectory_length = mesh_positions.shape[0]
        # add the velocity features to the data if necessary and update the node features
        data = add_and_update_node_features(data, self.config.second_order_dynamics)

        with torch.no_grad():
            for current_step in range(first_step + 1, trajectory_length):
                # update collider positions and potentially other external (sensory) information such as point_clouds
                # the second-to-last stored positions are passed as previous positions; if only one step is
                # stored so far, that step is used instead
                if len(output_trajectories) > 1:
                    prev_pos = output_trajectories[-2]
                else:
                    prev_pos = output_trajectories[-1]
                data = _update_external_state(batch_index, current_step, data, prev_pos)

                # recompute *all* edges, remove edge distances in preparation for next iteration
                data = recompute_external_edges(data, self.env.config, self._device)
                # compute and add relative distances
                data = add_distances_from_positions(data, self.config.train_iterator.add_euclidian_distance)

                # predict velocities/accelerations
                predicted_dynamics = self.simulator(data, z, use_cached_gnn_output=False)
                # remove sample, task and time dimension
                predicted_dynamics = predicted_dynamics[0, 0, 0]

                # update positions
                if self.config.second_order_dynamics:
                    # second order: the prediction is added to the current velocities, which are then added
                    # to the positions
                    velocities_indices = [index for index, value in enumerate(data.x_description) if
                                          value == keys.VELOCITIES]
                    # NOTE(review): indexing with a Python list returns a copy, so the `+=` on mesh_vel below
                    # does not write back into data.x. Presumably the velocity features are refreshed by
                    # _update_external_state in the next iteration - confirm.
                    all_vel = data.x[:, velocities_indices]
                    mesh_vel = all_vel[node_mask]
                    mesh_vel += predicted_dynamics
                    data[keys.POSITIONS][node_mask] += mesh_vel
                else:
                    # first order: the prediction is added to the positions directly
                    data[keys.POSITIONS][node_mask] += predicted_dynamics

                # add updated mesh positions to output trajectories
                # (stored as a detached copy, so later in-place updates of data do not alter earlier steps)
                new_positions = copy.deepcopy(data[keys.POSITIONS][node_mask].detach())
                output_trajectories.append(new_positions)
        # finalize output trajectories
        output_trajectories = torch.stack(output_trajectories, dim=0)
        return output_trajectories, additional_visualization

    def _build_context_step_batch(self, graph: Data) -> Batch:
        """
        Build the context batch for the step prediction.

        Every time step that is flagged in `graph.mesh_indices[0]` becomes one `Data` object holding the mesh
        (and, if present, collider) positions of that step. All other attributes (node features, edges, faces,
        descriptions, ...) are taken over unchanged from `graph`. The steps are collated into one batch, labels
        are added via `add_label` and the edge features are recomputed from the step positions.
        :param graph: Graph of a single task, including the context node (and collider) positions per time step.
        :return: A single batch containing one graph per selected context step.
        :raises ValueError: If no time step is selected as context step.
        """
        # use previous node types as x features
        assert "collider_velocity" not in graph.x_description
        # if len(x.shape) > 2:
        #     raise NotImplementedError(
        #         "x with more than 2 dims means that there is more than just collider and mesh node types. Not implemented yet.")

        context_positions = graph[keys.CONTEXT_NODE_POSITIONS]
        # the positions are expected to be 4-dimensional; only the time dimension is needed here
        _, num_time_steps, _, _ = context_positions.shape
        context_indices = graph.mesh_indices[0]
        has_collider = keys.COLLIDER in graph.node_type_description

        # Everything below is identical for all steps, so it is read from the graph only once.
        x = graph.x
        # more features for visualization if present
        mesh_faces = graph.get("mesh_faces")
        collider_faces = graph.get("collider_faces")
        # relevant for evaluation to know from where to start the trajectory
        anchor_index = graph.anchor_indices

        context_step_list = []
        for i in range(num_time_steps):
            if not context_indices[i]:
                continue
            pos = context_positions[0, i]
            if has_collider:
                # combine mesh and collider positions, mesh nodes first
                collider_pos = graph[keys.CONTEXT_COLLIDER_POSITIONS][0, i]
                pos = torch.cat([pos, collider_pos], dim=0)
            # ignore task properties
            # use same edge type, edge properties and so on as in the graph. We will update that with correct relative
            # positions after noising
            step = Data(pos=pos, x=x, mesh_faces=mesh_faces, collider_faces=collider_faces, anchor_index=anchor_index,
                        edge_index=graph.edge_index, edge_attr=graph.edge_attr, edge_type=graph.edge_type,
                        node_type=graph.node_type,
                        x_description=graph.x_description, node_type_description=graph.node_type_description,
                        edge_type_description=graph.edge_type_description,
                        next_mesh_pos=graph.next_mesh_pos,
                        )
            context_step_list.append(step)

        if not context_step_list:
            # fail with a clear message instead of an obscure error when collating an empty list
            raise ValueError("Cannot build a context batch: no time step is selected as context step.")

        context_batch = Batch.from_data_list(context_step_list)
        context_batch.mesh_indices = graph.mesh_indices
        context_batch.task_indices = [graph.task_indices]
        context_batch = add_label(context_batch, second_order_dynamics=self.config.second_order_dynamics)
        # add relative positions
        context_batch = add_distances_from_positions(context_batch, self.config.train_iterator.add_euclidian_distance)
        return context_batch

    def save_checkpoint(self, directory: str, iteration: int, is_initial_save: bool, is_final_save: bool = False):
        """
        Save a checkpoint via the base class and, on the initial save, additionally the prior parameters.
        Args:
            directory: Directory the checkpoint files are written to.
            iteration: Forwarded to the base class.
            is_initial_save: If True, the parameter dict of the prior is written to `<keys.PRIOR>.pt`
                inside `directory`.
            is_final_save: Forwarded to the base class.
        """
        super().save_checkpoint(directory, iteration, is_initial_save, is_final_save)
        # The posterior learner has nothing to persist; the prior parameters are only written once,
        # together with the initial checkpoint.
        if not is_initial_save:
            return
        prior_params = self._prior.get_params_dict()
        prior_path = os.path.join(directory, f"{keys.PRIOR}.pt")
        torch.save(prior_params, prior_path)

    def _condition_model_and_get_posterior_samples(self, *,
                                                   batch: Batch,
                                                   num_posterior_fit_steps: int,
                                                   num_posterior_samples: int,
                                                   mode: str | None = None,
                                                   logging: bool = False) -> Tuple[torch.Tensor, ValueDict]:
        """
        Fit the posterior learner on the given batch and draw latent samples from the fitted posterior.
        Args:
            batch: Data to condition the model on. Its task indices (`keys.TASK_INDICES`) are passed on to
                the posterior learner.
            num_posterior_fit_steps: Number of steps the posterior learner is fitted for
            num_posterior_samples: How many z should be drawn from the posterior
            mode: Mode that is set on the posterior learner. Only used (and then required) if the posterior
                learner is a MultiDaftLearner. For "eval_from_prior" the learner is additionally reset via
                `reset_eval_from_prior` before fitting.
            logging: Whether to log the fitting of the posterior learner

        Returns: z samples from the posterior and the results returned by the fit of the posterior learner,
            extended by the additional information of the lnpdf if logging is enabled

        Raises:
            ValueError: If the posterior learner is a MultiDaftLearner and no mode is given.

        """
        # run the GNN part of the simulator on the batch before the lnpdf is built
        self.simulator.gnn_forward(batch)
        task_ids = batch[keys.TASK_INDICES]

        learner = self.posterior_learner
        if isinstance(learner, MultiDaftLearner):
            # this learner type needs to know in which mode it is fitted
            if mode is None:
                raise ValueError("Mode may not be None for MultiDaftPosteriorLearner")
            learner.mode = mode
            if mode == "eval_from_prior":
                # reset the learner state for eval_from_prior before it is fitted again
                learner.reset_eval_from_prior()

        log_density = get_lnpdf(
            lnpdf_config=self.config.posterior_learner.lnpdf,
            batch=batch,
            simulator=self.simulator,
            prior=self._prior,
        )
        fit_results = learner.fit(
            n_steps=num_posterior_fit_steps,
            task_indices=task_ids,
            lnpdf=log_density,
            logging=logging,
        )
        # draw the requested number of samples from the fitted posterior
        z = learner.sample(
            n_samples=num_posterior_samples,
            task_indices=task_ids,
            lnpdf=log_density,
        )
        if logging:
            fit_results.update(log_density.get_additional_information(z))
        return z, fit_results

    @property
    def posterior_learner(self):
        """The posterior learner created in `__init__` via `get_posterior_learner`."""
        return self._posterior_learner

    @property
    def simulator(self) -> LTSGNS_Step_Simulator:
        if self._simulator is None:
            raise ValueError("Simulator not set")
        return self._simulator

    def get_prior(self, loading_config: ConfigDict) -> GMM:
        """
        Create the GMM prior over the latent variable, either from a checkpoint or from the config.
        Args:
            loading_config: If `enable_loading` is set, the prior parameters are read from
                `<checkpoint_path>/<keys.PRIOR>.pt`. Otherwise new parameters are created from the prior
                config (`config.posterior_learner.lnpdf.prior`) for a single task.

        Returns: The prior as GMM on `self._device`.

        """
        prior_config = self.config.posterior_learner.lnpdf.prior
        if loading_config.enable_loading:
            # load the prior parameters; map_location allows loading a checkpoint that was saved on a
            # different device (e.g. trained on GPU, evaluated on CPU)
            prior_path = os.path.join(loading_config.checkpoint_path, keys.PRIOR + ".pt")
            prior_dict = torch.load(prior_path, map_location=self._device)
            prior_w = prior_dict["log_w"]
            prior_mean = prior_dict["mean"]
            # The checkpoint is written from the parameter dict of the prior GMM in save_checkpoint, so the
            # "prec" entry is presumably already the precision the GMM was built with. It is used as is:
            # inverting it again would restore a prior that differs from the one that was saved.
            prior_prec = prior_dict["prec"]
        else:
            # create new params
            # NOTE(review): the third return value is treated as a covariance here and inverted to obtain
            # the precision - confirm against create_initial_gmm_parameters.
            prior_w, prior_mean, prior_cov = create_initial_gmm_parameters(
                n_tasks=1,
                d_z=self.simulator.d_z,
                n_components=prior_config.n_components,
                prior_scale=prior_config.prior_scale,
                initial_var=prior_config.initial_var,
            )
            prior_prec = torch.linalg.inv(prior_cov)
        prior = GMM(
            log_w=prior_w,
            mean=prior_mean,
            prec=prior_prec,
            device=self._device
        )
        return prior
