import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import datetime
import os
from env import ContextualQueue
from agent import CQBEps, CQBOpt, RandomPolicy, OptimalPolicy, CQBEpsopt, CQBts, UCB1, TS1
from util import sigmoid, project_l2_norm 

# ---------------------------------------------------------------------------
# Figure appearance shared by every plot this script produces.
# ---------------------------------------------------------------------------
FIGSIZE = (20, 14)        # passed to plt.figure / set_size_inches
DPI = 300                 # resolution used when saving the PNG
RC_FONTS = {
    'font.size': 50,
    'axes.labelsize': 50,
    'xtick.labelsize': 40,
    'ytick.labelsize': 40,
    'legend.fontsize': 40,
}
LINEWIDTH = 10            # width of each mean curve
MARKERSIZE = 20           # base marker size (Q-ThS's star is drawn 5 larger)
CAPSIZE = 10              # error-bar cap length
ELINEWIDTH = 3            # error-bar line width
PLOT_EVERY = 500          # spacing, in time steps, between markers / error bars
Y_LIM = (0, 500)          # fixed y-axis range for Q(t)
LEGEND_KW = {'loc': 'best', 'markerscale': 1.2, 'framealpha': 1.0}

# One row per algorithm, keeping the legend label, the plotting style and the
# names of its arrays in the results archive together so they cannot drift:
#   (label, (color, linestyle, marker, markersize, zorder), (avg_* key, se_* key))
_ALGORITHM_TABLE = (
    (r'CQB-$\epsilon$',     ('#FF3300', '-',  'o', MARKERSIZE,     3), ('avg_cqb', 'se_cqb')),
    ('CQB-Opt',             ('#4169E1', '-',  's', MARKERSIZE,     4), ('avg_ucb_opt', 'se_ucb_opt')),
    ('Optimal Policy',      ('#2ca02c', '--', '^', MARKERSIZE,     5), ('avg_optimal', 'se_optimal')),
    ('Random Policy',       ('gray',    ':',  'D', MARKERSIZE,     2), ('avg_random', 'se_random')),
    (r'CQB-$\epsilon$-Opt', ('#FFC300', ':',  'p', MARKERSIZE,     3), ('avg_cqb_epsopt', 'se_cqb_epsopt')),
    ('CQB-ThS',             ('#800080', ':',  'h', MARKERSIZE,     3), ('avg_cqb_ts', 'se_cqb_ts')),
    ('Q-UCB',               ('#17becf', '-.', 'v', MARKERSIZE,     3), ('avg_ucb1', 'se_ucb1')),
    ('Q-ThS',               ('#E377C2', '-.', '*', MARKERSIZE + 5, 3), ('avg_ts1', 'se_ts1')),
)

# Legend label -> (color, linestyle, marker, markersize, zorder).
STYLE_MAP = {label: style for label, style, _ in _ALGORITHM_TABLE}
# Legend label -> (avg_* key, se_* key) looked up in the results archive.
ALGORITHM_KEYS = {label: keys for label, _, keys in _ALGORITHM_TABLE}


def start_fig():
    """Apply the shared font sizes and open a fresh, uniformly laid-out figure.

    Updates matplotlib's global rcParams with RC_FONTS, creates a new figure of
    size FIGSIZE (which becomes the current figure), and fixes its margins so
    every plot produced by this script has the same layout.
    """
    plt.rcParams.update(RC_FONTS)
    margins = dict(left=0.12, right=0.98, top=0.95, bottom=0.14)
    fig = plt.figure(figsize=FIGSIZE)
    fig.subplots_adjust(**margins)


def plot_performance(filename):
    start_fig()
    
    if not filename:
        return

    try:
        data = np.load(filename, allow_pickle=True)
    except FileNotFoundError:
        return

    try:
        T = int(np.array(data['T']).item()) if 'T' in data else 5000
        runs = int(np.array(data['runs']).item()) if 'runs' in data else 1
        
        first_avg_key = next(iter(ALGORITHM_KEYS.values()))[0]
        if first_avg_key in data:
            T_plot = data[first_avg_key].shape[0]
        else:
            T_plot = T

        x = np.arange(T_plot)
        mark = np.arange(0, T_plot, PLOT_EVERY)
        if T_plot > 0 and mark[-1] != T_plot - 1:
            mark = np.append(mark, T_plot - 1)
            mark = np.unique(mark)
            
    except Exception:
        return

    se_data_present = runs > 1 and 'se_optimal' in data

    for label, (color, linestyle, marker, markersize, zorder) in STYLE_MAP.items():
        
        avg_key, se_key = ALGORITHM_KEYS.get(label, (None, None))
        
        if avg_key in data:
            try:
                avg_data = data[avg_key][:T_plot]

                plt.plot(x, avg_data, label=label,
                         color=color, linestyle=linestyle, linewidth=LINEWIDTH, zorder=zorder)

                plt.plot(mark, avg_data[mark],
                         color=color, linestyle='none',
                         marker=marker, markersize=markersize, zorder=zorder + 1)
                
                if se_data_present and se_key in data:
                    se_data = data[se_key][:T_plot]
                    
                    if se_data.size > 0 and np.any(se_data[mark] != 0):
                        plt.errorbar(mark, avg_data[mark], yerr=se_data[mark],
                                     fmt='none', ecolor=color, capsize=CAPSIZE, elinewidth=ELINEWIDTH, zorder=zorder)
            except KeyError:
                pass
            except Exception:
                pass
        else:
            pass 


    plt.xlabel("Time Step t"); plt.ylabel("Q(t)")
    plt.xlim(0, T_plot)
    plt.ylim(*Y_LIM)
    plt.legend(**LEGEND_KW)
    plt.grid(True, linestyle='--', alpha=0.6)

    out = f'./{os.path.basename(filename).replace(".npz", "_plot.png")}'
    plt.gcf().set_size_inches(*FIGSIZE)
    plt.savefig(out, dpi=DPI, bbox_inches='tight')
    plt.show() 


