import argparse
import os
import json
import numpy as np
import matplotlib.pyplot as plt

def calculate_loss_barrier(loss_curve):
    """Return the curve's maximum minus the mean of its two endpoints.

    Args:
        loss_curve: Indexable, non-empty sequence of numeric values.

    Returns:
        ``max(loss_curve) - (first + last) / 2``.
    """
    peak = np.max(loss_curve)
    first, last = loss_curve[0], loss_curve[-1]
    endpoint_mean = (first + last) / 2
    return peak - endpoint_mean

def _load_results(path):
    """Load one JSON results file.

    Args:
        path: Filesystem path of the JSON file to read.

    Returns:
        The object decoded by ``json.load``.
    """
    # Use a context manager so the handle is closed as soon as parsing ends.
    with open(path, 'r') as handle:
        return json.load(handle)


def _resolve_output_path(output_dir):
    """Translate the ``--output-dir`` value into the PDF path to write.

    Args:
        output_dir: Raw value of the ``--output-dir`` option.

    Returns:
        ``<output_dir>/plot.pdf`` when the value names a directory (it exists
        as one, or ends with a path separator); otherwise
        ``<output_dir>.pdf``, i.e. the value is used as a file-name prefix.
    """
    names_directory = (
        os.path.isdir(output_dir) or output_dir.endswith(("/", os.sep))
    )
    if names_directory:
        return os.path.join(output_dir, "plot.pdf")
    # Non-directory values keep working as a path prefix, e.g.
    # "figs/imagenet" -> "figs/imagenet.pdf".
    return output_dir + ".pdf"


def main():
    """Plot loss and accuracy interpolation curves from three JSON files.

    Parses ``--file-1``, ``--file-2``, ``--file-3`` and ``--output-dir`` from
    the command line, then draws a 2x3 grid: one column per results file,
    row 0 showing the "Test Loss" curves and row 1 the "Test Acc" curves.
    In every panel the "Naive" entry and the "permu_head_init_ortho_opt"
    entry (labelled "Match") are plotted over an evenly spaced grid on
    [0, 1]. The figure is saved as a PDF and the saved path is printed.
    """
    parser = argparse.ArgumentParser(description="Plot ImageNet test loss and accuracy for best weight matching")
    parser.add_argument("--file-1", type=str, required=True, help="Path to first JSON results file")
    parser.add_argument("--file-2", type=str, required=True, help="Path to second JSON results file")
    parser.add_argument("--file-3", type=str, required=True, help="Path to third JSON results file")
    parser.add_argument("--output-dir", type=str, default=".", help="Directory to save output plot")
    args = parser.parse_args()

    file_paths = [args.file_1, args.file_2, args.file_3]
    # One triple of x tick labels (left end, midpoint, right end) per file.
    xtick_labels = [
        ["Model 1", r"$\lambda$", "Model 2"],
        ["Model 1", r"$\lambda$", "Model 3"],
        ["Model 2", r"$\lambda$", "Model 3"],
    ]

    data = [_load_results(path) for path in file_paths]

    # Apply the style sheet first so the explicit rcParams below always win.
    plt.style.use('tableau-colorblind10')
    plt.rcParams.update({
        "font.family": "serif",
        'legend.frameon': False,
        'lines.linewidth': 2,
    })

    FONT_SMALL = 11
    FONT_MEDIUM = 13
    FONT_LARGE = 16

    # (results key, legend label, line colour) for each plotted method.
    # NOTE: selecting a "best" curve via calculate_loss_barrier is not done
    # here; the curves stored under these fixed keys are plotted directly.
    methods = [
        ("Naive", "Naive", "steelblue"),
        ("permu_head_init_ortho_opt", "Match", "lightsalmon"),
    ]
    # (metric key, y-axis label, whether the first panel carries the legend)
    # for each subplot row.
    rows = [
        ("Test Loss", "Validation Loss", True),
        ("Test Acc", "Validation Accuracy (%)", False),
    ]

    fig, axs = plt.subplots(len(rows), len(data), figsize=(13, 6))

    for row, (metric_key, ylabel, with_legend) in enumerate(rows):
        for col, (results, tick_labels) in enumerate(zip(data, xtick_labels)):
            ax = axs[row, col]
            for method_key, label, color in methods:
                curve = results[method_key][metric_key]
                # Size the grid from each curve so files with a different
                # number of interpolation points still plot correctly.
                lambda_values = np.linspace(0, 1, len(curve))
                ax.plot(lambda_values, curve, label=label, color=color)
            ax.set_xticks([0, 0.5, 1])
            ax.set_xticklabels(tick_labels, fontsize=FONT_MEDIUM)
            ax.tick_params(axis='y', labelsize=FONT_MEDIUM)
            if col == 0:
                # Only the leftmost panel of a row is labelled.
                ax.set_ylabel(ylabel, fontsize=FONT_LARGE, labelpad=15)
                if with_legend:
                    ax.legend(loc='upper left', fontsize=FONT_SMALL)

    fig.tight_layout(w_pad=4.0, h_pad=2.5)

    output_path = _resolve_output_path(args.output_dir)
    # Create the destination directory up front; savefig does not.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    fig.savefig(output_path)
    plt.close(fig)
    print(f"Saved plot to {output_path}")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
