import torch
import utils
import tqdm

class ProgressBar:
    """tqdm progress bar that also reports a running top-1 matching accuracy.

    ``step`` is called with a mapping of the caller's variables (presumably the
    training loop's ``locals()`` -- TODO confirm against the caller); ``epoch``
    advances the bar by one tick and resets the accuracy counters.
    """

    def __init__(self, length):
        """Create a bar of ``length`` ticks with zeroed accuracy counters.

        Args:
            length: number of ticks; each ``epoch`` call advances the bar by one.
        """
        self.pbar = tqdm.tqdm(total=length)
        # Running count of correct top-1 matches and of rows scored since the
        # last reset.
        self.tpfn_acc, self.tot_acc = 0, 0

    @staticmethod
    def _count_correct(symbol_emb, context_emb):
        """Return how many rows i of ``symbol_emb`` score highest against row i of ``context_emb``.

        Scores are the dot products ``symbol_emb @ context_emb.T``; row i is
        correct when its argmax is i. Assumes matching pairs share the same row
        index in both tensors -- TODO confirm with the caller.

        Returns:
            The number of correct rows as a Python int (0 for an empty batch).
        """
        batch_size = symbol_emb.size(0)
        if batch_size == 0:
            # Nothing to score; also avoids argmax over a possibly empty dimension.
            return 0
        # Metric only: no autograd graph is needed for the similarity matrix.
        with torch.no_grad():
            scores = symbol_emb @ context_emb.transpose(0, 1)
            targets = torch.arange(batch_size, device=symbol_emb.device)
            return (scores.argmax(-1) == targets).sum().item()

    def step(self, local):
        """Accumulate this batch's accuracy and refresh the bar's description.

        Args:
            local: mapping read for ``"result"`` (holding ``"symbol_emb"`` and
                ``"context_emb"`` tensors), ``"step"`` (an int), ``"loss"`` (a
                tensor) and ``"self"`` (an object with an ``optimizer`` attribute).
        """
        result = local["result"]
        symbol_emb = result["symbol_emb"]
        context_emb = result["context_emb"]
        step = local["step"]
        optimizer = local["self"].optimizer
        loss = local["loss"]

        self.tpfn_acc += self._count_correct(symbol_emb, context_emb)
        self.tot_acc += symbol_emb.size(0)

        # The epsilon keeps the ratio defined if only empty batches were seen.
        acc = self.tpfn_acc / (self.tot_acc + 1e-7)
        self.pbar.set_description(f"\rstep:{step: 4d}, lr:{utils.get_learning_rate(optimizer):1.5f}, loss:{loss.item():1.4f}, acc:{acc:1.4f}")

    def epoch(self, local):
        """Advance the bar by one tick and reset the running accuracy counters.

        Args:
            local: unused here (presumably accepted so callers can pass the same
                mapping they pass to ``step``).
        """
        self.pbar.update(1)
        self.tpfn_acc, self.tot_acc = 0, 0

    def close(self):
        """Close the underlying tqdm bar once no more updates will be made."""
        self.pbar.close()
