import os
import shutil

from transformers import TrainerCallback, TrainerControl, TrainingArguments


class RenameCheckpointCallback(TrainerCallback):
    """Rename each freshly saved ``checkpoint-<global_step>`` dir to ``epoch-<N>``.

    ``N`` is ``int(state.epoch)`` (the epoch counter truncated toward zero), so
    several saves inside the same epoch (e.g. with step-based saving) map to the
    same target name; the most recent save replaces the earlier one.

    NOTE(review): the Trainer's own bookkeeping (``save_total_limit`` rotation,
    ``state.best_model_checkpoint``, auto-resume) presumably looks for
    ``checkpoint-*`` directories, so renamed directories may be invisible to it
    — confirm against the installed transformers version.
    """

    def on_save(
        self, args: TrainingArguments, state, control: TrainerControl, **kwargs
    ):
        """Move the checkpoint written at ``state.global_step`` to ``epoch-<N>``.

        Args:
            args: Training arguments; ``output_dir`` is where checkpoints are
                written and ``should_save`` says whether this process writes
                to disk.
            state: Trainer state providing ``global_step`` and ``epoch``.
            control: Passed through unchanged.

        Returns:
            The ``control`` object, unmodified.
        """
        # Every process receives on_save, but only the process(es) that write
        # checkpoints should touch the directory; otherwise they race and the
        # losers fail inside shutil.move.
        if not args.should_save:
            return control
        # Guard against an unset epoch counter; int(None) would raise.
        if state.epoch is None:
            return control

        last_ckpt_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        new_ckpt_dir = os.path.join(args.output_dir, f"epoch-{int(state.epoch)}")

        if not os.path.exists(last_ckpt_dir):
            return control

        # shutil.move() moves *into* an existing destination directory
        # (yielding epoch-N/checkpoint-M), so drop any earlier save for this
        # epoch first to get a genuine rename.
        if os.path.isdir(new_ckpt_dir):
            print(f"Replacing existing {new_ckpt_dir}")
            shutil.rmtree(new_ckpt_dir)

        print(f"Renaming {last_ckpt_dir} -> {new_ckpt_dir}")
        shutil.move(last_ckpt_dir, new_ckpt_dir)

        return control
