#!/usr/bin/env python3
import argparse
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import transforms as mtransforms  # for mixed data/x, axes/y coords

# CSV file names this script looks for in each input folder; main() only
# processes the names that are present in every folder it is given.
TARGETS = {"training_accuracy.csv", "training_loss.csv", "accuracy.csv", "loss.csv"}

def load_csv_keep_cols(path: Path) -> pd.DataFrame:
    """Read a CSV and return only its 'tick' and '0' columns as clean numbers.

    Both columns are coerced to numeric; rows where either value cannot be
    parsed (or is missing) are discarded. 'tick' is then cast to int and the
    result is sorted by 'tick' with a fresh 0..n-1 index.

    Raises:
        ValueError: if the file lacks a 'tick' or a '0' column.
    """
    wanted = ["tick", "0"]
    raw = pd.read_csv(path)

    if any(col not in raw.columns for col in wanted):
        raise ValueError(f"File {path} must contain columns: 'tick' and '0'.")

    # Unparseable cells become NaN so they can be dropped in one pass below.
    numeric = pd.DataFrame(
        {col: pd.to_numeric(raw[col], errors="coerce") for col in wanted}
    )
    cleaned = numeric.dropna(subset=wanted)
    # Ticks are integers; the cast is only safe once the NaN rows are gone.
    cleaned = cleaned.assign(tick=cleaned["tick"].astype(int))
    return cleaned.sort_values("tick").reset_index(drop=True)

def concat_by_tick(df1: pd.DataFrame, df2: pd.DataFrame):
    """
    Append df2 after df1 on a shared 'tick' axis.

    df2's ticks are shifted by the largest tick of df1, so two runs covering
    0..A and 0..B become one series covering 0..(A+B). Neither input frame is
    modified.

    Returns:
        (concatenated_df, boundary_tick), where boundary_tick is the
        x-position at which the second run starts. If df1 is empty, a copy of
        df2 is returned unchanged and boundary_tick is None.
    """
    if df1.empty:
        return df2.copy(), None

    boundary = int(df1["tick"].max())  # matches 0..A + 0..B -> 0..(A+B)
    # assign() returns a new frame, so the caller's df2 keeps its own ticks.
    shifted = df2.assign(tick=df2["tick"] + boundary)

    out = pd.concat([df1, shifted], ignore_index=True)
    # A second run that starts at tick 0 lands exactly on `boundary`, tying
    # with the last row of df1. A stable sort guarantees that df1's rows stay
    # ahead of df2's rows at that seam (the default quicksort does not).
    out = out.sort_values("tick", kind="mergesort").reset_index(drop=True)
    return out, boundary

def smooth(y: pd.Series, window: int = 10) -> np.ndarray:
    """
    Smooth `y` with a trailing rolling mean and return it as a numpy array.

    The mean uses `window` samples with min_periods=1, so the first entries
    average over however many samples are available. Any smoothed value that
    is not strictly positive is replaced by NaN, which keeps the result safe
    to draw on a log-scaled axis.
    """
    rolled = pd.Series(y).rolling(window=window, min_periods=1).mean()
    is_positive = rolled.gt(0)
    # Blank out zeros and negatives (NaNs stay NaN) to avoid log-scale issues.
    return rolled.mask(~is_positive, np.nan).to_numpy()


def _find_target_csvs(folder: Path) -> dict:
    """Return {file name: path} for every TARGETS CSV located directly in `folder`."""
    return {p.name: p for p in folder.glob("*.csv") if p.name in TARGETS}


def _concat_and_save(first_run: dict, second_run: dict, names, outdir: Path, prefix: str):
    """Concatenate `second_run` after `first_run` for each file name and save it.

    Each result is written to `outdir / f"{prefix}_{name}"` so that different
    models do not overwrite each other's output.

    Returns:
        (series, boundaries): dicts keyed by file name holding the
        concatenated DataFrame and the tick at which the second run starts
        (None when the first run is empty).
    """
    series = {}
    boundaries = {}
    for name in names:
        merged, boundary = concat_by_tick(
            load_csv_keep_cols(first_run[name]),
            load_csv_keep_cols(second_run[name]),
        )
        merged = merged[["tick", "0"]]
        merged.to_csv(outdir / f"{prefix}_{name}", index=False)
        series[name] = merged
        boundaries[name] = boundary
    return series, boundaries


