import wandb


class WandbLogger:
    """Buffer metrics in a dict and flush them once per epoch.

    Metrics are accumulated in ``self.log_dict`` via :meth:`add_log` /
    :meth:`add_full_log`.  :meth:`write_log` prints the buffered values,
    forwards them to Weights & Biases when ``self.online`` is truthy, and
    then empties the buffer.
    """

    def __init__(self, opt):
        """Start a wandb run and create an empty metric buffer.

        Args:
            opt: Options object.  The attributes ``dataset``, ``stage``,
                ``wandb_run`` and ``wandb_online`` are read from it, and the
                object itself is handed to wandb as the run config.
        """
        run_group = f'{opt.dataset}_{opt.stage}'
        run_job_type = f'{opt.wandb_run}'

        # NOTE: the run is initialised unconditionally; ``wandb_online`` only
        # gates the later ``wandb.log`` calls made through ``write_log``.
        wandb.init(
            project='Multimodal Representation',
            group=run_group,
            job_type=run_job_type,
            config=opt,
        )

        self.online = opt.wandb_online
        self.log_dict = {}

    def add_log(self, key, value):
        """Store a single ``key -> value`` entry, replacing any existing one."""
        self.log_dict.update({key: value})

    def add_full_log(self, log_dict):
        """Merge every entry of ``log_dict`` into the buffer."""
        self.log_dict.update(log_dict)

    def write_online_log(self):
        """Send the currently buffered entries to wandb (buffer is kept)."""
        pending = self.log_dict
        wandb.log(pending)

    def write_log(self, epoch):
        """Print the buffer, optionally push it to wandb, then clear it.

        Args:
            epoch: Value shown in the printed ``Epoch ...`` prefix.  It is
                not forwarded to wandb.
        """
        summary = f'Epoch {epoch}: {self.log_dict}'
        print(summary)

        if self.online:
            self.write_online_log()

        # Clear in place so any outside reference to the dict sees it emptied.
        self.log_dict.clear()