def run_and_save_simulation():
    lambda_ = 0.7
    eps = 0.1
    d = 5
    K = 5 
    kappa = 10
    L = 3
    S = 1
    reg = 1
    seed = 0
    T = 5000 
    N = T   
    runs = 10 


    cqb_queue_lengths = []      
    ucb_opt_queue_lengths = [] 
    optimal_queue_lengths = [] 
    random_queue_lengths = [] 
    
    cqb_epsopt_queue_lengths = [] 
    cqb_ts_queue_lengths = []    
    ucb1_queue_lengths = []   
    ts1_queue_lengths = []  
    
    ALGORITHMS = [
        ('CQBEps', CQBEps, cqb_queue_lengths),
        ('CQBOpt', CQBOpt, ucb_opt_queue_lengths),
        ('CQBEpsopt', CQBEpsopt, cqb_epsopt_queue_lengths),
        ('CQBts', CQBts, cqb_ts_queue_lengths),
        ('OptimalPolicy', OptimalPolicy, optimal_queue_lengths),
        ('RandomPolicy', RandomPolicy, random_queue_lengths),
        ('UCB1', UCB1, ucb1_queue_lengths),
        ('TS1', TS1, ts1_queue_lengths)
    ]
    
    for algo_name, PolicyClass, results_list in ALGORITHMS:
        print(f"Running {algo_name}...")
        
        for run in tqdm(range(runs)):
            env = ContextualQueue(lambda_, eps, d, K, N, T, kappa, L, S, seed + run)
            
            if algo_name in ['OptimalPolicy', 'RandomPolicy']:
                agent = PolicyClass(env)
            else:
                agent = PolicyClass(env, lambda_=lambda_, eps=eps, d=d, K=K, kappa=kappa, L=L, S=S, reg=reg, seed=seed + run)
            
            agent.run(T)
            results_list.append(env.queue_length_history)

    results_map = {
        'avg_cqb': cqb_queue_lengths, 'se_cqb': cqb_queue_lengths,
        'avg_ucb_opt': ucb_opt_queue_lengths, 'se_ucb_opt': ucb_opt_queue_lengths,
        'avg_optimal': optimal_queue_lengths, 'se_optimal': optimal_queue_lengths,
        'avg_random': random_queue_lengths, 'se_random': random_queue_lengths,
        
        'avg_cqb_epsopt': cqb_epsopt_queue_lengths, 'se_cqb_epsopt': cqb_epsopt_queue_lengths,
        'avg_cqb_ts': cqb_ts_queue_lengths, 'se_cqb_ts': cqb_ts_queue_lengths,
        'avg_ucb1': ucb1_queue_lengths, 'se_ucb1': ucb1_queue_lengths,
        'avg_ts1': ts1_queue_lengths, 'se_ts1': ts1_queue_lengths
    }

    final_data = {'T': T, 'runs': runs, 'lambda_': lambda_, 'd': d, 'K': K}
    
    for key, data_list in results_map.items():
        if not data_list: continue
        
        data_array = np.array(data_list)
        
        if key.startswith('avg_'):
            final_data[key] = np.mean(data_array, axis=0)
        elif key.startswith('se_'):
            if runs > 1:
                final_data[key] = np.std(data_array, axis=0) 
            else:
                final_data[key] = np.zeros(data_array.shape[1])
        
    
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    filename = f"queue_performance_8_algos_results_{timestamp}.npz" 
    
    np.savez(filename, **final_data)

    print(f"\nSimulation complete. Data saved to: {filename}")
    return filename

def main():
    """Run all simulations, then plot the archive they produced."""
    results_path = run_and_save_simulation()
    if not results_path:
        return
    print("Generating performance plot...")
    plot_performance(results_path)


if __name__ == "__main__":
    main()