import os
import sys
import numpy as np
import re
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple, Optional

# Add project root to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from OffClusBandit.core.Tools import build_method_filename, create_algorithm_instance, compute_mean_and_ci, load_regret, get_mean_regrets
from OffClusBandit.configs.default import alpha_list, plot_synthetic_gamma_list, choose_gamma_alpha_list, plot_yelp_gamma_list, plot_ml_gamma_list
from OffClusBandit.configs.default import plot_long_T_yelp_gamma_list, plot_long_T_ml_gamma_list, plot_long_T_synthetic_gamma_list
from OffClusBandit.configs.default import plot_methods as methods
from OffClusBandit.configs.default import offline_learn_method 
from matplotlib.ticker import ScalarFormatter
# Global Matplotlib style shared by every figure produced by this script.
# Individual plotting functions may still override single settings (for
# example by passing their own ``figsize`` to ``plt.subplots``).
_RC_OVERRIDES = {
    # Fonts and text sizes
    'font.family': 'serif',
    'font.serif': ['Times New Roman'],
    'font.size': 16,
    'axes.labelsize': 16,
    'axes.titlesize': 16,
    'legend.fontsize': 15,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    # Figure geometry and line appearance
    'figure.figsize': (20, 4),
    'lines.linewidth': 2,
    'lines.markersize': 6,
    'grid.alpha': 0.5,
    'legend.frameon': False,
    'figure.dpi': 300,
}
plt.rcParams.update(_RC_OVERRIDES)


# Directory that contains this script. The path is made absolute first (the
# same way the sys.path setup at the top of the file does) so the data and
# figure locations do not depend on how the script was invoked or on the
# current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Where the aggregated ``*_regret_vs_T.npz`` summaries are written and read.
PLOT_DATA_DIR = os.path.join(BASE_DIR, 'plot_data')
# Where the rendered PDF figures are written.
PLOT_FIGS_DIR = os.path.join(BASE_DIR, 'plot_figs')


def _legend_label(method: str) -> str:
    """Translate an internal method key into the name shown in the legend."""
    label = re.sub("_estimate_gamma", "", method, count=1)
    for old, new in (
        ("OffNCLUB", "Off-C²LUB"),
        ("OffCLUB", "Off-CLUB"),
        ("OffDBSCAN", "DBSCAN"),
        ("OffXMeans", "XMeans"),
        ("improve", "Improve"),
        ("Pessimistic", "Underestimation"),
        ("Optimistic", "Overestimation"),
    ):
        label = label.replace(old, new)
    return label


