#!/usr/bin/env python3
import re
import argparse
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import matplotlib.ticker as mticker


def parse_args(argv=None):
    """Parse the command-line options for the loss-curve plotter.

    Args:
        argv: Argument strings to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, so existing
            zero-argument callers behave exactly as before.

    Returns:
        argparse.Namespace with attributes ``log`` (str), ``series``
        ("train" | "val" | "both"), ``out`` (str) and ``dpi`` (int).
    """
    p = argparse.ArgumentParser(description="Plot loss curves from training .log")
    p.add_argument("--log", type=str, default="../../data/energy-stats/ebm_training_epochs.log",
                   help="Path to the .log file")
    p.add_argument("--series", choices=["train", "val", "both"], default="both",
                   help="Which series to plot (train, val, or both)")
    p.add_argument("--out", default="loss_curves.png", help="Output image filename")
    p.add_argument("--dpi", type=int, default=300, help="Figure DPI")
    return p.parse_args(argv)

# One loss value: optional sign, decimal digits (".75", "1.", "0.42") and an
# optional exponent ("3.1e-05"). Losses can be negative or printed in exponent
# form, and a bare "." must never be captured (float(".") raises).
_NUM = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"

# Regex patterns (robust to extra spaces). _NUM uses only non-capturing groups,
# so the capture-group numbering below is unaffected by it.
_EPOCH_RE = re.compile(r"^\s*Epoch\s+(\d+)\b", re.IGNORECASE)
_OVERALL_LOSS_RE = re.compile(
    rf"Overall\s+Loss\(T/V\):\s*({_NUM})\s*/\s*({_NUM})",
    re.IGNORECASE,
)
_METHOD_LOSS_RE = re.compile(
    rf"^\s*-\s+([^\|]+?)\s*\|\s*Loss\(T/V\):\s*({_NUM})\s*/\s*({_NUM})",
    re.IGNORECASE,
)


def _parse_log(log_path):
    """Extract per-epoch overall and per-method losses from a training log.

    An "Epoch N" line sets the current epoch. Later "Overall Loss(T/V): a / b"
    and "- <method> | Loss(T/V): a / b" lines are attributed to it. Lines before
    the first epoch header are ignored. If an epoch reports a value more than
    once, the last one wins.

    Args:
        log_path: Path of the UTF-8 log file to read.

    Returns:
        Tuple ``(epochs, overall_train, overall_val, method_losses)``:
        ``epochs`` lists each epoch number once, in order of first appearance.
        ``overall_train`` / ``overall_val`` map epoch -> float.
        ``method_losses`` maps method name -> {"train": {epoch: float},
        "val": {epoch: float}}.
    """
    seen_epochs = {}  # used as an insertion-ordered set: epoch -> None
    overall_train = {}
    overall_val = {}
    method_losses = defaultdict(lambda: {"train": {}, "val": {}})

    current_epoch = None
    with open(log_path, "r", encoding="utf-8") as f:
        for line in f:
            # Detect epoch, but don't `continue`: the same line can also carry
            # the Overall Loss.
            m_epoch = _EPOCH_RE.search(line)
            if m_epoch:
                current_epoch = int(m_epoch.group(1))
                # Dedupe over the whole log, not just consecutive repeats, so a
                # re-logged/resumed epoch doesn't make the x-axis run backwards.
                seen_epochs[current_epoch] = None

            if current_epoch is None:
                continue

            # Overall loss, on the same line as "Epoch" or elsewhere.
            m_overall = _OVERALL_LOSS_RE.search(line)
            if m_overall:
                overall_train[current_epoch] = float(m_overall.group(1))
                overall_val[current_epoch] = float(m_overall.group(2))

            # Method-specific lines.
            m_method = _METHOD_LOSS_RE.search(line)
            if m_method:
                name = m_method.group(1).strip()
                method_losses[name]["train"][current_epoch] = float(m_method.group(2))
                method_losses[name]["val"][current_epoch] = float(m_method.group(3))

    return list(seen_epochs), overall_train, overall_val, dict(method_losses)


def _plot_overall(epochs, y_train, y_val, series, out_file, dpi):
    """Draw the requested overall-loss series and save the figure to *out_file*.

    Args:
        epochs: x values, one per epoch.
        y_train: training losses aligned with *epochs* (NaN where missing).
        y_val: validation losses aligned with *epochs* (NaN where missing).
        series: "train", "val" or "both" — which curves to draw.
        out_file: destination path passed to ``savefig``.
        dpi: resolution passed to ``savefig``.
    """
    fig, ax = plt.subplots(figsize=(9, 6))
    try:
        if series in ("train", "both"):
            ax.plot(epochs, y_train, linewidth=2, label="training loss")
        if series in ("val", "both"):
            ax.plot(epochs, y_val, linewidth=2, label="validation loss")

        ax.set_xlabel("Epoch")
        ax.set_ylabel("Average loss during training")
        ax.set_title("Average loss for train/validation datasets")

        # x-axis ticks every 10 epochs
        ax.xaxis.set_major_locator(mticker.MultipleLocator(10))
        ax.grid(True, linestyle="--", alpha=0.6)

        ax.legend()
        fig.tight_layout()
        fig.savefig(out_file, dpi=dpi, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)


def main():
    """Read the log given by --log and save the overall loss curves as a PDF.

    The output path is --out with its extension replaced by ".pdf" (or ".pdf"
    appended if it has none); a path already ending in ".pdf" is used as-is.
    Only the overall losses are plotted; per-method losses are parsed but unused.

    Raises:
        FileNotFoundError: if the log file does not exist.
        ValueError: if the log contains no "Epoch N" lines.
    """
    args = parse_args()
    log_path = Path(args.log)
    if not log_path.exists():
        raise FileNotFoundError(f"Log file not found: {log_path.resolve()}")

    epochs, overall_train, overall_val, _method_losses = _parse_log(log_path)
    if not epochs:
        raise ValueError("No epochs found. Check the log format or regexes.")

    # Epochs without an "Overall Loss" line become NaN, i.e. gaps in the curve.
    y_overall_train = [overall_train.get(ep, np.nan) for ep in epochs]
    y_overall_val = [overall_val.get(ep, np.nan) for ep in epochs]

    # Always write a PDF: swap whatever suffix --out has (not only ".png").
    if args.out.lower().endswith(".pdf"):
        out_file = args.out
    else:
        out_file = str(Path(args.out).with_suffix(".pdf"))

    _plot_overall(epochs, y_overall_train, y_overall_val, args.series, out_file, args.dpi)
    # plt.show()



if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
