import utils
import tqdm

class ProgressBar:
    """tqdm progress bar reporting step, learning rate, loss and running accuracy.

    The bar advances by one tick per ``epoch`` call; ``step`` only refreshes
    the description text. Accuracy is accumulated across ``step`` calls and
    reset by ``epoch``.
    """

    def __init__(self, length):
        """Create the bar.

        Args:
            length: total number of ticks the bar expects, i.e. the number of
                ``epoch`` calls (presumably the number of epochs -- confirm
                with the caller).
        """
        # The bar is driven manually via ``update``; it never iterates, so
        # pass ``total`` rather than wrapping a throwaway ``range``.
        self.pbar = tqdm.tqdm(total=length)
        # Running count of correct predictions and of samples seen since the
        # last ``epoch`` call.
        self.tpfn_acc, self.tot_acc = 0, 0

    def step(self, local):
        """Accumulate accuracy for one batch and refresh the bar description.

        Args:
            local: mapping (presumably the trainer's ``locals()`` -- confirm)
                that must provide ``local["result"]["logits"]``,
                ``local["data"]["src"]``, ``local["data"]["tgt"]``,
                ``local["step"]``, ``local["self"].optimizer`` and
                ``local["loss"]``.
        """
        logits = local["result"]["logits"]
        data = local["data"]
        src, tgt = data["src"], data["tgt"]
        batch_size = src.size(0)
        step = local["step"]
        optimizer = local["self"].optimizer
        loss = local["loss"]

        # NOTE(review): keepdim=True leaves a trailing size-1 dim on the
        # predictions, so this is a per-sample comparison only if ``tgt`` has
        # the same trailing dim (e.g. (batch, 1)). If logits were
        # (batch, classes) and tgt were 1-D (batch,), the comparison would
        # broadcast to (batch, batch) and over-count. Confirm tgt's shape.
        correct = (logits.argmax(-1, keepdim=True) == tgt).sum().item()
        self.tpfn_acc += correct
        self.tot_acc += batch_size

        # Epsilon avoids division by zero if no samples have been counted.
        acc = self.tpfn_acc / (self.tot_acc + 1e-7)
        self.pbar.set_description(
            f"\rstep:{step: 4d}, lr:{utils.get_learning_rate(optimizer):1.5f}, loss:{loss.item():1.4f}, acc:{acc:1.4f}"
        )

    def epoch(self, local):
        """Advance the bar by one tick and reset the running accuracy.

        ``local`` is not read.
        """
        self.pbar.update(1)
        self.tpfn_acc, self.tot_acc = 0, 0

    def close(self):
        """Close the underlying tqdm bar once progress reporting is finished."""
        self.pbar.close()
