import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def read_training(path: str) -> tuple[np.ndarray, np.ndarray]:
    """Load the training series from a CSV file.

    The file is first parsed with a header row and the column labelled
    ``"0"`` is used. If no such column exists, the file is re-read as
    headerless and its first column is used instead.

    Args:
        path: Location of the CSV file.

    Returns:
        A pair ``(x, y)`` where ``y`` is the series as a float array (values
        that cannot be parsed as numbers become NaN) and ``x`` is
        ``0 .. len(y) - 1``.
    """
    df = pd.read_csv(path)
    df.columns = [str(c) for c in df.columns]
    try:
        s = df["0"]
    except KeyError:
        # Only a missing "0" label means "this file has no header"; any other
        # failure (missing file, parse error) should surface directly.
        s = pd.read_csv(path, header=None).iloc[:, 0]
    y = pd.to_numeric(s, errors="coerce").astype(float).to_numpy()
    x = np.arange(len(y))
    return x, y

def read_interp(path: str, expect_cols: int = 50):
    """Load a grid of interpolation values from a CSV as a float array.

    The file is parsed with a header row first, selecting the columns whose
    stripped labels are ``"0"`` .. ``str(expect_cols - 1)``. If none of those
    labels is present, the file is re-read as headerless and its first
    ``min(expect_cols, n_columns)`` columns are selected instead.

    Every selected column is coerced to numeric (unparseable cells become
    NaN); a column with no numeric value at all is dropped. The survivors
    are returned in ascending order of their integer label.

    Args:
        path: Location of the CSV file.
        expect_cols: Number of interpolation columns to look for.

    Returns:
        A 2-D float array of shape ``(rows, kept_columns)``.

    Raises:
        ValueError: If no selected column contains any numeric value.
    """
    # Attempt 1: the file carries a header row.
    frame = pd.read_csv(path)
    frame.columns = [str(label).strip() for label in frame.columns]
    labels_present = set(frame.columns)
    selected = [str(i) for i in range(expect_cols) if str(i) in labels_present]

    # Attempt 2: no expected label found, so treat the file as headerless.
    if not selected:
        frame = pd.read_csv(path, header=None)
        width = frame.shape[1]
        frame.columns = list(map(str, range(width)))
        selected = list(map(str, range(min(expect_cols, width))))

    # Keep only the columns that hold at least one parseable number.
    numeric_by_index = {}
    for label in selected:
        values = pd.to_numeric(frame[label], errors="coerce")
        if values.isna().all():
            continue
        numeric_by_index[int(label)] = values.astype(float)

    if not numeric_by_index:
        raise ValueError("No numeric interpolation columns found after filtering.")

    # Reassemble in ascending order of the integer column label.
    columns_in_order = [numeric_by_index[i] for i in sorted(numeric_by_index)]
    return pd.concat(columns_in_order, axis=1).to_numpy()

def smooth_ma(y, w=10):
    """Return the centred moving average of ``y`` as a NumPy array.

    The window is ``w`` samples wide. Near the edges the window is simply
    truncated (``min_periods=1``), so the output has the same length as the
    input.
    """
    series = pd.Series(y)
    windows = series.rolling(window=w, center=True, min_periods=1)
    return windows.mean().to_numpy()

def _interp_x(n_rows: int, n_cols: int) -> np.ndarray:
    """Return flattened x positions where row ``i`` spans ``i + linspace(0, 1, n_cols)``."""
    offsets = np.linspace(0.0, 1.0, n_cols)
    return (np.arange(n_rows)[:, None] + offsets[None, :]).reshape(-1)