def plot_regret_vs_T_multi(files: List[str], methods: Dict[str, Dict], colors_list: Dict[str, str],markers_list: Dict[str, str],linestyles_list: Dict[str, str]):
    """Draw one row of four panels (one per ``.npz`` file) and save it as a PDF.

    Each archive is expected to hold, for every key in ``methods``, an array
    whose row 0 contains the mean values and row 1 the confidence-interval
    half-widths. The x axis is fixed to the 21 dataset sizes 0, 5000, ...,
    100000, so each row must contain 21 entries.

    Args:
        files: Paths of the ``.npz`` archives, at most four; panel ``i`` shows
            ``files[i]`` under the ``i``-th hard-coded title.
        methods: Mapping whose keys are the method names to plot (the values
            are not used here).
        colors_list: Line/band colour per method name.
        markers_list: Marker per method name.
        linestyles_list: Line style per method name.

    Side effects:
        Creates ``PLOT_FIGS_DIR`` if needed and writes
        ``combine_fig_<offline_learn_method>.pdf`` into it, where
        ``offline_learn_method`` is the module-level name at call time.
        Prints the output path. The figure is closed afterwards so repeated
        calls do not accumulate open figures.

    Raises:
        KeyError: If a method is missing from an archive or from one of the
            style mappings.
        IndexError: If more than four files are given.
    """
    fig, axes = plt.subplots(1, 4, figsize=(20, 5), sharey=True)
    T_values = [5000 * i for i in range(21)]

    titles = [f'{distribution} under Synthetic Dataset' for distribution in['Equal Distribution', 'Semi-Random Distribution']]
    titles += ['Equal Distribution under Yelp Dataset','Equal Distribution under MovieLens Dataset']

    for idx, file in enumerate(files):
        ax = axes[idx]
        # NpzFile keeps the underlying zip open; close it deterministically.
        with np.load(file) as data:
            for method in methods:
                # Read the member once: every access decompresses it again.
                stats = np.asarray(data[method])
                means, cis = stats[0], stats[1]
                ax.plot(T_values, means, label=_legend_label(method), color=colors_list[method],
                        marker=markers_list[method], linestyle=linestyles_list[method])
                ax.fill_between(T_values, means - cis, means + cis,
                                color=colors_list[method], alpha=0.4)
        ax.set_title(f'{titles[idx]}')
        ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.7)

        # Scientific notation on the x axis rendered with math text.
        ax.xaxis.set_major_formatter(ScalarFormatter(useMathText=True))
        ax.ticklabel_format(axis='x', style='sci', scilimits=(0, 0))
        if idx == 0:
            # The y axis is shared, so only the left-most panel is labelled.
            ax.set_ylabel('Average Suboptimality Gap')
        ax.set_xlabel(r'The size of dataset $|\mathcal{D}|$')

    fig.subplots_adjust(wspace=0.05)

    # One shared legend above the panels, taken from the first panel's lines.
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=10,handletextpad=0.2, labelspacing=0.3,columnspacing=1, prop={'weight': 'bold'})
    fig.tight_layout()

    # makedirs(exist_ok=True) avoids the check-then-create race of
    # os.path.exists + os.mkdir.
    os.makedirs(PLOT_FIGS_DIR, exist_ok=True)
    out_path = os.path.join(PLOT_FIGS_DIR, f"combine_fig_{offline_learn_method}.pdf")
    fig.savefig(out_path, format='pdf', bbox_inches='tight')
    print(f"Figure saved as {out_path}")
    plt.close(fig)


def save_npz_regret(results: Dict[str, List[Tuple[int, float, float]]],save_npz_filename = None):
    """Write per-method means and confidence intervals to an ``.npz`` archive.

    Args:
        results: Maps a method name to a list of ``(T, mean, ci)`` tuples.
        save_npz_filename: Target path or writable binary file object, passed
            straight to ``np.savez``. A real target must be supplied; the
            ``None`` default is not a usable destination.

    Each method is stored under its own name as a 2-row array: row 0 holds
    the means and row 1 the confidence intervals, in the order given. The
    ``T`` values are not stored. A method with no data points is stored as
    an empty ``(2, 0)`` array instead of raising.
    """
    save_npz = {}
    for method, data in results.items():
        means = [mean for _, mean, _ in data]
        cis = [ci for _, _, ci in data]
        save_npz[method] = np.asarray([means, cis])
    np.savez(save_npz_filename, **save_npz)


