import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from typing import List, Dict, Tuple, Optional
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from OffClusBandit.core.Tools import compute_mean_and_ci, get_gamma_regrets
from OffClusBandit.configs.default import alpha_list, choose_gamma_alpha_list, dataset_configs
from OffClusBandit.configs.default import plot_gamma_methods as methods
from OffClusBandit.configs.default import gamma_vary_list as gamma_list
import re
# Paper-style matplotlib defaults (serif Times font, 2pt lines, frameless
# legends, 300 dpi), applied once when this module is imported.
_PAPER_RC_PARAMS = {
    'font.family': 'serif',
    'font.serif': ['Times New Roman'],
    'font.size': 12,
    'axes.labelsize': 12,
    'axes.titlesize': 14,
    'legend.fontsize': 12,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'figure.figsize': (5, 4),
    'lines.linewidth': 2,
    'lines.markersize': 6,
    'grid.alpha': 0.3,
    'legend.frameon': False,
    'figure.dpi': 300,
}
plt.rcParams.update(_PAPER_RC_PARAMS)




def _display_label(method_name: str) -> str:
    """Turn an internal method key into its legend label.

    Example: 'OffNCLUB_estimate_gamma_Pessimistic' -> 'Off-C²LUB_Underestimation'.
    """
    label = re.sub("_estimate_gamma", "", method_name, count=1)
    # Order matters: each replacement sees the output of the previous one.
    for old, new in (
        ("OffNCLUB", "Off-C²LUB"),
        ("OffCLUB", "Off-CLUB"),
        ("OffDBSCAN", "DBSCAN"),
        ("OffXMeans", "XMeans"),
        ("improve", "Improve"),
        ("Pessimistic", "Underestimation"),
        ("Optimistic", "Overestimation"),
    ):
        label = label.replace(old, new)
    return label


def _mean_regret_per_gamma(regrets, num_gammas: int) -> List[float]:
    """Average a strided regret sequence per gamma index.

    Entry ``regrets[k]`` is attributed to gamma index ``k % num_gammas``.
    Presumably this is one run of gammas per seed; confirm against get_gamma_regrets.
    """
    return [np.mean(np.array(regrets[j::num_gammas])) for j in range(num_gammas)]


