
class EvalLogger:
    """Accumulate embedding-matching evaluation metrics and log them once per epoch.

    Each :meth:`step` call adds one batch's summed loss, sample count and
    number of correct best matches. :meth:`epoch` logs the per-sample averages
    as a single JSON object through ``logger.info`` and then resets the
    accumulators.
    """

    def __init__(self, src_symbols, tgt_symbols, logger):
        """
        Args:
            src_symbols: Stored as ``self.src_symbols``; not read by this class.
            tgt_symbols: Modulus applied to symbol ids before they are compared
                in :meth:`step`.
            logger: Object with an ``info(str)`` method; receives one line per epoch.
        """
        self.logger = logger
        self.logid = 0  # incremented after every logged epoch; never reset
        self.src_symbols = src_symbols
        self.tgt_symbols = tgt_symbols
        self._reset()

    def _reset(self):
        """Zero the per-epoch accumulators (samples, summed loss, correct matches)."""
        self.samples = 0
        self.tot_loss = 0
        self.tpfn_acc = 0

    @staticmethod
    def _to_builtin(value):
        """Unwrap scalars exposing ``.item()`` (e.g. torch/NumPy scalars) so json can encode them."""
        item = getattr(value, "item", None)
        return item() if callable(item) else value

    def step(self, local):
        """Add one evaluation batch to the running totals.

        Args:
            local: Mapping with ``local["data"]["symbol"]`` (symbol ids),
                ``local["result"]["symbol_emb"]`` and ``local["result"]["context_emb"]``
                (embeddings), and ``local["self"]``, whose
                ``loss_fn(a, b, reduction="sum")`` returns a scalar tensor.
                Assumes both embeddings are 2-D ``(N, D)`` and that context ``j``
                belongs to symbol ``j`` -- TODO confirm against caller.
        """
        symbols = local["data"]["symbol"]
        symbols_emb = local["result"]["symbol_emb"]
        contexts_emb = local["result"]["context_emb"]

        # For every symbol embedding, index of the context with the highest dot-product score.
        best_match = (symbols_emb @ contexts_emb.transpose(0, 1)).argmax(-1)
        # A match counts as correct when the matched context's symbol id equals this
        # symbol's id modulo tgt_symbols (presumably mapping ids to target classes -- confirm).
        labels = symbols % self.tgt_symbols

        self.tot_loss += local["self"].loss_fn(symbols_emb, contexts_emb, reduction="sum").item()
        self.samples += symbols_emb.size(0)
        self.tpfn_acc += (labels[best_match] == labels).sum().item()

    def epoch(self, local):
        """Log per-sample loss/accuracy for the epoch as JSON, then reset the accumulators.

        Args:
            local: Mapping whose ``local["self"].global_step`` is included in the record.
        """
        import json  # local import so the class needs no module-level import

        # Exact mean; an epoch with no samples reports 0.0 instead of dividing by zero.
        denom = self.samples or 1
        self.logger.info(json.dumps({
            "logid": self.logid,
            "global_step": self._to_builtin(local["self"].global_step),
            "samples": self.samples,
            "loss": self.tot_loss / denom,
            "accuracy": self.tpfn_acc / denom,
        }))

        self.logid += 1
        self._reset()
