from pydgn.training.callback.loss import Loss


class iCGMMLoss(Loss):
    """Loss callback for iCGMM that reports a value produced by the model itself.

    ``forward`` ignores the targets and returns the third element of the model
    outputs. Batch values are read from ``state.batch_loss`` under this
    callback's ``__name__`` key. The per-batch sample count is taken from
    ``state.batch_num_nodes``, i.e. nodes rather than targets are counted.
    """

    # Key under which this loss is looked up in state.batch_loss / state.epoch_loss.
    __name__ = 'iCGMM Loss'

    def __init__(self):
        super().__init__()
        # Epoch-level value recorded at the end of the previous training epoch;
        # starts at -inf so any first value compares as larger.
        self.old_likelihood = -float('inf')
        # Kept for backward compatibility; this class never reads or updates it.
        self.new_likelihood = None

    def _accumulate_batch(self, state):
        """Append the current batch value and add the batch's node count."""
        self.batch_losses.append(state.batch_loss[self.__name__].item())
        # Counting nodes instead of targets is what works for unsupervised CGMM.
        self.num_samples += state.batch_num_nodes

    def on_training_batch_end(self, state):
        """Accumulate the training batch value, counting samples as nodes."""
        self._accumulate_batch(state)

    def on_training_epoch_end(self, state):
        """Finalize the epoch via the base class, then remember the epoch value."""
        super().on_training_epoch_end(state)
        # NOTE(review): the original contained a disabled (commented-out) check
        # that would have stopped training when this epoch's value fell below
        # the previous one; it remains intentionally disabled. Confirm the sign
        # convention of outputs[2] before enabling any such check.
        self.old_likelihood = state.epoch_loss[self.__name__].item()

    def on_eval_batch_end(self, state):
        """Accumulate the evaluation batch value, counting samples as nodes."""
        self._accumulate_batch(state)

    def forward(self, targets, *outputs):
        """Return ``outputs[2]`` as the loss value; ``targets`` are ignored.

        ``outputs[2]`` is presumably the likelihood computed by the iCGMM
        model; verify against the model's output tuple.
        """
        return outputs[2]

    def on_backward(self, state):
        """Do nothing, overriding whatever the base ``Loss.on_backward`` does.

        NOTE(review): presumably because iCGMM parameters are not fitted by
        gradient descent; confirm against the model's training procedure.
        """
        pass