def plot_avg_regret_vs_gamma_three_datasets(
    gamma_list: List[float],
    methods: Dict[str, Dict],
    T: int,
    nu: int,
    d: int,
    m: int,
    L: int,
    pj: int,
    alpha: float,
    choose_gamma_alpha: float,
    color_list: Optional[List[str]] = None,
    save_filename: Optional[str] = None,
    y_limit: float = 0.25,
    num_datasets: int = 3
) -> None:
    """Plot average suboptimality gap against the estimated gamma, one subplot per dataset.

    Draws one subplot for each of the first ``num_datasets`` entries of
    ``dataset_configs``, using the regrets returned by ``get_gamma_regrets``:

    * Methods whose config has ``gamma_required`` set are drawn as a solid curve.
      The curve shows the mean regret at each gamma, and the gamma with the lowest
      mean is printed.
    * All other methods are drawn as a dashed horizontal line at ``regrets[0]``.

    Args:
        gamma_list: Gamma values on the x-axis; also forwarded to get_gamma_regrets.
        methods: Method name -> config dict (reads ``gamma_required``). Never modified.
        T, nu, d, m, L, pj, alpha, choose_gamma_alpha: Forwarded unchanged to
            get_gamma_regrets.
        color_list: Colours assigned to methods in ``methods`` order, cycled.
            If None or empty, matplotlib's default colour cycle is used.
        save_filename: If given, the figure is saved as PDF into ``plot_figs/`` next
            to this module (only the basename is used) and then closed.
        y_limit: Upper y-limit shared by all subplots.
        num_datasets: Number of datasets/subplots. With 2, 'OffXMeans_improve' is
            left out.
    """
    # Fixed marker per known method, so a method keeps its marker across figures.
    marker_order = [
        'OffDBSCAN_improve', 'OffXMeans_improve', 'LinUCB_IND', 'CLUB', 'SCLUB',
        'OffCLUB', 'OffNCLUB', 'OffNCLUB_estimate_gamma_Pessimistic',
        'OffNCLUB_estimate_gamma_Optimistic',
    ]
    markers = ['o', 'd', 's', 'D', '^', 'v', 'p', '*', 'h', 'd', 'H']
    method_markers = {name: markers[i % len(markers)] for i, name in enumerate(marker_order)}

    # Assign colours over the FULL method list, before filtering, so each method
    # keeps the same colour in the 2- and 3-dataset figures.
    if not color_list:
        color_list = plt.rcParams['axes.prop_cycle'].by_key()['color']
    method_colors = {name: color_list[i % len(color_list)] for i, name in enumerate(methods)}

    if num_datasets == 2:
        # Leave XMeans out of the two-panel figure. Use a filtered copy rather than
        # methods.pop(), which mutated the caller's shared config and raised
        # KeyError when the key was missing.
        methods = {name: cfg for name, cfg in methods.items() if name != 'OffXMeans_improve'}

    # squeeze=False keeps `axes` indexable even when num_datasets == 1.
    fig, axes = plt.subplots(1, num_datasets, figsize=(int(5 * num_datasets), 4),
                             sharey=True, squeeze=False)
    axes = axes[0]
    num_gammas = len(gamma_list)
    # Maps label -> latest line handle. Dict insertion order gives the legend order.
    legend_entries = {}

    for i, config in enumerate(dataset_configs[:num_datasets]):
        ax = axes[i]
        results = get_gamma_regrets(
            seed_list=config["seed_list"],
            methods=methods,
            is_empirical=config["is_empirical"],
            dataset=config["dataset"],
            T=T,
            nu=nu,
            d=d,
            m=m,
            L=L,
            pj=pj,
            choose_gamma_alpha=choose_gamma_alpha,
            alpha=alpha,
            gamma_list=gamma_list,
        )

        for method_name, regrets in results.items():
            label_name = _display_label(method_name)
            gamma_required = methods[method_name]['gamma_required']
            if gamma_required:
                ys = _mean_regret_per_gamma(regrets, num_gammas)
            else:
                # A gamma-independent method is drawn as a flat reference line.
                ys = np.full(num_gammas, regrets[0])
            line_obj, = ax.plot(
                gamma_list,
                ys,
                color=method_colors[method_name],
                linestyle='-' if gamma_required else '--',
                label=label_name,
                marker=method_markers.get(method_name, 'x'),
                markevery=10,
            )
            if gamma_required and ys:
                # Report the gamma that gives the lowest mean regret.
                best_gamma = gamma_list[int(np.argmin(ys))]
                print(f"{method_name} min-mean corresponding gamma: {best_gamma}")
            legend_entries[label_name] = line_obj

        ax.set_title(config["title"], fontsize=16)
        ax.set_xlabel(r'$\hat{\gamma}$', fontsize=16)
        ax.tick_params(labelsize=16)
        ax.set_ylim(0, y_limit)
        if i == 0:
            ax.set_ylabel('Average Suboptimality Gap', fontsize=16)
        ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.7)

    labels = list(legend_entries)
    handles = list(legend_entries.values())
    legend_kwargs = dict(
        loc='upper center',
        ncol=max(1, len(labels)),  # ncol must be >= 1 even if no lines were drawn
        bbox_to_anchor=(0.5, 1.05),
        # Put the size inside `prop`: Legend ignores `fontsize` whenever `prop` is
        # given, so the earlier fontsize=15 never took effect.
        prop={'weight': 'bold', 'size': 15},
    )
    if num_datasets == 2:
        # Tighter spacing so the single-row legend fits the narrower figure.
        legend_kwargs.update(handletextpad=0.2, labelspacing=0.3, columnspacing=1)
    fig.legend(handles, labels, **legend_kwargs)

    fig.tight_layout()
    if save_filename is not None:
        plot_figs_dir = os.path.join(os.path.dirname(__file__), 'plot_figs')
        os.makedirs(plot_figs_dir, exist_ok=True)
        out_path = os.path.join(plot_figs_dir, os.path.basename(save_filename))
        fig.savefig(out_path, format='pdf', bbox_inches='tight')
        print(f"Figure saved as {out_path}")
        # Release the saved figure so repeated calls do not pile up open figures.
        plt.close(fig)

