import os
import wandb
from flax.training import orbax_utils
import orbax


# Name of the sub-directory, created under the active wandb run directory,
# into which `log_results` writes its per-generation checkpoints.
CKPT_DIR = "checkpoints"


def init_logger(args, config=None):
    """Start a wandb run configured from CLI arguments and an optional base config.

    The run config is a copy of ``config`` overlaid with ``vars(args)``, so
    command-line values win whenever both define the same key. The caller's
    ``config`` mapping is never modified.

    Args:
        args: Namespace-like object exposing ``wandb_project``,
            ``wandb_entity`` and ``wandb_group`` attributes; all of its
            attributes are recorded in the run config.
        config: Optional base configuration mapping. ``None`` (the default)
            is treated as an empty config.

    Raises:
        AssertionError: If ``args.wandb_project`` or ``args.wandb_entity``
            is missing/empty.
    """
    # Raise explicitly instead of using an `assert` statement: asserts are
    # stripped under `python -O`, which would silently skip this validation.
    # The exception type is kept so existing callers see the same error.
    if not (args.wandb_project and args.wandb_entity):
        raise AssertionError(
            "Must provide --wandb_project and --wandb_entity arguments to log results."
        )

    # Work on a copy so the caller's mapping is left untouched.
    run_config = {} if config is None else config.copy()
    run_config.update(vars(args))

    wandb.init(
        config=run_config,
        project=args.wandb_project,
        entity=args.wandb_entity,
        group=args.wandb_group,
        job_type="train",
    )


# Module-level checkpointer instance, created once at import time and reused
# by every `log_results` call that saves a checkpoint.
# NOTE(review): this accesses `orbax.checkpoint` after a bare `import orbax`;
# presumably the submodule is loaded as a side effect of another import —
# confirm, or import `orbax.checkpoint` explicitly.
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()


def log_results(metrics, train_state, gen=None):
    """Log metrics to wandb and stdout, optionally checkpointing the train state.

    Args:
        metrics: Passed unchanged to ``wandb.log`` and then printed.
        train_state: Object to checkpoint; only used when ``gen`` is given.
        gen: Generation label for the checkpoint. ``None`` (the default)
            skips checkpointing. Any other value, including ``0``, saves
            ``train_state`` under
            ``<wandb.run.dir>/<CKPT_DIR>/gen_<gen>``.
    """
    # Log metrics
    wandb.log(metrics)
    print(metrics)

    # Log checkpoints
    # Compare against None rather than relying on truthiness: generation 0 is
    # a valid generation and must not be treated as "no checkpoint requested".
    if gen is None:
        return

    # Build the whole path with os.path.join instead of concatenating "/" so
    # the separator is consistent on every platform.
    ckpt_path = os.path.join(wandb.run.dir, CKPT_DIR, f"gen_{gen}")
    save_args = orbax_utils.save_args_from_target(train_state)
    orbax_checkpointer.save(ckpt_path, train_state, save_args=save_args)
