


import numpy as np
import pandas as pd
from tqdm import tqdm
import datetime
import os
import sys
import time
from sklearn.preprocessing import StandardScaler
from ucimlrepo import fetch_ucirepo
import matplotlib.pyplot as plt


try:
    from util import * 
    from env import ContextualQueue 
    from agent import (
        OptimalPolicy, RandomPolicy, CQBEps, CQBOpt, 
        CQBEpsopt, CQBts, UCB1, TS1 
    ) 
except ImportError as e:
    print(f"Error importing essential files (util.py, env.py, agent.py): {e}")
    sys.exit(1)


# Dataset id handed to ``fetch_ucirepo`` by ``preprocess_coupon_data``.
UCI_COUPON_ID = 603
# Default CSV path that the preprocessing step writes and the simulation reads.
FINAL_COUPON_FILE = 'preprocessed_coupon_recommendation_final1.csv'
# Name of the label column selected from the dataset's targets frame.
TARGET_COLUMN = 'Y' 




def preprocess_coupon_data(data_id=UCI_COUPON_ID, file_name=FINAL_COUPON_FILE):
    """Download the dataset, encode and scale its features, and write a CSV.

    Steps, in order:
      1. Fetch the dataset ``data_id`` through ``fetch_ucirepo``.
      2. Turn the literal string ``'nan'`` into a real NaN, then fill a fixed
         set of columns with the category ``'Missing'``.
      3. One-hot encode every object-dtype column.
      4. Standardize the ``int64``/``float64`` columns, except a fixed list of
         columns that are left on their original scale.
      5. Append the ``TARGET_COLUMN`` target and save everything to
         ``file_name`` (UTF-8, no index column).

    Args:
        data_id: Dataset identifier passed to ``fetch_ucirepo``.
        file_name: Destination path of the CSV file.

    Returns:
        ``file_name``, unchanged.
    """
    dataset = fetch_ucirepo(id=data_id)

    features = dataset.data.features
    target = dataset.data.targets[TARGET_COLUMN]

    # ``replace`` hands back a new frame, so the column writes below do not
    # touch the frame owned by the fetched dataset object.
    features = features.replace('nan', np.nan)

    fill_with_missing = (
        'Bar', 'CoffeeHouse', 'CarryAway',
        'RestaurantLessThan20', 'Restaurant20To50', 'car',
    )
    for name in fill_with_missing:
        if name not in features.columns:
            continue
        features[name] = features[name].fillna('Missing')

    object_columns = features.select_dtypes(include=['object']).columns
    encoded = pd.get_dummies(features, columns=object_columns, dummy_na=False)

    # Columns deliberately kept out of standardization.
    unscaled = {
        'has_children', 'toCoupon_GEQ5min', 'toCoupon_GEQ15min',
        'toCoupon_GEQ25min', 'direction_same', 'direction_opp',
    }
    numeric_columns = encoded.select_dtypes(include=['int64', 'float64']).columns.tolist()
    scaled_columns = [name for name in numeric_columns if name not in unscaled]

    if scaled_columns:
        encoded[scaled_columns] = StandardScaler().fit_transform(encoded[scaled_columns])

    combined = pd.concat([encoded, target.to_frame()], axis=1)
    combined.to_csv(file_name, index=False, encoding='utf-8')

    print(f"Preprocessing complete. Data saved to: {file_name}")
    return file_name