def _rolling_mean(values, window: int) -> np.ndarray:
    """Trailing rolling mean (min_periods=1) without the positive-only masking of smooth()."""
    return pd.Series(values).rolling(window, min_periods=1).mean().to_numpy()


def _plot_model(ax_loss, series: dict, boundaries: dict, window: int) -> None:
    """Draw one model's curves: losses on `ax_loss` (log scale), accuracies on a twin axis.

    Curves whose CSV is missing from `series` are skipped instead of raising
    KeyError. Vertical dotted lines mark each concatenation boundary, and the
    first boundary is annotated with "m2o" to its left and "m2m" to its right.
    """
    ax_acc = ax_loss.twinx()

    # (file name, legend label, colour, line style)
    loss_curves = (
        ("training_loss.csv", "training loss", "C0", "solid"),
        ("loss.csv", "test loss", "C2", "dotted"),
    )
    acc_curves = (
        ("training_accuracy.csv", "training accuracy", "C1", (0, (3, 5, 1, 5))),
        ("accuracy.csv", "test accuracy", "C3", (0, (5, 10))),
    )

    handles = []
    labels = []

    # Loss curves (left y-axis), smoothed; smooth() drops non-positive values
    # so they do not break the log scale.
    for fname, label, color, style in loss_curves:
        if fname not in series:
            continue
        frame = series[fname]
        line, = ax_loss.plot(
            frame["tick"].values,
            smooth(frame["0"].values, window=window),
            label=label, c=color, linewidth=1, linestyle=style,
        )
        handles.append(line)
        labels.append(label)

    # Accuracy curves (right y-axis), smoothed; a plain rolling mean is used
    # because an accuracy of exactly 0 is a valid value on a linear axis.
    for fname, label, color, style in acc_curves:
        if fname not in series:
            continue
        frame = series[fname]
        line, = ax_acc.plot(
            frame["tick"].values,
            _rolling_mean(frame["0"].values, window),
            label=label, c=color, linewidth=1, linestyle=style,
        )
        handles.append(line)
        labels.append(label)

    ax_loss.set_xlabel(r'iteration $i$')
    ax_loss.set_ylabel("training & test loss")
    ax_loss.set_yscale("log")
    ax_acc.set_ylabel("training & test accuracy")

    # Global x-range over all non-empty series; used for the label offset and
    # as a fallback for the axis limits.
    populated = [df for df in series.values() if not df.empty]
    if populated:
        x_min = min(df["tick"].min() for df in populated)
        x_max = max(df["tick"].max() for df in populated)
    else:
        x_min, x_max = 0, 1
    dx = max(1, 0.015 * (x_max - x_min))  # data-space x offset for boundary labels

    # The x-limits follow the test-accuracy series when it is available and
    # fall back to the global range otherwise.
    reference = series.get("accuracy.csv")
    if reference is not None and not reference.empty:
        ax_loss.set_xlim([reference["tick"].values.min(), reference["tick"].values.max()])
    else:
        ax_loss.set_xlim([x_min, x_max])

    # Draw unique boundary lines (concat points)
    unique_bounds = sorted({b for b in boundaries.values() if b is not None})
    for b in unique_bounds:
        ax_loss.axvline(x=b, linestyle="dotted", alpha=0.5, linewidth=1, c="black")

    # Add labels to the left/right of (the first) boundary
    if unique_bounds:
        b = unique_bounds[0]
        # x is given in data coordinates, y in axes coordinates.
        trans = mtransforms.blended_transform_factory(ax_loss.transData, ax_loss.transAxes)
        # Place near the top (y=0.97 in axes coords), slightly offset in x.
        ax_loss.text(b - dx, 0.97, "m2o", ha="right", va="top", transform=trans)
        ax_loss.text(b + dx, 0.97, "m2m", ha="left", va="top", transform=trans)

    # One combined legend: loss entries first, then accuracy entries.
    if handles:
        ax_loss.legend(handles, labels, fontsize=8, loc='center')

    ax_loss.ticklabel_format(style='sci', axis='x', scilimits=(0, 0), useMathText=True)
    ax_acc.ticklabel_format(style='sci', axis='x', scilimits=(0, 0), useMathText=True)

    # Leave a little headroom above an accuracy of 1.0.
    ax_acc.set_ylim(top=1.02)


