import json
import time

import torch

import utils

class TrainLogger:
    """Emit one JSON-formatted metrics record per training step.

    Each call to :meth:`step` writes a single-line JSON object to
    ``logger.info`` and increments :attr:`logid`.
    """

    def __init__(self, logger):
        # logger: any object exposing ``info(str)`` (e.g. a logging.Logger).
        self.logger = logger
        self.start = time.time()  # wall-clock reference for the "time" field
        self.logid = 0            # incremented once per logged step

    @staticmethod
    def _in_batch_accuracy(symbol_emb, context_emb):
        """Return the fraction of rows i whose best-scoring context is row i.

        Scores are dot products ``symbol_emb @ context_emb.T``. Row i of
        ``symbol_emb`` counts as correct when column i holds its largest
        score, i.e. the i-th context is treated as the positive for the
        i-th symbol. Presumably both inputs are (batch, dim) — TODO confirm
        against the model output.
        """
        # no_grad: the (batch x batch) score matrix is only a metric, so it
        # must not extend the autograd graph of the live embeddings.
        with torch.no_grad():
            scores = symbol_emb @ context_emb.transpose(0, 1)
            targets = torch.arange(symbol_emb.size(0), device=symbol_emb.device)
            return (scores.argmax(-1) == targets).float().mean().item()

    @staticmethod
    def _format_record(record):
        """Serialize ``record`` as a single-line JSON object (key order kept)."""
        # default=str is deliberate: a logging call must not abort training
        # because a value (e.g. a tensor learning rate) is not JSON-native.
        return json.dumps(record, default=str)

    def step(self, local):
        """Log metrics for one optimization step and advance ``logid``.

        Args:
            local: mapping (presumably the training loop's ``locals()``) with
                ``local["result"]["symbol_emb"]`` and
                ``local["result"]["context_emb"]`` tensors, ``local["epoch"]``,
                ``local["step"]``, ``local["loss"]`` (anything with
                ``.item()``), and ``local["self"]`` exposing ``.optimizer``.
        """
        symbol_emb = local["result"]["symbol_emb"]
        context_emb = local["result"]["context_emb"]
        optimizer = local["self"].optimizer

        record = {
            "logid": self.logid,
            "global_step": self.logid,
            "time": time.time() - self.start,
            "epoch": local["epoch"],
            "step": local["step"],
            "batch_size": symbol_emb.size(0),
            "loss": local["loss"].item(),
            "accuracy": self._in_batch_accuracy(symbol_emb, context_emb),
            "learning_rate": utils.get_learning_rate(optimizer),
        }
        self.logger.info(self._format_record(record))

        self.logid += 1

    def epoch(self, local):
        """Per-epoch hook; intentionally a no-op."""
        pass
