#!/usr/bin/env python3
"""
Usage:
python kernelatt_experiment.py \
  --rbf_kernelattn_log logs/transformer_rbf_kernelattn.log \
  --rbf_dva_log logs/transformer_rbf_dva.log \
  --nonsmooth_kernelattn_log logs/transformer_nonsmooth_kernelattn.log \
  --nonsmooth_dva_log logs/transformer_nonsmooth_dva.log \
  --epochs 100 \
  --title "Transformer Comparison: Kernel vs DVA, Smooth vs Non-smooth" \
  --outdir fig/transformer_comparison \
  --csv

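Produces a single publication-quality line chart (saved as PNG, PDF and SVG,
plus an optional CSV of the aligned values) of Transformer validation loss
versus N_train (epoch times a fixed points-per-epoch factor), comparing:
  - Transformer+RBF+KernelAttn        (legend label: Smooth+KernelAttn)
  - Transformer+RBF+DVA               (legend label: Smooth+DVA)
  - Transformer+NonSmooth+KernelAttn  (legend label: NonSmooth+KernelAttn)
  - Transformer+NonSmooth+DVA         (legend label: NonSmooth+DVA)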
"""

import argparse
import os
import re
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ---------- Parsing ----------

# Matches lines such as "[Transformer] Epoch 3/100 ... Val loss: 0.1234".
# group(1): model tag (matched case-insensitively), group(2): current epoch,
# group(3): total epochs, group(4): validation loss.
# The loss pattern accepts an optional sign, an optional fraction, leading-dot
# decimals and an exponent, so values like "1.5e-03" are captured whole
# instead of being silently truncated to "1.5".
LOG_RE = re.compile(
    r"\[(CNN|Transformer)\]\s*Epoch\s*(\d+)\s*/\s*(\d+).*?Val loss:\s*"
    r"([\-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][\-+]?\d+)?)",
    re.IGNORECASE,
)

def parse_log_file(path):
    """Extract per-epoch validation losses from a training log.

    Every line is searched with ``LOG_RE``; non-matching lines are ignored.
    If the same model/epoch pair appears more than once, the last occurrence
    wins.

    Args:
        path: Path of the log file. It is decoded as UTF-8 and undecodable
            bytes are dropped.

    Returns:
        dict: ``{'CNN': {epoch: val_loss}, 'Transformer': {epoch: val_loss}}``
        with int epochs and float losses. Both keys are always present.
    """
    # LOG_RE is case-insensitive, so the captured tag may be e.g.
    # "transformer"; map it to the canonical key (indexing with the raw tag
    # used to raise KeyError for such lines).
    canonical = {"cnn": "CNN", "transformer": "Transformer"}
    results = {name: {} for name in canonical.values()}
    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        for line in f:
            m = LOG_RE.search(line)
            if m is None:
                continue
            model = canonical[m.group(1).lower()]
            results[model][int(m.group(2))] = float(m.group(4))
    return results

def union_epochs(*dicts):
    """Return the sorted set of epochs present in any of the given mappings.

    Each argument maps a name (e.g. a model tag) to ``{epoch: value}``, the
    shape produced by ``parse_log_file``. Epochs from every inner mapping of
    every argument are pooled.
    """
    seen = set()
    for parsed in dicts:
        for per_model in parsed.values():
            seen.update(per_model)
    return sorted(seen)

def to_series(epochs, model_dict):
    """Align ``model_dict`` to ``epochs``.

    Returns a list with one entry per epoch, in order: the stored value, or
    ``np.nan`` when that epoch has no entry.
    """
    missing = np.nan
    aligned = []
    for epoch in epochs:
        aligned.append(model_dict.get(epoch, missing))
    return aligned

def sanitize_filename(s):
    """Turn ``s`` into a conservative file stem.

    Any character other than ASCII letters, digits, ``-``, ``_``, ``.``,
    ``(``, ``)`` and space becomes ``_``. Surrounding whitespace is then
    trimmed, and the remaining spaces are replaced by ``_``.
    """
    safe = set("-_.() abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
    pieces = []
    for ch in s:
        pieces.append(ch if ch in safe else "_")
    trimmed = "".join(pieces).strip()
    return trimmed.replace(" ", "_")

# ---------- Plotting ----------

def configure_matplotlib_for_papers():
    """Apply matplotlib rcParams tuned for small, high-resolution paper figures.

    Updates the global ``plt.rcParams``, so the settings affect every figure
    created afterwards in the process.
    """
    # Font sizes in points.
    text_sizes = {
        "font.size": 10,
        "axes.titlesize": 10,
        "axes.labelsize": 8,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "legend.fontsize": 6,
    }
    resolution = {
        "figure.dpi": 600,
        "savefig.dpi": 600,
    }
    line_widths = {
        "axes.linewidth": 0.8,
        "xtick.major.width": 0.8,
        "ytick.major.width": 0.8,
    }
    # Type 42 (TrueType) font embedding in vector output; constrained layout on.
    output = {
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "figure.constrained_layout.use": True,
    }
    plt.rcParams.update({**text_sizes, **resolution, **line_widths, **output})

def make_plot(
    epochs,
    series_dict,
    title,
    outdir,
    figwidth=3.25,
    figheight=2.2,
    legend_loc="best",
    points_per_epoch=400 * 16 * 100,
):
    """Plot validation-loss curves against N_train and save PNG/PDF/SVG files.

    Args:
        epochs: Epoch numbers forming the shared x grid. Each is multiplied by
            ``points_per_epoch`` to obtain the plotted x value.
        series_dict: Maps a curve label to a sequence of losses aligned with
            ``epochs`` (NaN where a run has no value). Only the labels
            "Smooth+KernelAttn", "Smooth+DVA", "NonSmooth+KernelAttn" and
            "NonSmooth+DVA" are drawn. Labels that are missing or all-NaN
            are skipped.
        title: Axes title, also used (sanitized) as the output file stem.
            "val_loss_comparison" is used when it is falsy.
        outdir: Output directory; created if it does not exist.
        figwidth: Figure width in inches.
        figheight: Figure height in inches.
        legend_loc: ``loc`` passed to ``Axes.legend``.
        points_per_epoch: Factor converting epochs to N_train. The default
            reproduces the previously hard-coded 400 * 16 * 100.
            NOTE(review): what the three factors stand for is not visible
            here; confirm against the training configuration.

    Returns:
        tuple[Path, Path, Path]: Paths of the PNG, PDF and SVG files written.
    """
    configure_matplotlib_for_papers()

    x_vals = np.asarray([e * points_per_epoch for e in epochs])

    # Label -> (colour, linestyle). The dict order fixes the legend order.
    styles = {
        "Smooth+KernelAttn": ("blue", "-"),
        "Smooth+DVA": ("green", "--"),
        "NonSmooth+KernelAttn": ("orange", "-."),
        "NonSmooth+DVA": ("red", ":"),
    }

    fig, ax = plt.subplots(figsize=(figwidth, figheight))
    try:
        plotted_any = False
        for label, (color, linestyle) in styles.items():
            y = series_dict.get(label)
            if y is None:
                continue
            y = np.asarray(y, dtype=float)
            present = ~np.isnan(y)
            if not present.any():
                continue
            plotted_any = True
            # Drop missing (NaN) points so each curve connects its own
            # samples. Otherwise a run that is sparse on the shared epoch
            # grid is drawn as disconnected single points, which are
            # invisible without markers.
            ax.plot(
                x_vals[present],
                y[present],
                color=color,
                linestyle=linestyle,
                linewidth=1.0,
                label=label,
            )

        ax.set_xlabel(r"$N_{train}$")
        ax.set_ylabel("Validation Loss")
        if title:
            ax.set_title(title)

        # Configure the axis' own ScalarFormatter in one call. Previously the
        # formatter configured by ticklabel_format was immediately replaced
        # by a fresh ScalarFormatter, which discarded scilimits=(0, 0).
        ax.ticklabel_format(axis="x", style="sci", scilimits=(0, 0), useMathText=True)

        if plotted_any:
            ax.legend(loc=legend_loc, frameon=False, handlelength=3, fontsize=6)

        outdir = Path(outdir)
        outdir.mkdir(parents=True, exist_ok=True)
        base = sanitize_filename(title) if title else "val_loss_comparison"
        paths = tuple(outdir / f"{base}.{ext}" for ext in ("png", "pdf", "svg"))
        for path in paths:
            fig.savefig(path)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)

    return paths

def assemble_dataframe(epochs, series_dict):
    """Build a table with an ``epoch`` column followed by one column per series.

    Series columns are added in ``series_dict`` iteration order. List-like
    series must have the same length as ``epochs``.
    """
    table = pd.DataFrame({"epoch": epochs})
    for column_name, column_values in series_dict.items():
        table[column_name] = column_values
    return table

# ---------- CLI ----------

def main():
    """Command-line entry point: parse the four Transformer logs, plot, and save.

    Validation losses are taken from the ``[Transformer]`` lines of each log.
    They are aligned on the union of those epochs (capped at ``--epochs``)
    and plotted to PNG/PDF/SVG files in ``--outdir``. With ``--csv``, the
    aligned values are also written to a CSV whose stem matches the figures.
    """
    import sys  # only used to send warnings to stderr

    parser = argparse.ArgumentParser(description="Plot Transformer variants with RBF/NonSmooth kernels and Kernel/DVA attn.")
    parser.add_argument("--rbf_kernelattn_log", type=str, required=True, help="Log file for Transformer with RBF kernel + KernelAttn")
    parser.add_argument("--rbf_dva_log", type=str, required=True, help="Log file for Transformer with RBF kernel + DVA")
    parser.add_argument("--nonsmooth_kernelattn_log", type=str, required=True, help="Log file for Transformer with NonSmooth kernel + KernelAttn")
    parser.add_argument("--nonsmooth_dva_log", type=str, required=True, help="Log file for Transformer with NonSmooth kernel + DVA")
    parser.add_argument("--epochs", type=int, default=100, help="Max number of epochs to consider")
    parser.add_argument("--title", type=str, default=None, help="Figure title")
    parser.add_argument("--outdir", type=str, default="figures", help="Output directory")
    parser.add_argument("--csv", action="store_true", help="Also write CSV")
    parser.add_argument("--figwidth", type=float, default=3.25)
    parser.add_argument("--figheight", type=float, default=2.2)
    parser.add_argument("--legend_loc", type=str, default="best")
    args = parser.parse_args()

    # Curve label -> log path. The labels are the series keys make_plot draws.
    log_paths = {
        "Smooth+KernelAttn": args.rbf_kernelattn_log,
        "Smooth+DVA": args.rbf_dva_log,
        "NonSmooth+KernelAttn": args.nonsmooth_kernelattn_log,
        "NonSmooth+DVA": args.nonsmooth_dva_log,
    }

    # Report a missing log as a usage error instead of a traceback.
    for path in log_paths.values():
        if not Path(path).is_file():
            parser.error(f"log file not found: {path}")

    # Only the Transformer results are plotted: label -> {epoch: val_loss}.
    transformer_runs = {
        label: parse_log_file(path).get("Transformer", {})
        for label, path in log_paths.items()
    }

    # Build the epoch grid from Transformer entries only, so CNN-only epochs
    # in a log do not add all-NaN rows / gaps to the Transformer series.
    epochs = union_epochs(transformer_runs)
    if args.epochs is not None:
        epochs = [e for e in epochs if e <= args.epochs]

    series_dict = {
        label: to_series(epochs, run) for label, run in transformer_runs.items()
    }

    # A run with no data is simply left out of the plot; say so instead of
    # dropping the curve silently.
    for label, values in series_dict.items():
        if np.all(np.isnan(values)):
            print(
                f"Warning: no Transformer validation losses for {label} in {log_paths[label]}",
                file=sys.stderr,
            )

    # Plot
    title = args.title if args.title else "Transformer Variants: Val Loss vs Epoch"
    png_path, pdf_path, svg_path = make_plot(
        epochs,
        series_dict,
        title,
        args.outdir,
        figwidth=args.figwidth,
        figheight=args.figheight,
        legend_loc=args.legend_loc,
    )

    # Optional CSV, named like the figures (make_plot has created outdir).
    if args.csv:
        df = assemble_dataframe(epochs, series_dict)
        csv_path = Path(args.outdir) / (sanitize_filename(title) + ".csv")
        df.to_csv(csv_path, index=False)
        print("CSV written:", csv_path)

    print("Saved figure to:")
    for path in (png_path, pdf_path, svg_path):
        print(" -", path)

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
