'''PINA Callbacks Implementations'''

from lightning.pytorch.callbacks import Callback
import torch
import copy


class MetricTracker(Callback):
    """
    PINA implementation of a Lightning Callback to track relevant
    metrics during training.

    At the end of every training epoch a snapshot of
    ``trainer.logged_metrics`` is stored. The :attr:`metrics` property
    stacks those snapshots into one tensor per metric name.
    """

    def __init__(self):
        super().__init__()
        # One deep-copied ``trainer.logged_metrics`` dict per finished epoch.
        self._collection = []

    def on_train_epoch_end(self, trainer, __):
        """
        Store a snapshot of the metrics logged so far.

        :param trainer: The Lightning trainer whose ``logged_metrics``
            are recorded.
        :param __: The module being trained (unused).
        """
        # Deep copy so the stored snapshot is independent of the dict
        # (and tensors) the trainer keeps using.
        self._collection.append(copy.deepcopy(trainer.logged_metrics))

    @property
    def metrics(self):
        """
        Metrics tracked over all recorded epochs.

        :return: A dict mapping every metric name present in *all*
            recorded snapshots to ``torch.stack`` of its per-epoch values.
            Empty if no epoch has been recorded yet.
        :rtype: dict
        """
        # ``set.intersection`` needs at least one argument; with no
        # snapshots there is simply nothing to report.
        if not self._collection:
            return {}
        # Keep only metrics logged in every snapshot so each stacked tensor
        # has exactly one entry per epoch. Iterate the first snapshot rather
        # than the set so the key order is deterministic across runs.
        common_keys = set.intersection(*map(set, self._collection))
        return {
            key: torch.stack([snapshot[key] for snapshot in self._collection])
            for key in self._collection[0]
            if key in common_keys
        }

    