#!/usr/bin/env python3
import argparse
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import transforms as mtransforms  # for mixed data/x, axes/y coords

# Metric CSV file names that main() looks for in both input folders; any other
# *.csv files in those folders are ignored.
TARGETS = {"training_accuracy.csv", "training_loss.csv", "accuracy.csv", "loss.csv"}

def load_csv_keep_cols(path: Path) -> pd.DataFrame:
    """Read *path* and return only its 'tick' and '0' columns as clean numbers.

    Cells that cannot be parsed as numbers are coerced to NaN and any row with
    a NaN in either column is dropped. 'tick' is then cast to int, and the
    frame is sorted by tick with a fresh 0..n-1 index.

    Raises:
        ValueError: if the CSV lacks a 'tick' or a '0' column.
    """
    wanted = ["tick", "0"]
    raw = pd.read_csv(path)
    missing = [col for col in wanted if col not in raw.columns]
    if missing:
        raise ValueError(f"File {path} must contain columns: 'tick' and '0'.")

    # Build a fresh frame so the original table is never modified in place.
    numeric = pd.DataFrame(
        {col: pd.to_numeric(raw[col], errors="coerce") for col in wanted}
    )
    cleaned = numeric.dropna(subset=wanted).assign(
        tick=lambda frame: frame["tick"].astype(int)  # ticks are integers
    )
    return cleaned.sort_values("tick").reset_index(drop=True)

def concat_by_tick(df1: pd.DataFrame, df2: pd.DataFrame):
    """
    Append df2 after df1, shifting df2's ticks so the combined series continues.

    df2's ticks are offset by ``df1["tick"].max()``, so ranges 0..A and 0..B
    become 0..(A+B). When df2 starts at tick 0, the seam tick therefore appears
    twice: df1's last row is followed by df2's first row. Neither input is
    modified.

    Returns:
        (concatenated_df, boundary_tick): boundary_tick is the x-position where
        df2's data starts. It is None when df1 is empty, in which case an
        unshifted copy of df2 is returned.
    """
    if df1.empty:
        return df2.copy(), None
    boundary = int(df1["tick"].max())  # matches 0..A + 0..B -> 0..(A+B)
    shifted = df2.assign(tick=df2["tick"] + boundary)  # copy; df2 untouched
    out = pd.concat([df1, shifted], ignore_index=True)
    # Use a stable sort so rows sharing a tick (notably the duplicated seam
    # tick) keep their df1-then-df2 order. The default quicksort may swap
    # them, which makes the plotted line zig-zag at the boundary.
    out = out.sort_values("tick", kind="mergesort").reset_index(drop=True)
    return out, boundary

def smooth(y: pd.Series, window: int = 10) -> np.ndarray:
    """
    Smooth *y* with a trailing rolling mean and return it as a NumPy array.

    Leading outputs average over however many samples exist so far
    (``min_periods=1``). Smoothed values that are not strictly positive, or
    are NaN, come back as NaN, so the result can be plotted on a log-scaled
    axis.
    """
    averaged = pd.Series(y).rolling(window=window, min_periods=1).mean().to_numpy()
    # NaN > 0 is False, so NaNs stay NaN together with non-positive values.
    return np.where(averaged > 0, averaged, np.nan)

