import os
import sys
import numpy as np
import re
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple, Optional

# Add the directory two levels above this file's directory (presumably the
# project root) to sys.path so the package imports below resolve when this
# file is executed directly as a script.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

# Resolve from the absolute path of this file: a bare os.path.dirname(__file__)
# is '' when the script is launched as `python script.py`, which would make
# PLOT_DATA_DIR relative to the current working directory instead.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PLOT_DATA_DIR = os.path.join(BASE_DIR, 'plot_data')

from OffClusBandit.core.Tools import build_method_filename, create_algorithm_instance, compute_mean_and_ci, load_regret, get_mean_regrets
from OffClusBandit.configs.default import alpha_list, plot_synthetic_gamma_list, choose_gamma_alpha_list, plot_yelp_gamma_list, plot_ml_gamma_list
from OffClusBandit.configs.default import long_T_yelp_gamma_list, long_T_ml_gamma_list, long_T_synthetic_gamma_list
from OffClusBandit.configs.default import plot_methods as methods
from OffClusBandit.configs.default import offline_learn_method 
from matplotlib.ticker import ScalarFormatter
def save_npz_regret(results: Dict[str, List[Tuple[int, float, float]]], save_npz_filename=None):
    """Save per-method mean regrets and confidence intervals to an ``.npz`` archive.

    Each key of *results* becomes one array in the archive, built from
    ``[means, cis]``: row 0 holds the mean regrets and row 1 the CI values,
    in the order given by *results*. The T values are not stored. A method
    with no data is saved as an empty ``(2, 0)`` array.

    Args:
        results: maps a method name to a list of ``(T, mean, ci)`` tuples.
        save_npz_filename: destination path, passed straight to ``np.savez``
            (which appends ``.npz`` if the path lacks that suffix).

    Raises:
        TypeError: if *save_npz_filename* is None.
    """
    if save_npz_filename is None:
        raise TypeError("save_npz_filename must be a path, got None")

    save_npz = {}
    for method, data in results.items():
        # Split the (T, mean, ci) triples column-wise; T is intentionally dropped.
        means = tuple(mean for _, mean, _ in data)
        cis = tuple(ci for _, _, ci in data)
        save_npz[method] = [means, cis]
    np.savez(save_npz_filename, **save_npz)


def save_plot_data(offline_learn_method):
    """Compute mean regrets for every dataset and gamma-alpha value and save them.

    For each dataset in ``synthetic``, ``yelp`` and ``ml`` and each value in
    ``choose_gamma_alpha_list``, ``get_mean_regrets`` is evaluated over seeds
    1-10 at T = 200000, 400000, ..., 1000000. The result is written with
    ``save_npz_regret`` into ``PLOT_DATA_DIR`` as
    ``{offline_learn_method}_uniform_1000000_{dataset}_{choose_gamma_alpha:.3f}_regret_vs_T.npz``.

    Args:
        offline_learn_method: forwarded to ``get_mean_regrets`` and used as the
            output filename prefix.
    """
    # Five horizons, 0.2M .. 1M (T = 0 is not included).
    T_values = [200000 * i for i in range(1, 6)]
    nu, d, m, L = 1000, 20, 10, 20
    alpha = alpha_list[0]
    uniforms = ['uniform']  # indexed by pj when building the filename
    seed_list = list(range(1, 11))

    # (dataset, pj, is_empirical, gamma list) for each run.
    dataset_configs = [
        ('synthetic', 0, 0, long_T_synthetic_gamma_list),
        ('yelp', 0, 1, long_T_yelp_gamma_list),
        ('ml', 0, 1, long_T_ml_gamma_list),
    ]

    # Create the output directory once; exist_ok avoids the exists()/mkdir() race.
    os.makedirs(PLOT_DATA_DIR, exist_ok=True)

    for dataset, pj, is_empirical, gamma_list in dataset_configs:
        for choose_gamma_alpha in choose_gamma_alpha_list:
            results = get_mean_regrets(
                seed_list=seed_list,
                methods=methods,
                is_empirical=is_empirical,
                dataset=dataset,
                T_values=T_values,
                nu=nu,
                d=d,
                m=m,
                L=L,
                pj=pj,
                choose_gamma_alpha=choose_gamma_alpha,
                alpha=alpha,
                gamma_list=gamma_list,
                offline_learn_method=offline_learn_method,
            )
            filename = (
                f"{offline_learn_method}_{uniforms[pj]}_{T_values[-1]}_"
                f"{dataset}_{choose_gamma_alpha:.3f}_regret_vs_T.npz"
            )
            save_npz_regret(results, os.path.join(PLOT_DATA_DIR, filename))


