import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def read_training(path: str):
    """Load one numeric series from a CSV file together with its sample index.

    The file is first parsed with a header row and the column labelled
    ``"0"`` is used. If that attempt fails for any reason (for example, no
    such column exists), the file is re-read without a header and its first
    column is used instead.

    Args:
        path: Location of the CSV file.

    Returns:
        A tuple ``(x, y)`` where ``y`` is a float array of the column values
        (entries that cannot be parsed as numbers become NaN) and ``x`` is
        ``0, 1, ..., len(y) - 1``.
    """
    try:
        frame = pd.read_csv(path)
        # Normalise labels to strings so the lookup below is by text.
        frame.columns = list(map(str, frame.columns))
        raw = frame["0"]
    except Exception:
        # Best-effort fallback: treat the file as header-less.
        raw = pd.read_csv(path, header=None).iloc[:, 0]

    values = pd.to_numeric(raw, errors="coerce").astype(float).to_numpy()
    return np.arange(len(values)), values

def read_interp(path: str, expect_cols: int = 50):
    """Load a table of interpolation values from a CSV file.

    The file is first parsed with a header row; columns whose (stripped)
    labels are ``"0"`` .. ``str(expect_cols - 1)`` are selected. If none of
    those labels is present, the file is re-read as header-less and its first
    ``min(expect_cols, n_columns)`` columns are selected by position.

    Selected columns are coerced to numbers; a column in which nothing parses
    as a number is dropped.

    Args:
        path: Location of the CSV file.
        expect_cols: Upper bound on the number of columns to select.

    Returns:
        A 2-D float array of shape ``(rows, kept_columns)`` with the kept
        columns in ascending column-index order.

    Raises:
        ValueError: If no selected column contains any numeric value.
    """
    # Attempt 1: the file carries a header row.
    frame = pd.read_csv(path)
    frame.columns = [str(label).strip() for label in frame.columns]
    selected = [str(i) for i in range(expect_cols) if str(i) in frame.columns]

    # Attempt 2: none of the expected labels was found, so treat as header-less.
    if not selected:
        frame = pd.read_csv(path, header=None)
        width = frame.shape[1]
        frame.columns = [str(i) for i in range(width)]
        selected = [str(i) for i in range(min(expect_cols, width))]

    # `selected` is already in ascending index order, so appending in loop
    # order yields the columns sorted by their numeric label.
    numeric_columns = []
    for label in selected:
        as_number = pd.to_numeric(frame[label], errors="coerce")
        if as_number.isna().all():
            # Nothing in this column parsed as a number; skip it.
            continue
        numeric_columns.append(as_number.astype(float))

    if not numeric_columns:
        raise ValueError("No numeric interpolation columns found after filtering.")

    return pd.concat(numeric_columns, axis=1).to_numpy()

def smooth_ma(y, w=10):
    """Return a centered moving average of *y* as a NumPy array.

    Args:
        y: One-dimensional sequence of values to smooth.
        w: Window length in samples.

    Returns:
        An array with one averaged value per input sample. Because
        ``min_periods=1`` is used, positions near the ends are averaged over
        however many samples fall inside the window instead of becoming NaN.
    """
    series = pd.Series(y)
    windows = series.rolling(window=w, center=True, min_periods=1)
    return windows.mean().to_numpy()