def main():
    """Concatenate matching CSVs from two run folders and plot the result.

    For every name in ``TARGETS`` present in BOTH ``--folder1`` and
    ``--folder2``, the two files are loaded with ``load_csv_keep_cols`` and
    joined with ``concat_by_tick``, so folder2's ticks continue after
    folder1's. Each joined table is written to ``--outdir`` under the same
    file name.

    The available curves are then drawn into one figure:
      * losses, smoothed, on a log-scaled left y-axis;
      * accuracies, smoothed, on the right y-axis;
      * a dotted vertical line at each concatenation point, with "m2o"/"m2m"
        labels around the first one.

    Metrics that are not present in both folders are skipped instead of
    causing a crash.

    Raises:
        SystemExit: when no target file exists in both folders, or when
            ``--smooth`` is not a positive integer (via ``ArgumentParser.error``).
    """
    ap = argparse.ArgumentParser(description="Concat matching CSVs from two folders by 'tick' and plot column '0'.")
    ap.add_argument("--folder1", default="./dla_avs_m2o/0-to_vs", type=Path, help="Path of m2o folder")
    ap.add_argument("--folder2", default="./dla_avs_m2m/0-0", type=Path, help="Path of m2m folder")
    ap.add_argument("--outdir", default=Path("./dla_avs_out"), type=Path, help="Where to write concatenated CSVs and figure")
    ap.add_argument("--figure", default="dla_avs.pdf", help="Filename for the output figure (saved in outdir)")
    ap.add_argument("--smooth", default=10, type=int, help="Rolling window size for smoothing")
    args = ap.parse_args()
    if args.smooth < 1:
        # A rolling window below 1 can never satisfy min_periods=1.
        ap.error("--smooth must be a positive integer")

    f1 = {p.name: p for p in args.folder1.glob("*.csv") if p.name in TARGETS}
    f2 = {p.name: p for p in args.folder2.glob("*.csv") if p.name in TARGETS}
    common = sorted(f1.keys() & f2.keys())  # keys are already limited to TARGETS
    if not common:
        raise SystemExit(
            "None of the target files found in BOTH folders: " + ", ".join(sorted(TARGETS))
        )

    args.outdir.mkdir(parents=True, exist_ok=True)

    # Concatenate and save each metric, remembering where folder2's data starts.
    series = {}      # file name -> concatenated DataFrame with 'tick' and '0'
    boundaries = {}  # file name -> boundary tick (int), or None
    for name in common:
        dfc, boundary = concat_by_tick(load_csv_keep_cols(f1[name]), load_csv_keep_cols(f2[name]))
        dfc.to_csv(args.outdir / name, index=False)
        series[name] = dfc
        boundaries[name] = boundary

    # Global x-range, used as fallback x-limits and to scale the label offset.
    # `series` is non-empty here because `common` is non-empty.
    x_min = min(df["tick"].min() for df in series.values())
    x_max = max(df["tick"].max() for df in series.values())
    dx = max(1, 0.012 * (x_max - x_min))  # data-space x offset for boundary labels

    # Plot: losses on the left axis, accuracies on the right axis.
    fig, ax_left = plt.subplots(figsize=(6, 2))
    ax_right = ax_left.twinx()
    axes = {"left": ax_left, "right": ax_right}
    handles = {"left": [], "right": []}

    # Each entry: (file, y-axis, legend label, color, linestyle, log-safe smoothing?).
    # List order fixes the drawing order on each axis and the legend order.
    curves = [
        ("training_accuracy.csv", "right", "training accuracy", "C1", (0, (3, 5, 1, 5)), False),
        ("training_loss.csv", "left", "training loss", "C0", "solid", True),
        ("loss.csv", "left", "test loss", "C2", "dotted", True),
        ("accuracy.csv", "right", "test accuracy", "C3", (0, (5, 10)), False),
    ]
    for name, side, label, color, style, log_safe in curves:
        if name not in series:
            continue  # metric not present in both folders
        df = series[name]
        if log_safe:
            # smooth() turns non-positive values into NaN for the log-scaled axis.
            y = smooth(df["0"].values, window=args.smooth)
        else:
            y = pd.Series(df["0"].values).rolling(args.smooth, min_periods=1).mean().to_numpy()
        (line,) = axes[side].plot(
            df["tick"].values, y, label=label, c=color, linewidth=1, linestyle=style
        )
        handles[side].append(line)

    ax_left.set_xlabel(r'iteration $i$')
    ax_left.set_ylabel("training & test loss")
    if handles["left"]:
        ax_left.set_yscale("log")
    ax_right.set_ylabel("training & test accuracy")
    ax_right.set_ylim(top=1.02, bottom=0.88)
    # Clamp x to the test-accuracy tick range when it exists; otherwise use the global range.
    xlim_df = series.get("accuracy.csv")
    if xlim_df is not None:
        ax_left.set_xlim(xlim_df["tick"].min(), xlim_df["tick"].max())
    else:
        ax_left.set_xlim(x_min, x_max)

    # Draw one dotted line per distinct concatenation point.
    unique_bounds = sorted({b for b in boundaries.values() if b is not None})
    for b in unique_bounds:
        ax_left.axvline(x=b, linestyle="dotted", alpha=0.5, linewidth=1, c="black")

    # Label the two sides of the first boundary.
    if unique_bounds:
        b = unique_bounds[0]
        # x in data coords, y in axes coords (0.97 = just below the top edge).
        trans = mtransforms.blended_transform_factory(ax_left.transData, ax_left.transAxes)
        ax_left.text(b - dx, 0.97, "m2o", ha="right", va="top", transform=trans)
        ax_left.text(b + dx, 0.97, "m2m", ha="left", va="top", transform=trans)

    # Combine both axes' curves into a single legend. It is attached to the
    # twin axes, which are drawn after ax_left; a legend on ax_left would be
    # painted over by the accuracy curves.
    legend_handles = handles["left"] + handles["right"]
    if legend_handles:
        ax_right.legend(handles=legend_handles, loc="lower left", fontsize=8)

    ax_left.set_title("DLA @CIFAR10")
    # The twin shares this x-axis, so formatting it once covers both axes.
    ax_left.ticklabel_format(style='sci', axis='x', scilimits=(0, 0), useMathText=True)

    fig_path = args.outdir / args.figure
    fig.savefig(fig_path, bbox_inches="tight")
    plt.close(fig)  # release the figure explicitly
    print(f"Wrote figure: {fig_path}")
    print(f"Concatenated CSVs in: {args.outdir.resolve()}")

# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
