import numpy as np
import matplotlib.pyplot as plt

def plot_quantile_band(data, color, label, δ=1e-4, ax=None):
    """Plot the median run of ``data`` plus two faint runs bounding a tail band.

    Runs are ranked by their value at the last step. The run at the median
    rank (``n_runs // 2``) is drawn as a solid line labelled ``label``. The
    runs at rank ``floor(δ * n_runs)`` and at its mirror image
    ``n_runs - 1 - floor(δ * n_runs)`` are drawn as faint (alpha 0.2) lines in
    the same colour, outlining an approximate ``[δ, 1 - δ]`` band. Each
    selected run is plotted as a whole trajectory, so the band lines are real
    runs rather than pointwise quantiles.

    Parameters
    ----------
    data : array_like, shape (n_runs, n_steps)
        One row per run; columns are plotted against steps ``1..n_steps``.
    color : matplotlib colour
        Colour shared by the median line and the two band lines.
    label : object
        Legend label for the median line (formatted through an f-string).
    δ : float, default 1e-4
        Tail fraction in ``[0, 1]``. ``0`` selects the minimum and maximum
        runs.
    ax : matplotlib Axes, optional
        Axes to draw on. Defaults to the current axes (``plt.gca()``), which
        matches calling ``plt.plot`` directly.

    Raises
    ------
    ValueError
        If ``data`` is not a 2-D array with at least one run and one step, or
        if ``δ`` is outside ``[0, 1]``.
    """
    data = np.asarray(data)
    if data.ndim != 2 or data.shape[0] == 0 or data.shape[1] == 0:
        raise ValueError(
            "data must be a non-empty 2-D array of shape (n_runs, n_steps), "
            f"got shape {data.shape}"
        )
    # Written as a negated range check so NaN is rejected too.
    if not 0 <= δ <= 1:
        raise ValueError(f"δ must lie in [0, 1], got {δ!r}")
    if ax is None:
        ax = plt.gca()

    num_runs, num_steps = data.shape

    # Rank runs by their final value (value at the last step).
    sorted_indices = np.argsort(data[:, -1])

    # Clamp so δ = 1 cannot index past the end. Mirror the lower rank so the
    # two band lines sit symmetrically around the median. The old
    # int((1 - δ) * n) overflowed for δ = 0 and was off by one from the mirror.
    lower_rank = min(int(δ * num_runs), num_runs - 1)
    upper_rank = num_runs - 1 - lower_rank
    median_rank = num_runs // 2

    # Full trajectories of the selected runs.
    lower_run = data[sorted_indices[lower_rank]]
    upper_run = data[sorted_indices[upper_rank]]
    median_run = data[sorted_indices[median_rank]]

    # 1-based time axis.
    steps = np.arange(1, num_steps + 1)

    ax.plot(steps, median_run, label=f'{label}', color=color, alpha=1.0)
    ax.plot(steps, lower_run, color=color, alpha=0.2)
    ax.plot(steps, upper_run, color=color, alpha=0.2)

# Settings for each experiment: the two result archives to compare, their
# legend labels and colours, the y-axis label, and the output PDF path.
# Replaces the previous commented-out toggling, which had left Muon data
# plotted under Lion labels and saved over Lion_synthetic.pdf.
EXPERIMENTS = {
    "Lion": {
        "files": ("Lion_synthetic.npz", "Lion++_synthetic.npz"),
        "labels": ("Lion", "Lion++"),
        "colors": ("blue", "red"),
        "ylabel": r"$\frac{1}{T} \sum_{t=1}^T \| \nabla F(x_t) \|$",
        "output": "Lion_synthetic.pdf",
    },
    "Muon": {
        "files": ("Muon_synthetic.npz", "Muon_synthetic2.npz"),
        "labels": ("Muon", "Muon++"),
        "colors": ("green", "orange"),
        "ylabel": r"$\frac{1}{T} \sum_{t=1}^T \| \nabla F(X_t) \|$",
        "output": "Muon_synthetic.pdf",
    },
}


def load_avg_norms(path):
    """Return the ``avg_norms`` array stored in the ``.npz`` archive at *path*."""
    # Use the NpzFile as a context manager so its file handle is closed; the
    # indexed array is read fully into memory before the archive closes.
    with np.load(path) as archive:
        return archive["avg_norms"]


def main(experiment="Muon"):
    """Plot the quantile bands of one experiment's two runs, save and show them.

    Parameters
    ----------
    experiment : str, default "Muon"
        Key into ``EXPERIMENTS`` selecting which result files, labels, colours,
        y-axis label and output PDF to use.

    Raises
    ------
    ValueError
        If ``experiment`` is not a key of ``EXPERIMENTS``.
    """
    try:
        config = EXPERIMENTS[experiment]
    except KeyError:
        raise ValueError(
            f"unknown experiment {experiment!r}; expected one of {sorted(EXPERIMENTS)}"
        ) from None

    # Set global font sizes.
    plt.rcParams.update({
        'axes.labelsize': 16,
        'xtick.labelsize': 14,
        'ytick.labelsize': 14,
        'legend.fontsize': 15,
        'axes.titlesize': 18
    })

    # Load every archive before creating the figure, so a missing file fails
    # before anything is drawn.
    runs = [load_avg_norms(path) for path in config["files"]]

    plt.figure()
    for data, label, color in zip(runs, config["labels"], config["colors"]):
        plot_quantile_band(data, color=color, label=label)

    plt.xlabel(r"$T$")
    plt.ylabel(config["ylabel"])
    plt.title("Normal noise, n=1")
    plt.grid(False)
    plt.legend()
    plt.tight_layout()
    plt.savefig(config["output"])
    plt.show()

if __name__ == "__main__":
    # Build the figure only when run as a script, not when imported.
    main()