import os
import sys
import numpy as np
import re
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple, Optional

# Add the project root (three directory levels above this file) to sys.path so the
# ``OffClusBandit`` imports that follow can be resolved. Guard against appending
# the same entry again if the module is re-imported/reloaded.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if _PROJECT_ROOT not in sys.path:
    sys.path.append(_PROJECT_ROOT)

# Anchor output paths on the absolute location of this file (as done for the
# project root above) so PLOT_DATA_DIR does not depend on the current working
# directory when ``__file__`` happens to be a relative path.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PLOT_DATA_DIR = os.path.join(BASE_DIR, 'plot_data')

from OffClusBandit.core.Tools import build_method_filename, create_algorithm_instance, compute_mean_and_ci, load_regret, get_mean_regrets
from OffClusBandit.configs.default import alpha_list, plot_synthetic_gamma_list, choose_gamma_alpha_list, plot_yelp_gamma_list, plot_ml_gamma_list
from OffClusBandit.configs.default import long_T_yelp_gamma_list, long_T_ml_gamma_list, long_T_synthetic_gamma_list
from OffClusBandit.configs.default import plot_methods as methods
from OffClusBandit.configs.default import offline_learn_method 
from matplotlib.ticker import ScalarFormatter


def save_npz_regret(results: Dict[str, List[Tuple[int, float, float]]], save_npz_filename=None):
    """Save per-method regret means and confidence intervals to a ``.npz`` archive.

    Args:
        results: Mapping from method name to a list of ``(T, mean, ci)`` tuples.
            The ``T`` column is not stored.
        save_npz_filename: Destination path handed to ``np.savez`` (which
            appends ``.npz`` if missing). Required despite the ``None`` default,
            which is kept only for signature compatibility.

    Each method is stored under its own key as ``[means, cis]``; ``np.savez``
    turns that into an array of shape ``(2, len(data))`` (``(2, 0)`` for a
    method with no data points).

    Raises:
        TypeError: if ``save_npz_filename`` is ``None``.
    """
    if save_npz_filename is None:
        raise TypeError("save_npz_filename must be a path, got None")

    save_npz = {}
    for method, data in results.items():
        # Drop the T column; keep means and CIs in their original order.
        # Built directly (rather than via zip(*...)) so an empty list is valid.
        means = tuple(mean for _, mean, _ in data)
        cis = tuple(ci for _, _, ci in data)
        save_npz[method] = [means, cis]
    np.savez(save_npz_filename, **save_npz)


def save_plot_data(offline_learn_method):
    """Compute mean regrets for each dataset / gamma-selection setting and save them.

    For every dataset in ``synthetic``, ``yelp`` and ``ml`` (in that order) and
    every value in ``choose_gamma_alpha_list``, this calls ``get_mean_regrets``
    over seeds 1-10 at T = 0.2M, 0.4M, ..., 1M and writes the result with
    ``save_npz_regret`` to::

        PLOT_DATA_DIR/<offline_learn_method>_<uniform>_<T_max>_<dataset>_<choose_gamma_alpha:.3f>_regret_vs_T.npz

    Args:
        offline_learn_method: Offline learning method name; forwarded to
            ``get_mean_regrets`` and used as the output filename prefix.
    """
    T_values = [200000 * i for i in range(1, 6)]  # 0.2M, 0.4M, 0.6M, 0.8M, 1M
    nu = 1000
    d = 20
    m = 10
    L = 20
    alpha = alpha_list[0]
    pj = 0
    uniforms = ['uniform', 'half']
    seed_list = list(range(1, 11))

    # dataset -> (is_empirical flag passed to get_mean_regrets, gamma list)
    dataset_settings = {
        'synthetic': (0, long_T_synthetic_gamma_list),
        'yelp': (1, long_T_yelp_gamma_list),
        'ml': (1, long_T_ml_gamma_list),
    }

    # Create the output directory once; exist_ok avoids the check-then-create race.
    os.makedirs(PLOT_DATA_DIR, exist_ok=True)

    for dataset, (is_empirical, best_gamma_list) in dataset_settings.items():
        for choose_gamma_alpha in choose_gamma_alpha_list:
            results = get_mean_regrets(
                seed_list=seed_list,
                methods=methods,
                is_empirical=is_empirical,
                dataset=dataset,
                T_values=T_values,
                nu=nu,
                d=d,
                m=m,
                L=L,
                pj=pj,
                choose_gamma_alpha=choose_gamma_alpha,
                alpha=alpha,
                gamma_list=best_gamma_list,
                offline_learn_method=offline_learn_method
            )
            save_npz_filename = os.path.join(PLOT_DATA_DIR, f"{offline_learn_method}_{uniforms[pj]}_{T_values[-1]}_{dataset}_{choose_gamma_alpha:.3f}_regret_vs_T.npz")
            save_npz_regret(results, save_npz_filename)