def _summarize_runs(queue_lengths, T):
    """Reduce per-run queue histories to mean and standard-error curves.

    Args:
        queue_lengths: Mapping ``policy name -> list of per-run histories``.
        T: Horizon, used only as the length of the all-zero placeholder curves
            emitted for a policy that has no recorded runs.

    Returns:
        ``(avg, se)``: two dicts keyed ``avg_<policy>`` / ``se_<policy>``
        (policy name lower-cased), each holding one value per time step.
    """
    avg_queue_lengths = {}
    se_queue_lengths = {}

    for policy_name, lengths_list in queue_lengths.items():
        suffix = policy_name.lower()

        if not lengths_list:
            avg_queue_lengths[f"avg_{suffix}"] = np.zeros(T)
            se_queue_lengths[f"se_{suffix}"] = np.zeros(T)
            continue

        lengths_array = np.array(lengths_list)  # one row per run
        n_runs = lengths_array.shape[0]
        mean_curve = np.mean(lengths_array, axis=0)
        avg_queue_lengths[f"avg_{suffix}"] = mean_curve

        if n_runs > 1:
            # Standard error of the mean: sample std (ddof=1) over sqrt(n).
            # A bare np.std() would be the population standard deviation,
            # which is not what the "se_" keys and the error bars claim to be.
            se_queue_lengths[f"se_{suffix}"] = (
                np.std(lengths_array, axis=0, ddof=1) / np.sqrt(n_runs)
            )
        else:
            # A single run carries no spread information; match the length of
            # the recorded history instead of assuming it equals T.
            se_queue_lengths[f"se_{suffix}"] = np.zeros_like(mean_curve, dtype=float)

    return avg_queue_lengths, se_queue_lengths


def run_all_policies(data_file):
    """Simulate every policy on the preprocessed data and save the results.

    Each policy is run ``runs`` times with seeds ``seed + run``. For every run
    a fresh ``ContextualQueue`` is built, the agent is constructed as
    ``PolicyClass(env, lambda_=..., seed=...)``, ``agent.run(T)`` is called and
    ``env.queue_length_history`` is collected. The per-policy mean curve and
    standard error of the mean are written to a timestamped ``.npz`` file
    together with the settings ``T``, ``runs``, ``lambda_``, ``d``, ``K``,
    ``N`` and the list of policy names.

    Args:
        data_file: Path of the preprocessed CSV handed to ``ContextualQueue``.

    Returns:
        The name of the written ``.npz`` file, or ``None`` when ``data_file``
        does not exist or the environment cannot be initialized.
    """
    if not os.path.exists(data_file):
        print(f"Error: Data file '{data_file}' not found.")
        return None

    lambda_ = 0.5
    T = 4000
    runs = 10
    seed = 42

    policies = {
        'OptimalPolicy': OptimalPolicy,
        'RandomPolicy': RandomPolicy,
        'CQBEps': CQBEps,
        'CQBOpt': CQBOpt,
        'CQBEpsopt': CQBEpsopt,
        'CQBts': CQBts,
        'UCB1': UCB1,
        'TS1': TS1
    }

    queue_lengths = {name: [] for name in policies}

    print(f"\n--- 🔢 Coupon Recommendation Simulation (T={T}, Runs={runs}, lambda={lambda_}) ---")

    try:
        # Throw-away environment, built only to read the problem dimensions.
        temp_env = ContextualQueue(data_file=data_file, T=T, seed=seed, lambda_=lambda_)
        K_val = temp_env.K
        d_val = temp_env.d
        N_val = temp_env.N
        print(f"Step 2/3: Loaded Data. N={N_val} instances, d={d_val} features, K={K_val} servers (0, 1).")
    except Exception as e:
        print(f"Environment initialization error: {e}")
        return None

    for policy_name, PolicyClass in policies.items():
        print(f"\nRunning {policy_name}...")
        start_time = time.time()

        for run in tqdm(range(runs)):
            # Environment and agent share the same per-run seed.
            env = ContextualQueue(data_file=data_file, lambda_=lambda_, T=T, seed=seed + run)
            agent = PolicyClass(env, lambda_=lambda_, seed=seed + run)
            agent.run(T)
            queue_lengths[policy_name].append(env.queue_length_history)

        end_time = time.time()
        print(f"  {policy_name} avg time: {(end_time - start_time) / runs:.2f} s")

    avg_queue_lengths, se_queue_lengths = _summarize_runs(queue_lengths, T)

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    filename = f"queue_performance_coupon_l{lambda_}_R{runs}_{timestamp}.npz"

    save_data = {
        'T': T,
        'runs': runs,
        'lambda_': lambda_,
        'd': d_val,
        'K': K_val,
        'N': N_val,
        'policies': list(policies.keys()),
        **avg_queue_lengths,
        **se_queue_lengths
    }

    np.savez(filename, **save_data)

    print(f"\nSimulation complete. Data saved to: {filename}")
    return filename


