#!/usr/bin/env python3
import argparse
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


def smooth(y: pd.Series, window: int = 10) -> np.ndarray:
    """
    Smooth *y* with a trailing rolling mean and return it as a numpy array.

    Each output value averages the last ``window`` samples, including the current one.
    Because ``min_periods=1``, the leading values average over however many samples exist so far.
    Results that are not strictly positive become NaN, so a log-scaled axis simply skips them.
    """
    series = pd.Series(y)
    rolled = series.rolling(window=window, min_periods=1).mean()
    keep = rolled > 0  # log scale cannot show values <= 0
    return rolled.where(keep, np.nan).to_numpy()

def load_and_merge_phase(phase_dir: Path) -> pd.DataFrame:
    """Load one phase's metric CSVs and inner-join them on ``tick``.

    Files read from *phase_dir*:

    * ``variance.csv`` contains a ``tick`` column plus other columns that are kept as-is.
      ``main`` reads ``conv1.weight`` from here.
    * ``loss.csv``, ``training_loss.csv``, ``accuracy.csv`` and ``training_accuracy.csv``
      each contain a ``tick`` column and a value column named ``"0"``.
      The value column is renamed to ``loss``, ``training_loss``, ``accuracy`` and
      ``training_accuracy`` respectively, and coerced to numeric (unparsable -> NaN).

    Ticks are coerced to numeric. Rows whose tick cannot be parsed are dropped.
    Only ticks present in all five files survive.

    Returns:
        The merged frame, sorted by ``tick``, with a fresh 0..n-1 index.
        Its columns are the variance columns followed by
        ``loss, training_loss, accuracy, training_accuracy``.

    Raises:
        FileNotFoundError: if any of the five CSVs is missing.
        KeyError: if a CSV lacks its ``tick`` (or ``"0"``) column.
    """
    # CSV file name -> name the metric column gets in the merged frame.
    metric_files = {
        "loss.csv": "loss",
        "training_loss.csv": "training_loss",
        "accuracy.csv": "accuracy",
        "training_accuracy.csv": "training_accuracy",
    }

    def _read(filename: str) -> pd.DataFrame:
        # Read one CSV and return it with a numeric ``tick`` and no unparsable-tick rows.
        df = pd.read_csv(phase_dir / filename)
        df["tick"] = pd.to_numeric(df["tick"], errors="coerce")
        # pandas' merge matches NaN keys with each other (unlike SQL joins).
        # Rows whose tick failed to parse would therefore be cross-joined into
        # spurious rows, so drop them here.
        return df.dropna(subset=["tick"])

    merged = _read("variance.csv")
    for filename, column in metric_files.items():
        raw = _read(filename)
        metric = pd.DataFrame(
            {"tick": raw["tick"], column: pd.to_numeric(raw["0"], errors="coerce")}
        )
        merged = merged.merge(metric, on="tick", how="inner")

    return merged.sort_values("tick").reset_index(drop=True)


def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Concatenate phase_1 and phase_2 CSVs and plot metrics vs conv1.weight."
    )
    parser.add_argument(
        "--base-dir",
        type=Path,
        default=Path("."),
        help="Directory containing phase_1 and phase_2 (default: current directory).",
    )
    parser.add_argument(
        "--save-csv",
        action="store_true",
        help="Save intermediate merged CSVs and final concatenated CSV.",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("m2o_long.pdf"),
        help="Where to save the figure (default: m2o_long.pdf in the current directory).",
    )
    return parser


def _concat_phases(df1: pd.DataFrame, df2: pd.DataFrame) -> pd.DataFrame:
    """Append phase_2 after phase_1, shifting phase_2 ticks past phase_1's last tick."""
    # Shift so that phase_2's tick 0 lands right after the last phase_1 tick.
    # An empty phase_1 has max() == NaN, which would make every phase_2 tick NaN;
    # start at 0 in that case.
    last_tick = df1["tick"].max()
    offset = 0 if pd.isna(last_tick) else last_tick + 1
    shifted = df2.assign(tick=df2["tick"] + offset)
    return pd.concat([df1, shifted], ignore_index=True).sort_values("tick")


def _load_wd_loss(path: Path) -> pd.DataFrame:
    """Read the SGD reference CSV, coercing its two plotted columns to numbers.

    Raises:
        SystemExit: if *path* is not an existing file.
    """
    if not path.is_file():
        raise SystemExit(f"{path} not found")
    df = pd.read_csv(path)
    for column in ("training_loss", "variance_of_first_layer"):
        df[column] = pd.to_numeric(df[column], errors="coerce")
    return df


def main():
    """Plot smoothed training loss against conv1.weight variance and save the figure.

    Loads ``phase_1`` (required) and ``phase_2`` (optional) under ``--base-dir`` with
    ``load_and_merge_phase``. If phase_2 exists, its ticks are appended after phase_1's.
    The plot overlays the SGD reference points from ``<base-dir>/wd_loss.csv``.
    With ``--save-csv``, the merged and concatenated tables are written to ``--base-dir``.
    The figure is written to ``--output``.

    Raises:
        SystemExit: if ``phase_1`` or ``wd_loss.csv`` is missing under ``--base-dir``.
    """
    args = _build_arg_parser().parse_args()

    base = args.base_dir
    phase1_dir = base / "phase_1"
    phase2_dir = base / "phase_2"

    if not phase1_dir.is_dir():
        raise SystemExit("phase_1 directory must exist in base-dir")
    # Load the reference points up front, so a missing file fails before the
    # phase loading and merging work.
    df_wd = _load_wd_loss(base / "wd_loss.csv")

    # Always load phase_1.
    df1 = load_and_merge_phase(phase1_dir)
    if args.save_csv:
        df1.to_csv(base / "phase_1_merged.csv", index=False)

    # Optionally load phase_2 and append it.
    if phase2_dir.is_dir():
        df2 = load_and_merge_phase(phase2_dir)
        if args.save_csv:
            df2.to_csv(base / "phase_2_merged.csv", index=False)
        combined = _concat_phases(df1, df2)
    else:
        combined = df1.copy()
    combined = combined.reset_index(drop=True)

    if args.save_csv:
        combined.to_csv(base / "phases_concatenated.csv", index=False)

    fig, ax = plt.subplots()

    # Both axes are logarithmic.
    # The x-axis is inverted, so variance decreases from left to right.
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.invert_xaxis()

    # Smoothed training loss along the path. smooth() also turns non-positive
    # values into NaN, because a log axis cannot draw them.
    # (``combined`` also carries loss/accuracy/training_accuracy if more curves are wanted.)
    path_line, = ax.plot(
        combined["conv1.weight"],
        smooth(combined["training_loss"]),
        label="LLPF m2o path",
        linestyle="--",
        c="C0",
        linewidth=1,
    )
    ax.set_xlabel("weights variance of layer conv1.weight")
    ax.set_ylabel("training loss")
    ax.grid(True)

    sgd_points = ax.scatter(
        df_wd["variance_of_first_layer"],
        df_wd["training_loss"],
        s=10,
        alpha=0.8,
        c="C1",
        label="low-loss models from SGD",
    )

    handles = [path_line, sgd_points]
    ax.legend(handles, [h.get_label() for h in handles], loc="best")
    ax.set_title("ResNet18 @CIFAR10")

    fig.tight_layout()
    fig.savefig(args.output)
    plt.close(fig)  # release the figure's memory once it is written


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when it is imported.
    main()