def main(training_csv: str, interp_accuracy_csv: str, interp_loss_csv: str, out_path: str | None, cols: int) -> None:
    """Plot smoothed interpolation loss and accuracy against iteration.

    Loss is drawn on the left y axis (solid line) and accuracy on a twin
    right y axis (dotted line). Each grid is flattened row by row, smoothed
    with a 50-sample centred moving average, and positioned so that row
    ``i`` covers x in ``[i, i + 1]``.

    Args:
        training_csv: Path to the training-loss CSV. The file is read, but
            its curve is not currently drawn.
        interp_accuracy_csv: Path to the interpolation-accuracy CSV.
        interp_loss_csv: Path to the interpolation-loss CSV.
        out_path: Output path without extension; ``.pdf`` and ``.png`` are
            appended. Nothing is saved when this is ``None`` or empty.
        cols: Number of interpolation columns requested from each CSV.
    """
    # The training curve is not drawn at the moment; the file is still read so
    # that a missing or unreadable training CSV fails exactly as before.
    read_training(training_csv)

    loss_grid = read_interp(interp_loss_csv, expect_cols=cols)          # (rows, kept_cols)
    accuracy_grid = read_interp(interp_accuracy_csv, expect_cols=cols)  # (rows, kept_cols)

    # Each series gets its own x positions. read_interp may drop non-numeric
    # columns, so the two grids are not guaranteed to share a shape, and a
    # single shared x vector would make plot() fail on mismatched lengths.
    x_loss = _interp_x(*loss_grid.shape)
    x_accuracy = _interp_x(*accuracy_grid.shape)

    y_loss = smooth_ma(loss_grid.reshape(-1), w=50)
    y_accuracy = smooth_ma(accuracy_grid.reshape(-1), w=50)

    fig, ax_left = plt.subplots(1, 1, figsize=(6, 2))
    ax_right = ax_left.twinx()

    loss_line, = ax_left.plot(
        x_loss, y_loss,
        linestyle='solid', label="linear interpolation loss", linewidth=1, c="C0", alpha=0.8,
    )
    accuracy_line, = ax_right.plot(
        x_accuracy, y_accuracy,
        linestyle='dotted', label="linear interpolation accuracy", linewidth=1, c="C1", alpha=0.8,
    )

    ax_left.set_xlabel(r'iteration $i$')
    ax_left.set_ylabel("loss")
    ax_right.set_ylabel("accuracy")
    ax_left.set_ylim(top=0.04)
    ax_right.set_ylim(top=1.005, bottom=0.988)
    ax_left.set_title("CCT7 @CIFAR10 (FDF strategy)")

    # One combined legend on the left axes covering the lines of both axes.
    ax_left.legend(
        [loss_line, accuracy_line],
        ["interpolation loss", "interpolation accuracy"],
        fontsize=8,
    )
    ax_left.grid(True, linestyle="--", linewidth=0.5, alpha=0.5)
    ax_left.ticklabel_format(style='sci', axis='x', scilimits=(0, 0), useMathText=True)
    fig.tight_layout()

    if out_path:
        fig.savefig(f"{out_path}.pdf", bbox_inches="tight")
        fig.savefig(f"{out_path}.png", dpi=200, bbox_inches="tight")

if __name__ == "__main__":
    # Command-line entry point: collect the CSV locations and plot options,
    # then render the figure.
    ap = argparse.ArgumentParser(description="Plot training loss with consecutive linear interpolation points.")
    ap.add_argument("--training", default="training_loss.csv", help="path to training_loss.csv")
    ap.add_argument("--interp", default="consec_linear_interpolation_loss.csv", help="path to consec_linear_interpolation_loss.csv")
    ap.add_argument("--interp_accuracy", default="consec_linear_interpolation_accuracy.csv", help="path to consec_linear_interpolation_accuracy.csv")
    ap.add_argument("--columns", type=int, default=50, help="number of interpolation columns (default: 50)")
    ap.add_argument(
        "-o", "--output", default="linear_interpolation_cct",
        help="output path without extension; both <output>.pdf and <output>.png are written "
             "(pass an empty string to skip saving)",
    )
    args = ap.parse_args()
    # Keyword arguments guard against swapping main()'s accuracy/loss order.
    main(
        training_csv=args.training,
        interp_accuracy_csv=args.interp_accuracy,
        interp_loss_csv=args.interp,
        out_path=args.output,
        cols=args.columns,
    )