import os
from lightning.pytorch.callbacks import Callback
import shutil
class SaveLastKModels(Callback):
    """Prune a checkpoint directory so only the most recent entries remain.

    Each time Lightning invokes ``on_save_checkpoint``, every entry (file or
    sub-directory) in the checkpoint directory is ranked by modification time
    and all but the ``max_k`` newest are deleted.

    Args:
        max_k: Number of most-recent entries to keep. ``0`` removes every
            existing entry.

    Raises:
        ValueError: If ``max_k`` is negative.
    """

    def __init__(self, max_k):
        super().__init__()
        if max_k < 0:
            raise ValueError(f"max_k must be non-negative, got {max_k}")
        self.max_k = max_k

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        """Delete all but the ``max_k`` newest entries in the checkpoint dir.

        The directory is derived from ``pl_module.save_dir`` by replacing
        ``"save"`` with ``"checkpoints"``. If it does not exist (or is not a
        directory) nothing happens.

        Args:
            trainer: The Lightning trainer; only ``is_global_zero`` is read
                (treated as ``True`` when the attribute is absent).
            pl_module: Module whose ``save_dir`` attribute locates the run.
            checkpoint: The checkpoint dict being saved; not read or modified.
        """
        # Only one process should touch the filesystem: in multi-process runs
        # every rank reaches this hook, and concurrent deletions would race.
        if not getattr(trainer, "is_global_zero", True):
            return

        # NOTE(review): str.replace rewrites *every* "save" substring in the
        # path (e.g. "/saved/run/save" -> "/checkpointsd/run/checkpoints");
        # confirm this matches the project's save_dir layout.
        dirpath = pl_module.save_dir.replace("save", "checkpoints")
        if not os.path.isdir(dirpath):
            return

        # Oldest first, so the entries to drop form a prefix of the list.
        entries = sorted(
            os.listdir(dirpath),
            key=lambda name: os.path.getmtime(os.path.join(dirpath, name)),
        )

        # NOTE(review): Lightning appears to call this hook before the new
        # checkpoint file is written, so after the save completes the
        # directory may hold max_k + 1 entries -- confirm that is intended.
        excess = len(entries) - self.max_k
        if excess <= 0:
            return

        for name in entries[:excess]:
            path = os.path.join(dirpath, name)
            if os.path.isdir(path):
                shutil.rmtree(path)
            else:
                os.remove(path)