import argparse
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

# One row per algorithm, listed in plotting / legend order:
# (npz archive key, legend label, line colour, marker style).
_ALG_STYLE_TABLE = (
    ("alg3", "NCBF-UCB", "#377eb8", "o"),
    ("alg4", "Logistic-UCB-1", "#999999", "s"),
    ("alg5", "ada-OFU-ECOLog", "#984ea3", "^"),
    ("alg1", "NeuralLog-UCB-1", "#ff7f00", "v"),
    ("alg2", "NeuralLog-UCB-2", "#e41a1c", "D"),
    ("alg6", "NeuralLog-TS-1", "#4daf4a", "X"),
    ("alg7", "NeuralLog-TS-2", "#a65628", "*"),
)

# Lookups consumed by the plotting code. They are all derived from the single
# table above so adding/reordering an algorithm only needs one edit.
DESIRED_ORDER = [row[0] for row in _ALG_STYLE_TABLE]
ALG_LABELS = {row[0]: row[1] for row in _ALG_STYLE_TABLE}
ALG_COLORS = {row[0]: row[2] for row in _ALG_STYLE_TABLE}
ALG_MARKERS = {row[0]: row[3] for row in _ALG_STYLE_TABLE}


def plot_subfigure(npz_file: str, output_file: str) -> None:
    """Plot per-algorithm regret curves from an ``.npz`` archive and save the figure.

    For every key of ``DESIRED_ORDER`` present in the archive, the stored array
    is averaged over axis 0 and drawn as a line, with a shaded band of
    +/- 2 standard deviations (also taken over axis 0). Rows are presumably
    independent runs and columns timesteps -- confirm against the code that
    writes the archive. Keys not listed in ``DESIRED_ORDER`` are ignored.
    A 1-D array is treated as a single run.

    Args:
        npz_file: Path to an ``.npz`` archive readable by ``np.load``.
        output_file: Destination path for the saved figure (written at 300 dpi).

    Raises:
        ValueError: If the archive contains none of the keys in ``DESIRED_ORDER``.
    """
    # Use the archive as a context manager so its file handle is released.
    # Indexing an NpzFile reads the array into memory, so the arrays stay
    # valid after the ``with`` block closes the file.
    with np.load(npz_file) as data:
        available = set(data.files)
        curves = {
            key: np.atleast_2d(data[key])
            for key in DESIRED_ORDER
            if key in available
        }
    if not curves:
        raise ValueError(
            f"{npz_file!r} contains none of the expected keys {DESIRED_ORDER}"
        )

    fig, ax = plt.subplots(figsize=(4.2, 3))
    for key, runs in curves.items():
        # Each curve gets its own x-axis, so archives whose algorithms were run
        # for different horizons still plot correctly.
        horizon = runs.shape[1]
        rounds = np.arange(1, horizon + 1)
        avg = runs.mean(axis=0)
        std = runs.std(axis=0)
        ax.plot(
            rounds,
            avg,
            label=ALG_LABELS[key],
            color=ALG_COLORS[key],
            linewidth=1.5,
            marker=ALG_MARKERS[key],
            markersize=4,
            markevery=max(horizon // 10, 1),  # roughly ten markers per curve
        )
        ax.fill_between(
            rounds,
            avg - 2 * std,
            avg + 2 * std,
            color=ALG_COLORS[key],
            alpha=0.15,
        )
    ax.set_xlabel("Timesteps", fontsize=8)
    ax.set_ylabel("Regret", fontsize=8)
    ax.tick_params(labelsize=8)
    legend = ax.legend(fontsize=8, loc="upper left")
    legend.get_frame().set_alpha(0.5)
    ax.xaxis.set_major_locator(MaxNLocator(nbins=7, integer=True))
    ax.yaxis.set_major_locator(MaxNLocator(nbins=7))
    ax.grid(linestyle="--", linewidth=0.8, alpha=0.7)
    fig.tight_layout()
    fig.savefig(output_file, dpi=300)
    plt.show()
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)

def main() -> None:
    """Read ``--npz_file`` and ``--output`` from the command line and render the plot.

    Both options are mandatory; argparse exits with a usage error if either
    is missing.
    """
    cli = argparse.ArgumentParser()
    for flag in ("--npz_file", "--output"):
        cli.add_argument(flag, required=True)
    opts = cli.parse_args()
    plot_subfigure(npz_file=opts.npz_file, output_file=opts.output)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