def main(training_csv: str, interp_loss_csv: str, interp_accuracy_csv: str, cols: int, axs, name):
    """Plot smoothed interpolation loss and accuracy curves on one panel.

    Loss is drawn on ``axs`` (left y-axis) and accuracy on a twin axis sharing
    the same x-axis (right y-axis). Each CSV row ``i`` is spread over the
    x-interval ``[i, i + 1]`` (both ends included), and both flattened curves
    are smoothed with a 50-sample centered moving average before plotting.

    Args:
        training_csv: Currently unused; kept so existing callers keep working.
        interp_loss_csv: CSV file read with ``read_interp`` for the loss values.
        interp_accuracy_csv: CSV file read with ``read_interp`` for the
            accuracy values.
        cols: Forwarded to ``read_interp`` as ``expect_cols``.
        axs: Matplotlib axes to draw on; a twin y-axis is created from it.
        name: Y-limit preset, either ``"resnet"`` or ``"dla"``.

    Raises:
        NotImplementedError: If ``name`` is not a supported preset.
        ValueError: If the loss and accuracy tables differ in shape.
    """
    # Reject an unknown preset up front, before any file is read or anything
    # is drawn, so a bad call does not leave a half-configured panel behind.
    if name not in ("resnet", "dla"):
        raise NotImplementedError(f"unsupported preset name: {name!r}")

    loss_table = read_interp(interp_loss_csv, expect_cols=cols)  # (rows, cols)
    accuracy_table = read_interp(interp_accuracy_csv, expect_cols=cols)

    # Both curves are plotted against one x grid derived from the table shape,
    # so the two tables must agree.
    if loss_table.shape != accuracy_table.shape:
        raise ValueError(
            "Interpolation loss and accuracy tables differ in shape: "
            f"{loss_table.shape} vs {accuracy_table.shape}."
        )

    n_rows, n_points = loss_table.shape

    # x for row i is i + linspace(0, 1, n_points), i.e. k / (n_points - 1)
    # for k = 0 .. n_points - 1, flattened row by row.
    x_interp = (np.arange(n_rows)[:, None] + np.linspace(0.0, 1.0, n_points)).reshape(-1)

    y_loss = smooth_ma(loss_table.reshape(-1), w=50)
    y_accuracy = smooth_ma(accuracy_table.reshape(-1), w=50)

    ax_right = axs.twinx()
    loss_line, = axs.plot(
        x_interp, y_loss,
        linestyle='solid', label="interpolation loss",
        linewidth=1, alpha=0.8, c='C0',
    )
    accuracy_line, = ax_right.plot(
        x_interp, y_accuracy,
        linestyle='dotted', label="interpolation accuracy",
        linewidth=1, alpha=0.8, c='C1',
    )

    axs.set_xlabel(r'iteration $i$')
    axs.set_ylabel("loss")
    ax_right.set_ylabel("accuracy")

    # One combined legend on the left axes covering the lines of both y-axes.
    axs.legend(
        [loss_line, accuracy_line],
        ["interpolation loss", "interpolation accuracy"],
        fontsize=8,
        loc='upper left',
    )
    axs.ticklabel_format(style='sci', axis='x', scilimits=(0, 0), useMathText=True)

    # Preset-specific upper y-limits.
    if name == "resnet":
        ax_right.set_ylim(top=1.002)
        axs.set_ylim(top=0.05)
    else:  # name == "dla" (validated above)
        ax_right.set_ylim(top=1.001)


if __name__ == "__main__":
    # One row of two panels, one panel per run directory.
    fig, axs = plt.subplots(1, 2, figsize=(10, 3))

    training_loss_file = "training_loss.csv"
    interp_loss_file = "consec_linear_interpolation_loss.csv"
    interp_accuracy_file = "consec_linear_interpolation_accuracy.csv"

    # (run directory, preset name passed to main(), panel title)
    panels = [
        ("./dla_cifar10/0-0", "dla", "DLA @CIFAR10"),
        ("./resnet18_cifar100/0-0", "resnet", "ResNet18 @CIFAR100"),
    ]

    for panel_axes, (path, preset, title) in zip(axs, panels):
        main(
            f"{path}/{training_loss_file}",
            f"{path}/{interp_loss_file}",
            f"{path}/{interp_accuracy_file}",
            50,
            panel_axes,
            preset,
        )
        panel_axes.set_title(title)

    fig.tight_layout(pad=0.5)

    # NOTE(review): the output filename spelling is kept exactly as it was;
    # other tooling may reference it - confirm before renaming.
    fig.savefig("linear_interperpolation_appendix.pdf", bbox_inches="tight")

