import wandb


class DummyLogger:
    """Logger that keeps its run settings but ignores every logged value.

    ``WandbLogger`` subclasses this and overrides the two ``log_*`` hooks;
    this base version accepts the same calls and does nothing with them.
    """

    def __init__(self, project: str, job_type: str, config):
        """Store the run settings on the instance.

        Args:
            project: Project name.
            job_type: Job type label.
            config: Run configuration, stored as given (its type is not
                constrained here).
        """
        self.project, self.job_type, self.config = project, job_type, config

    def log_loss(self, kvs, step):
        """Accept loss values for ``step`` and discard them."""

    def log_metric(self, kvs, step):
        """Accept metric values for ``step`` and discard them."""


class WandbLogger(DummyLogger):
    """Logger that forwards losses and metrics to a Weights & Biases run.

    The run is started in ``__init__`` and ended by ``finish()``. Prefer
    calling ``finish()`` explicitly, or using the logger as a context
    manager (``with WandbLogger(...) as logger:``), once logging is done.
    ``__del__`` is only a best-effort fallback: finalizers may run late,
    during interpreter shutdown, or not at all.
    """

    def __init__(self, project: str, job_type: str, config):
        """Start a wandb run.

        Args:
            project: wandb project name.
            job_type: wandb job type label.
            config: Run configuration forwarded unchanged to ``wandb.init``.
        """
        super().__init__(project, job_type, config)
        # Keep a handle on *this* run so logging and finishing target it,
        # rather than whatever run is globally active in the wandb module.
        self.run = wandb.init(project=project, job_type=job_type,
                              config=config)

    def finish(self):
        """End this logger's wandb run. Safe to call more than once."""
        # getattr: __del__ can reach here on an object whose __init__ raised
        # before self.run was assigned.
        run = getattr(self, "run", None)
        if run is None:
            return
        # Clear first so a repeated call (explicit finish, then __del__)
        # is a no-op instead of finishing the run twice.
        self.run = None
        run.finish()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.finish()
        # Do not suppress exceptions raised inside the with-block.
        return False

    def __del__(self):
        # Fallback only. This goes through the stored run object instead of
        # the module global ``wandb``, which may already be torn down when
        # finalizers run at interpreter shutdown.
        self.finish()

    def _log(self, kvs, step):
        """Send ``kvs`` to this logger's run at ``step``.

        Raises:
            RuntimeError: if the run is no longer active.
        """
        if self.run is None:
            raise RuntimeError(
                "WandbLogger has no active wandb run (was finish() called?)")
        self.run.log(kvs, step=step)

    def log_loss(self, kvs, step):
        """Log loss values ``kvs`` at ``step``."""
        self._log(kvs, step)

    def log_metric(self, kvs, step):
        """Log metric values ``kvs`` at ``step``."""
        self._log(kvs, step)




