import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import datetime
import os
from env import ContextualQueue
from agent import CQBEps
from util import *

# Figure size in inches (width, height) for the K-sensitivity plot.
FIGSIZE = 20, 14
# Matplotlib rcParams overrides applied before plotting (font sizes in points).
RC_FONTS = dict([
    ('font.size', 50),
    ('axes.labelsize', 50),
    ('xtick.labelsize', 40),
    ('ytick.labelsize', 40),
    ('legend.fontsize', 40),
])
# Curve line width and marker size.
LINEWIDTH, MARKERSIZE = 10, 20
# Spacing, in time steps, between markers / error bars on each curve.
PLOT_EVERY = 500
# Fixed y-axis limits (min, max) for the queue-length plot.
Y_LIM = 0, 100

def run_k_sensitivity_simulation(lambda_=0.7, eps=0.1, d=5, T=5000, runs=10,
                                 k_values=(3, 5, 7, 9), params=None,
                                 output_dir=None):
    """Simulate CQB-eps for several values of K and save queue-length statistics.

    For each K, ``runs`` repetitions are executed with seeds
    ``params['seed'] + run_idx``. The same seed sequence is reused for every K.
    The per-step mean and standard deviation (NumPy default ``ddof=0``) of
    ``env.queue_length_history`` across runs are collected and written to a
    timestamped ``.npz`` file.

    Every argument defaults to the value this experiment originally hard-coded,
    so calling it with no arguments reproduces the original run.

    Args:
        lambda_, eps, d: Forwarded positionally to ``ContextualQueue`` and
            ``CQBEps``.
        T: Number of steps passed to ``agent.run``; also forwarded to
            ``ContextualQueue``.
        runs: Number of seeded repetitions per K (must be >= 1).
        k_values: Values of K to sweep.
        params: Overrides for ``kappa``, ``L``, ``S``, ``reg`` and ``seed``.
            Missing keys fall back to the defaults below.
        output_dir: Directory for the result file. ``None`` writes it to the
            current working directory and returns a bare file name.

    Returns:
        Path of the saved ``.npz`` file. It holds ``cqb_results``,
        ``std_results``, ``k_values``, ``T`` and ``runs``. The two result
        entries are dicts keyed by K, so loading them requires
        ``allow_pickle=True``.

    Raises:
        ValueError: If ``runs < 1``, or if the runs for one K produced
            histories of different lengths.
    """
    if runs < 1:
        raise ValueError(f"runs must be >= 1, got {runs}")

    default_params = {'kappa': 10, 'L': 3, 'S': 1, 'reg': 1, 'seed': 0}
    params = {**default_params, **(params or {})}
    k_values = list(k_values)

    cqb_results = {}
    std_results = {}

    for K in k_values:
        print(f"\nRunning CQB-eps with K={K} for {runs} runs...")
        all_runs_lengths = []

        for run_idx in tqdm(range(runs)):
            seed = params['seed'] + run_idx
            # NOTE(review): T is passed for two consecutive positional
            # parameters; confirm against ContextualQueue's signature.
            env = ContextualQueue(lambda_, eps, d, K, T, T, params['kappa'],
                                  params['L'], params['S'], seed)
            agent = CQBEps(env, lambda_, eps, d, K, params['kappa'],
                           params['L'], params['S'], params['reg'], seed)

            agent.run(T)
            all_runs_lengths.append(env.queue_length_history)

        # Per-step averaging needs every run to have the same length. Fail
        # with a clear message rather than a numpy shape error.
        history_lengths = {len(h) for h in all_runs_lengths}
        if len(history_lengths) != 1:
            raise ValueError(
                f"K={K}: queue_length_history lengths differ across runs "
                f"({sorted(history_lengths)}); cannot average per time step."
            )

        stacked = np.asarray(all_runs_lengths)
        cqb_results[K] = stacked.mean(axis=0)
        std_results[K] = stacked.std(axis=0)

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    filename = f"integrated_2_1_results_{timestamp}.npz"
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        filename = os.path.join(output_dir, filename)
    np.savez(filename, cqb_results=cqb_results, std_results=std_results,
             k_values=k_values, T=T, runs=runs)
    return filename

def plot_k_integrated_results(filename, output_path='./21.png', show=True,
                              error_bar='std'):
    """Plot the mean queue length Q(t) against t for every K stored in ``filename``.

    Each K is drawn as one curve. Markers and vertical error bars are placed
    every ``PLOT_EVERY`` steps.

    Args:
        filename: ``.npz`` file with ``cqb_results``/``std_results`` dicts keyed
            by K plus ``k_values``, ``T`` and ``runs``, as written by
            ``run_k_sensitivity_simulation``.
        output_path: Where the figure is saved (300 dpi).
        show: Whether to call ``plt.show()`` after saving.
        error_bar: ``'std'`` plots the stored per-step standard deviation, which
            is the original behavior. ``'se'`` divides it by ``sqrt(runs)`` to
            plot the standard error of the mean.

    Raises:
        ValueError: If ``error_bar`` is not ``'std'`` or ``'se'``.
    """
    if error_bar not in ('std', 'se'):
        raise ValueError(f"error_bar must be 'std' or 'se', got {error_bar!r}")

    # The result dicts are stored as object arrays, so allow_pickle is
    # required. Only load result files you trust. The context manager closes
    # the underlying NpzFile handle.
    with np.load(filename, allow_pickle=True) as data:
        results = data['cqb_results'].item()
        std_results = data['std_results'].item()
        k_values = data['k_values']
        T = int(data['T'])
        runs = int(data['runs'])

    plt.rcParams.update(RC_FONTS)
    fig = plt.figure(figsize=FIGSIZE)
    plt.subplots_adjust(left=0.12, right=0.98, top=0.95, bottom=0.14)

    colors = ['#FFC699', '#FF9966', '#FF6633', '#FF3300']
    markers = ['o', 's', '^', 'D']

    for i, K in enumerate(k_values):
        avg = results[K][:T]
        err = std_results[K][:T]
        if error_bar == 'se':
            err = err / np.sqrt(runs)

        # Build marker/error-bar indices from the actual series length so a
        # history shorter than T cannot index out of bounds.
        mark = np.arange(0, len(avg), PLOT_EVERY)

        label = rf"CQB-$\epsilon$ (K={int(K)})"
        color = colors[i % len(colors)]

        plt.plot(avg, label=label, color=color, linewidth=LINEWIDTH,
                 marker=markers[i % len(markers)], markersize=MARKERSIZE,
                 markevery=mark)

        plt.errorbar(mark, avg[mark], yerr=err[mark],
                     fmt='none', ecolor=color, capsize=10, elinewidth=3)

    plt.xlabel("Time Step t")
    plt.ylabel("Q(t)")
    plt.ylim(*Y_LIM)
    plt.legend(loc='upper right', markerscale=1.5)
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.tight_layout()

    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

if __name__ == "__main__":
    # Run the K sweep, then plot the results file it saved.
    plot_k_integrated_results(run_k_sensitivity_simulation())