import torch
from pydgn.training.callback.score import Score


class iCGMMCompleteLikelihoodScore(Score):
    """Complete log-likelihood score for iCGMM, normalised by node count.

    After every training or evaluation batch, the scalar stored under this
    score's name in ``state.batch_score`` is recorded. The batch's node count
    (``state.batch_num_nodes``) is added to a running total.

    At the end of an evaluation epoch, the recorded values are summed and
    divided by the total node count, and the result is published as the
    epoch score. Training-epoch aggregation is left to the ``Score`` base
    class, which is not shown here.

    ``_score_fun`` picks which model output is used. Variants of this score
    override it to select a different position.
    """
    __name__ = 'iCGMM Complete Log Likelihood'

    def __init__(self):
        super().__init__()

    def _accumulate_batch(self, state):
        """Append this batch's score and add its node count to the running total."""
        scores = self.batch_scores
        scores.append(state.batch_score[self.__name__].item())
        # The node count is the normaliser; per the original author's note,
        # this works for unsupervised CGMM.
        self.num_samples += state.batch_num_nodes

    def on_training_batch_end(self, state):
        self._accumulate_batch(state)

    def on_eval_epoch_end(self, state):
        summed = torch.tensor(self.batch_scores).sum()
        state.update(epoch_score={self.__name__: summed / self.num_samples})
        # Release the accumulators. They are presumably re-initialised by the
        # base class at the next epoch start (not visible in this file).
        self.batch_scores = None
        self.num_samples = None

    def on_eval_batch_end(self, state):
        self._accumulate_batch(state)

    def _score_fun(self, targets, *outputs, batch_loss_extra):
        """Use the model output at position 2 as this batch's score."""
        return outputs[2]


class iCGMMCompleteLikelihoodScore2(iCGMMCompleteLikelihoodScore):
    """Complete log-likelihood variant that reads the model output at position 3.

    Accumulation and normalisation are inherited unchanged. Only the source
    of the per-batch value differs.
    """
    __name__ = 'iCGMM Complete Log Likelihood 2'

    def _score_fun(self, targets, *outputs, batch_loss_extra):
        # NOTE(review): what outputs[3] represents is defined by the model's
        # forward pass, not visible here.
        selected = outputs[3]
        return selected


class iCGMMCompleteLikelihoodScore3(iCGMMCompleteLikelihoodScore):
    """Complete log-likelihood variant that reads the model output at position 4.

    Accumulation and normalisation are inherited unchanged. Only the source
    of the per-batch value differs.
    """
    __name__ = 'iCGMM Complete Log Likelihood 3'

    def _score_fun(self, targets, *outputs, batch_loss_extra):
        # NOTE(review): what outputs[4] represents is defined by the model's
        # forward pass, not visible here.
        selected = outputs[4]
        return selected


class iCGMMCurrentStates(Score):
    """Publish the model's ``Ccurr`` attribute as the epoch score.

    Nothing is gathered per batch. At the end of every training and
    evaluation epoch, a detached copy of ``state.model.Ccurr`` is reported
    under this score's name.
    """
    __name__ = 'iCGMM Current States'

    def __init__(self):
        super().__init__()

    def _publish_snapshot(self, state):
        """Report a detached copy of ``state.model.Ccurr`` as the epoch score."""
        # clone() copies the data, so later in-place model updates do not
        # alter the reported value. detach() drops any autograd history.
        snapshot = state.model.Ccurr.clone().detach()
        state.update(epoch_score={self.__name__: snapshot})

    def on_training_batch_end(self, state):
        # Deliberately empty: this score has no per-batch contribution.
        pass

    def on_training_epoch_end(self, state):
        self._publish_snapshot(state)

    def on_eval_epoch_end(self, state):
        self._publish_snapshot(state)

    def on_eval_batch_end(self, state):
        # Deliberately empty: this score has no per-batch contribution.
        pass

    def _score_fun(self, targets, *outputs, batch_loss_extra):
        """Return a dummy zero; the real value comes from the model attribute."""
        return torch.tensor(0.)


class iCGMMCurrentAlpha(Score):
    """Publish the model's ``alpha`` attribute as the epoch score.

    Nothing is gathered per batch. At the end of every training and
    evaluation epoch, a detached copy of ``state.model.alpha`` is reported
    under this score's name.
    """
    __name__ = 'iCGMM Current Alpha'

    def __init__(self):
        super().__init__()

    def _publish_snapshot(self, state):
        """Report a detached copy of ``state.model.alpha`` as the epoch score."""
        # clone() copies the data, so later in-place model updates do not
        # alter the reported value. detach() drops any autograd history.
        snapshot = state.model.alpha.clone().detach()
        state.update(epoch_score={self.__name__: snapshot})

    def on_training_batch_end(self, state):
        # Deliberately empty: this score has no per-batch contribution.
        pass

    def on_training_epoch_end(self, state):
        self._publish_snapshot(state)

    def on_eval_epoch_end(self, state):
        self._publish_snapshot(state)

    def on_eval_batch_end(self, state):
        # Deliberately empty: this score has no per-batch contribution.
        pass

    def _score_fun(self, targets, *outputs, batch_loss_extra):
        """Return a dummy zero; the real value comes from the model attribute."""
        return torch.tensor(0.)


class iCGMMCurrentGamma(Score):
    """Publish the model's ``gamma`` attribute as the epoch score.

    Nothing is gathered per batch. At the end of every training and
    evaluation epoch, a detached copy of ``state.model.gamma`` is reported
    under this score's name.
    """
    __name__ = 'iCGMM Current Gamma'

    def __init__(self):
        super().__init__()

    def _publish_snapshot(self, state):
        """Report a detached copy of ``state.model.gamma`` as the epoch score."""
        # clone() copies the data, so later in-place model updates do not
        # alter the reported value. detach() drops any autograd history.
        snapshot = state.model.gamma.clone().detach()
        state.update(epoch_score={self.__name__: snapshot})

    def on_training_batch_end(self, state):
        # Deliberately empty: this score has no per-batch contribution.
        pass

    def on_training_epoch_end(self, state):
        self._publish_snapshot(state)

    def on_eval_epoch_end(self, state):
        self._publish_snapshot(state)

    def on_eval_batch_end(self, state):
        # Deliberately empty: this score has no per-batch contribution.
        pass

    def _score_fun(self, targets, *outputs, batch_loss_extra):
        """Return a dummy zero; the real value comes from the model attribute."""
        return torch.tensor(0.)

