import loggers.Logger as Logger
import wandb
import omegaconf


class Wandb_logger(Logger.Logger):
    """Logger backend that forwards metrics to Weights & Biases.

    Metric names are namespaced as ``<phase_name>/<parameter>`` (or
    ``<phase_name>/<parameter>/<key>`` for dict values). Every logged row also
    carries a step counter, ``pretrain/step`` or ``train/step`` by default.
    That counter is registered as the x-axis for the ``pretrain/*`` and
    ``train/*`` metric namespaces respectively.
    """

    def __init__(self, project_cfg, group="", entity=None, printer=False) -> None:
        """Start a wandb run configured from ``project_cfg``.

        Args:
            project_cfg: OmegaConf config providing ``seed``, ``job_type`` and
                ``project_name``. It is converted to a plain container with
                interpolations resolved (``throw_on_missing=True``, so a
                missing value raises) and passed to wandb as the run config.
            group: run group forwarded to ``wandb.init``.
            entity: entity forwarded to ``wandb.init``.
            printer: if true, print run metadata here and every logged value
                in ``log``.
        """
        # NOTE(review): the base class is expected to set self.project_cfg
        # (used below) and self.phase_name (used in log) -- confirm in
        # loggers.Logger.
        super().__init__(project_cfg=project_cfg)
        run_name = f"seed_{project_cfg.seed}"
        job_type = f"{project_cfg.job_type}"

        wandb_config = omegaconf.OmegaConf.to_container(
            project_cfg, resolve=True, throw_on_missing=True
        )
        wandb.init(
            entity=entity,
            project=self.project_cfg.project_name,
            group=group,
            name=run_name,
            job_type=job_type,
            config=wandb_config,
        )

        # Define a custom x-axis for each phase. Declaring the step metric by
        # itself does not attach it to any chart. The glob + step_metric
        # binding makes wandb plot e.g. "pretrain/loss" against
        # "pretrain/step" instead of its internal step counter.
        # NOTE(review): only the "pretrain" and "train" namespaces are bound;
        # metrics from other phase names are not -- confirm the phase names
        # used by callers.
        for prefix in ("pretrain", "train"):
            step_metric = f"{prefix}/step"
            wandb.define_metric(step_metric)
            wandb.define_metric(f"{prefix}/*", step_metric=step_metric)
        self.printer = printer

        if self.printer:
            print("PROJECT: ", self.project_cfg.project_name)
            print("GROUP: ", group)
            print("JOB TYPE: ", job_type)
            print("RUN NAME: ", run_name)

    def log(self, parameter, value, step, step_name=None):
        """Log ``value`` under the current phase at ``step``.

        Args:
            parameter: metric name; it is prefixed with ``self.phase_name``.
            value: a single value, or a dict that is expanded into one metric
                per key, named ``<phase_name>/<parameter>/<key>``.
            step: value recorded for the step metric in the same row.
            step_name: name of the step metric. Defaults to
                ``"pretrain/step"`` when the phase is ``"pretrain"`` and
                ``"train/step"`` otherwise.
        """
        if step_name is None:
            step_name = (
                "pretrain/step" if self.phase_name == "pretrain" else "train/step"
            )

        if isinstance(value, dict):
            metrics = {
                "/".join([self.phase_name, parameter, key]): item
                for key, item in value.items()
            }
        else:
            metrics = {"/".join([self.phase_name, parameter]): value}

        # Send all values for this step in one call. Each wandb.log call
        # commits its own history row by default, so logging per key would
        # split a single step across several rows. An empty dict has nothing
        # to record, so no call is made.
        if metrics:
            wandb.log({**metrics, step_name: step})

        if self.printer:
            print(
                f"\nphase:{self.phase_name}, {step_name}: {step}| parameter: {parameter} = {value}"
            )

    def finish(self):
        """Finish the active wandb run via ``wandb.finish()``."""
        wandb.finish()
