import numpy

class EmbeddingsLogger:
    """Periodically append the source-symbol embedding matrix to a text file.

    Every ``etc``-th call to :meth:`step` appends one line to ``path``. The
    line holds the first ``src_symbols`` rows of
    ``model.symbol_embeddings.weight``, flattened into a single row of
    comma-separated ``'%f'`` values.

    Args:
        src_symbols: Number of leading embedding rows to log.
        model: Object exposing ``symbol_embeddings.weight``. Presumably a
            torch embedding weight, since it is ``detach()``-ed and moved to
            CPU before conversion (TODO confirm).
        path: Output file path. It is truncated when the logger is created.
        etc: Log once every ``etc`` steps.

    The file is closed by :meth:`close`, when a ``with`` block exits, or
    (best effort) when the logger is garbage-collected.
    """

    def __init__(self, src_symbols, model, path, etc=10):
        self.etc = etc
        self.model = model
        self.src_symbols = src_symbols
        self.stepd = 0  # number of step() calls seen so far
        self.file = open(path, "w")

    def step(self, local):
        """Count one step; on every ``etc``-th step, append one embeddings row.

        ``local`` is accepted to match the caller's hook signature but is
        not used.
        """
        self.stepd += 1
        if self.stepd % self.etc != 0:
            return
        weights = self.model.symbol_embeddings.weight[:self.src_symbols]
        # reshape (rather than view) also accepts non-contiguous tensors.
        row = weights.detach().cpu().reshape(1, -1).numpy()
        numpy.savetxt(self.file, row, fmt='%f', delimiter=',')
        # Flush so each logged row reaches the OS even if the process dies
        # before the file is closed.
        self.file.flush()

    def epoch(self, local):
        """Epoch hook that does nothing; ``local`` is unused."""

    def close(self):
        """Close the output file. Calling this more than once is safe."""
        # getattr guards against __init__ having failed before self.file
        # was assigned (e.g. open() raised).
        handle = getattr(self, "file", None)
        if handle is not None and not handle.closed:
            handle.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never suppress exceptions

    def __del__(self):
        self.close()

