import argparse
import os
import json
import numpy as np
import matplotlib.pyplot as plt

def calculate_loss_barrier(loss_curve):
    """Return the loss barrier of an interpolation curve.

    The barrier is the highest loss along the curve minus the average of the
    losses at its two endpoints (the first and last samples). For finite
    values it is never negative, because the peak is at least as large as
    either endpoint.

    Args:
        loss_curve: Sequence of loss values sampled along the interpolation
            path (e.g. a list of floats or a NumPy array).

    Returns:
        The barrier as a NumPy scalar.

    Raises:
        ValueError: From ``np.max`` when ``loss_curve`` is empty.
    """
    # Take the peak first so empty input fails in np.max, not in the indexing.
    peak = np.max(loss_curve)
    endpoint_mean = (loss_curve[0] + loss_curve[-1]) / 2
    return peak - endpoint_mean

def main():
    """Plot naive vs. weight-matching (WM) interpolation losses for three model pairs.

    Reads three JSON result files given via ``--file-1``, ``--file-2`` and
    ``--file-3``. Each file is expected to contain the keys
    ``train_loss_interp_naive``, ``test_loss_interp_naive``,
    ``train_loss_interp_clever_list`` and ``test_loss_interp_clever_list``.

    The figure is a 2x3 grid with one column per file. The top row shows the
    ``train_*`` curves and the bottom row shows the ``test_*`` curves.

    Across all files, the WM test curve with the smallest loss barrier (as
    computed by ``calculate_loss_barrier``) is located. The column for that
    file plots that WM curve; every other column plots the first WM curve in
    its list.

    Where the PDF is saved depends on ``--output-dir``:
      * ``.`` or a value ending in a path separator: saved as
        ``<output-dir>/plot.pdf``.
      * Any other value: used as a filename prefix, saved as
        ``<output-dir>.pdf``.
    """
    parser = argparse.ArgumentParser(description="Plot train and test loss for weight matching")
    parser.add_argument("--file-1", type=str, required=True)
    parser.add_argument("--file-2", type=str, required=True)
    parser.add_argument("--file-3", type=str, required=True)
    parser.add_argument("--output-dir", type=str, default=".")
    args = parser.parse_args()

    file_paths = [args.file_1, args.file_2, args.file_3]
    # x-axis labels per column: the two endpoint models of each pair.
    xtick_labels = [
        ["Model 1", r"$\lambda$", "Model 2"],
        ["Model 1", r"$\lambda$", "Model 3"],
        ["Model 2", r"$\lambda$", "Model 3"],
    ]

    # Load each file inside a context manager so its handle is closed promptly.
    data = []
    for path in file_paths:
        with open(path, 'r') as f:
            data.append(json.load(f))

    # Assumes every curve is sampled on the same evenly spaced lambda grid
    # in [0, 1]. The grid size is taken from the first file's naive test curve.
    num_points = len(data[0]["test_loss_interp_naive"])
    lambda_values = np.linspace(0, 1, num_points)

    plt.rcParams.update({
        "font.family": "serif",
        'legend.frameon': False,
        'lines.linewidth': 2,
    })

    colors = ["steelblue", "lightsalmon"]  # (naive, WM)
    plt.style.use('tableau-colorblind10')
    fig, axs = plt.subplots(2, 3, figsize=(13, 6))

    FONT_SMALL = 11
    FONT_MEDIUM = 13
    FONT_LARGE = 16

    # Find the minimum loss barrier across all WM ("clever") test curves.
    min_barrier = float("inf")
    min_pair = (None, None)  # (file_index, clever_index)

    for file_idx, d in enumerate(data):
        for clever_idx, loss_curve in enumerate(d["test_loss_interp_clever_list"]):
            barrier = calculate_loss_barrier(loss_curve)
            if barrier < min_barrier:
                min_barrier = barrier
                min_pair = (file_idx, clever_idx)

    # Each entry is (grid row, data-key prefix, y-axis label).
    # NOTE(review): row 0 plots the "train_*" curves but is labelled
    # "Validation Loss" -- confirm which split those curves were computed on.
    row_specs = [
        (0, "train", "Validation Loss"),
        (1, "test", "Test Loss"),
    ]

    for i, d in enumerate(data):
        # The file holding the min-barrier curve shows that WM curve; every
        # other file shows its first WM curve.
        clever_idx = min_pair[1] if i == min_pair[0] else 0

        for row, split, ylabel in row_specs:
            ax = axs[row, i]
            ax.plot(lambda_values, d[f"{split}_loss_interp_naive"], label="Naive", color=colors[0])
            ax.plot(lambda_values, d[f"{split}_loss_interp_clever_list"][clever_idx], label="WM", color=colors[1])
            ax.set_xticks([0, 0.5, 1])
            ax.set_xticklabels(xtick_labels[i], fontsize=FONT_MEDIUM)
            ax.tick_params(axis='y', labelsize=FONT_MEDIUM)
            if i == 0:
                ax.set_ylabel(ylabel, fontsize=FONT_LARGE, labelpad=15)
                if row == 0:
                    # One legend for the whole figure, on the top-left axes.
                    ax.legend(loc='upper left', fontsize=FONT_SMALL)

    plt.tight_layout(w_pad=4.0, h_pad=2.5)

    # "--output-dir" doubles as a filename prefix. "." or any value ending in a
    # path separator (empty basename) is treated as a directory, so that e.g.
    # "out/" yields "out/plot.pdf" rather than the hidden file "out/.pdf".
    output_dir = args.output_dir
    if output_dir == '.' or not os.path.basename(output_dir):
        output_path = os.path.join(output_dir, "plot.pdf")
    else:
        output_path = output_dir + ".pdf"
    plt.savefig(output_path)
    plt.close(fig)  # release the figure's resources once it is written
    print(f"Saved plot to {output_path}")
    print(f"Minimum loss barrier: {min_barrier:.4f} at file index {min_pair[0]}, clever index {min_pair[1]}")

# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
