from pathlib import Path
from typing import Tuple

from megatron.training.global_vars import get_wandb_writer
from megatron.training.utils import print_rank_last
from megatron.training.wandb_utils import _get_artifact_name_and_version, _get_wandb_artifact_tracker_filename


def on_save_checkpoint_success(checkpoint_path: str, tracker_filename: str, save_dir: str, iteration: int) -> None:
    """Log a successfully persisted checkpoint to Weights & Biases as a "model" artifact.

    Meant to run after a checkpoint has been written. If ``get_wandb_writer()``
    returns a falsy value, nothing is logged and no tracker file is written.

    Args:
        checkpoint_path (str): path of the saved checkpoint. It is resolved to an
            absolute path, which is what the artifact's ``file://`` reference uses.
        tracker_filename (str): path of the tracker file for this checkpoint
            iteration; attached to the artifact as a regular file.
        save_dir (str): root folder for all checkpoints; used to derive the
            artifact name and the location of the W&B artifact tracker file.
        iteration (int): checkpoint iteration, stored in the artifact metadata.
    """
    # Resolve up front so the file:// reference below carries an absolute path.
    resolved_ckpt = Path(checkpoint_path).resolve()
    ckpt_uri_path = str(resolved_ckpt)

    writer = get_wandb_writer()
    if not writer:
        # W&B logging is not configured for this run.
        return

    name, version = _get_artifact_name_and_version(Path(save_dir), resolved_ckpt)
    model_artifact = writer.Artifact(name, type="model", metadata={"iteration": iteration})
    # checksum=False: record the local checkpoint by reference without checksumming its contents.
    model_artifact.add_reference(f"file://{ckpt_uri_path}", checksum=False)
    model_artifact.add_file(tracker_filename)
    # The version string from the helper becomes the alias of the logged artifact.
    writer.log_artifact(model_artifact, aliases=[version])

    # Persist "<entity>/<project>" of the active run into the tracker file derived from save_dir.
    tracker_path = _get_wandb_artifact_tracker_filename(save_dir)
    tracker_path.write_text(f"{writer.run.entity}/{writer.run.project}")
