
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import datetime
import os
import sys
import time
from sklearn.preprocessing import StandardScaler
from imblearn.over_sampling import SMOTE
from ucimlrepo import fetch_ucirepo

# Name given to the label column in the preprocessed and resampled CSV files.
TARGET_COLUMN = 'HeartDisease_Level'
# Default dataset id handed to ucimlrepo.fetch_ucirepo.
UCI_ID = 45
# Seed used for both SMOTE and the with-replacement row sampling.
RANDOM_STATE = 42
# Number of rows drawn (with replacement) for the final dataset.
TARGET_SIZE_EXPANDED = 5000
# Intermediate CSV written by preprocessing and deleted after resampling.
TEMP_FILE_NAME = 'preprocessed_heart_disease_temp.csv'
# Final CSV written by run_smote_and_sampling and fed to the simulation.
FINAL_FILE_NAME = 'smote_random_sampled_heart_disease_5000.csv'


# Project-local simulation modules. Nothing below can run without them, so
# abort early -- but say which import failed instead of exiting silently.
try:
    from util import *
    from env import ContextualQueue
    from agent import (
        OptimalPolicy, RandomPolicy, CQBEps, CQBOpt,
        CQBEpsopt, CQBts, UCB1, TS1
    )
except ImportError as exc:
    # sys.exit with a string prints it to stderr and exits with status 1,
    # i.e. the same exit status as before.
    sys.exit(f"Required project module could not be imported: {exc}")

# Figure size in inches and resolution used when the plot is saved.
FIGSIZE = (20, 14)
DPI = 300
# Font sizes pushed into matplotlib's rcParams by start_fig().
RC_FONTS = {
    'font.size': 50,
    'axes.labelsize': 50,
    'xtick.labelsize': 40,
    'ytick.labelsize': 40,
    'legend.fontsize': 40
}
# Line / marker / error-bar geometry used by plot_performance().
LINEWIDTH = 10
MARKERSIZE = 20
CAPSIZE = 10
ELINEWIDTH = 3
# Spacing (in time steps) between the markers drawn on each curve.
PLOT_EVERY = 500
# Fixed y-axis limits passed to plt.ylim.
Y_LIM = (0, 500)
# Keyword arguments forwarded to plt.legend.
LEGEND_KW = dict(loc='best', markerscale=1.2, framealpha=1.0)

# Legend label -> (color, linestyle, marker, markersize, zorder).
# Iteration order of this dict is the order in which curves are drawn.
STYLE_MAP = {
    r'CQB-$\epsilon$': ('#FF3300', '-', 'o', MARKERSIZE, 3),
    'CQB-Opt': ('#4169E1', '-', 's', MARKERSIZE, 4),
    'Optimal Policy': ('#2ca02c', '--', '^', MARKERSIZE, 5),
    'Random Policy': ('gray', ':', 'D', MARKERSIZE, 2),
    r'CQB-$\epsilon$-Opt': ('#FFC300', ':', 'p', MARKERSIZE, 3),
    'CQB-ThS': ('#800080', ':', 'h', MARKERSIZE, 3),
    'Q-UCB': ('#17becf', '-.', 'v', MARKERSIZE, 3),
    'Q-ThS': ('#E377C2', '-.', '*', MARKERSIZE + 5, 3),
}

# Legend label -> (average key, error key) inside the saved .npz archive.
# The keys are "avg_"/"se_" + the lower-cased policy name used by
# run_all_policies().
ALGORITHM_KEYS = {
    r'CQB-$\epsilon$': ('avg_cqbeps', 'se_cqbeps'),
    'CQB-Opt': ('avg_cqbopt', 'se_cqbopt'),
    'Optimal Policy': ('avg_optimalpolicy', 'se_optimalpolicy'),
    'Random Policy': ('avg_randompolicy', 'se_randompolicy'),
    r'CQB-$\epsilon$-Opt': ('avg_cqbepsopt', 'se_cqbepsopt'),
    'CQB-ThS': ('avg_cqbts', 'se_cqbts'),
    'Q-UCB': ('avg_ucb1', 'se_ucb1'),
    'Q-ThS': ('avg_ts1', 'se_ts1'),
}