def save_plot_data(offline_learn_method):
    """Aggregate regrets for every dataset scenario and cache them as ``.npz``.

    For each of the four scenarios (synthetic with ``pj=0``, synthetic with
    ``pj=1``, yelp, ml) and each value in ``choose_gamma_alpha_list``, this
    calls ``get_mean_regrets`` over seeds 1..10 and stores the result with
    ``save_npz_regret`` in ``PLOT_DATA_DIR`` under the name
    ``<offline_learn_method>_<uniform|half>_<max T>_<dataset>_<gamma alpha>_regret_vs_T.npz``.

    Args:
        offline_learn_method: Forwarded to ``get_mean_regrets`` and used as
            the file-name prefix.
    """
    # Dataset sizes 0, 5000, ..., 100000 (the value 0 is included).
    # NOTE(review): the original note says the plot gamma lists carry a
    # leading entry for T=0 -- confirm against the config module.
    T_values = [5000 * i for i in range(21)]
    nu = 1000
    d = 20
    m = 10
    L = 20
    alpha = alpha_list[0]
    # File-name tag for each pj value: index 0 -> 'uniform', 1 -> 'half'.
    uniforms = ['uniform', 'half']
    seed_list = list(range(1, 11))
    scenarios = [('synthetic', 0), ('synthetic', 1), ('yelp', 0), ('ml', 0)]

    for dataset, pj in scenarios:
        if dataset == 'synthetic':
            best_gamma_list = plot_synthetic_gamma_list
            is_empirical = 0
        else:
            is_empirical = 1
            if dataset == 'yelp':
                best_gamma_list = plot_yelp_gamma_list
            else:
                best_gamma_list = plot_ml_gamma_list

        for choose_gamma_alpha in choose_gamma_alpha_list:
            results = get_mean_regrets(
                seed_list=seed_list,
                methods=methods,
                is_empirical=is_empirical,
                dataset=dataset,
                T_values=T_values,
                nu=nu,
                d=d,
                m=m,
                L=L,
                pj=pj,
                choose_gamma_alpha=choose_gamma_alpha,
                alpha=alpha,
                gamma_list=best_gamma_list,
                offline_learn_method=offline_learn_method
            )

            save_npz_filename = os.path.join(PLOT_DATA_DIR, f"{offline_learn_method}_{uniforms[pj]}_{T_values[-1]}_{dataset}_{choose_gamma_alpha:.3f}_regret_vs_T.npz")
            # exist_ok avoids the check-then-create race of exists + mkdir.
            os.makedirs(PLOT_DATA_DIR, exist_ok=True)
            save_npz_regret(results, save_npz_filename)

def cal_mean_overcome_best_reward(files, methods, mode: str = 'perT', start_index: int = 4, as_percent: bool = False):
    """Print the relative gap reduction of the gamma-estimate variants per file.

    For every archive in ``files`` four numbers are printed: the reduction of
    ``OffNCLUB_estimate_gamma_Pessimistic`` and
    ``OffNCLUB_estimate_gamma_Optimistic`` relative to ``OffCLUB`` and to
    ``ARMUL``. A reduction is ``1 - numerator / denominator`` computed on the
    mean series (row 0 of each stored array) from ``start_index`` onwards.

    Args:
        files: Paths (or readable binary file objects) of ``.npz`` archives.
        methods: Method names to read from each archive; names absent from
            an archive are skipped.
        mode: ``'perT'`` averages the per-point reductions (points whose
            denominator is not positive are ignored); any other value
            compares the averages of the two series instead (yielding 0 when
            the denominator average is not positive).
        start_index: First index of the series taken into account.
        as_percent: Scale the result by 100 and append ``%`` when printing.

    A comparison whose numerator or denominator series is unavailable is
    printed as ``nan`` rather than aborting the whole report.
    """
    def avg_reduction_perT(numer_arr, denom_arr):
        numer_arr = np.array(numer_arr, dtype=float)[start_index:]
        denom_arr = np.array(denom_arr, dtype=float)[start_index:]
        # Division by zero is expected here; such points are masked to NaN
        # and excluded by nanmean.
        with np.errstate(divide='ignore', invalid='ignore'):
            ratios = 1.0 - np.where(denom_arr > 0, numer_arr / denom_arr, np.nan)
        val = float(np.nanmean(ratios))
        return (val * 100.0) if as_percent else val

    def reduction_from_avgs(numer_arr, denom_arr):
        numer_mean = float(np.nanmean(np.array(numer_arr, dtype=float)[start_index:]))
        denom_mean = float(np.nanmean(np.array(denom_arr, dtype=float)[start_index:]))
        val = 0.0 if denom_mean <= 0 else (1.0 - numer_mean / denom_mean)
        return (val * 100.0) if as_percent else val

    reducer = avg_reduction_perT if mode == 'perT' else reduction_from_avgs
    suffix = '%' if as_percent else ''
    # (printed prefix, numerator series, denominator series), in print order.
    comparisons = (
        ("OffCLUB^2_Underestimation overcome OffCLUB: ", 'OffNCLUB_estimate_gamma_Pessimistic', 'OffCLUB'),
        ("OffCLUB^2_Underestimation overcome ARMUL: ", 'OffNCLUB_estimate_gamma_Pessimistic', 'ARMUL'),
        ("OffCLUB^2_Overestimation overcome OffCLUB: ", 'OffNCLUB_estimate_gamma_Optimistic', 'OffCLUB'),
        ("OffCLUB^2_Overestimation overcome ARMUL: ", 'OffNCLUB_estimate_gamma_Optimistic', 'ARMUL'),
    )

    for file in files:
        # Copy the needed rows out, then release the archive handle.
        with np.load(file) as data:
            available = set(data.files)
            series = {m: np.array(data[m][0], dtype=float) for m in methods if m in available}

        values = []
        for _, numer_key, denom_key in comparisons:
            if numer_key in series and denom_key in series:
                values.append(reducer(series[numer_key], series[denom_key]))
            else:
                values.append(float('nan'))

        for (message, _, _), value in zip(comparisons, values):
            print(message, f"{value:.4f}{suffix}")



