
"""visualization.py

Plotting utilities for learning curves, uncertainty heatmaps and PDE fields.
Each function renders one figure and saves it as an image file (e.g. PNG) at
the given output path, creating the parent directory if needed.
"""
import os
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np

def plot_learning_curve(history: dict, out_path: str, title: Optional[str] = None):
    """Plot per-epoch training/validation curves and save the figure.

    ``"train_loss"`` and ``"val_loss"`` are drawn first (when present), followed
    by every other key of *history* that starts with ``"val_"``, in the dict's
    iteration order. Each series is plotted against its own epoch axis
    ``1..len(series)``, so a missing, empty or shorter series does not break
    the plot.

    Args:
        history: Mapping from series name to a sequence of per-epoch values.
        out_path: Destination file passed to ``savefig``. Its parent directory
            is created if it does not exist.
        title: Optional figure title; omitted when falsy.
    """
    out_dir = os.path.dirname(out_path)
    # dirname() is "" for a bare filename, and os.makedirs("") raises.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    series_names = ["train_loss", "val_loss"] + [
        key for key in history if key.startswith("val_") and key != "val_loss"
    ]

    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        plotted_any = False
        for name in series_names:
            values = history.get(name, [])
            if len(values) == 0:
                continue  # absent or empty series: nothing to draw
            ax.plot(range(1, len(values) + 1), values, label=name)
            plotted_any = True
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss / Metric")
        if title:
            ax.set_title(title)
        if plotted_any:
            # legend() with no labelled artists only emits a warning.
            ax.legend()
        fig.tight_layout()
        fig.savefig(out_path)
    finally:
        # Close this specific figure even if plotting/saving fails, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)

def plot_pde_field(field: np.ndarray, out_path: str, cmap: str = "viridis"):
    """Render *field* as a colour-mapped image with a colorbar and save it.

    Args:
        field: Array accepted by ``imshow`` (typically a 2-D grid). Because
            ``origin="lower"`` is used, ``field[0, 0]`` is drawn in the
            lower-left corner.
        out_path: Destination file passed to ``savefig``. Its parent directory
            is created if it does not exist.
        cmap: Matplotlib colormap name applied to the image.
    """
    out_dir = os.path.dirname(out_path)
    # dirname() is "" for a bare filename, and os.makedirs("") raises.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(5, 4))
    try:
        image = ax.imshow(field, origin="lower", cmap=cmap)
        fig.colorbar(image, ax=ax)
        ax.set_title("PDE Field")
        fig.tight_layout()
        fig.savefig(out_path)
    finally:
        # Close this specific figure even if rendering/saving fails.
        plt.close(fig)

def plot_uncertainty_heatmap(
    uncertainty: np.ndarray, out_path: str, cmap: Optional[str] = None
):
    """Render *uncertainty* as a heatmap with a colorbar and save it.

    Args:
        uncertainty: Array accepted by ``imshow`` (typically a 2-D grid).
            Because ``origin="lower"`` is used, ``uncertainty[0, 0]`` is drawn
            in the lower-left corner.
        out_path: Destination file passed to ``savefig``. Its parent directory
            is created if it does not exist.
        cmap: Optional matplotlib colormap name. ``None`` keeps matplotlib's
            default colormap.
    """
    out_dir = os.path.dirname(out_path)
    # dirname() is "" for a bare filename, and os.makedirs("") raises.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(5, 4))
    try:
        image = ax.imshow(uncertainty, origin="lower", cmap=cmap)
        fig.colorbar(image, ax=ax)
        ax.set_title("Uncertainty heatmap")
        fig.tight_layout()
        fig.savefig(out_path)
    finally:
        # Close this specific figure even if rendering/saving fails.
        plt.close(fig)
