import torch
from multi_daft_vi.lnpdf import LNPDF

from ltsgns_mp.algorithms.posterior_learners.abstract_posterior_learner import AbstractPosteriorLearner
from ltsgns_mp.util import keys

from ltsgns_mp.util.own_types import ValueDict, ConfigDict


class TaskPropLearner(AbstractPosteriorLearner):
    """Posterior "learner" that uses the task properties of the current batch as latent samples.

    Nothing is optimised: :meth:`fit` caches ``lnpdf.batch[keys.TASK_PROPERTIES]`` and
    :meth:`sample` returns that cached tensor replicated ``n_samples`` times along a new
    leading dimension.
    """

    def __init__(self, posterior_learner_config: ConfigDict, n_all_tasks: int, device: str):
        """
        :param posterior_learner_config: Posterior learner configuration; its ``d_z`` entry must
            equal the last dimension of the task properties passed to :meth:`fit`.
        :param n_all_tasks: Number of different training tasks. Needed for getting the correct shape of the latent samples
        :param device: Forwarded to the base class initializer.
        """
        super().__init__(posterior_learner_config, device=device)
        self._d_z = posterior_learner_config.d_z
        self._n_all_tasks = n_all_tasks
        # Filled in by ``fit``; ``None`` until ``fit`` has been called at least once.
        self._current_task_prop = None

    def sample(self, n_samples: int, task_indices: torch.Tensor) -> torch.Tensor:
        """Return the cached task properties replicated ``n_samples`` times.

        :param n_samples: Number of copies stacked along a new leading dimension
            (``0`` yields an empty tensor).
        :param task_indices: Unused; the samples always correspond to the batch seen by the
            most recent :meth:`fit` call.
        :return: Tensor of shape ``(n_samples, *task_prop.shape)`` on ``self._device`` —
            presumably ``(n_samples, n_tasks, d_z)``; confirm against the batch layout.
        :raises RuntimeError: If :meth:`fit` has not been called yet.
        """
        if self._current_task_prop is None:
            raise RuntimeError(
                "TaskPropLearner.sample() called before fit(); no task properties are cached."
            )
        # Move once, then replicate, instead of replicating and moving n_samples copies.
        task_prop = self._current_task_prop.to(self._device)
        # ``repeat`` copies (like the former ``torch.stack``), unlike ``expand`` which would
        # return an aliasing view; it also handles n_samples == 0, which ``torch.stack`` rejects.
        return task_prop.unsqueeze(0).repeat(n_samples, *([1] * task_prop.dim()))

    def fit(self, n_steps: int, task_indices: torch.Tensor, lnpdf: LNPDF, logging: bool = False) -> ValueDict:
        """Cache the task properties of the current batch; no optimisation is performed.

        :param n_steps: Unused.
        :param task_indices: Unused.
        :param lnpdf: Target whose ``batch[keys.TASK_PROPERTIES]`` provides the task properties.
        :param logging: Unused.
        :return: An empty dict, since there are no training metrics to report.
        :raises ValueError: If the last dimension of the task properties differs from ``d_z``.
        """
        task_prop = lnpdf.batch[keys.TASK_PROPERTIES]
        # Validate before storing so a failed call leaves the previous state untouched;
        # raise explicitly because ``assert`` is stripped under ``python -O``.
        if task_prop.shape[-1] != self.d_z:
            raise ValueError(f"Change the d_z to {task_prop.shape[-1]}.")
        self._current_task_prop = task_prop
        return {}

    @property
    def n_all_tasks(self) -> int:
        """Number of different training tasks, as given to the constructor."""
        return self._n_all_tasks

    @property
    def d_z(self):
        """Latent dimensionality read from ``posterior_learner_config.d_z``."""
        return self._d_z

    def save_checkpoint(self, directory: str, iteration: int, is_initial_save: bool, is_final_save: bool = False) -> None:
        """No-op: this learner holds no trained parameters, only the last batch's task properties."""
        pass
