from typing import Tuple

import torch
from multi_daft_vi.lnpdf import LNPDF

from ltsgns_mp.algorithms.lnpdfs.ltsgns_mp_lnpdf import LTSGNS_MP_LNPDF
from ltsgns_mp.util import keys
from ltsgns_mp.architectures.loss_functions.log_likelihood import log_likelihood_per_time_step


class LTSGNS_Step_LNPDF(LTSGNS_MP_LNPDF):
    """LNPDF variant whose likelihood is computed from the mesh context only.

    Construction is forwarded unchanged to ``LTSGNS_MP_LNPDF``; this class
    only supplies its own ``_log_likelihood``.
    """

    def __init__(self, config, batch, simulator, prior):
        # No additional state here: the parent constructor receives all
        # four arguments as-is.
        super().__init__(config, batch, simulator, prior)

    def _log_likelihood(self, z: torch.Tensor):
        """Compute the mesh log-likelihood of the batch targets for ``z``.

        The simulator is called on ``self.batch`` together with ``z``,
        asking it to use its cached GNN output. The resulting prediction is
        scored against ``self.batch.y`` via ``self._compute_likelihood``
        with ``context_type=keys.MESH``.

        Args:
            z: Tensor passed through to the simulator as its second argument.

        Returns:
            A pair ``(mesh_log_llh, additional_information)``, where the
            second element is always an empty dict.
        """
        prediction = self.simulator(self.batch, z, use_cached_gnn_output=True)
        targets = self.batch.y

        # there is only mesh context
        log_llh = self._compute_likelihood(prediction, targets, context_type=keys.MESH)
        return log_llh, {}
