import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import datetime
import os
from env import ContextualQueue
from agent import CQBEps
from util import *

# --- Figure geometry / output -------------------------------------------
FIGSIZE = (20, 14)  # matplotlib figsize (width, height) in inches
DPI = 300           # resolution passed to plt.savefig

# --- Font sizes, applied via plt.rcParams.update before plotting ---------
RC_FONTS = dict(
    [
        ('font.size', 50),
        ('axes.labelsize', 50),
        ('xtick.labelsize', 40),
        ('ytick.labelsize', 40),
        ('legend.fontsize', 40),
    ]
)

# --- Line / marker / error-bar styling ------------------------------------
LINEWIDTH = 10      # width of the mean curve
MARKERSIZE = 20     # size of the sparse markers drawn on top of each curve
CAPSIZE = 10        # error-bar cap length
ELINEWIDTH = 3      # error-bar line width

# --- Sampling / axes ------------------------------------------------------
PLOT_EVERY = 500    # spacing (in time steps) between markers and error bars
Y_LIM = (0, 350)    # fixed y-axis range for Q(t)

def run_sensitivity_simulation(lambda_=0.7, eps_values=(0.05, 0.1, 0.15), d=5, K=5,
                               T=5000, runs=10, params=None, output_dir=None):
    """Run CQB-eps on ContextualQueue for several epsilon values and save the results.

    For each epsilon in ``eps_values``, ``runs`` runs of ``T`` steps are executed.
    Run ``r`` seeds both the environment and the agent with ``params['seed'] + r``,
    so the same seeds are reused for every epsilon value.

    Args:
        lambda_: value passed as the first argument to ContextualQueue and CQBEps.
        eps_values: iterable of epsilon values to sweep.
        d, K: passed through to ContextualQueue and CQBEps.
        T: number of steps per run (passed to ``agent.run``).
        runs: number of runs per epsilon value; must be >= 1.
        params: optional dict overriding any of the defaults
            ``{'kappa': 10, 'L': 3, 'S': 1, 'reg': 1, 'seed': 0}``.
        output_dir: directory for the output file. ``None`` (default) writes to
            the current working directory and returns a bare file name.

    Returns:
        Path of the written ``.npz`` archive. It holds ``results`` and ``std_results``
        (dicts mapping eps -> per-step mean / std of ``env.queue_length_history``
        across runs, stored as pickled object arrays), ``eps_values``, ``T`` and ``runs``.

    Raises:
        ValueError: if ``runs`` < 1.
    """
    if runs < 1:
        raise ValueError(f"runs must be >= 1, got {runs}")

    cfg = {'kappa': 10, 'L': 3, 'S': 1, 'reg': 1, 'seed': 0}
    if params:
        cfg.update(params)
    eps_values = list(eps_values)

    results = {}
    std_results = {}

    for eps in eps_values:
        print(f"\nRunning with eps={eps} for {runs} runs...")
        all_runs_lengths = []

        for run_idx in tqdm(range(runs)):
            seed = cfg['seed'] + run_idx
            # NOTE(review): T is passed twice; what each slot means depends on
            # ContextualQueue's signature (not visible here) - confirm.
            env = ContextualQueue(lambda_, eps, d, K, T, T, cfg['kappa'], cfg['L'], cfg['S'], seed)
            agent = CQBEps(env, lambda_, eps, d, K, cfg['kappa'], cfg['L'], cfg['S'], cfg['reg'], seed)

            agent.run(T)
            all_runs_lengths.append(env.queue_length_history)

        # Convert once; raises if the runs produced histories of different lengths.
        lengths = np.asarray(all_runs_lengths)
        results[eps] = lengths.mean(axis=0)
        std_results[eps] = lengths.std(axis=0)

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    filename = f"integrated_1_2_results_{timestamp}.npz"
    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
        filename = os.path.join(output_dir, filename)
    np.savez(filename, results=results, std_results=std_results, eps_values=eps_values, T=T, runs=runs)
    return filename