# Helpers that compute each method's best mean suboptimality gap, express it as a ratio to the optimal Off-C^2LUB, and render the comparison as LaTeX text
def compute_min_means_for_dataset(dataset_cfg, T, *, nu=1000, d=20, m=10, L=20, pj=0,
                                  alpha=None, choose_gamma_alpha=None):
    """Return ``{method_name: best mean regret}`` for one dataset.

    For methods whose config has ``gamma_required`` set, regrets are first averaged
    per gamma and the lowest of those means is kept. For all other methods,
    ``regrets[0]`` is used directly.

    Args:
        dataset_cfg: Dataset config dict. Reads 'seed_list', 'is_empirical' and
            'dataset'.
        T: Horizon forwarded to get_gamma_regrets.
        nu, d, m, L, pj: Experiment settings forwarded to get_gamma_regrets. The
            defaults are the values this function used to hard-code.
        alpha, choose_gamma_alpha: Forwarded to get_gamma_regrets. None means
            ``alpha_list[0]`` / ``choose_gamma_alpha_list[0]``, which are the values
            the script entry point uses. The function previously read these from
            globals that only exist when the file runs as a script.
    """
    if alpha is None:
        alpha = alpha_list[0]
    if choose_gamma_alpha is None:
        choose_gamma_alpha = choose_gamma_alpha_list[0]
    res = get_gamma_regrets(
        seed_list=dataset_cfg["seed_list"],
        methods=methods,
        is_empirical=dataset_cfg["is_empirical"],
        dataset=dataset_cfg["dataset"],
        T=T,
        nu=nu,
        d=d,
        m=m,
        L=L,
        pj=pj,
        choose_gamma_alpha=choose_gamma_alpha,
        alpha=alpha,
        gamma_list=gamma_list,
    )
    num_gammas = len(gamma_list)
    return {
        method_name: _best_mean_regret(regrets, methods[method_name]['gamma_required'], num_gammas)
        for method_name, regrets in res.items()
    }


def _best_mean_regret(regrets, gamma_required, num_gammas):
    """Return the lowest per-gamma mean regret, or ``regrets[0]`` for gamma-free methods.

    For gamma-dependent methods, ``regrets[k]`` is grouped by ``k % num_gammas`` and
    each group is averaged before taking the minimum.
    """
    if not gamma_required:
        return float(regrets[0])
    per_gamma_means = [float(np.mean(np.array(regrets[j::num_gammas]))) for j in range(num_gammas)]
    # np.min (not builtin min) so that a NaN mean propagates instead of being skipped.
    return float(np.min(np.array(per_gamma_means)))

def build_latex_sentence_for_datasets(T, picked_datasets=("synthetic", "yelp")):
    r"""Build a LaTeX paragraph comparing each baseline's best gap with Off-C²LUB's.

    For every dataset in ``picked_datasets``, each baseline's best mean
    suboptimality gap (from compute_min_means_for_dataset) is divided by that of
    'OffNCLUB'. The ratios are rendered as ``x.xx\times``, or ``N/A`` when they
    cannot be computed.

    Args:
        T: Horizon forwarded to compute_min_means_for_dataset.
        picked_datasets: Dataset names, matched against ``cfg['dataset']`` in
            ``dataset_configs``. 'synthetic' gets the introductory sentence; every
            other dataset gets a "Similar trends hold" sentence.

    Returns:
        The sentences joined by single spaces.

    Raises:
        KeyError: If a picked dataset is not present in ``dataset_configs``.
    """
    configs_by_name = {cfg['dataset']: cfg for cfg in dataset_configs}
    sentences = []
    for ds in picked_datasets:
        min_means = compute_min_means_for_dataset(configs_by_name[ds], T)
        # Ratios are computed per dataset and passed along explicitly. The Off-CLUB
        # ratio used to leak from the last loop iteration into every sentence.
        sentences.append(_latex_comparison_sentence(ds, _gap_ratios_to_offnclub(min_means)))
    return " ".join(sentences)


