# viz/plotting.py
"""
Plotting utilities for experiment results: trade-off curves and summary tables.
Uses matplotlib. For interactive notebooks consider using plotly.
"""
import matplotlib.pyplot as plt
import numpy as np
import os

def plot_tradeoff_lambda(x_vals, metrics_dict, outpath=None, title="Tradeoff Curve",
                         xlabel="lambda / inner-steps", ylabel="metric value"):
    """
    Plot one line per metric against a shared x-axis, then save or show it.

    x_vals: list of lambda or ne values (the shared x-axis)
    metrics_dict: dict of metric_name -> list of values (same length as x_vals)
    outpath: if given, the figure is written there at dpi=150 (parent
        directories are created) and then closed; otherwise plt.show() is called.
    title, xlabel, ylabel: figure title and axis labels.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    for name, vals in metrics_dict.items():
        ax.plot(x_vals, vals, marker="o", label=name)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    # Only draw a legend when there are labelled lines; an empty metrics_dict
    # would otherwise make matplotlib warn about missing labelled artists.
    if metrics_dict:
        ax.legend()
    ax.grid(True)
    if outpath:
        try:
            os.makedirs(os.path.dirname(outpath) or ".", exist_ok=True)
            fig.savefig(outpath, dpi=150)
        finally:
            # pyplot keeps every open figure alive; close it so repeated calls
            # (e.g. a parameter sweep) do not accumulate figures in memory.
            plt.close(fig)
    else:
        plt.show()

def _to_display_array(img):
    """Convert one image (numpy array or torch tensor) to a numpy array for imshow.

    Channel-first inputs of shape [C,H,W] with C in (1, 3, 4) are moved to
    channel-last; a single channel is dropped so the result is 2-D [H,W].
    Values are clipped to [0, 1].
    """
    if hasattr(img, "detach"):
        # Tensors that require grad refuse .numpy(); detach them first.
        img = img.detach()
    if hasattr(img, "cpu"):
        img = img.cpu().numpy()
    img = np.asarray(img)
    if img.ndim == 3 and img.shape[0] in (1, 3, 4):
        img = np.transpose(img, (1, 2, 0))
    if img.ndim == 3 and img.shape[-1] == 1:
        img = img[..., 0]
    return img.clip(0, 1)

def save_examples_grid(image_tensors, filenames, outpath="viz/grid.png", ncols=4):
    """
    Save a grid of images, each captioned with a (truncated) name, to `outpath`.

    image_tensors: list of numpy arrays or torch tensors [C,H,W] in [0,1]
        (C may be 1, 3 or 4; 2-D [H,W] grayscale arrays are also accepted)
    filenames: list of captions, one per image; each is cut to 30 characters
    outpath: destination image file; parent directories are created if missing
    ncols: number of grid columns (must be >= 1)

    Raises ValueError if ncols < 1.
    """
    import math

    if ncols < 1:
        raise ValueError("ncols must be >= 1")
    n = len(image_tensors)
    # Keep at least one row so an empty input still yields a valid (blank) figure.
    nrows = max(1, math.ceil(n / ncols))
    fig = plt.figure(figsize=(ncols * 3, nrows * 3))
    try:
        for i, img in enumerate(image_tensors):
            ax = fig.add_subplot(nrows, ncols, i + 1)
            arr = _to_display_array(img)
            if arr.ndim == 2:
                # Single channel: pin the colour scale to the documented [0,1]
                # range instead of matplotlib's per-image min/max autoscaling.
                ax.imshow(arr, cmap="gray", vmin=0, vmax=1)
            else:
                ax.imshow(arr)
            ax.axis("off")
            ax.set_title(str(filenames[i])[:30])
        fig.tight_layout()
        os.makedirs(os.path.dirname(outpath) or ".", exist_ok=True)
        fig.savefig(outpath, dpi=150)
    finally:
        # Always release the figure, even if an image fails to render.
        plt.close(fig)