def plot_integrated_results(filename, error_bar="std", show=True):
    """Plot the mean queue length Q(t) for each epsilon from a saved results archive.

    Reads the ``.npz`` archive written by ``run_sensitivity_simulation``. For each
    epsilon it draws:
    - the mean curve;
    - markers every ``PLOT_EVERY`` steps, plus the final step;
    - error bars at the same points.
    The figure is saved as ``./<archive stem>_plot.png``.

    Args:
        filename: path to the ``.npz`` archive.
        error_bar: ``"std"`` (default) draws +/-1 standard deviation across runs;
            ``"se"`` draws the standard error, std / sqrt(runs).
        show: if True (default), call ``plt.show()``; otherwise close the figure
            after saving (useful for headless / batch use).

    Returns:
        Path of the saved PNG.

    Raises:
        ValueError: if ``error_bar`` is not ``"std"`` or ``"se"``.
    """
    if error_bar not in ("std", "se"):
        raise ValueError(f"error_bar must be 'std' or 'se', got {error_bar!r}")

    # allow_pickle is required because the result dicts were saved as object
    # arrays. Only load archives you produced yourself: unpickling untrusted
    # data can execute code. The context manager closes the underlying zip file.
    with np.load(filename, allow_pickle=True) as data:
        results = data['results'].item()
        std_results = data['std_results'].item()
        eps_values = data['eps_values']
        T = int(data['T'])
        runs = int(data['runs'])

    err_scale = 1.0 / np.sqrt(runs) if error_bar == "se" else 1.0

    plt.rcParams.update(RC_FONTS)
    fig = plt.figure(figsize=FIGSIZE)
    plt.subplots_adjust(left=0.12, right=0.98, top=0.95, bottom=0.14)

    # NOTE(review): '#FFDAB9' is very light on a white background; consider a darker shade.
    colors = ['#FF6347', '#FF8C00', '#FFDAB9']
    markers = ['o', 's', '^']

    # Sparse sample points for markers/error bars; always include the last step.
    # arange is already sorted and unique, and T - 1 exceeds its last element,
    # so the result stays sorted and unique.
    mark = np.arange(0, T, PLOT_EVERY)
    if T > 0 and mark[-1] != T - 1:
        mark = np.append(mark, T - 1)

    for i, eps in enumerate(eps_values):
        # eps comes back as np.float64; it hashes/compares equal to the float dict key.
        avg = results[eps][:T]
        err = std_results[eps][:T] * err_scale

        label = rf"CQB-$\epsilon$ ($\epsilon={eps}$)"
        color = colors[i % len(colors)]
        marker = markers[i % len(markers)]

        # Draw order per series: error bars below the line, and the line below the markers.
        plt.plot(range(T), avg, label=label, color=color,
                 linewidth=LINEWIDTH, linestyle='-', zorder=i*2+2)

        plt.plot(mark, avg[mark], color=color, linestyle='none',
                 marker=marker, markersize=MARKERSIZE, zorder=i*2+3)

        plt.errorbar(mark, avg[mark], yerr=err[mark],
                     fmt='none', ecolor=color, capsize=CAPSIZE,
                     elinewidth=ELINEWIDTH, zorder=i*2+1)

    plt.xlabel("Time Step t")
    plt.ylabel("Q(t)")
    plt.xlim(0, T)
    plt.ylim(*Y_LIM)
    plt.legend(loc='upper left', markerscale=1.2, framealpha=1.0)
    plt.grid(True, linestyle='--', alpha=0.6)

    # splitext (rather than str.replace) guarantees the output never equals the
    # input path, even when the archive does not end in ".npz".
    stem = os.path.splitext(os.path.basename(filename))[0]
    out = f'./{stem}_plot.png'
    plt.savefig(out, dpi=DPI, bbox_inches='tight')
    print(f"Plot saved to: {out}")

    if show:
        plt.show()
    else:
        plt.close(fig)
    return out

if __name__ == "__main__":
    # Run the epsilon sweep, then plot straight from the archive it wrote.
    plot_integrated_results(run_sensitivity_simulation())