def generate_latex_table_from_saved_data(files, methods):
    """Print (and return) a LaTeX regret table for each saved ``.npz`` file.

    Each file is read as an archive holding one array per method whose row 0
    contains the mean regrets for the five table columns (0.2M, 0.4M, 0.6M,
    0.8M, 1M). For every file a ``tabular`` with one row per algorithm is
    printed. In each column the smallest positive value is set in bold.

    A missing file, or a method absent from a file, is reported and filled
    with 0.0 so a table can still be produced. Those placeholders never count
    as the best value.

    Args:
        files: iterable of ``.npz`` paths; one table is generated per path.
        methods: unused; kept so existing callers keep working.

    Returns:
        list[str]: the generated LaTeX code, one string per file, in order.
    """
    n_cols = 5  # matches the five "Dataset size" header columns below

    # Algorithm name mapping (internal key -> LaTeX display name).
    # Insertion order is the row order (follows the original table order).
    method_mapping = {
        'OffDBSCAN_improve': 'DBSCAN\\_Improve',
        'OffXMeans_improve': 'XMeans\\_Improve',
        'CLUB': 'CLUB',
        'OffCLUB': 'Off-CLUB',
        'LinUCB_IND': 'LinUCB\\_IND',
        'SCLUB': 'SCLUB',
        'ARMUL': 'ARMUL',
        'OffNCLUB': 'Off-C$^2$LUB',
        'OffNCLUB_estimate_gamma_Pessimistic': 'Off-C$^2$LUB\\_Underestimation',
        'OffNCLUB_estimate_gamma_Optimistic': 'Off-C$^2$LUB\\_Overestimation'
    }
    method_order = list(method_mapping)

    header = (
        "\\diagbox{\\large Algorithm }{\\large Dataset size} & \\large $0.2M$ & "
        "\\large $0.4M$ & \\large $0.6M$ & \\large $0.8M$ & \\large $1M$ \\\\ \\hline"
    )

    tables = []
    for file in files:
        print(f"Reading data file: {file}")

        all_data = {}
        if os.path.exists(file):
            # Load the archive once; the context manager closes the NpzFile handle.
            with np.load(file) as data:
                available = set(data.files)
                for method in method_order:
                    if method in available:
                        means = data[method][0]
                        all_data[method] = [means[col] for col in range(n_cols)]
                    else:
                        print(f"Method {method} not found in {file}, using default values")
                        all_data[method] = [0.0] * n_cols
        else:
            print(f"Data file {file} does not exist, using default values")
            all_data = {method: [0.0] * n_cols for method in method_order}

        # Best (minimum) value per column. Zeros are placeholders for missing data,
        # so only positive values compete; None means "nothing to highlight".
        best_values = []
        for col in range(n_cols):
            candidates = [all_data[method][col] for method in method_order if all_data[method][col] > 0]
            best_values.append(min(candidates) if candidates else None)

        latex_table = [
            "\\begin{tabular}{|l|c|c|c|c|c|}",
            "\\hline",
            header,
        ]
        for method in method_order:
            cells = []
            for value, best in zip(all_data[method], best_values):
                text = f"{value:.6f}"
                # Compare at the printed precision so ties render consistently.
                if best is not None and round(value, 6) == round(best, 6):
                    text = f"\\textbf{{{text}}}"
                cells.append(text)
            latex_table.append(f"{method_mapping[method]} & {' & '.join(cells)} \\\\")
        latex_table.append("\\hline")
        latex_table.append("\\end{tabular}%")

        latex_code = "\n".join(latex_table)
        print("\nGenerated LaTeX table code:")
        print(latex_code)
        tables.append(latex_code)

    return tables

if __name__ == "__main__":
    # Save and read with the same prefix; previously data was saved under the
    # hard-coded 'random' prefix but read back with the configured
    # offline_learn_method, so the table could come from missing/stale files.
    save_plot_data(offline_learn_method)

    # One table per dataset, at T_max = 1M and choose_gamma_alpha = 0.100.
    files = [
        os.path.join(PLOT_DATA_DIR, f"{offline_learn_method}_uniform_1000000_{dataset}_0.100_regret_vs_T.npz")
        for dataset in ('synthetic', 'yelp', 'ml')
    ]
    generate_latex_table_from_saved_data(files, methods)
