import wandb

class WandbWriter:
    """Mirror scalar logging to a wrapped writer and to Weights & Biases.

    Every call is forwarded unchanged to ``writer``. For wandb, scalars are
    buffered per step and sent in a single ``wandb.log`` call once a call for
    a different step arrives (or when the writer is closed), so all metrics
    recorded for one step end up in the same wandb row.
    """

    def __init__(self, writer, args, project='NPGApprox'):
        """Wrap ``writer`` and start a wandb run.

        Args:
            writer: Object providing ``add_scalars``, ``export_scalars_to_json``
                and ``close``; presumably a tensorboardX ``SummaryWriter`` —
                TODO confirm against the caller.
            args: Object whose ``vars()`` becomes the wandb run config
                (e.g. an ``argparse.Namespace``).
            project: wandb project name; defaults to ``'NPGApprox'``.
        """
        self.writer = writer
        self.args = args

        # Scalars recorded for ``current_step`` that have not been sent to wandb yet.
        self.logs_dict = {}
        # Step the buffered scalars belong to; None until the first add_scalars call.
        self.current_step = None

        wandb.login()
        wandb.init(project=project, config=vars(self.args))

    def add_scalars(self, k, val_dict, total_num_steps):
        """Record ``val_dict[k]`` at step ``total_num_steps``.

        The whole ``val_dict`` is forwarded to the wrapped writer. Only the
        entry keyed by ``k`` is buffered for wandb, under the name ``k``.

        Raises:
            KeyError: if ``k`` is not a key of ``val_dict`` (raised after the
                wrapped writer has already been called).
        """
        self.writer.add_scalars(k, val_dict, total_num_steps)

        # A different step means the buffered step is complete: flush it first
        # so its values are logged at their own step, not the new one.
        if self.current_step is not None and self.current_step != total_num_steps:
            self.log()
        self.current_step = total_num_steps

        self.logs_dict[k] = val_dict[k]

    def export_scalars_to_json(self, path):
        """Delegate JSON export of recorded scalars to the wrapped writer."""
        self.writer.export_scalars_to_json(path)

    def close(self):
        """Flush pending scalars, close the wrapped writer and finish the wandb run.

        ``log`` is otherwise only triggered when a later step arrives, so the
        last step's buffered scalars would never reach wandb without this flush.
        """
        if self.logs_dict:
            self.log()
        self.writer.close()
        wandb.finish()

    def log(self):
        """Send the buffered scalars to wandb at ``current_step`` and clear the buffer."""
        wandb.log(self.logs_dict, step=self.current_step)
        self.logs_dict = {}