import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import datetime
import os
from env import ContextualQueue
from agent import CQBOpt
from util import *

# Shared plot styling: large fonts and thick lines for a big figure.
FIGSIZE = (20, 14)  # (width, height) passed to plt.figure(figsize=...)
# matplotlib rcParams overrides: 50 for base text and axis labels,
# 40 for tick labels and the legend.
RC_FONTS = {
    **dict.fromkeys(('font.size', 'axes.labelsize'), 50),
    **dict.fromkeys(('xtick.labelsize', 'ytick.labelsize', 'legend.fontsize'), 40),
}
LINEWIDTH = 10    # width of each plotted curve
MARKERSIZE = 20   # size of the markers drawn on the curves
PLOT_EVERY = 500  # spacing (in time steps) between markers / error bars
Y_LIM = (0, 350)  # fixed y-axis range for the Q(t) plot

def run_opt_sensitivity_simulation(lambda_=0.7, eps_values=(0.05, 0.1, 0.15),
                                   d=5, K=5, T=5000, runs=10, params=None):
    """Run CQB-Opt for several epsilon values and save per-step queue statistics.

    For every ``eps`` in ``eps_values`` a fresh ``ContextualQueue`` and
    ``CQBOpt`` agent are built ``runs`` times (seed ``params['seed'] + run_idx``),
    the agent is run for ``T`` steps, and ``env.queue_length_history`` is
    collected. The per-step mean and standard deviation across runs are stored.

    Every argument defaults to the value this script previously hard-coded, so
    calling it with no arguments reproduces the original experiment.

    Args:
        lambda_: Passed as ``lambda_`` to both ``ContextualQueue`` and ``CQBOpt``.
        eps_values: Iterable of epsilon values to sweep.
        d, K: Passed through unchanged to ``ContextualQueue`` and ``CQBOpt``.
        T: Passed twice to ``ContextualQueue`` and used as the ``agent.run``
            step count.
        runs: Number of independent runs per epsilon.
        params: Optional overrides for the default settings
            ``{'kappa': 10, 'L': 3, 'S': 1, 'reg': 1, 'seed': 0}``; keys not
            given keep their default.

    Returns:
        Name of the ``.npz`` archive written to the current working directory.
        It contains ``CQB_results`` and ``std_results`` (dicts mapping eps to a
        per-step array, stored as pickled object arrays, so readers need
        ``allow_pickle=True`` and ``.item()``), ``eps_values``, ``T`` and ``runs``.
    """
    defaults = {'kappa': 10, 'L': 3, 'S': 1, 'reg': 1, 'seed': 0}
    params = {**defaults, **(params or {})}
    # Materialize so any iterable is accepted and the archive stores the same
    # array it did when this was a hard-coded list.
    eps_values = list(eps_values)

    CQB_results = {}
    std_results = {}

    for eps in eps_values:
        print(f"\nRunning CQB-Opt with eps={eps} for {runs} runs...")
        all_runs_lengths = []

        for run_idx in tqdm(range(runs)):
            # Seeds depend only on run_idx, so every eps value reuses the same
            # seed sequence across its runs.
            seed = params['seed'] + run_idx
            env = ContextualQueue(lambda_, eps, d, K, T, T, params['kappa'],
                                  params['L'], params['S'], seed)
            agent = CQBOpt(env, lambda_, eps, d, K, params['kappa'],
                           params['L'], params['S'], params['reg'], seed)

            agent.run(T)
            all_runs_lengths.append(env.queue_length_history)

        # Per-time-step statistics across runs (np.std defaults to ddof=0).
        CQB_results[eps] = np.mean(all_runs_lengths, axis=0)
        std_results[eps] = np.std(all_runs_lengths, axis=0)

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    filename = f"integrated_1_3_results_{timestamp}.npz"
    np.savez(filename, CQB_results=CQB_results, std_results=std_results,
             eps_values=eps_values, T=T, runs=runs)
    return filename

def plot_opt_integrated_results(filename, save_path='./13.png'):
    """Plot mean queue length Q(t) for each epsilon, with standard-error bars.

    Args:
        filename: Path to an ``.npz`` archive with the keys ``CQB_results``,
            ``std_results`` (dicts mapping eps to a per-step array),
            ``eps_values``, ``T`` and ``runs``, as written by
            ``run_opt_sensitivity_simulation``.
        save_path: Where the figure is saved (300 dpi). Defaults to the
            previously hard-coded ``'./13.png'``.

    The archive stores the across-run standard deviation. Error bars show the
    standard error of the mean, ``std / sqrt(runs)``.
    """
    # allow_pickle is needed because the result dicts were saved as object
    # arrays; only load archives you produced yourself. The ``with`` block
    # closes the file handle that NpzFile keeps open.
    with np.load(filename, allow_pickle=True) as data:
        results = data['CQB_results'].item()
        std_results = data['std_results'].item()
        eps_values = data['eps_values']
        T = int(data['T'])
        runs = int(data['runs'])

    plt.rcParams.update(RC_FONTS)
    plt.figure(figsize=FIGSIZE)
    # No subplots_adjust here: tight_layout() below recomputes the margins.

    colors = ['#4169E1', '#0000CD', '#00008B']
    markers = ['o', 's', '^']

    for i, eps in enumerate(eps_values):
        avg = results[eps][:T]
        # Standard error of the mean across the independent runs.
        se = std_results[eps][:T] / np.sqrt(runs)
        # Marker / error-bar positions, bounded by the actual series length so
        # a history shorter than T cannot index out of range.
        mark = np.arange(0, len(avg), PLOT_EVERY)

        # Raw f-string: "\e" is not a valid escape sequence in a normal string.
        label = rf"CQB-Opt ($\epsilon$={eps})"
        color = colors[i % len(colors)]

        plt.plot(avg, label=label, color=color, linewidth=LINEWIDTH,
                 marker=markers[i % len(markers)], markersize=MARKERSIZE, markevery=mark)

        plt.errorbar(mark, avg[mark], yerr=se[mark],
                     fmt='none', ecolor=color, capsize=10, elinewidth=3)

    plt.xlabel("Time Step t")
    plt.ylabel("Q(t)")
    plt.ylim(*Y_LIM)
    plt.legend(loc='upper right', markerscale=1.5)
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.tight_layout()

    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

if __name__ == "__main__":
    # Run the epsilon sweep, then plot directly from the archive it just wrote.
    plot_opt_integrated_results(run_opt_sensitivity_simulation())