# ----------------------------------------------------------------------
# Function Definitions
# ----------------------------------------------------------------------

def preprocess_heart_disease_data(data_id=UCI_ID):
    """Fetch the dataset, clean/encode/standardise it and cache it as CSV.

    The result is written to ``TEMP_FILE_NAME`` with the label stored in a
    column named ``TARGET_COLUMN``.

    Args:
        data_id: Dataset id forwarded to ``fetch_ucirepo``.

    Returns:
        True once the CSV has been written.
    """
    dataset = fetch_ucirepo(id=data_id)
    features = dataset.data.features
    targets = dataset.data.targets

    # '?' entries are treated as missing; any row with a missing value in
    # either the features or the target is discarded.
    features = features.replace('?', np.nan)
    complete_rows = pd.concat([features, targets], axis=1).dropna()

    feature_frame = complete_rows.drop(targets.columns, axis=1)
    label_series = complete_rows[targets.columns[0]]

    # One-hot encode whatever columns are still object-typed.
    object_columns = feature_frame.select_dtypes(include=['object']).columns
    encoded = pd.get_dummies(feature_frame, columns=object_columns, dummy_na=False)

    # Standardise only the int64/float64 columns; other dtypes are left as-is.
    to_scale = encoded.select_dtypes(include=['int64', 'float64']).columns.tolist()
    standardised = encoded.copy()
    standardised[to_scale] = StandardScaler().fit_transform(standardised[to_scale])

    labelled = pd.concat(
        [standardised, label_series.to_frame(name=TARGET_COLUMN)],
        axis=1,
    )
    labelled.to_csv(TEMP_FILE_NAME, index=False, encoding='utf-8')

    return True

def run_smote_and_sampling():
    """Oversample the cached dataset with SMOTE and draw the final sample.

    Reads ``TEMP_FILE_NAME``, resamples it with SMOTE, draws
    ``TARGET_SIZE_EXPANDED`` rows with replacement, writes them to
    ``FINAL_FILE_NAME`` and then tries to delete the temporary file.

    Returns:
        ``FINAL_FILE_NAME``.

    Raises:
        FileNotFoundError: If ``TEMP_FILE_NAME`` does not exist.
    """
    cached = pd.read_csv(TEMP_FILE_NAME)

    labels = cached[TARGET_COLUMN]
    features = cached.drop(TARGET_COLUMN, axis=1)

    oversampler = SMOTE(random_state=RANDOM_STATE)
    features_balanced, labels_balanced = oversampler.fit_resample(features, labels)

    balanced = pd.concat(
        [features_balanced, labels_balanced.to_frame(name=TARGET_COLUMN)],
        axis=1,
    )

    expanded = balanced.sample(
        n=TARGET_SIZE_EXPANDED, replace=True, random_state=RANDOM_STATE
    )
    expanded.to_csv(FINAL_FILE_NAME, index=False, encoding='utf-8')

    # Best-effort cleanup: a leftover temp file is not worth failing over.
    try:
        os.remove(TEMP_FILE_NAME)
    except OSError:
        pass

    return FINAL_FILE_NAME


def _mean_and_standard_error(histories):
    """Return the per-step mean and standard error of the mean across runs."""
    stacked = np.array(histories)
    mean = np.mean(stacked, axis=0)

    n_runs = stacked.shape[0]
    if n_runs > 1:
        # Standard error of the mean = sample std (ddof=1) / sqrt(n_runs).
        # The raw standard deviation is not a standard error and does not
        # shrink as more runs are averaged.
        std_err = np.std(stacked, axis=0, ddof=1) / np.sqrt(n_runs)
    else:
        # A single run carries no spread information.
        std_err = np.zeros(stacked.shape[1])
    return mean, std_err


