
class EvalLogger:
    """Accumulate evaluation loss/accuracy over batches and log one summary per epoch.

    ``step`` is meant to be called once per evaluation batch and ``epoch`` once
    at the end of an evaluation pass. ``epoch`` writes a JSON-like line through
    ``logger.info`` and then resets the running totals.

    Args:
        src_symbols: Stored on the instance; not used by this class itself.
        tgt_symbols: Stored on the instance; not used by this class itself.
        logger: Any object with an ``info(str)`` method (e.g. a ``logging.Logger``).
    """

    def __init__(self, src_symbols, tgt_symbols, logger):
        self.logger = logger
        self.samples = 0     # number of rows (src.size(0)) seen since the last epoch()
        self.tot_loss = 0    # running sum of loss_fn(..., reduction="sum") values
        self.tpfn_acc = 0    # running count of correct argmax predictions
        self.logid = 0       # incremented after every emitted summary line

        self.src_symbols = src_symbols
        self.tgt_symbols = tgt_symbols

    def step(self, local):
        """Accumulate loss, sample count, and correct predictions for one batch.

        Args:
            local: Mapping with ``local["result"]["logits"]``,
                ``local["data"]["src"]`` / ``local["data"]["tgt"]`` tensors, and
                ``local["self"]`` exposing ``loss_fn(logits, tgt, reduction=...)``.
                Presumably ``loss_fn`` with ``reduction="sum"`` returns the
                batch-summed loss as a scalar tensor -- TODO confirm.
        """
        logits = local["result"]["logits"]
        src, tgt = local["data"]["src"], local["data"]["tgt"]

        self.tot_loss += local["self"].loss_fn(logits, tgt, reduction="sum").item()
        self.samples += src.size(0)

        # Align the predictions with tgt's shape before comparing. Using
        # argmax(-1, keepdim=True) broadcasts a (B, 1) tensor against a (B,)
        # target into a (B, B) matrix and overcounts matches. Reshaping keeps
        # the (B, 1)-target case identical and fixes the (B,)-target case.
        preds = logits.argmax(-1)
        if preds.shape != tgt.shape:
            preds = preds.reshape(tgt.shape)
        self.tpfn_acc += (preds == tgt).sum().item()
        # NOTE(review): samples counts rows of src, so if tgt holds several
        # labels per row (e.g. per-token targets) accuracy can exceed 1.

    def epoch(self, local):
        """Log the averaged loss/accuracy since the last call, then reset totals.

        Args:
            local: Mapping whose ``local["self"]`` exposes ``global_step``.
        """
        samples = self.samples
        # Exact means; with no samples seen both metrics are reported as 0.0.
        loss = self.tot_loss / samples if samples else 0.0
        accuracy = self.tpfn_acc / samples if samples else 0.0

        # str(dict) with single quotes swapped for double quotes yields a
        # JSON-like line; this relies on the keys being fixed strings and the
        # values being plain numbers.
        self.logger.info(str({
            "logid"       : self.logid,
            "global_step" : local["self"].global_step,
            "samples"     : samples,
            "loss"        : loss,
            "accuracy"    : accuracy,
        }).replace("'","\""))

        self.logid += 1
        self.samples = 0
        self.tpfn_acc = 0
        self.tot_loss = 0
