# train/checkpoint.py
import os
import glob
import torch

def save_best_checkpoint(model, optimizer, scheduler, epoch, val, output_dir, best_val):
    """Save a checkpoint when ``val`` is strictly lower than ``best_val``.

    The checkpoint is a dict with the keys ``'epoch'``, ``'model_state_dict'``,
    ``'optimizer_state_dict'``, ``'scheduler_state_dict'`` and ``'val'``,
    written to ``<output_dir>/best_epoch{epoch+1}_{val:.4f}.pth``.

    Args:
        model: Object exposing ``state_dict()``.
        optimizer: Object exposing ``state_dict()``.
        scheduler: Object exposing ``state_dict()``.
        epoch: Epoch index stored in the checkpoint; the file name uses
            ``epoch + 1``.
        val: Validation value for this epoch (lower is treated as better).
        output_dir: Directory to write into; created if it does not exist.
        best_val: Best value seen so far.

    Returns:
        The path of the written checkpoint, or ``None`` when ``val`` is not
        an improvement (nothing is written in that case).
    """
    if not val < best_val:
        return None

    # An empty output_dir means "current directory"; os.makedirs('') would raise.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    ckpt_path = os.path.join(output_dir, f"best_epoch{epoch+1}_{val:.4f}.pth")
    payload = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'val': val,
    }

    # Write to a sibling temp file and rename, so an interrupted save never
    # leaves a truncated file under the final name. The ".tmp" suffix keeps
    # the partial file from matching a 'best_epoch*.pth' glob.
    tmp_path = ckpt_path + ".tmp"
    try:
        torch.save(payload, tmp_path)
        os.replace(tmp_path, ckpt_path)
    except BaseException:
        # Best-effort cleanup of the partial file, then propagate the error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    return ckpt_path

def clean_old_checkpoints(output_dir, keep=1):
    """Delete all but the ``keep`` most recently modified best checkpoints.

    Only files matching ``best_epoch*.pth`` directly inside ``output_dir``
    are considered; recency is judged by modification time.

    Args:
        output_dir: Directory containing the checkpoints. Glob
            metacharacters in the path are treated literally.
        keep: Number of newest checkpoints to retain. ``0`` removes all
            of them.

    Raises:
        ValueError: If ``keep`` is negative.
    """
    if keep < 0:
        raise ValueError(f"keep must be non-negative, got {keep}")

    # Escape the directory part so names like "[run]" are not read as patterns.
    pattern = os.path.join(glob.escape(output_dir), 'best_epoch*.pth')

    dated = []
    for path in glob.glob(pattern):
        try:
            dated.append((os.path.getmtime(path), path))
        except OSError:
            # File vanished (or became unreadable) after listing; skip it.
            continue
    dated.sort()

    # Compute the count explicitly: a ``[:-keep]`` slice is empty when
    # keep == 0, which would delete nothing instead of everything.
    stale_count = max(len(dated) - keep, 0)
    for _, path in dated[:stale_count]:
        try:
            os.remove(path)
        except FileNotFoundError:
            # Already gone; that is the desired end state.
            continue
        except OSError as exc:
            # Cleanup stays best-effort, but the failure is reported.
            print(f"Warning: could not remove old checkpoint {path}: {exc}")

def load_checkpoint_if_exists(model, optimizer, scheduler, ckpt_path, device):
    """Restore training state from ``ckpt_path`` when that file exists.

    Args:
        model: Receives ``checkpoint['model_state_dict']``.
        optimizer: Receives ``checkpoint['optimizer_state_dict']``.
        scheduler: Receives ``checkpoint['scheduler_state_dict']``.
        ckpt_path: Checkpoint file path; may be empty or ``None``.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        ``(start_epoch, best_val)``. When a checkpoint is loaded,
        ``start_epoch`` is the stored ``'epoch'`` (default 0) plus one and
        ``best_val`` is the stored ``'val'`` (default ``inf``). Otherwise
        ``(0, inf)``.
    """
    # Guard: no path given, or the path is not a regular file -> fresh start.
    if not ckpt_path or not os.path.isfile(ckpt_path):
        return 0, float('inf')

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location=device)

    # Restore each component from its entry, in model/optimizer/scheduler order.
    restore_plan = (
        (model, 'model_state_dict'),
        (optimizer, 'optimizer_state_dict'),
        (scheduler, 'scheduler_state_dict'),
    )
    for component, key in restore_plan:
        component.load_state_dict(checkpoint[key])

    # The stored epoch is the last completed one, so resume at the next.
    start_epoch = checkpoint.get('epoch', 0) + 1
    best_val = checkpoint.get('val', float('inf'))
    print(f"Loaded checkpoint {ckpt_path}, resume from epoch {start_epoch}")
    return start_epoch, best_val
