import os
import torch
import glob

def get_latest_checkpoint(folder, prefix):
    """Return the most recently changed checkpoint in ``folder``, or ``None``.

    Candidates are the files in ``folder`` whose names match ``{prefix}*.pt``.

    Args:
        folder: Directory to search. It is treated as a literal path, so glob
            metacharacters in it (``[``, ``*``, ``?``) are matched verbatim.
        prefix: Leading part of the checkpoint file name. It is inserted into
            the glob pattern unchanged.

    Returns:
        The candidate path with the greatest ``os.path.getctime`` value, or
        ``None`` when no file matches.
    """
    # Escape only the directory part: a folder such as "runs/exp[1]" would
    # otherwise be read as a character class and silently match nothing.
    pattern = os.path.join(glob.escape(folder), f"{prefix}*.pt")
    candidates = glob.glob(pattern)
    if not candidates:
        return None
    # NOTE: ctime is the creation time on Windows but the last metadata change
    # time on Unix, so "latest" here means "most recently changed".
    return max(candidates, key=os.path.getctime)

def save_checkpoint(state, filename):
    """Serialize ``state`` to ``filename`` with ``torch.save``.

    Missing parent directories of ``filename`` are created first.

    Args:
        state: Object to serialize; passed to ``torch.save`` unchanged.
        filename: Destination path. A bare file name (no directory part) is
            written relative to the current working directory.
    """
    parent = os.path.dirname(filename)
    # dirname() is '' for a bare file name, and os.makedirs('') raises
    # FileNotFoundError, so only create the parent when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(state, filename)

def load_checkpoint(model, checkpoint_path, part='base', map_location=None):
    """Load a checkpoint file and restore selected sub-modules of ``model``.

    Args:
        model: Object exposing ``backbone``, ``projector`` and ``classifier``
            attributes, each with a ``load_state_dict`` method. Only the
            attributes selected by ``part`` are accessed.
        checkpoint_path: Path passed to ``torch.load``.
        part: ``'base'`` restores backbone, projector and classifier (read
            from the checkpoint keys of the same names); ``'head'`` restores
            only the classifier. Any other value restores nothing.
        map_location: Forwarded to ``torch.load`` to remap tensor storages,
            e.g. ``'cpu'`` to open a GPU-saved checkpoint on a CPU-only host.
            The default ``None`` keeps ``torch.load``'s own default.

    Returns:
        The object returned by ``torch.load``, so callers can read any other
        entries stored alongside the state dicts.

    Raises:
        KeyError: If a state dict required by ``part`` is missing from the
            checkpoint.
    """
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    if part == 'base':
        model.backbone.load_state_dict(checkpoint['backbone'])
        model.projector.load_state_dict(checkpoint['projector'])
    # The classifier is restored for both modes, after the base modules.
    if part in ('base', 'head'):
        model.classifier.load_state_dict(checkpoint['classifier'])
    return checkpoint

def cleanup_old_checkpoints(ckpt_dir, prefix, latest_epoch, final_epoch):
    """Delete checkpoints in ``ckpt_dir`` except the latest and final epochs.

    Candidates are the files in ``ckpt_dir`` whose names match
    ``{prefix}*.pt``. A candidate is kept when its file name contains
    ``epoch{latest_epoch}.pt`` or ``epoch{final_epoch}.pt``; every other
    candidate is removed.

    Args:
        ckpt_dir: Directory holding the checkpoints. It is treated as a
            literal path, so glob metacharacters in it are matched verbatim.
        prefix: Leading part of the checkpoint file names. It is inserted
            into the glob pattern unchanged.
        latest_epoch: Epoch number whose checkpoint must be kept.
        final_epoch: Epoch number whose checkpoint must be kept.
    """
    keep_markers = (f"epoch{latest_epoch}.pt", f"epoch{final_epoch}.pt")
    # Escape only the directory part so a name like "exp[1]" is not read as a
    # character class (which would match no files and clean up nothing).
    pattern = os.path.join(glob.escape(ckpt_dir), f"{prefix}*.pt")
    for ckpt in sorted(glob.glob(pattern)):
        # Compare against the file name only: matching the full path would
        # protect every file whenever a directory name contains a marker.
        name = os.path.basename(ckpt)
        if any(marker in name for marker in keep_markers):
            continue
        os.remove(ckpt)
