import argparse
import os
import json
import numpy as np
import matplotlib.pyplot as plt

def calculate_loss_barrier(loss_curve):
    """Return the loss barrier of an interpolation curve.

    The barrier is the highest loss anywhere on the curve minus the mean of
    the two endpoint losses (first and last entries).
    """
    peak = np.max(loss_curve)
    endpoint_mean = (loss_curve[0] + loss_curve[-1]) / 2
    return peak - endpoint_mean

def smooth_interp(alpha_curve, weight=0.0):
    """Blend a curve toward the straight line joining its endpoints.

    Each point becomes ``(1 - weight) * curve[k] + weight * line[k]``, where
    ``line`` linearly interpolates from ``curve[0]`` to ``curve[-1]`` over the
    same number of evenly spaced points.

    Args:
        alpha_curve: 1-D sequence of values (e.g. a loss curve sampled along
            the interpolation coefficient).
        weight: Blend factor. ``0.0`` (the default) returns the curve
            unchanged, which is what this function previously did
            unconditionally; ``0.5`` gives the equal-weight average of each
            point with its endpoint interpolation.

    Returns:
        np.ndarray of floats with the same length as ``alpha_curve``
        (an empty array for empty input).
    """
    curve = np.asarray(alpha_curve, dtype=float)
    if curve.size == 0 or weight == 0:
        # Skip the arithmetic entirely: with non-finite endpoints the
        # interpolated line contains inf/nan, and 0 * inf would become nan.
        return curve
    lambdas = np.linspace(0.0, 1.0, curve.size)
    line = (1.0 - lambdas) * curve[0] + lambdas * curve[-1]
    return (1.0 - weight) * curve + weight * line

_FONT_SMALL = 11
_FONT_MEDIUM = 13
_FONT_LARGE = 16
_COLORS = ("steelblue", "lightsalmon")
# Loss keys read from each JSON run; also used verbatim as the row y-labels.
_METRICS = ("Val Loss", "Test Loss")
# Endpoint model names for the three columns (file-1, file-2, file-3).
_PAIRS = (("Model 1", "Model 2"), ("Model 1", "Model 3"), ("Model 2", "Model 3"))
# Extensions that mark an --output-dir value as really being a file path
# (kept working for callers who passed a full file name there).
_IMAGE_EXTENSIONS = frozenset(
    {".png", ".pdf", ".svg", ".jpg", ".jpeg", ".eps", ".ps", ".tif", ".tiff", ".pgf", ".webp"}
)


def _parse_args():
    """Parse the command-line arguments for the plotting script."""
    parser = argparse.ArgumentParser(description="Plot Val/Test loss for 3 pairs with smoothing on permu_head_init_ortho_opt")
    parser.add_argument("--file-1", type=str, required=True)
    parser.add_argument("--file-2", type=str, required=True)
    parser.add_argument("--file-3", type=str, required=True)
    parser.add_argument("--output-dir", type=str, default=".", help="Directory to save output plot")
    parser.add_argument(
        "--output-name",
        type=str,
        default="val_test_loss.png",
        help="File name of the plot inside --output-dir",
    )
    return parser.parse_args()


def _load_json(path):
    """Load and return the JSON document stored at ``path``."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _resolve_output_path(output_dir, output_name):
    """Return the file path to save the figure to, creating directories as needed.

    ``output_dir`` is normally a directory and the plot goes to
    ``output_dir/output_name``. If it is not an existing directory and ends in
    a known image extension, it is used directly as the output file path.
    """
    ext = os.path.splitext(output_dir)[1].lower()
    if ext in _IMAGE_EXTENSIONS and not os.path.isdir(output_dir):
        parent = os.path.dirname(output_dir)
        if parent:
            os.makedirs(parent, exist_ok=True)
        return output_dir
    os.makedirs(output_dir, exist_ok=True)
    return os.path.join(output_dir, output_name)


def _plot_panel(ax, run, metric, xtick_labels):
    """Draw the Naive and Match curves for one loss metric on ``ax``."""
    naive = run["Naive"][metric]
    # Only the permu_head_init_ortho_opt ("Match") curve goes through smooth_interp.
    match = smooth_interp(run["permu_head_init_ortho_opt"][metric])
    # Give each curve its own evenly spaced lambda grid so curves with
    # different numbers of points still plot instead of raising a
    # shape-mismatch error.
    ax.plot(np.linspace(0, 1, len(naive)), naive, label="Naive", color=_COLORS[0])
    ax.plot(np.linspace(0, 1, len(match)), match, label="Match", color=_COLORS[1])
    ax.set_xticks([0, 0.5, 1])
    ax.set_xticklabels(xtick_labels, fontsize=_FONT_MEDIUM)
    ax.tick_params(axis="y", labelsize=_FONT_MEDIUM)


def main():
    """Plot validation and test loss along the interpolation path for 3 model pairs.

    Reads three JSON files (one per model pair). Each is indexed as
    ``run["Naive"][metric]`` and ``run["permu_head_init_ortho_opt"][metric]``
    for ``metric`` in ``"Val Loss"`` / ``"Test Loss"``. It then writes a 2x3
    figure: row 0 is validation loss, row 1 is test loss, one column per pair.

    The figure is saved to ``--output-dir/--output-name``. The directory is
    created if needed. An ``--output-dir`` that is an image file path is
    used as-is.
    """
    args = _parse_args()
    runs = [_load_json(path) for path in (args.file_1, args.file_2, args.file_3)]

    plt.rcParams.update({
        "font.family": "serif",
        "legend.frameon": False,
        "lines.linewidth": 2,
    })
    plt.style.use("tableau-colorblind10")

    fig, axs = plt.subplots(2, 3, figsize=(13, 6))
    for col, (run, (left, right)) in enumerate(zip(runs, _PAIRS)):
        xtick_labels = [left, r"$\lambda$", right]
        for row, metric in enumerate(_METRICS):
            ax = axs[row, col]
            _plot_panel(ax, run, metric, xtick_labels)
            if col == 0:
                ax.set_ylabel(metric, fontsize=_FONT_LARGE, labelpad=15)
    # A single legend on the top-left panel serves the whole figure.
    axs[0, 0].legend(loc="upper left", fontsize=_FONT_SMALL)

    fig.tight_layout(w_pad=4.0, h_pad=2.5)
    output_path = _resolve_output_path(args.output_dir, args.output_name)
    fig.savefig(output_path)
    plt.close(fig)
    print(f"Saved plot to {output_path}")

if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    main()