# Figure geometry and export resolution.
FIGSIZE, DPI = (20, 14), 300

# Matplotlib rcParams overrides: two large text sizes, one for general/axis
# label text and a smaller one for tick labels and the legend.
RC_FONTS = {
    **dict.fromkeys(('font.size', 'axes.labelsize'), 50),
    **dict.fromkeys(('xtick.labelsize', 'ytick.labelsize', 'legend.fontsize'), 40),
}

# Curve, marker and error-bar styling.
LINEWIDTH = 10
MARKERSIZE = 20
CAPSIZE = 10
ELINEWIDTH = 3

# Spacing (in time steps) between the markers drawn on each curve.
PLOT_EVERY = 500
# Fixed y-axis range of the plot.
Y_LIM = (0, 100)
# Keyword arguments forwarded to ``plt.legend``.
LEGEND_KW = {'loc': 'best', 'markerscale': 1.2, 'framealpha': 1.0}

def start_fig():
    """Apply the shared font settings and open a new figure with fixed margins.

    Updates the global ``plt.rcParams`` with ``RC_FONTS``, creates a figure of
    size ``FIGSIZE`` (which becomes the current figure) and sets its subplot
    margins. Returns nothing.
    """
    plt.rcParams.update(RC_FONTS)
    figure = plt.figure(figsize=FIGSIZE)
    margins = {'left': 0.12, 'right': 0.98, 'top': 0.95, 'bottom': 0.14}
    figure.subplots_adjust(**margins)


# One row per plotted algorithm, in legend order:
# (legend label, colour, line style, marker, marker size, z-order, npz key suffix)
_SERIES = (
    (r'CQB-$\epsilon$', '#FF3300', '-', 'o', MARKERSIZE, 3, 'cqbeps'),
    ('CQB-Opt', '#4169E1', '-', 's', MARKERSIZE, 4, 'cqbopt'),
    ('Optimal Policy', '#2ca02c', '--', '^', MARKERSIZE, 5, 'optimalpolicy'),
    ('Random Policy', 'gray', ':', 'D', MARKERSIZE, 2, 'randompolicy'),
    (r'CQB-$\epsilon$-Opt', '#FFC300', ':', 'p', MARKERSIZE, 3, 'cqbepsopt'),
    ('CQB-ThS', '#800080', ':', 'h', MARKERSIZE, 3, 'cqbts'),
    ('Q-UCB', '#17becf', '-.', 'v', MARKERSIZE, 3, 'ucb1'),
    ('Q-ThS', '#E377C2', '-.', '*', MARKERSIZE + 5, 3, 'ts1'),
)

# label -> (colour, line style, marker, marker size, z-order)
STYLE_MAP = {row[0]: row[1:6] for row in _SERIES}

# label -> (mean-curve key, standard-error key) inside the results .npz file
ALGORITHM_KEYS = {row[0]: (f'avg_{row[6]}', f'se_{row[6]}') for row in _SERIES}


