import torch
import os

import numpy as np

def save_checkpoint(epoch, model, metrics, model_dir_path, best_model_path):
    """Save a per-epoch checkpoint and, if validation improved, the best model.

    Args:
        epoch: Epoch number; used in the checkpoint filename and stored in
            the saved payload.
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        metrics: Dict of metrics. Reads ``"validation_loss"`` (missing or
            ``None`` is treated as ``inf``), ``"train_loss"`` (optional) and
            ``"best_validation_loss"`` (missing is treated as ``inf``).
        model_dir_path: Directory that receives ``checkpoint_epoch_{epoch}.pt``;
            created if it does not exist.
        best_model_path: File path the best model is written to; its parent
            directory is created if needed.

    Returns:
        True if the best model was (re)written, i.e.
        ``validation_loss <= best_validation_loss``; False otherwise.
        ``metrics["best_validation_loss"]`` is NOT updated here; the caller
        remains responsible for tracking the best loss.
    """
    val_loss = metrics.get("validation_loss")
    if val_loss is None:
        # Treat a missing/None validation loss as "no improvement possible";
        # this also keeps the comparison below from raising TypeError on None.
        val_loss = float("inf")
    best_val_loss = metrics.get("best_validation_loss", float("inf"))

    # Snapshot the weights once so the epoch checkpoint and the best-model
    # file store exactly the same payload.
    payload = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "validation_loss": val_loss,
        "train_loss": metrics.get("train_loss", None),
    }

    # Save checkpoint (an empty dir path means "current directory").
    if model_dir_path:
        os.makedirs(model_dir_path, exist_ok=True)
    checkpoint_path = os.path.join(model_dir_path, f"checkpoint_epoch_{epoch}.pt")
    torch.save(payload, checkpoint_path)

    # ``<=`` so a tie still refreshes the best model; written as ``not (<=)``
    # so a NaN loss never counts as an improvement.
    if not val_loss <= best_val_loss:
        return False

    print(f"Epoch {epoch}: Saving best model with validation loss: {val_loss:.4f}")
    best_dir = os.path.dirname(best_model_path)
    if best_dir:
        os.makedirs(best_dir, exist_ok=True)
    torch.save(payload, best_model_path)
    return True

def save_checkpoint_if_validation_improved(epoch, model, metrics, model_dir_path, best_model_path):
    """Save only the best model, and only if validation loss did not get worse.

    Unlike a full checkpointing routine, no per-epoch checkpoint is written.

    Args:
        epoch: Epoch number stored in the saved payload.
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        metrics: Dict of metrics. Reads ``"validation_loss"`` (missing or
            ``None`` is treated as ``inf``), ``"train_loss"`` (optional) and
            ``"best_validation_loss"`` (missing is treated as ``inf``).
        model_dir_path: Unused; kept so the signature matches callers.
        best_model_path: File path the best model is written to; its parent
            directory is created if needed.

    Returns:
        True if the best model was (re)written, i.e.
        ``validation_loss <= best_validation_loss``; False otherwise.
        ``metrics["best_validation_loss"]`` is NOT updated here; the caller
        remains responsible for tracking the best loss.
    """
    val_loss = metrics.get("validation_loss")
    if val_loss is None:
        # Treat a missing/None validation loss as "no improvement possible";
        # this also keeps the comparison below from raising TypeError on None.
        val_loss = float("inf")
    best_val_loss = metrics.get("best_validation_loss", float("inf"))

    # ``<=`` so a tie still refreshes the best model; written as ``not (<=)``
    # so a NaN loss never counts as an improvement.
    if not val_loss <= best_val_loss:
        return False

    print(f"Epoch {epoch}: Saving best model with validation loss: {val_loss:.4f}")
    best_dir = os.path.dirname(best_model_path)
    if best_dir:
        os.makedirs(best_dir, exist_ok=True)
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "validation_loss": val_loss,
            "train_loss": metrics.get("train_loss", None),
        },
        best_model_path,
    )
    return True
