from typing import Tuple

import torch
from multi_daft_vi.lnpdf import LNPDF

from ltsgns_mp.util import keys
from ltsgns_mp.architectures.loss_functions.log_likelihood import log_likelihood_per_time_step


class LTSGNS_MP_LNPDF(LNPDF):
    """
    Log density of ``z`` for a single batch: a simulator-based likelihood plus a prior term.

    The likelihood compares the simulator's context predictions with the ground-truth context positions stored
    in ``batch``. If the simulator also returns collider predictions, a second likelihood term is computed for the
    collider and both terms are blended using ``config.likelihood.mesh_llh_weighting``.
    """

    def __init__(self, config, batch, simulator, prior):
        """
        :param config: lnpdf config; ``config.likelihood`` is read for the std values and the mesh weighting
        :param batch: batch holding the context ground truth and the index tensors used to select it
        :param simulator: callable producing the prediction dict; also exposes ``d_z``
        :param prior: object whose ``log_density(z=..., compute_grad=...)`` returns the prior log density
        """
        super().__init__()
        self.config = config
        self.batch = batch
        self.simulator = simulator
        # The prior's config is part of the lnpdf config, but the prior object itself is built in the algorithm
        # class and only handed over here.
        self.prior = prior
        # Filled by every call to log_density(); read back through get_additional_information().
        self._additional_information = None

    def log_density(self, z: torch.Tensor, compute_grad: bool = False) -> Tuple[torch.Tensor, torch.Tensor | None]:
        """
        Evaluate likelihood + prior at ``z`` and optionally the gradient of the summed density w.r.t. ``z``.

        Side effects: the individual terms are stored in ``self._additional_information``; with
        ``compute_grad=True`` the tensor ``z`` is switched to ``requires_grad`` in place and ``z.grad`` is
        overwritten.

        :param z: point(s) at which the density is evaluated
        :param compute_grad: whether to backpropagate and return ``z.grad``
        :return: ``(log_density, z.grad)`` if ``compute_grad`` else ``(log_density, None)``
        """
        if compute_grad:
            z.requires_grad_(True)
            # Clear a gradient left over from an earlier call so that backward() does not accumulate into it.
            stale_grad = z.grad
            if stale_grad is not None:
                stale_grad.zero_()

        llh_term, info = self._log_likelihood(z)
        prior_term = self._log_prior_density(z=z)
        total = llh_term + prior_term

        info.update(
            {
                "log_prior_density": prior_term.squeeze(),
                "log_likelihood": llh_term.squeeze(),
                "log_density": total.squeeze(),
            }
        )
        self._additional_information = info

        # total has shape (num_samples, num_tasks)
        if not compute_grad:
            return total, None

        # Reduce to a scalar so a single backward pass fills z.grad.
        total.sum().backward()
        # Neither the gradient nor the density may contain NaNs.
        # NOTE(review): plain asserts are skipped when Python runs with -O.
        assert not torch.isnan(z.grad).any()
        assert not torch.isnan(total).any()
        return total, z.grad

    def get_num_dimensions(self):
        """
        :return: the simulator's ``d_z``
        """
        return self.simulator.d_z

    def _log_likelihood(self, z: torch.Tensor):
        """
        Run the simulator for ``z`` and score its context (and, if present, collider) predictions.

        :param z: point(s) handed to the simulator
        :return: ``(log_likelihood, additional_information)`` where the dict holds the raw and the weighted
            mesh/collider terms
        :raises NotImplementedError: if the batch's context type is neither mesh nor point cloud
        """
        predictions_dict = self.simulator(
            self.batch,
            z,
            use_cached_gnn_output=True,
            predict_context_timesteps=True,
        )
        context_predictions = predictions_dict["context_predictions"]

        # The first entry of the batch's context-type field decides which ground truth is used.
        context_type = self.batch[keys.CONTEXT_TYPE][0]
        if context_type == keys.MESH:
            selected = self.batch.mesh_indices
            all_positions = self.batch[keys.CONTEXT_NODE_POSITIONS]
        elif context_type == keys.POINT_CLOUD:
            selected = self.batch.point_cloud_indices
            all_positions = self.batch[keys.CONTEXT_POINT_CLOUD_POSITIONS]
        else:
            raise NotImplementedError(f"Unknown context type: {context_type}")
        context_gth = all_positions[selected]
        mesh_log_llh = self._compute_likelihood(context_predictions, context_gth, context_type)

        collider_predictions = predictions_dict["collider_context_predictions"]
        if collider_predictions is None:
            # No collider term: contribute zeros of matching shape.
            collider_log_llh = torch.zeros_like(mesh_log_llh)
        else:
            collider_gth = self.batch[keys.CONTEXT_COLLIDER_POSITIONS]
            # Merge the two leading axes -> (batch*selected_context_timesteps, n_vertices, d_world)
            collider_gth = collider_gth.view(-1, *collider_gth.shape[2:])
            # The collider is always treated as a mesh.
            collider_log_llh = self._compute_likelihood(collider_predictions, collider_gth, context_type=keys.MESH)

        mesh_weight = self.config.likelihood.mesh_llh_weighting
        weighted_mesh = mesh_weight * mesh_log_llh
        weighted_collider = (1 - mesh_weight) * collider_log_llh

        additional_information = {
            "mesh_log_llh": mesh_log_llh.squeeze(),
            "weighted_mesh_log_llh": weighted_mesh.squeeze(),
            "collider_log_llh": collider_log_llh.squeeze(),
            "weighted_collider_log_llh": weighted_collider.squeeze(),
        }
        return weighted_mesh + weighted_collider, additional_information

    def _compute_likelihood(self, predictions, gth, context_type):
        """
        Log likelihood of ``predictions`` given ``gth``, summed over dimension 2 of the per-time-step result.

        :param predictions: simulator predictions passed on to ``log_likelihood_per_time_step``
        :param gth: ground truth whose leading axis is split into ``(batch.num_graphs, -1)``
        :param context_type: ``keys.MESH`` or ``keys.POINT_CLOUD``; selects the std taken from the config
        :return: the per-time-step log likelihood reduced over dimension 2
        :raises NotImplementedError: for any other context type
        """
        # Split the leading axis -> (batch, selected_context_timesteps, n_vertices, d_world)
        shaped_gth = gth.view(self.batch.num_graphs, -1, *gth.shape[1:])

        if context_type == keys.MESH:
            std = self.config.likelihood.mesh_std
        elif context_type == keys.POINT_CLOUD:
            std = self.config.likelihood.pc_std
        else:
            raise NotImplementedError(f"Unknown context type: {context_type}")

        # shape (z samples, batch, selected_context_timesteps)
        per_step = log_likelihood_per_time_step(predictions, shaped_gth, likelihood_std=std, gth_type=context_type)
        return per_step.sum(dim=2)

    def _log_prior_density(self, z: torch.Tensor):
        """
        :param z: point(s) at which the prior is evaluated
        :return: first element of ``self.prior.log_density(z=z, compute_grad=False)``
        """
        prior_log_density, _unused_grad = self.prior.log_density(z=z, compute_grad=False)
        return prior_log_density

    def get_additional_information(self, z: torch.Tensor):
        """
        Re-evaluate the log density at ``z`` (without gradients) and return its separately logged parts.
        Useful to figure out the influence of each part.

        :param z: point(s) at which the density is evaluated
        :return: dict with the raw/weighted mesh and collider likelihoods, the prior, the likelihood and the
            total log density
        """
        # Only the side effect matters here: log_density() refreshes self._additional_information.
        self.log_density(z=z, compute_grad=False)
        return self._additional_information