def plot_performance(filename):
    """Plot the averaged queue-length curves stored in a results ``.npz`` file.

    For every entry of ``STYLE_MAP`` whose mean-curve key (looked up through
    ``ALGORITHM_KEYS``) exists in the archive, the curve is drawn with a
    marker every ``PLOT_EVERY`` steps (plus the last step). Error bars are
    added at the marker positions when the archive holds more than one run and
    non-zero ``se_*`` data. The figure is written next to the working
    directory as ``<npz stem>_plot.png`` and then shown.

    Args:
        filename: Path of the ``.npz`` file produced by the simulation. A
            falsy value makes the function return immediately.

    Returns:
        None. Problems (missing file, unreadable metadata) are reported on
        stdout and end the function early without creating a figure.
    """
    print(f"Step 3/3: Generating Plot from {filename} ---")

    # Validate the input before touching matplotlib so that an early exit
    # does not leave an empty figure open.
    if not filename:
        return

    try:
        # NOTE(security): allow_pickle=True lets NumPy unpickle object arrays,
        # which can execute arbitrary code. Only open archives you created
        # yourself; the keys read below do not appear to need pickling.
        archive = np.load(filename, allow_pickle=True)
    except FileNotFoundError:
        print(f"Error: The file '{filename}' was not found.")
        return

    # The archive keeps a file handle open; the with-block guarantees it is
    # closed on every exit path, including the early return below.
    with archive as data:
        try:
            T = int(np.array(data['T']).item()) if 'T' in data else 4000
            runs = int(np.array(data['runs']).item()) if 'runs' in data else 1

            # The plotted horizon follows the length of the first listed
            # curve when present, otherwise the stored T.
            first_avg_key = next(iter(ALGORITHM_KEYS.values()))[0]
            if first_avg_key in data:
                T_plot = data[first_avg_key].shape[0]
            else:
                T_plot = T

            x = np.arange(T_plot)
            mark = np.arange(0, T_plot, PLOT_EVERY)
            # Always put a marker on the final time step.
            if T_plot > 0 and mark[-1] != T_plot - 1:
                mark = np.unique(np.append(mark, T_plot - 1))
        except Exception as e:
            print(f"Error reading metadata from NPZ: {e}")
            return

        start_fig()

        se_data_present = 'se_optimalpolicy' in data

        for label, (color, linestyle, marker, markersize, zorder) in STYLE_MAP.items():
            avg_key, se_key = ALGORITHM_KEYS.get(label, (None, None))

            if avg_key not in data:
                continue

            try:
                avg_data = data[avg_key][:T_plot]

                plt.plot(x, avg_data, label=label,
                         color=color, linestyle=linestyle, linewidth=LINEWIDTH, zorder=zorder)

                # Markers are drawn separately so they appear only at the
                # sampled positions and sit above the line.
                plt.plot(mark, avg_data[mark],
                         color=color, linestyle='none',
                         marker=marker, markersize=markersize, zorder=zorder + 1)

                if runs > 1 and se_data_present and se_key in data:
                    se_data = data[se_key][:T_plot]

                    if se_data.size > 0 and np.any(se_data[mark] != 0):
                        plt.errorbar(mark, avg_data[mark], yerr=se_data[mark],
                                     fmt='none', ecolor=color, capsize=CAPSIZE,
                                     elinewidth=ELINEWIDTH, zorder=zorder)
            except KeyError:
                print(f"Warning: Key '{avg_key}' missing from NPZ file.")
            except Exception as e:
                print(f"Error plotting {label}: {e}")

    plt.xlabel("Time Step t")
    plt.ylabel("Q(t)")
    plt.xlim(0, T_plot)
    plt.ylim(*Y_LIM)
    plt.legend(**LEGEND_KW)
    plt.grid(True, linestyle='--', alpha=0.6)

    # splitext only strips the final extension, so the output name can never
    # collide with the input file and inner ".npz" substrings are left alone.
    stem, _ = os.path.splitext(os.path.basename(filename))
    out = f'./{stem}_plot.png'
    plt.gcf().set_size_inches(*FIGSIZE)
    plt.savefig(out, dpi=DPI, bbox_inches='tight')
    plt.show()
    print(f"\n✅ Graph saved to: {out}")


if __name__ == "__main__":
    # Pipeline: preprocess the data -> simulate all policies -> plot results.
    # Each stage only runs when the previous one returned a usable file name.
    try:
        final_data_file = preprocess_coupon_data(file_name=FINAL_COUPON_FILE)
    except Exception as e:
        print(f"Data preparation error: {e}")
        sys.exit(1)

    output_npz_file = (
        run_all_policies(data_file=final_data_file) if final_data_file else None
    )

    if output_npz_file:
        plot_performance(output_npz_file)