def run_all_policies(data_file, lambda_=0.2, T=4000, runs=10, seed=42):
    """Simulate every policy on the contextual queue and save the results.

    Each policy is run ``runs`` times on a fresh ``ContextualQueue``; the
    per-step queue-length histories are averaged across runs and stored,
    together with their standard error, in a timestamped ``.npz`` file in
    the current directory.

    Args:
        data_file: Path of the CSV handed to ``ContextualQueue``.
        lambda_: Value forwarded as ``lambda_`` to the environment and to
            every agent.
        T: Horizon passed to the environment and to ``agent.run``.
        runs: Number of repetitions per policy; repetition ``r`` uses the
            seed ``seed + r`` for both environment and agent.
        seed: Base random seed.

    Returns:
        The name of the ``.npz`` file written, or ``None`` when
        ``data_file`` does not exist or the environment cannot be built
        (the reason is reported on stderr).
    """
    if not os.path.exists(data_file):
        print(f"run_all_policies: data file not found: {data_file}", file=sys.stderr)
        return None

    policies = {
        'OptimalPolicy': OptimalPolicy,
        'RandomPolicy': RandomPolicy,
        'CQBEps': CQBEps,
        'CQBOpt': CQBOpt,
        'CQBEpsopt': CQBEpsopt,
        'CQBts': CQBts,
        'UCB1': UCB1,
        'TS1': TS1
    }

    # Build one throw-away environment purely to read the problem
    # dimensions that are stored alongside the results.
    try:
        probe_env = ContextualQueue(data_file=data_file, T=T, seed=seed, lambda_=lambda_)
        K_val = probe_env.K
        d_val = probe_env.d
        N_val = probe_env.N
    except Exception as exc:
        # Still best-effort (callers expect None), but no longer silent.
        print(f"run_all_policies: could not build environment: {exc!r}", file=sys.stderr)
        return None

    queue_lengths = {name: [] for name in policies}

    for policy_name, PolicyClass in policies.items():
        for run in tqdm(range(runs), desc=policy_name):
            env = ContextualQueue(data_file=data_file, lambda_=lambda_, T=T, seed=seed + run)
            agent = PolicyClass(env, lambda_=lambda_, seed=seed + run)

            agent.run(T)

            queue_lengths[policy_name].append(env.queue_length_history)

    avg_queue_lengths = {}
    se_queue_lengths = {}

    for policy_name, lengths_list in queue_lengths.items():
        key = policy_name.lower()
        if lengths_list:
            mean, std_err = _mean_and_standard_error(lengths_list)
        else:
            # No runs were recorded (runs == 0): store zero placeholders.
            mean, std_err = np.zeros(T), np.zeros(T)
        avg_queue_lengths[f"avg_{key}"] = mean
        se_queue_lengths[f"se_{key}"] = std_err

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    filename = f"queue_performance_heart_disease_T{T}_R{runs}_{timestamp}.npz"

    save_data = {
        'T': T,
        'runs': runs,
        'lambda_': lambda_,
        'd': d_val,
        'K': K_val,
        'N': N_val,
        'policies': list(policies),
        **avg_queue_lengths,
        **se_queue_lengths
    }

    np.savez(filename, **save_data)

    return filename

def start_fig():
    """Apply the shared font settings and open a new figure with fixed margins."""
    margins = dict(left=0.12, right=0.98, top=0.95, bottom=0.14)

    plt.rcParams.update(RC_FONTS)
    plt.figure(figsize=FIGSIZE)
    plt.subplots_adjust(**margins)


def _draw_curve(label, avg_data, se_data, mark):
    """Draw one algorithm's line, its sparse markers and optional error bars."""
    color, linestyle, marker, markersize, zorder = STYLE_MAP[label]

    # Clip to this series' own length so an array shorter than the first
    # one is still plotted instead of failing on a shape mismatch.
    steps = np.arange(avg_data.shape[0])
    mark = mark[mark < avg_data.shape[0]]

    plt.plot(steps, avg_data, label=label,
             color=color, linestyle=linestyle, linewidth=LINEWIDTH, zorder=zorder)

    # Markers are drawn as a separate, line-less series so they appear only
    # every PLOT_EVERY steps rather than on every point.
    plt.plot(mark, avg_data[mark],
             color=color, linestyle='none',
             marker=marker, markersize=markersize, zorder=zorder + 1)

    if (se_data is not None
            and se_data.shape == avg_data.shape
            and np.any(se_data[mark] != 0)):
        plt.errorbar(mark, avg_data[mark], yerr=se_data[mark],
                     fmt='none', ecolor=color, capsize=CAPSIZE,
                     elinewidth=ELINEWIDTH, zorder=zorder)