def main():
    """Command-line entry point.

    For ResNet18 and CCT7, each target CSV of the m2o run is concatenated with
    the same-named CSV of the m2m run (m2o first). The concatenated series are
    written to `--outdir` as `resnet_<name>` / `cct_<name>`, and a two-panel
    figure (one panel per model) is saved next to them.

    Raises:
        SystemExit: if no target file is present in all four input folders,
            or (via argparse) if `--smooth` is smaller than 1.
    """
    ap = argparse.ArgumentParser(
        description="Concat matching CSVs from the m2o and m2m folders of each model by 'tick' and plot column '0'."
    )
    ap.add_argument("--folder_resnet_m2m", default="./avs_resnet18_m2m/0-0", type=Path, help="Path of ResNet m2m folder")
    ap.add_argument("--folder_resnet_m2o", default="./avs_resnet18_m2o/0-to_vs", type=Path, help="Path of ResNet m2o folder")
    ap.add_argument("--folder_cct_m2m", default="./avs_cct7_m2m/0-0", type=Path, help="Path of CCT m2m folder")
    ap.add_argument("--folder_cct_m2o", default="./avs_cct7_m2o/0-to_vs", type=Path, help="Path of CCT m2o folder")
    ap.add_argument("--outdir", default=Path("./output"), type=Path, help="Where to write concatenated CSVs and figure")
    ap.add_argument("--figure", default="avs_resnet_cct.pdf", help="Filename for the output figure (saved in outdir)")
    ap.add_argument("--smooth", default=10, type=int, help="Rolling window size for smoothing")
    args = ap.parse_args()

    if args.smooth < 1:
        # A rolling window below 1 cannot satisfy min_periods=1.
        ap.error("--smooth must be a positive integer")

    resnet_m2m = _find_target_csvs(args.folder_resnet_m2m)
    resnet_m2o = _find_target_csvs(args.folder_resnet_m2o)
    cct_m2m = _find_target_csvs(args.folder_cct_m2m)
    cct_m2o = _find_target_csvs(args.folder_cct_m2o)

    # Only file names available for every model and every run are processed.
    common = sorted(resnet_m2m.keys() & resnet_m2o.keys() & cct_m2m.keys() & cct_m2o.keys())
    if not common:
        raise SystemExit(
            "None of the target files were found in all four folders: "
            + ", ".join(sorted(TARGETS))
        )

    args.outdir.mkdir(parents=True, exist_ok=True)

    # m2o comes first, m2m second; the per-model prefix keeps the CCT CSVs
    # from overwriting the ResNet CSVs in the shared output directory.
    series_resnet, boundaries_resnet = _concat_and_save(resnet_m2o, resnet_m2m, common, args.outdir, "resnet")
    series_cct, boundaries_cct = _concat_and_save(cct_m2o, cct_m2m, common, args.outdir, "cct")

    # Plot: losses on left axis, accuracy on right axis
    fig, axs = plt.subplots(1, 2, figsize=(10, 3))

    _plot_model(axs[0], series_resnet, boundaries_resnet, args.smooth)
    axs[0].set_title("ResNet18 @CIFAR10")
    _plot_model(axs[1], series_cct, boundaries_cct, args.smooth)
    axs[1].set_title("CCT7 @CIFAR10")

    fig.subplots_adjust(left=0.01, right=0.99, bottom=0.01, top=0.99, wspace=0.37, hspace=0.30)

    fig_path = args.outdir / args.figure
    fig.savefig(fig_path, bbox_inches="tight")
    print(f"Wrote figure: {fig_path}")
    print(f"Concatenated CSVs in: {args.outdir.resolve()}")

# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
