import argparse
import os
import json
import numpy as np
import matplotlib.pyplot as plt

def calculate_loss_barrier(loss_curve):
    """Return the loss barrier of an interpolation curve.

    The barrier is the maximum value of ``loss_curve`` minus the mean of its
    first and last entries (the two endpoints of the interpolation path).

    Args:
        loss_curve: Non-empty indexable sequence of numeric loss values.

    Returns:
        The peak value minus the average of the two endpoint values.
    """
    peak = np.max(loss_curve)
    endpoint_mean = (loss_curve[0] + loss_curve[-1]) / 2
    return peak - endpoint_mean

def _load_json(path):
    """Parse the JSON file at ``path``, closing the file handle afterwards."""
    with open(path, 'r') as handle:
        return json.load(handle)


def _best_barrier_index(loss_curves):
    """Return the index of the curve with the lowest loss barrier.

    Ties keep the earliest index. A curve whose barrier does not compare
    as smaller than the running minimum (e.g. NaN) is never selected, so
    index 0 is returned if no curve qualifies.

    Raises:
        ValueError: If ``loss_curves`` is empty.
    """
    if len(loss_curves) == 0:
        raise ValueError("no candidate loss curves to choose from")

    best_idx = 0
    min_barrier = float("inf")
    for idx, loss_curve in enumerate(loss_curves):
        barrier = calculate_loss_barrier(loss_curve)
        if barrier < min_barrier:
            min_barrier = barrier
            best_idx = idx
    return best_idx


def _plot_pair(ax, naive_curve, matched_curve, tick_labels, colors, tick_fontsize):
    """Draw the naive and weight-matched curves on ``ax`` and style its ticks."""
    series = (
        (naive_curve, "Naive", colors[0]),
        (matched_curve, "WM", colors[1]),
    )
    for curve, label, color in series:
        # Each curve gets its own [0, 1] grid so that curves sampled at a
        # different number of interpolation points still plot correctly.
        lambda_values = np.linspace(0, 1, len(curve))
        ax.plot(lambda_values, curve, label=label, color=color)
    ax.set_xticks([0, 0.5, 1])
    ax.set_xticklabels(tick_labels, fontsize=tick_fontsize)
    ax.tick_params(axis='y', labelsize=tick_fontsize)


def main():
    """Plot test loss and perplexity interpolation curves for three model pairs.

    Reads three JSON files given on the command line. Each file must contain
    the keys ``test_loss_interp_naive``, ``test_loss_interp_clever_list``,
    ``test_ppl_interp_naive`` and ``test_ppl_interp_clever_list``. For every
    file the entry of ``test_loss_interp_clever_list`` with the lowest loss
    barrier is selected, and the same index is used for the perplexity list.
    The result is a 2x3 grid (row 0: test loss, row 1: test perplexity; one
    column per file) saved as a PDF.

    Raises:
        ValueError: If a file's ``test_loss_interp_clever_list`` is empty.
    """
    parser = argparse.ArgumentParser(description="Plot test loss and perplexity for best weight matching")
    parser.add_argument("--file-1", type=str, required=True)
    parser.add_argument("--file-2", type=str, required=True)
    parser.add_argument("--file-3", type=str, required=True)
    parser.add_argument("--output-dir", type=str, default=".", help="Directory to save output plot")
    args = parser.parse_args()

    file_paths = [args.file_1, args.file_2, args.file_3]
    xtick_labels = [
        ["Model 1", r"$\lambda$", "Model 2"],
        ["Model 1", r"$\lambda$", "Model 3"],
        ["Model 2", r"$\lambda$", "Model 3"],
    ]

    data = [_load_json(path) for path in file_paths]

    # Pick, per file, the candidate with the lowest test loss barrier.
    best_indices = []
    for path, d in zip(file_paths, data):
        try:
            best_indices.append(_best_barrier_index(d["test_loss_interp_clever_list"]))
        except ValueError as err:
            raise ValueError(f"{path}: 'test_loss_interp_clever_list' is empty") from err

    plt.rcParams.update({
        "font.family": "serif",
        'legend.frameon': False,
        'lines.linewidth': 2,
    })

    colors = ["steelblue", "lightsalmon"]
    plt.style.use('tableau-colorblind10')
    fig, axs = plt.subplots(2, 3, figsize=(13, 6))

    FONT_SMALL = 11
    FONT_MEDIUM = 13
    FONT_LARGE = 16

    # One column per model pair.
    for col, (d, labels) in enumerate(zip(data, xtick_labels)):
        best_idx = best_indices[col]

        # Row 0: Test Loss
        loss_ax = axs[0, col]
        _plot_pair(
            loss_ax,
            d["test_loss_interp_naive"],
            d["test_loss_interp_clever_list"][best_idx],
            labels,
            colors,
            FONT_MEDIUM,
        )

        # Row 1: Test Perplexity
        ppl_ax = axs[1, col]
        _plot_pair(
            ppl_ax,
            d["test_ppl_interp_naive"],
            d["test_ppl_interp_clever_list"][best_idx],
            labels,
            colors,
            FONT_MEDIUM,
        )

        # Only the leftmost column carries y-axis labels and the legend.
        if col == 0:
            loss_ax.set_ylabel("Test Loss", fontsize=FONT_LARGE, labelpad=15)
            loss_ax.legend(loc='upper left', fontsize=FONT_SMALL)
            ppl_ax.set_ylabel("Test Perplexity", fontsize=FONT_LARGE, labelpad=15)

    fig.tight_layout(w_pad=4.0, h_pad=2.5)

    # NOTE(review): despite the flag name and help text, any --output-dir
    # value other than '.' is used as a file-path prefix ("<value>.pdf"),
    # not as a directory. Kept as-is so existing invocations keep writing
    # to the same location; confirm whether this is intended.
    output_path = os.path.join(args.output_dir, "plot.pdf") if args.output_dir == '.' else args.output_dir + ".pdf"

    # Make sure the destination folder exists before writing the figure.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    fig.savefig(output_path)
    print(f"Saved plot to {output_path}")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
