from pydgn.training.event.handler import EventHandler


class iCGMMOptimizer(EventHandler):
    """
    Event handler that drives the iCGMM model across training/evaluation epochs.

    On every epoch start it copies the state's "compute_intermediate_outputs" flag onto the model.
    After a training epoch it invokes the model's ``update()`` method (the M-step).
    After an evaluation epoch it clears the model's flag again.
    """

    def __init__(self, **kwargs):
        # Extra keyword arguments are accepted for constructor compatibility but are not used.
        super().__init__()

    @staticmethod
    def _sync_intermediate_outputs_flag(state):
        """
        Copy ``state.compute_intermediate_outputs`` onto ``state.model.compute_intermediate_outputs``
        :param state: the shared State object
        """
        model = state.model
        model.compute_intermediate_outputs = state.compute_intermediate_outputs

    def on_training_epoch_start(self, state):
        """
        Before a training epoch, tell the model whether it should compute statistics,
        as dictated by the "compute_intermediate_outputs" field of the state
        :param state: the shared State object
        """
        self._sync_intermediate_outputs_flag(state)

    def on_training_epoch_end(self, state):
        """
        After a training epoch, run the M_step by calling the model's update()
        :param state: the shared State object
        """
        state.model.update()

    def on_eval_epoch_start(self, state):
        """
        Before an evaluation epoch, tell the model whether it should compute statistics,
        as dictated by the "compute_intermediate_outputs" field of the state
        :param state: the shared State object
        """
        self._sync_intermediate_outputs_flag(state)

    def on_eval_epoch_end(self, state):
        """
        After an evaluation epoch, set the model's "compute_intermediate_outputs" field back to False.
        This is not strictly required, but it can make debugging easier.
        :param state: the shared State object
        """
        state.model.compute_intermediate_outputs = False