if __name__ == "__main__":
    # Per-method styling. Colours come from the config; markers and line
    # styles are assigned by position in tmp_methods.
    colors_list = {method: config['color'] for method, config in methods.items()}
    tmp_methods = ['OffDBSCAN_improve', 'OffXMeans_improve', 'LinUCB_IND', 'CLUB', 'SCLUB', 'OffCLUB', 'OffNCLUB',
                'OffNCLUB_estimate_gamma_Pessimistic', 'OffNCLUB_estimate_gamma_Optimistic', 'ARMUL']
    markers = ['o','d', 's', 'D', '^', 'v', 'p', '*', 'h',  'H', 'X', 'o']
    linestyles = [
        (0, (3, 1, 1, 1, 1, 1)),
        (0, (1, 1)),
        '--',
        '-.',
        ':',
        (0, (3, 1, 1, 1)),
        '-',
        (0, (5, 2, 1, 2)),
        (0, (5, 1)),
        (0, (3, 1, 3, 1)),
        '-',
    ]
    markers_list = {method: markers[i % len(markers)] for i, method in enumerate(tmp_methods)}
    linestyles_list = {method: linestyles[i % len(linestyles)] for i, method in enumerate(tmp_methods)}

    # Settings for the summary sentence.
    mode = 'avg'         # 'perT' or 'avg'
    start_index = 4      # index 4 of the 5000-step grid, i.e. |D| >= 20k
    as_percent = True

    def compute_four(npz, mode, start_index, as_percent):
        """Return (over vs Off-CLUB, under vs Off-CLUB, over vs ARMUL, under vs ARMUL).

        Each value is the relative reduction ``1 - numerator / denominator``
        of the mean series from ``start_index`` onwards; a comparison whose
        series is missing from ``npz`` yields NaN.
        """
        def avg_reduction_perT(numer_arr, denom_arr):
            numer_arr = np.array(numer_arr, dtype=float)[start_index:]
            denom_arr = np.array(denom_arr, dtype=float)[start_index:]
            with np.errstate(divide='ignore', invalid='ignore'):
                ratios = 1.0 - np.where(denom_arr > 0, numer_arr / denom_arr, np.nan)
            val = float(np.nanmean(ratios))
            return (val * 100.0) if as_percent else val

        def reduction_from_avgs(numer_arr, denom_arr):
            numer_mean = float(np.nanmean(np.array(numer_arr, dtype=float)[start_index:]))
            denom_mean = float(np.nanmean(np.array(denom_arr, dtype=float)[start_index:]))
            val = 0.0 if denom_mean <= 0 else (1.0 - numer_mean / denom_mean)
            return (val * 100.0) if as_percent else val

        reducer = avg_reduction_perT if mode == 'perT' else reduction_from_avgs
        series = {m: np.array(npz[m][0], dtype=float) for m in tmp_methods if m in npz}

        def reduction(numer_key, denom_key):
            # Guard on the keys that are actually indexed (the previous code
            # tested for 'SCLUB' while reading 'ARMUL').
            if numer_key not in series or denom_key not in series:
                return float('nan')
            return reducer(series[numer_key], series[denom_key])

        over_off = reduction('OffNCLUB_estimate_gamma_Optimistic', 'OffCLUB')
        under_off = reduction('OffNCLUB_estimate_gamma_Pessimistic', 'OffCLUB')
        over_arm = reduction('OffNCLUB_estimate_gamma_Optimistic', 'ARMUL')
        under_arm = reduction('OffNCLUB_estimate_gamma_Pessimistic', 'ARMUL')
        return over_off, under_off, over_arm, under_arm

    def pick_file(files, dataset, distribution='uniform'):
        """Return the first path whose file name mentions both tags, else None."""
        for fp in files:
            # Match on the file name only so directory names cannot interfere.
            name = os.path.basename(fp)
            if f"_{distribution}_" in name and f"_{dataset}_" in name:
                return fp
        return None

    def load_four(fp):
        """Compute the four metrics for ``fp``; all NaN if the file is unavailable."""
        if not fp or not os.path.exists(fp):
            return (float('nan'), float('nan'), float('nan'), float('nan'))
        # Close the archive as soon as the metrics are computed.
        with np.load(fp) as npz:
            return compute_four(npz, mode, start_index, as_percent)

    def fmt_pair(over):
        return f"{over:.1f}\\%"

    # NOTE: the loop variable intentionally rebinds the module-level
    # ``offline_learn_method`` imported from the config, because
    # plot_regret_vs_T_multi reads that global to name its output PDF.
    for offline_learn_method in ['random', 'LinUCB_ind']:
        # Uncomment to regenerate the .npz inputs before plotting.
        # save_plot_data(offline_learn_method)

        files = [
            os.path.join(PLOT_DATA_DIR, f"{offline_learn_method}_{distribution}_100000_synthetic_0.100_regret_vs_T.npz")
            for distribution in ['uniform', 'half']
        ]
        files += [
            os.path.join(PLOT_DATA_DIR, f"{offline_learn_method}_uniform_100000_{dataset}_0.100_regret_vs_T.npz")
            for dataset in ['yelp', 'ml']
        ]
        plot_regret_vs_T_multi(files, methods, colors_list,markers_list,linestyles_list)

        # Four metrics per dataset; all three use the 'uniform' distribution.
        syn_vals = load_four(pick_file(files, 'synthetic', 'uniform'))
        yelp_vals = load_four(pick_file(files, 'yelp', 'uniform'))
        ml_vals = load_four(pick_file(files, 'ml', 'uniform'))

        # Compose the LaTeX sentence. "\\_" emits a literal backslash followed
        # by an underscore without relying on an invalid escape sequence.
        latex_sentence = (
            f"On the synthetic dataset, Off-C$^2$LUB\\_Overestimation improves over Off-CLUB by {fmt_pair(syn_vals[0])} "
            f"and over other baselines by at least {fmt_pair(syn_vals[2])}, while Off-C$^2$LUB\\_Underestimation achieves "
            f"{fmt_pair(syn_vals[1])} and {fmt_pair(syn_vals[3])}, respectively. "
            f"On Yelp, the corresponding improvements are {fmt_pair(yelp_vals[0])} and {fmt_pair(yelp_vals[2])} for Overestimation, "
            f"and {fmt_pair(yelp_vals[1])} and {fmt_pair(yelp_vals[3])} for Underestimation. "
            f"On MovieLens, they are {fmt_pair(ml_vals[0])} and {fmt_pair(ml_vals[2])} for Overestimation, "
            f"and {fmt_pair(ml_vals[1])} and {fmt_pair(ml_vals[3])} for Underestimation."
        )

        print(latex_sentence)