import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import datetime
import os
import sys

try:
    from util import *
    from env import ContextualQueue
    from agent import (
        OptimalPolicy, RandomPolicy, CQBEps, CQBOpt,
        CQBEpsopt, CQBts, UCB1, TS1,
    )
except ImportError as e:
    # Surface the underlying ImportError (on stderr) so the missing
    # module/name is visible instead of a bare placeholder message.
    print(f"Failed to import project modules (util/env/agent): {e}", file=sys.stderr)
    sys.exit(1)

# ---------------------------------------------------------------------------
# Figure-wide settings
# ---------------------------------------------------------------------------
FIGSIZE = (20, 14)  # passed to plt.figure(figsize=...), in inches
DPI = 300           # resolution used when saving the PNG
RC_FONTS = dict([
    ('font.size', 50),
    ('axes.labelsize', 50),
    ('xtick.labelsize', 40),
    ('ytick.labelsize', 40),
    ('legend.fontsize', 40),
])

# Line / marker / error-bar geometry.
LINEWIDTH, MARKERSIZE = 10, 20
CAPSIZE, ELINEWIDTH = 10, 3
PLOT_EVERY = 500    # spacing (in time steps) between markers / error bars
Y_LIM = (0, 500)    # fixed y-axis range for Q(t)
LEGEND_KW = {'loc': 'best', 'markerscale': 1.2, 'framealpha': 1.0}

# ---------------------------------------------------------------------------
# Per-algorithm styling:
#   legend label -> (color, linestyle, marker, markersize, zorder)
# Insertion order is the order in which series are drawn (and thus listed
# in the legend).
# ---------------------------------------------------------------------------
STYLE_MAP = dict([
    (r'CQB-$\epsilon$',     ('#FF3300', '-',  'o', MARKERSIZE,     3)),
    ('CQB-Opt',             ('#4169E1', '-',  's', MARKERSIZE,     4)),
    ('Optimal Policy',      ('#2ca02c', '--', '^', MARKERSIZE,     5)),
    ('Random Policy',       ('gray',    ':',  'D', MARKERSIZE,     2)),
    (r'CQB-$\epsilon$-Opt', ('#FFC300', ':',  'p', MARKERSIZE,     3)),
    ('CQB-ThS',             ('#800080', ':',  'h', MARKERSIZE,     3)),
    ('Q-UCB',               ('#17becf', '-.', 'v', MARKERSIZE,     3)),
    ('Q-ThS',               ('#E377C2', '-.', '*', MARKERSIZE + 5, 3)),
])

# Legend label -> (mean-array key, error-array key) inside the saved .npz.
# Keys follow the "avg_<policy>" / "se_<policy>" scheme with a lower-cased
# policy name.
ALGORITHM_KEYS = {
    label: (f'avg_{suffix}', f'se_{suffix}')
    for label, suffix in (
        (r'CQB-$\epsilon$', 'cqbeps'),
        ('CQB-Opt', 'cqbopt'),
        ('Optimal Policy', 'optimalpolicy'),
        ('Random Policy', 'randompolicy'),
        (r'CQB-$\epsilon$-Opt', 'cqbepsopt'),
        ('CQB-ThS', 'cqbts'),
        ('Q-UCB', 'ucb1'),
        ('Q-ThS', 'ts1'),
    )
}

def run_mnist_simulation(lambda_=0.2, T=4000, runs=10, seed=42):
    """Run every policy on the MNIST-backed ContextualQueue and save the results.

    Each policy is simulated ``runs`` times. Run ``r`` seeds both the
    environment and the agent with ``seed + r``, so every policy is evaluated
    on the same sequence of seeds.

    Args:
        lambda_: Passed as ``lambda_`` to ``ContextualQueue`` and to every
            policy constructor (presumably the arrival rate — confirm in env.py).
        T: Number of time steps per run.
        runs: Number of independent runs per policy (must be >= 1).
        seed: Base random seed; run ``r`` uses ``seed + r``.

    Returns:
        Name of the ``.npz`` file written to the current working directory.
        It stores ``T``, ``runs``, ``lambda_``, ``policies`` (policy names)
        and, for each policy with lower-cased name ``<p>``:
          * ``avg_<p>`` — per-step mean of ``env.queue_length_history``
            across runs;
          * ``se_<p>`` — per-step standard error of that mean
            (sample std with ddof=1 divided by sqrt(runs); zeros when runs == 1).

    Raises:
        ValueError: If ``runs`` is less than 1.
    """
    if runs < 1:
        raise ValueError(f"runs must be >= 1, got {runs}")

    policies = {
        'OptimalPolicy': OptimalPolicy,
        'RandomPolicy': RandomPolicy,
        'CQBEps': CQBEps,
        'CQBOpt': CQBOpt,
        'CQBEpsopt': CQBEpsopt,
        'CQBts': CQBts,
        'UCB1': UCB1,
        'TS1': TS1,
    }

    # policy name -> list of per-run queue-length histories
    queue_lengths = {name: [] for name in policies}

    for policy_name, PolicyClass in policies.items():
        for run in tqdm(range(runs), desc=policy_name):
            # Seed depends only on the run index, so all policies share seeds.
            run_seed = seed + run
            env = ContextualQueue(
                data_file=None, lambda_=lambda_, T=T,
                seed=run_seed, data_load_option='keras',
            )
            agent = PolicyClass(env, lambda_=lambda_, seed=run_seed)
            agent.run(T)
            queue_lengths[policy_name].append(env.queue_length_history)

    avg_queue_lengths = {}
    se_queue_lengths = {}
    for policy_name, lengths_list in queue_lengths.items():
        # NOTE(review): assumes every run's history has the same length so the
        # list stacks into a (runs, steps) array — confirm against env.py.
        lengths_array = np.array(lengths_list)
        key = policy_name.lower()
        mean = np.mean(lengths_array, axis=0)
        if runs > 1:
            # Standard error of the mean, not the raw spread across runs.
            se = np.std(lengths_array, axis=0, ddof=1) / np.sqrt(runs)
        else:
            # A single run gives no spread estimate: zero-width error bars.
            se = np.zeros_like(mean, dtype=float)
        avg_queue_lengths[f"avg_{key}"] = mean
        se_queue_lengths[f"se_{key}"] = se

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    filename = f"queue_performance_mnist_T{T}_R{runs}_{timestamp}.npz"
    save_data = {
        'T': T, 'runs': runs, 'lambda_': lambda_,
        'policies': list(policies),
        **avg_queue_lengths, **se_queue_lengths,
    }
    np.savez(filename, **save_data)
    return filename

