import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import argparse
from scipy.stats import sem

def _load_seed_runs(data_dir, K, seeds, algo_name):
    """Load and stack the per-seed ``.npy`` logs of one algorithm.

    Files are looked up as ``{data_dir}/{K}_{seed}_{algo_name}.npy``; missing
    files are reported and skipped. Runs of unequal length are truncated to
    the shortest one, because ``np.array`` on a ragged list raises in recent
    NumPy instead of producing a 2-D array.

    Returns:
        An array with one row per found seed, or ``None`` if no file existed.
    """
    runs = []
    for seed in seeds:
        filepath = os.path.join(data_dir, f"{K}_{seed}_{algo_name}.npy")
        if not os.path.exists(filepath):
            print(f"[Warning] File not found: {filepath}")
            continue
        runs.append(np.load(filepath))

    if not runs:
        return None

    min_len = min(len(run) for run in runs)
    if any(len(run) != min_len for run in runs):
        print(f"[Warning] {algo_name}: seed runs differ in length; "
              f"truncating all to {min_len} rounds")
    return np.array([run[:min_len] for run in runs])


def _errorbar_indices(n_rounds, n_points):
    """Return about ``n_points`` evenly spaced indices into ``range(n_rounds)``.

    The last round is always included. The stride is clamped to at least 1 so
    curves shorter than ``n_points`` (or a non-positive ``n_points``) do not
    yield a zero step, which ``np.arange`` rejects.
    """
    if n_rounds <= 0:
        return np.array([], dtype=int)
    stride = max(1, n_rounds // max(1, n_points))
    indices = np.arange(0, n_rounds, stride)
    if indices[-1] != n_rounds - 1:
        indices = np.append(indices, n_rounds - 1)
    return indices


def plot_single_metric(metric, K, seeds, base_dir, algo_configs, colors, n_points_to_plot=10,
                       n_items=1682, dim=25):
    """Plot the seed-averaged curve of one metric for several algorithms and save it as PNG.

    Logs are read from ``{base_dir}/regretlog`` (metric ``'regret'``) or
    ``{base_dir}/timelog`` (metric ``'time'``), one file per
    ``{K}_{seed}_{file_key}.npy``. Each algorithm is drawn as its mean curve
    across seeds plus error-bar markers (standard error across seeds) at
    roughly ``n_points_to_plot`` rounds.

    Args:
        metric: ``'regret'`` or ``'time'``.
        K: Cascade length; used in the file names, title and output name.
        seeds: Seeds whose logs are averaged.
        base_dir: Experiment root; the figure is saved here as
            ``{Regret|Time}_N{n_items}_K{K}.png``.
        algo_configs: Dicts with keys ``'file_key'`` (log-file name part),
            ``'label'`` (legend text) and ``'color_idx'`` (index into colors).
        colors: Indexable color palette.
        n_points_to_plot: Approximate number of error-bar markers per curve.
        n_items: Value shown as ``N`` in the title and file name.
        dim: Value shown as ``d`` in the title.

    Raises:
        ValueError: If ``metric`` is not ``'regret'`` or ``'time'``.
    """
    if metric == 'regret':
        data_dir = os.path.join(base_dir, "regretlog")
        ylabel = 'Cumulative Regret'
        file_prefix = "Regret"
    elif metric == 'time':
        data_dir = os.path.join(base_dir, "timelog")
        ylabel = 'Time (seconds)'
        file_prefix = "Time"
    else:
        raise ValueError(f"Unknown metric {metric!r}; expected 'regret' or 'time'")

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        plotted_any = False
        for config in algo_configs:
            data_array = _load_seed_runs(data_dir, K, seeds, config['file_key'])
            if data_array is None:
                continue

            color = colors[config['color_idx']]
            mean_curve = data_array.mean(axis=0)
            # sem uses ddof=1, which is NaN for a single seed; show no spread instead.
            if data_array.shape[0] > 1:
                stderr_curve = sem(data_array, axis=0)
            else:
                stderr_curve = np.zeros_like(mean_curve)
            rounds = np.arange(len(mean_curve))

            ax.plot(rounds, mean_curve, color=color, alpha=0.8, linewidth=1.5)

            indices = _errorbar_indices(len(rounds), n_points_to_plot)
            ax.errorbar(
                rounds[indices],
                mean_curve[indices],
                yerr=stderr_curve[indices],
                fmt='o',
                color=color,
                label=config['label'],
                capsize=5,
                alpha=0.7,
                markersize=6,
            )
            plotted_any = True

        ax.set_xlabel(r'Round ($t$)')
        ax.set_ylabel(ylabel)
        ax.set_title(f"N={n_items}, K={K}, d={dim}", fontweight='bold')
        ax.grid(True, linestyle='-', alpha=0.6)

        # Only draw a legend when at least one labelled artist exists.
        if plotted_any:
            ax.legend(loc='upper left', frameon=True, edgecolor='gray')

        fig.tight_layout()

        save_path = os.path.join(base_dir, f"{file_prefix}_N{n_items}_K{K}.png")
        fig.savefig(save_path, dpi=300)
        print(f"Saved {metric} plot to {save_path}")
    finally:
        # Release the figure even if loading or saving fails.
        plt.close(fig)

def main(K, seeds, base_dir="/home/choihyunjun/neurips2025"):
    """Plot cumulative regret and running time for every configured algorithm.

    Args:
        K: Cascade length; selects the ``{K}_{seed}_{algo}.npy`` log files.
        seeds: Seeds whose logs are averaged into each curve.
        base_dir: Experiment root holding ``regretlog/`` and ``timelog/``;
            the output figures are also saved here.
    """
    colors = sns.color_palette("deep")

    # Each entry maps a log-file key to its legend label and palette slot.
    algo_configs = [
        {'file_key': 'CLogUCB',   'label': r'CLogUCB',    'color_idx': 0},
        {'file_key': 'VACLogUCB', 'label': r'VA-CLogUCB', 'color_idx': 1},
        {'file_key': 'UCBCCA',    'label': r'UCB-CCA',    'color_idx': 2},
        {'file_key': 'UCBCLB',    'label': r'UCB-CLB',    'color_idx': 3},
    ]

    font_settings = {
        'font.size': 14,
        'axes.titlesize': 18,
        'axes.labelsize': 16,
        'xtick.labelsize': 14,
        'ytick.labelsize': 14,
        'legend.fontsize': 13,
        'legend.title_fontsize': 14,
    }

    # rc_context scopes the font sizes to these plots instead of permanently
    # mutating the caller's global rcParams (matters when imported, e.g. in a notebook).
    with plt.rc_context(font_settings):
        plot_single_metric('regret', K, seeds, base_dir, algo_configs, colors)
        plot_single_metric('time', K, seeds, base_dir, algo_configs, colors)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Plot Bandit Experiment Results Separately")

    parser.add_argument('--K', type=int, required=True, help='Length of Cascade(K)')
    parser.add_argument('--seeds', type=int, nargs='+', default=[0, 1, 2, 3, 4],
                        help='List of seeds used in experiment')
    # Optional so the experiment root can be changed without editing the
    # script; when omitted, main()'s own default directory is used.
    parser.add_argument('--base_dir', type=str, default=None,
                        help='Experiment root containing regretlog/ and timelog/ '
                             '(defaults to the path hard-coded in main)')

    args = parser.parse_args()

    main_kwargs = {} if args.base_dir is None else {'base_dir': args.base_dir}
    main(args.K, args.seeds, **main_kwargs)