import time, torch, utils

class TrainLogger:
    """Emit one structured log line per training step.

    Each call to :meth:`step` writes a dict-shaped record (rendered with
    double quotes) through ``logger.info`` and advances an internal counter
    that is reported as both ``logid`` and ``global_step``.
    """

    def __init__(self, logger):
        """Store the logger and start the wall-clock timer.

        Args:
            logger: Object exposing ``info(str)``; every record is sent to it.
        """
        self.logger = logger
        self.start = time.time()  # reference point for the "time" field
        self.logid = 0            # number of records emitted so far

    @staticmethod
    def _accuracy(logits, tgt):
        """Return the fraction of positions where ``argmax(logits)`` equals ``tgt``.

        ``argmax`` is taken over the last axis with ``keepdim=True``, so the
        prediction carries a trailing axis of size 1. When ``tgt`` holds the
        same number of elements but lacks that axis, a direct ``==`` would
        broadcast the two into a pairwise grid and the mean would no longer
        be an accuracy; ``tgt`` is therefore reshaped to the prediction's
        shape first. If the element counts differ, the comparison falls back
        to plain broadcasting.
        """
        predicted = logits.argmax(-1, keepdim=True)
        if tgt.shape != predicted.shape and tgt.numel() == predicted.numel():
            tgt = tgt.reshape(predicted.shape)
        return (predicted == tgt).float().mean().item()

    def step(self, local):
        """Log metrics for one training step and advance the record counter.

        Args:
            local: Mapping that provides ``["data"]["src"]``,
                ``["data"]["tgt"]``, ``["result"]["logits"]``, ``["epoch"]``,
                ``["step"]``, ``["loss"]`` and ``["self"].optimizer``.
                NOTE(review): presumably the caller's ``locals()`` — confirm.
        """
        src, tgt = local["data"]["src"], local["data"]["tgt"]
        logits = local["result"]["logits"]
        optimizer = local["self"].optimizer
        loss = local["loss"]

        record = {
            "logid"         : self.logid,
            "global_step"   : self.logid,
            "time"          : time.time() - self.start,
            "epoch"         : local["epoch"],
            "step"          : local["step"],
            "batch_size"    : src.size(0),
            "loss"          : loss.item(),
            "accuracy"      : self._accuracy(logits, tgt),
            "learning_rate" : utils.get_learning_rate(optimizer),
        }
        # The dict repr is re-quoted with double quotes; the output format is
        # kept exactly as before so existing log consumers keep working.
        self.logger.info(str(record).replace("'", "\""))

        self.logid += 1

    def epoch(self, local):
        """End-of-epoch hook; currently does nothing."""
        pass