def plot_performance(filename, show=True):
    """Plot mean queue length Q(t) for each algorithm stored in a results file.

    For every label in STYLE_MAP whose mean-array key (from ALGORITHM_KEYS)
    exists in the file, this draws:
      * the full mean curve,
      * markers every PLOT_EVERY steps plus the final step,
      * error bars at those marker positions, taken from the matching
        ``se_*`` array.
    The figure is saved as a PNG next to ``filename`` (same stem, ``.png``).

    Args:
        filename: Path (str or path-like) to an ``.npz`` file such as the one
            written by ``run_mnist_simulation``.
        show: If True (default), call ``plt.show()`` after saving. Otherwise
            close the figure.

    Returns:
        Path of the saved PNG.

    Raises:
        KeyError: If the file contains none of the expected mean arrays.
    """
    plt.rcParams.update(RC_FONTS)
    fig = plt.figure(figsize=FIGSIZE)
    plt.subplots_adjust(left=0.12, right=0.98, top=0.95, bottom=0.14)

    # The context manager closes the file handle held by the NpzFile. Arrays
    # read from it are plain in-memory ndarrays and stay valid afterwards.
    # NOTE(review): allow_pickle=True permits code execution on crafted
    # files. Only load trusted result files; only numeric arrays are read here.
    with np.load(filename, allow_pickle=True) as data:
        available = [keys for keys in ALGORITHM_KEYS.values() if keys[0] in data]
        if not available:
            raise KeyError(f"no known 'avg_*' arrays found in {filename!r}")
        # 'avg_cqbeps' comes first in ALGORITHM_KEYS, so it still sets the
        # horizon whenever it is present.
        T_plot = data[available[0][0]].shape[0]
        x = np.arange(T_plot)
        mark = np.arange(0, T_plot, PLOT_EVERY)
        # Always mark the last step. Skip this for an empty series, where
        # mark is empty too.
        if T_plot and mark[-1] != T_plot - 1:
            mark = np.append(mark, T_plot - 1)

        for label, (color, linestyle, marker, markersize, zorder) in STYLE_MAP.items():
            avg_key, se_key = ALGORITHM_KEYS.get(label, (None, None))
            if avg_key not in data:
                continue
            avg_data = data[avg_key]
            plt.plot(x, avg_data, label=label, color=color, linestyle=linestyle,
                     linewidth=LINEWIDTH, zorder=zorder)
            # Markers sit one layer above their line so they are not hidden.
            plt.plot(mark, avg_data[mark], color=color, linestyle='none',
                     marker=marker, markersize=markersize, zorder=zorder + 1)

            se_data = data[se_key]
            plt.errorbar(mark, avg_data[mark], yerr=se_data[mark], fmt='none',
                         ecolor=color, capsize=CAPSIZE, elinewidth=ELINEWIDTH,
                         zorder=zorder)

    plt.xlabel("Time Step t")
    plt.ylabel("Q(t)")
    plt.ylim(*Y_LIM)
    plt.legend(**LEGEND_KW)
    plt.grid(True, linestyle='--', alpha=0.6)

    # splitext swaps only the real extension. str.replace would rewrite any
    # ".npz" substring, and for a non-.npz path it would save over the input.
    png_path = os.path.splitext(filename)[0] + ".png"
    plt.savefig(png_path, dpi=DPI, bbox_inches='tight')
    if show:
        plt.show()
    else:
        plt.close(fig)
    return png_path

if __name__ == "__main__":
    # Simulate every policy, then plot the results file that was just written.
    plot_performance(run_mnist_simulation())