def generate_latex_table():
    """Regenerate the regret plot data, then print and return a LaTeX results table.

    First calls ``save_plot_data("random")``, which writes ``.npz`` regret files
    into ``PLOT_DATA_DIR``. The table itself is a hard-coded string literal:
    its numbers are NOT computed from the data just saved, so they must be
    updated by hand whenever the experiments change.

    Returns:
        str: The LaTeX ``tabular`` source that was printed.
    """
    # Side effect: recompute and save the per-dataset regret data to disk.
    save_plot_data("random")
    
    # Output LaTeX table
    # NOTE(review): values are hard-coded, presumably copied from an earlier run
    # of the experiments -- confirm they match the current plot_data files.
    # The header cell uses \diagbox, so the document needs the `diagbox` package.
    latex_table = """\\begin{tabular}{|l|c|c|c|c|c|}
\\hline
\\diagbox{\\large Algorithm }{\\large Dataset size} & \\large $0.2M$ & \\large $0.4M$ & \\large $0.6M$ & \\large $0.8M$ & \\large $1M$ \\\\ \\hline
DBSCAN\\_Improve & 0.130149 & 0.058302 & 0.024758 & 0.011941 & 0.006520 \\\\ 
XMeans\\_Improve & 0.129364 & 0.058108 & 0.024791 & 0.011955 & 0.006536 \\\\
CLUB & 0.013153 & 0.000019 & \\textbf{0.000012} & \\textbf{0.000009} & \\textbf{0.000007} \\\\ 
Off-CLUB & 0.000038 & \\textbf{0.000018} & \\textbf{0.000012} & \\textbf{0.000009} & \\textbf{0.000007} \\\\ 
LinUCB & 0.140748 & 0.140993 & 0.140925 & 0.140898 & 0.140844 \\\\ 
LinUCB\\_IND & 0.144851 & 0.066494 & 0.028189 & 0.013586 & 0.007561 \\\\ 
SCLUB & 0.180909 & 0.131105 & 0.050664 & 0.051058 & 0.051328 \\\\
ARMUL & 0.146900 & 0.109480 & 0.078815 & 0.055882 & 0.051504 \\\\
Off-C$^2$LUB & 0.000041 & \\textbf{0.000018} & \\textbf{0.000012} & \\textbf{0.000009} & \\textbf{0.000007} \\\\ 
Off-C$^2$LUB\\_Underestimation & \\textbf{0.000037} & 0.000019 & 0.000013 & 0.000011 & 0.000008 \\\\ 
Off-C$^2$LUB\\_Overestimation & 0.000041 & \\textbf{0.000018} & \\textbf{0.000012} & \\textbf{0.000009} & \\textbf{0.000007} \\\\ 
\\hline
\\end{tabular}%"""
    
    print(latex_table)
    return latex_table

if __name__ == "__main__":
    # Script entry point: regenerate the plot data, then print the LaTeX table.
    generate_latex_table() 