def plot_performance(filename):
    """Plot the averaged queue-length curves stored in a results archive.

    One curve is drawn per ``STYLE_MAP`` entry whose ``avg_*`` array is
    present in the archive. When the archive records more than one run, the
    matching ``se_*`` array (if present) is drawn as error bars at the marker
    positions. The figure is saved as ``./<basename>_plot.png`` and shown.

    Args:
        filename: Path of the ``.npz`` archive. A falsy value is a no-op.

    Returns:
        None. A missing file or unreadable metadata is reported on stderr
        and no figure is created.
    """
    if not filename:
        return

    try:
        # NOTE(review): allow_pickle=True lets a crafted archive execute
        # arbitrary code on load. Only open trusted files; consider dropping
        # the flag if every array read here is a plain numeric array.
        archive = np.load(filename, allow_pickle=True)
    except FileNotFoundError:
        print(f"plot_performance: results file not found: {filename}", file=sys.stderr)
        return

    # The context manager closes the underlying zip file when we are done.
    with archive as data:
        try:
            T = int(np.array(data['T']).item()) if 'T' in data else 4000
            runs = int(np.array(data['runs']).item()) if 'runs' in data else 1

            # The plotted horizon follows the first algorithm's array when
            # available, falling back to the stored T otherwise.
            first_avg_key = next(iter(ALGORITHM_KEYS.values()))[0]
            T_plot = data[first_avg_key].shape[0] if first_avg_key in data else T
        except Exception as exc:
            print(f"plot_performance: unreadable metadata in {filename}: {exc!r}",
                  file=sys.stderr)
            return

        mark = np.arange(0, T_plot, PLOT_EVERY)
        if T_plot > 0 and mark[-1] != T_plot - 1:
            # Always put a marker on the final time step.
            mark = np.append(mark, T_plot - 1)

        # Error bars only make sense when several runs were averaged. Each
        # algorithm is checked on its own error key rather than being gated
        # on one particular algorithm's key being present.
        show_error_bars = runs > 1

        # Create the figure only now, so early returns leave no stray figure.
        start_fig()

        for label in STYLE_MAP:
            avg_key, se_key = ALGORITHM_KEYS.get(label, (None, None))
            if avg_key is None or avg_key not in data:
                continue

            try:
                avg_data = data[avg_key][:T_plot]
                se_data = None
                if show_error_bars and se_key in data:
                    se_data = data[se_key][:T_plot]
                _draw_curve(label, avg_data, se_data, mark)
            except Exception as exc:
                # Keep plotting the remaining algorithms, but say what failed.
                print(f"plot_performance: skipped {label!r}: {exc!r}", file=sys.stderr)

    plt.xlabel("Time Step t")
    plt.ylabel("Q(t)")
    plt.xlim(0, T_plot)
    plt.ylim(*Y_LIM)
    plt.legend(**LEGEND_KW)
    plt.grid(True, linestyle='--', alpha=0.6)

    # splitext only touches the real extension, unlike str.replace(".npz", ...).
    stem = os.path.splitext(os.path.basename(filename))[0]
    out = f'./{stem}_plot.png'
    plt.gcf().set_size_inches(*FIGSIZE)
    plt.savefig(out, dpi=DPI, bbox_inches='tight')
    plt.show()
    # Release the figure once it has been saved and shown.
    plt.close()



if __name__ == "__main__":
    # Pipeline: build the dataset -> simulate every policy -> plot the results.
    final_data_file = None

    try:
        if preprocess_heart_disease_data():
            final_data_file = run_smote_and_sampling()
    except Exception as exc:
        # Do not leave the intermediate CSV behind when preparation fails.
        if os.path.exists(TEMP_FILE_NAME):
            os.remove(TEMP_FILE_NAME)
        # Exiting with a string prints it to stderr and uses exit status 1,
        # so the status is unchanged but the cause is no longer hidden.
        sys.exit(f"Dataset preparation failed: {exc!r}")

    if final_data_file:
        output_npz_file = run_all_policies(data_file=final_data_file)

        if output_npz_file:
            plot_performance(output_npz_file)
        else:
            print("Simulation produced no results file; nothing to plot.",
                  file=sys.stderr)