import itertools
import json

import torch

class LogCdists:
    """Periodically log pairwise Euclidean distances between source-symbol embeddings.

    On every call to ``step`` whose ``local["step"]`` is a multiple of ``etc``,
    the distance between rows ``i`` and ``j`` of ``model.embeddings.weight`` is
    computed for every pair ``0 <= i < j < src_symbols`` and emitted as one
    JSON line via ``logger.info``. Each record carries ``semeqv = "T"`` when
    ``i % tgt_symbols == j % tgt_symbols`` and ``"F"`` otherwise.
    """

    def __init__(self, model, src_symbols, tgt_symbols, logger, etc=100):
        """
        Args:
            model: object exposing ``embeddings.weight``, a 2-D tensor
                (presumably a ``torch.nn.Embedding`` — confirm with caller).
            src_symbols: number of leading embedding rows to compare.
            tgt_symbols: modulus used to compute the ``semeqv`` flag.
            logger: object with an ``info(str)`` method.
            etc: log only when ``local["step"]`` is a multiple of this value.
        """
        self.etc         = etc
        self.model       = model
        self.src_symbols = src_symbols
        self.tgt_symbols = tgt_symbols

        self.logger      = logger
        self.logid       = 0  # running id, incremented once per emitted record
        self.global_step = 0  # counts every call to step(), logged or not

    def step(self, local):
        """Log pairwise embedding distances if ``local["step"]`` is a multiple of ``etc``.

        Args:
            local: mapping containing at least ``"step"`` (an int) and ``"epoch"``.
        """
        self.global_step += 1
        if local["step"] % self.etc != 0:
            return

        n = self.src_symbols
        if n < 2:
            return  # no (i, j) pairs exist, so there is nothing to log

        # Only rows [0, n) are ever read, so restrict the distance matrix to them.
        # no_grad: this is pure monitoring; don't build an autograd graph
        # through the embedding weights.
        with torch.no_grad():
            weights = self.model.embeddings.weight[:n]
            # One device->host transfer instead of a .item() sync per pair.
            dists = torch.cdist(weights, weights).tolist()

        tgt = self.tgt_symbols
        epoch, step = local["epoch"], local["step"]
        # combinations() yields (i, j) with i < j in the same order as the
        # nested loops it replaces.
        for i, j in itertools.combinations(range(n), 2):
            record = {
                "logid"       : self.logid,
                "global_step" : self.global_step,
                "epoch"       : epoch,
                "step"        : step,
                "i"           : i,
                "j"           : j,
                "dist"        : dists[i][j],
                "semeqv"      : "T" if (j % tgt) == (i % tgt) else "F",
            }
            # Real JSON encoding: quotes inside values are escaped correctly.
            # default=str mirrors the old str() fallback for non-JSON types.
            self.logger.info(json.dumps(record, default=str))
            self.logid += 1

    def epoch(self, local):
        """Epoch-boundary hook; this logger does nothing here."""
