import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import datetime
import os
from env import ContextualQueue
from agent import CQBEps
from util import *

# Shared styling for the d-sensitivity figure.
FIGSIZE = (20, 14)  # passed to plt.figure(figsize=...), i.e. (width, height) in inches
RC_FONTS = {  # applied globally via plt.rcParams.update before plotting
    "font.size": 50,
    "axes.labelsize": 50,
    "xtick.labelsize": 40,
    "ytick.labelsize": 40,
    "legend.fontsize": 40,
}
LINEWIDTH = 10  # line width of each mean-queue-length curve
MARKERSIZE = 20  # marker size for curve markers
PLOT_EVERY = 500  # spacing in time steps between markers / error bars
Y_LIM = (0, 350)  # fixed y-axis range for Q(t)

def run_d_sensitivity_simulation(*, d_values=(3, 5, 10, 20), runs=10, T=5000):
    """Run CQB-eps for several context dimensions d and save the results.

    For every ``d`` in ``d_values`` the agent is run ``runs`` times.
    Run ``r`` uses seed ``params['seed'] + r`` for both the environment and
    the agent. Each run records ``env.queue_length_history``, and these
    histories are averaged across runs.

    Args:
        d_values: Context dimensions to sweep (default ``(3, 5, 10, 20)``).
        runs: Independent runs per dimension (default 10). Must be >= 1.
        T: Horizon passed to the environment and to ``agent.run``
            (default 5000).

    Returns:
        Name of the ``.npz`` file written to the current working directory.
        The file contains:

        - ``cqb_results``: dict ``{d: mean queue-length trajectory}``.
        - ``std_results``: dict ``{d: across-run standard deviation}``.
        - ``d_values``, ``T`` and ``runs``.

    Raises:
        ValueError: If ``runs`` < 1, or if the per-run histories for a
            dimension have different lengths.
    """
    if runs < 1:
        raise ValueError(f"runs must be >= 1, got {runs}")

    lambda_ = 0.7
    eps = 0.1
    K = 5
    params = {'kappa': 10, 'L': 3, 'S': 1, 'reg': 1, 'seed': 0}

    cqb_results = {}
    std_results = {}

    for d in d_values:
        print(f"\nRunning CQB-eps with d={d} for {runs} runs...")
        all_runs_lengths = []

        for run_idx in tqdm(range(runs)):
            seed = params['seed'] + run_idx
            env = ContextualQueue(lambda_, eps, d, K, T, T, params['kappa'],
                                  params['L'], params['S'], seed)
            agent = CQBEps(env, lambda_, eps, d, K, params['kappa'], params['L'],
                           params['S'], params['reg'], seed)

            agent.run(T)
            all_runs_lengths.append(env.queue_length_history)

        # Stack explicitly: ragged histories raise a ValueError here instead of
        # producing a confusing result further down.
        lengths = np.asarray(all_runs_lengths, dtype=float)
        cqb_results[d] = lengths.mean(axis=0)
        std_results[d] = lengths.std(axis=0)

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    filename = f"integrated_3_1_results_{timestamp}.npz"
    # The dicts are stored as pickled object arrays, so readers need allow_pickle=True.
    np.savez(filename, cqb_results=cqb_results, std_results=std_results,
             d_values=list(d_values), T=T, runs=runs)
    return filename

def plot_d_integrated_results(filename, out_path='./31.png', show=True):
    """Plot the mean queue length Q(t) of CQB-eps for each context dimension d.

    Reads the ``.npz`` archive written by ``run_d_sensitivity_simulation``.
    It draws one curve per d, with markers and error bars every
    ``PLOT_EVERY`` steps. The bars show the standard error of the mean
    across runs (std / sqrt(runs)).

    Args:
        filename: Path to the ``.npz`` results archive. It stores pickled
            dicts, so only pass files this script produced.
        out_path: Where the PNG is saved (default ``'./31.png'``).
        show: If True (default), call ``plt.show()`` after saving.
            Otherwise close the figure to release its memory.
    """
    # The context manager closes the underlying zip file handle. Arrays
    # accessed inside the block are fully loaded into memory.
    with np.load(filename, allow_pickle=True) as data:
        results = data['cqb_results'].item()
        std_results = data['std_results'].item()
        d_values = data['d_values']
        T = int(data['T'])
        runs = int(data['runs'])

    plt.rcParams.update(RC_FONTS)
    fig = plt.figure(figsize=FIGSIZE)

    colors = ['#FFC699', '#FF9966', '#FF6633', '#FF3300']
    markers = ['o', 's', '^', 'D', 'v', 'p']
    # std_results holds the across-run standard deviation. Scale it to the
    # standard error of the mean over `runs` runs.
    se_scale = 1.0 / np.sqrt(max(runs, 1))

    for i, d_val in enumerate(d_values):
        avg = np.asarray(results[d_val])[:T]
        se = np.asarray(std_results[d_val])[:T] * se_scale
        # Marker / error-bar indices must stay inside this curve, which may be
        # shorter than T.
        mark = np.arange(0, len(avg), PLOT_EVERY)

        label = rf"CQB-$\epsilon$ (d={int(d_val)})"
        color = colors[i % len(colors)]

        plt.plot(avg, label=label, color=color, linewidth=LINEWIDTH,
                 marker=markers[i % len(markers)], markersize=MARKERSIZE,
                 markevery=mark)
        plt.errorbar(mark, avg[mark], yerr=se[mark],
                     fmt='none', ecolor=color, capsize=10, elinewidth=3)

    plt.xlabel("Time Step t")
    plt.ylabel("Q(t)")
    plt.ylim(*Y_LIM)
    plt.legend(loc='upper right', markerscale=1.5)
    plt.grid(True, linestyle='--', alpha=0.6)
    # tight_layout recomputes the subplot margins itself, so a preceding
    # subplots_adjust call would have no effect.
    plt.tight_layout()

    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        plt.close(fig)

if __name__ == "__main__":
    # Run the d-sweep first; the plotting step reads back the archive it wrote.
    plot_d_integrated_results(run_d_sensitivity_simulation())