import os
from typing import Tuple

import torch
from gmm_util.gmm import GMM
from multi_daft_vi.util_multi_daft import create_initial_gmm_parameters
from torch_geometric.data import Data, Batch

from ltsgns_mp.algorithms.abstract_algorithm import AbstractAlgorithm
from ltsgns_mp.algorithms.lnpdfs.get_lnpdf import get_lnpdf
from ltsgns_mp.algorithms.posterior_learners.get_posterior_learner import get_posterior_learner
from ltsgns_mp.algorithms.posterior_learners.multi_daft_learner import MultiDaftLearner
from ltsgns_mp.architectures.loss_functions.mse import mse
from ltsgns_mp.architectures.simulators.ltsgns_mp_simulator import LTSGNS_MP_Simulator
from ltsgns_mp.envs.env import Env
from ltsgns_mp.envs.train_iterator.trajectory_train_iterator import TrajectoryTrainBatch
from ltsgns_mp.util import keys
from ltsgns_mp.util.graph_input_output_util import add_distances_from_positions, remove_edge_distances
from ltsgns_mp.util.own_types import ValueDict, ConfigDict


class LTSGNS_MP(AbstractAlgorithm):

    def __init__(self, config: ConfigDict, simulator: LTSGNS_MP_Simulator,
                 env: Env, optimizer: torch.optim.Optimizer, loading_config: ConfigDict, device: str):
        """
        Set up the algorithm: initialize the base class, then build the posterior learner and the prior.

        Args:
            config: Algorithm config. `config.posterior_learner` is forwarded to `get_posterior_learner`.
            simulator: Simulator that is handed to the base class.
            env: Environment, also passed to the posterior learner factory.
            optimizer: Optimizer that is handed to the base class.
            loading_config: Loading options, handed to the base class and used by `get_prior`.
            device: Device identifier for the posterior learner (and the base class).
        """
        super().__init__(config, simulator, env, optimizer, loading_config, device)
        learner_config = config.posterior_learner
        self._posterior_learner = get_posterior_learner(config=learner_config, env=env, device=device)
        # the prior is built last, after the base class initializer and the posterior learner are in place
        self._prior = self.get_prior(loading_config)

    def _single_train_step(self, batch: TrajectoryTrainBatch) -> float:
        """
        Perform one training step on a trajectory batch.

        The posterior learner is fitted on the batch, latent samples z are drawn from it and the simulator
        is evaluated for these samples. The loss is the MSE between the mesh prediction and the ground truth
        positions. If `config.tight_elbo_training` is set, only the context timesteps are predicted and
        compared, otherwise the full trajectory prediction is used. If the simulator also returns collider
        context predictions, their MSE is added to the loss.

        Args:
            batch: Training batch wrapper; the graph batch is read from `batch.batch`.

        Returns: The loss of this step as a Python scalar (result of `.item()`), after it has been applied
            via `_apply_loss`.

        """
        # unpack batch
        batch = batch.batch

        z, _ = self._condition_model_and_get_posterior_samples(
            batch=batch,
            num_posterior_fit_steps=self.config.training.num_posterior_fit_steps,
            num_posterior_samples=self.config.training.num_z_samples_for_elbo_estimate,
            mode="train",
            logging=False)

        if self.config.tight_elbo_training:
            prediction = self.simulator(batch, z=z, use_cached_gnn_output=False, predict_context_timesteps=True)
            mesh_prediction = prediction["context_predictions"]
            # the context type of the first entry is used to look up the index tensor for the whole batch
            context_type = batch[keys.CONTEXT_TYPE][0]
            context_indices = batch[f"{context_type}_indices"]
            mesh_target = batch[keys.CONTEXT_NODE_POSITIONS][context_indices]
            # reshape to batch dim (the leading dim of the prediction is left out of the target shape)
            mesh_target = mesh_target.reshape(mesh_prediction.shape[1:])
        else:
            prediction = self.simulator(batch, z=z, use_cached_gnn_output=False, predict_context_timesteps=False)
            mesh_prediction = prediction["full_trajectory_predictions"]
            mesh_target = batch[keys.CONTEXT_NODE_POSITIONS]

        # MC estimate of the ELBO with std=1.0.
        loss = mse(mesh_prediction, mesh_target)

        collider_prediction = prediction["collider_context_predictions"]
        if collider_prediction is not None:
            # add the loss for the collider context
            collider_target = batch[keys.CONTEXT_COLLIDER_POSITIONS]
            loss = loss + mse(collider_prediction, collider_target)

        self._apply_loss(loss)
        # .item() yields a plain Python number, hence the float return annotation
        return loss.detach().item()

    def predict_trajectory(self, data: Data, visualize: bool = False, eval_only: bool = False) -> Tuple[
        torch.Tensor, ValueDict]:
        """
        Predict a full trajectory for a single data object.

        The edge distances of `data` are removed and recomputed from the positions, the data is wrapped in a
        batch of size one, the posterior is fitted on it and a single z sample is used for the prediction.
        All steps up to and including the anchor step are overwritten with the ground truth positions.

        Args:
            data: Graph data of one trajectory.
            visualize: If set, the posterior is fitted in mode "eval_from_prior" and the fitting is logged.
            eval_only: If set, the posterior is fitted in mode "eval_from_prior" (without logging, unless
                `visualize` is set as well).

        Returns: The predicted trajectory and the logging output of the posterior fit.

        """
        # add relative positions
        data = add_distances_from_positions(remove_edge_distances(data),
                                            self.config.train_iterator.add_euclidian_distance)
        # put it in a batch to have same interface as in training
        batch = Batch.from_data_list([data])

        if visualize or eval_only:
            mode = "eval_from_prior"
            fit_steps = self.config.eval.num_posterior_fit_steps_from_prior
        else:
            mode = "eval_from_checkpoint"
            fit_steps = self.config.eval.num_posterior_fit_steps_from_checkpoint

        z, additional_visualization = self._condition_model_and_get_posterior_samples(
            batch=batch,
            num_posterior_fit_steps=fit_steps,
            num_posterior_samples=1,
            mode=mode,
            logging=visualize)

        with torch.no_grad():
            simulator_output = self.simulator(batch, z=z, use_cached_gnn_output=True,
                                              predict_context_timesteps=False)
            # remove the batch dim and the z sample dim, as we only have 1 sample
            trajectory = simulator_output["full_trajectory_predictions"][0, 0]

        ground_truth = batch[keys.CONTEXT_NODE_POSITIONS][0]  # remove batch dim

        # get positions up to the first predicted step (including anchor step)
        anchor_step = batch[keys.ANCHOR_INDICES][0]
        num_known_steps = anchor_step + 1
        trajectory[:num_known_steps] = ground_truth[:num_known_steps]
        return trajectory, additional_visualization

    def save_checkpoint(self, directory: str, iteration: int, is_initial_save: bool, is_final_save: bool = False):
        """
        Save the algorithm, the posterior learner and (on the initial save only) the prior parameters.

        Args:
            directory: Directory the checkpoint files are written to.
            iteration: Current iteration, forwarded to the base class and the posterior learner.
            is_initial_save: Whether this is the initial save. Only then the prior is written.
            is_final_save: Whether this is the final save, forwarded to the base class and the posterior learner.
        """
        super().save_checkpoint(directory, iteration, is_initial_save, is_final_save)
        # save the posterior learner
        self.posterior_learner.save_checkpoint(directory, iteration, is_initial_save, is_final_save)

        if not is_initial_save:
            return
        # save the prior as "<keys.PRIOR>.pt" inside the checkpoint directory
        prior_params = self._prior.get_params_dict()
        prior_path = os.path.join(directory, f"{keys.PRIOR}.pt")
        torch.save(prior_params, prior_path)

    def _condition_model_and_get_posterior_samples(self, *,
                                                   batch: Batch,
                                                   num_posterior_fit_steps: int,
                                                   num_posterior_samples: int,
                                                   mode: str | None = None,
                                                   logging: bool = False) -> Tuple[torch.Tensor, ValueDict]:
        """
        Condition the model on the given batch, fit the posterior learner and return posterior samples.

        Args:
            batch: Data to condition the model on. Its task indices are read from `keys.TASK_INDICES`.
            num_posterior_fit_steps: Number of steps passed to the `fit` call of the posterior learner
            num_posterior_samples: How many z should be drawn from the posterior
            mode: Mode for the posterior learner. Only used (and then required) if the learner is a
                MultiDaftLearner
            logging: Whether to log the fitting of the posterior learner

        Returns: z samples from the posterior and the result of the `fit` call of the posterior learner,
            extended by additional information of the lnpdf if logging is enabled

        Raises:
            ValueError: If the posterior learner is a MultiDaftLearner and `mode` is None

        """
        # load the batch into the simulator
        self.simulator.gnn_forward(batch)
        task_indices = batch[keys.TASK_INDICES]

        learner = self.posterior_learner
        if isinstance(learner, MultiDaftLearner):
            # the MultiDaftLearner has to be told in which mode it is fitted
            if mode is None:
                raise ValueError("Mode may not be None for MultiDaftPosteriorLearner")
            learner.mode = mode
            if mode == "eval_from_prior":
                # start the fit from the prior again
                learner.reset_eval_from_prior()

        lnpdf = get_lnpdf(lnpdf_config=self.config.posterior_learner.lnpdf,
                          batch=batch, simulator=self.simulator, prior=self._prior)
        fit_results = learner.fit(n_steps=num_posterior_fit_steps,
                                  task_indices=task_indices,
                                  lnpdf=lnpdf,
                                  logging=logging)
        # Get the samples from the posterior
        z = learner.sample(n_samples=num_posterior_samples,
                           task_indices=task_indices,
                           lnpdf=lnpdf)
        if logging:
            # enrich the logging output with information of the lnpdf at the drawn samples
            fit_results.update(lnpdf.get_additional_information(z))
        return z, fit_results

    @property
    def posterior_learner(self):
        """The posterior learner created in `__init__` via `get_posterior_learner`."""
        return self._posterior_learner

    @property
    def simulator(self) -> LTSGNS_MP_Simulator:
        """
        The simulator of this algorithm.

        Raises:
            ValueError: If no simulator is set (`self._simulator` is None)
        """
        current_simulator = self._simulator
        if current_simulator is not None:
            return current_simulator
        raise ValueError("Simulator not set")

    def get_prior(self, loading_config: ConfigDict) -> GMM:
        """
        Build the GMM prior, either from a checkpoint or from freshly created parameters.

        Args:
            loading_config: If `enable_loading` is set, the prior parameters are read from
                "<checkpoint_path>/<keys.PRIOR>.pt", the file written by `save_checkpoint`. Otherwise new
                parameters are created from `config.posterior_learner.lnpdf.prior`.

        Returns: The prior as GMM on `self._device`

        """
        if loading_config.enable_loading:
            # load the prior parameters
            prior_path = os.path.join(loading_config.checkpoint_path, keys.PRIOR + ".pt")
            prior_dict = torch.load(prior_path, map_location=self._device)
            prior_w = prior_dict["log_w"]
            prior_mean = prior_dict["mean"]
            # The checkpoint holds the output of `GMM.get_params_dict()`, so the "prec" entry is used as the
            # precision directly. Inverting it again (as done for a freshly created covariance) would hand the
            # covariance to the GMM as precision.
            # NOTE(review): assumes "prec" in the params dict is the precision matrix, as the key and the
            # GMM keyword suggest - confirm against gmm_util.
            prior_prec = prior_dict["prec"]
        else:
            # create new params
            prior_config = self.config.posterior_learner.lnpdf.prior
            prior_w, prior_mean, prior_cov = create_initial_gmm_parameters(
                n_tasks=1,
                d_z=self.simulator.d_z,
                n_components=prior_config.n_components,
                prior_scale=prior_config.prior_scale,
                initial_var=prior_config.initial_var,
            )
            # the GMM is parameterized by the precision, so the initial covariance has to be inverted
            prior_prec = torch.linalg.inv(prior_cov)

        return GMM(
            log_w=prior_w,
            mean=prior_mean,
            prec=prior_prec,
            device=self._device
        )