def _gap_ratios_to_offnclub(min_means):
    """Return the (DBSCAN, CLUB, Off-CLUB, under, over) gap ratios relative to 'OffNCLUB'.

    A ratio is NaN when the method is missing, or when the OffNCLUB baseline is
    missing, non-finite or not positive.
    """
    base = min_means.get('OffNCLUB', np.nan)
    if not (np.isfinite(base) and base > 0):
        return (np.nan,) * 5
    keys = (
        'OffDBSCAN_improve',
        'CLUB',
        'OffCLUB',
        'OffNCLUB_estimate_gamma_Pessimistic',
        'OffNCLUB_estimate_gamma_Optimistic',
    )
    return tuple(float(min_means.get(key, np.nan) / base) for key in keys)


def _latex_comparison_sentence(ds, ratios):
    """Render one dataset's (DBSCAN, CLUB, Off-CLUB, under, over) ratios as a LaTeX sentence."""
    name_map = {
        'synthetic': 'the synthetic dataset',
        'yelp': 'Yelp',
        'ml': 'MovieLens',
    }
    db, cl, offclub, un, ov = ratios

    def fmt(x):
        return f"{x:.2f}\\times" if np.isfinite(x) else "N/A"

    pretty_name = name_map.get(ds, ds)
    if ds == 'synthetic':
        return (
            f"On {pretty_name}, relative suboptimality gaps compared to the optimal $\\hat{{\\gamma}}$ of Off-C$^2$LUB are: "
            f"DBSCAN (${fmt(db)}$), CLUB (${fmt(cl)}$), Off-CLUB (${fmt(offclub)}$), Off-C$^2$LUB\\_Underestimation (${fmt(un)}$), and Off-C$^2$LUB\\_Overestimation (${fmt(ov)}$)."
        )
    return (
        f"Similar trends hold for {pretty_name}: DBSCAN (${fmt(db)}$), CLUB (${fmt(cl)}$), Off-CLUB (${fmt(offclub)}$), Underestimation (${fmt(un)}$), and Overestimation (${fmt(ov)}$)."
    )


if __name__ == "__main__":
    # Experiment settings shared by every figure generated below.
    shared_kwargs = {
        'gamma_list': gamma_list,
        'methods': methods,
        'nu': 1000,
        'd': 20,
        'm': 10,
        'L': 20,
        'pj': 0,
        'y_limit': 0.25,
    }
    color_list = [cfg['color'] for cfg in methods.values()]
    # Kept as module-level names, not locals of a main() function, so that code in
    # this module which looks them up as globals still finds them.
    choose_gamma_alpha = choose_gamma_alpha_list[0]
    alpha = alpha_list[0]

    # Three-panel figures for T = 30000, 35000, 40000.
    T_list = [5000 * i for i in (6, 7, 8)]
    for T in T_list:
        plot_avg_regret_vs_gamma_three_datasets(
            T=T,
            alpha=alpha,
            choose_gamma_alpha=choose_gamma_alpha,
            color_list=color_list,
            save_filename=f"plot_figs/three_datasets_gamma_comparison_{T}.pdf",
            num_datasets=3,
            **shared_kwargs,
        )

    # Two-panel figure at T = 30000.
    T = 30000
    plot_avg_regret_vs_gamma_three_datasets(
        T=T,
        alpha=alpha,
        choose_gamma_alpha=choose_gamma_alpha,
        color_list=color_list,
        save_filename=f"plot_figs/two_datasets_gamma_comparison_{T}.pdf",
        num_datasets=2,
        **shared_kwargs,
    )

    # Uncomment to print the LaTeX comparison paragraph for the paper:
    # latex_text = build_latex_sentence_for_datasets(T, picked_datasets=("synthetic", "yelp"))
    # print(latex_text)