import torch

from XXX.uib.modules.summarizer import Summarizer
from XXX.uib.utils.tensor_chain import CpuTensorChain


class LatentLabelChain(Summarizer):
    """Summarizer that accumulates latent tensors and their labels across ``fit`` calls.

    Each call to :meth:`fit` appends one batch to two parallel
    ``CpuTensorChain`` buffers (``latent_x_k_Z`` and ``labels_x``);
    :meth:`reset` clears both.
    """

    # Accumulated latent batches; every appended tensor is 3-D. Axis meaning
    # presumably (x samples, k, Z) per the naming convention — TODO confirm.
    latent_x_k_Z: CpuTensorChain
    # Accumulated label batches, kept in step with ``latent_x_k_Z``.
    labels_x: CpuTensorChain

    def __init__(self):
        super().__init__()

        self.latent_x_k_Z = CpuTensorChain.create()
        self.labels_x = CpuTensorChain.create()

    def reset(self):
        """Clear both accumulated chains."""
        self.latent_x_k_Z.reset()
        self.labels_x.reset()

    def fit(self, latent_x_k_Z: torch.Tensor, labels_x: torch.Tensor):
        """Append one batch of latents together with its labels.

        Args:
            latent_x_k_Z: 3-D latent tensor; its leading dimension (the ``x``
                axis) indexes the samples in this batch.
            labels_x: labels for the same batch; its leading dimension must
                equal ``latent_x_k_Z.shape[0]``.

        Raises:
            ValueError: if ``latent_x_k_Z`` is not 3-D, or if ``labels_x`` is a
                scalar or its leading dimension differs from that of
                ``latent_x_k_Z``.
        """
        # Explicit exceptions instead of ``assert`` so validation survives
        # ``python -O``. Everything is checked *before* either chain is
        # touched, so the two chains can never end up out of step.
        if latent_x_k_Z.dim() != 3:
            raise ValueError(
                "latent_x_k_Z must be 3-D, got shape "
                f"{tuple(latent_x_k_Z.shape)}"
            )
        if labels_x.dim() == 0 or labels_x.shape[0] != latent_x_k_Z.shape[0]:
            raise ValueError(
                "labels_x must have the same leading dimension as "
                f"latent_x_k_Z ({latent_x_k_Z.shape[0]}), got shape "
                f"{tuple(labels_x.shape)}"
            )

        # Any tensor ops performed while appending (done inside
        # CpuTensorChain.append, not visible here) are not recorded by autograd.
        with torch.no_grad():
            self.labels_x.append(labels_x)
            self.latent_x_k_Z.append(latent_x_k_Z)
