# -*- coding: utf-8 -*-
import pandas as pd
import io
import re
import math
from collections import defaultdict

# Markdown summary table to read; it is parsed further down as pipe-delimited text.
input_filename = 'overall_summary_stats_mean_std_by_prior_rate_alg.md'

# Destination file for the generated LaTeX table source.
output_latex_filename = 'latex_table_dream4_no_pr_bm_bold.txt' # bm = bold math
# Number of decimal places used when formatting means and standard deviations.
output_precision = 2
# Absolute tolerance used when testing whether a mean equals the best mean (and gets bolded).
bolding_tolerance = 1e-6

# Translation table for LaTeX special characters, built once at import time
# instead of recompiling a regex on every call.
# The asterisk is deliberately NOT escaped: `\*` is LaTeX's discretionary
# multiplication sign, so mapping '*' to it makes the asterisk vanish from
# the typeset text. A plain '*' is already safe in text mode.
_LATEX_ESCAPES = str.maketrans({
    '&': r'\&',
    '%': r'\%',
    '$': r'\$',
    '#': r'\#',
    '_': r'\_',
    '{': r'\{',
    '}': r'\}',
    '~': r'\textasciitilde{}',
    '^': r'\textasciicircum{}',
    '\\': r'\textbackslash{}',
})


def escape_latex(text):
    r"""Return ``str(text)`` with LaTeX special characters escaped.

    Each of ``& % $ # _ { } ~ ^ \`` is replaced by its text-mode escape in a
    single pass, so the braces and backslashes introduced by a replacement
    are never escaped a second time. Every other character, including
    ``*``, is passed through unchanged.

    Args:
        text: Any object; it is converted with ``str()`` first.

    Returns:
        The escaped string.
    """
    return str(text).translate(_LATEX_ESCAPES)

# Read the markdown table and drop its second line (presumably the `|---|`
# header separator row) so the remainder parses as pipe-delimited text.
with open(input_filename, 'r', encoding='utf-8') as f:
    markdown_content = f.read()
md_lines = markdown_content.splitlines()
data_io = io.StringIO("\n".join(md_lines[:1] + md_lines[2:]))
df = pd.read_csv(data_io, sep='|', skipinitialspace=True)

# Normalise header names: trim surrounding whitespace, then surrounding backticks.
df.columns = [header.strip().strip('`') for header in df.columns]

# Rows are only usable when all three key columns are present.
required_cols = ['prior_rate', 'alg', 'data_size']
df = df.drop(columns=[''], errors='ignore').dropna(subset=required_cols, axis=0)

# Coerce whichever numeric columns exist; unparseable cells become NaN.
numeric_cols = ['dataloss_mean', 'dataloss_std', 'roc_mean', 'roc_std', 'data_size', 'prior_rate']
present_numeric_cols = [name for name in numeric_cols if name in df.columns]
for numeric_col in present_numeric_cols:
    df[numeric_col] = pd.to_numeric(df[numeric_col], errors='coerce')

# Second pass: drop rows whose key columns became NaN during coercion.
df = df.dropna(subset=required_cols, axis=0)
# data_size code -> dataset label used in the table, and the reverse lookup.
dataset_map = {3: "DREAM4-3", 6: "DREAM4-6", 9: "DREAM4-9"}
size_map = dict(zip(dataset_map.values(), dataset_map.keys()))
# Column order of the datasets in the generated table.
datasets = ["DREAM4-3", "DREAM4-6", "DREAM4-9"]

# Metric column prefixes and the header each one is displayed under.
metrics = ['dataloss', 'roc']
metric_map = {'dataloss': 'Loss', 'roc': 'ROC'}
metric_headers = [metric_map[metric_key] for metric_key in metrics]

# Distinct prior rates and algorithm names found in the data, in ascending order.
prior_rates = list(df['prior_rate'].unique())
prior_rates.sort()
algorithms = list(df['alg'].unique())
algorithms.sort()
# best_values[data_size][prior_rate][header] -> best mean within that group:
# the lowest dataloss_mean under 'Loss' and the highest roc_mean under 'ROC'.
best_values = defaultdict(lambda: defaultdict(dict))
grouped = df.groupby(['data_size', 'prior_rate'])
for (size_code, rate), subset in grouped:
    # Only data sizes that map to a known dataset label are tabulated.
    if size_code not in dataset_map:
        continue
    lowest_loss = subset['dataloss_mean'].min(skipna=True)
    highest_roc = subset['roc_mean'].max(skipna=True)
    # A group whose values are all NaN yields NaN and gets no entry.
    for header, best in (('Loss', lowest_loss), ('ROC', highest_roc)):
        if pd.notna(best):
            best_values[size_code][rate][header] = best

def _render_mean_std(mean_value, std_value):
    """Return the math-mode cell text for a mean/std pair, or "N/A" when the mean is missing."""
    if pd.isna(mean_value):
        return "N/A"
    mean_text = f"{mean_value:.{output_precision}f}"
    # A missing standard deviation is rendered as a question mark.
    std_text = f"{std_value:.{output_precision}f}" if pd.notna(std_value) else "?"
    return f"{mean_text}{{\\scriptstyle \\pm {std_text}}}"


# results[prior_rate][alg][dataset label][metric header] -> (raw mean, cell text)
results = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
for _, record in df.iterrows():
    rate = record['prior_rate']
    method = record['alg']
    size_code = record['data_size']
    # Skip rows for unknown data sizes or with a missing key value.
    if size_code not in dataset_map or pd.isna(rate) or pd.isna(method):
        continue
    label = dataset_map[size_code]

    for column_prefix, header in metric_map.items():
        mean_value = record.get(f'{column_prefix}_mean', float('nan'))
        std_value = record.get(f'{column_prefix}_std', float('nan'))
        results[rate][method][label][header] = (
            mean_value,
            _render_mean_std(mean_value, std_value),
        )



# --- Table preamble and the two header rows ---------------------------------
# Column counts are derived from `datasets` and `metric_headers` so that the
# column spec, the per-dataset spans and the \cmidrule ranges stay consistent
# if either list changes. With three datasets and two metrics the output is
# the same as the previously hard-coded `*{6}{c}` / `\multicolumn{2}` layout.
metrics_per_dataset = len(metric_headers)
value_col_count = len(datasets) * metrics_per_dataset

# Fragments are collected in a list and joined once instead of repeated `+=`.
header_parts = [
    "\\begin{table}[ht]\n",
    "\\centering\n",
    "\\small\n",
    "\\setlength{\\tabcolsep}{4pt}\n",
    "\\caption{Experiment Summary (Mean {\\scriptsize $\\pm$ Std Dev}, Best Mean Bolded)}\n",
    "\\label{tab:dream4_summary_no_pr_bm_bold}\n",
    f"\\begin{{tabular}}{{ll*{{{value_col_count}}}{{c}}}}\n",
    "\\toprule\n",
    "\\multicolumn{2}{c}{\\textbf{Datasize}}",
]

# First header row: one spanning cell per dataset.
for ds_name in datasets:
    header_parts.append(
        f" & \\multicolumn{{{metrics_per_dataset}}}{{c}}{{\\textbf{{{ds_name}}}}}"
    )
header_parts.append(" \\\\\n")

# Partial rules: one under the two label columns, then one under each dataset span.
header_parts.append("\\cmidrule(lr){1-2} ")
start_col = 3
for _ in datasets:
    end_col = start_col + metrics_per_dataset - 1
    header_parts.append(f"\\cmidrule(lr){{{start_col}-{end_col}}} ")
    start_col = end_col + 1
header_parts.append("\n")

# Second header row: the metric headers repeated for every dataset.
header_parts.append("\\textbf{Prior Rate} & \\textbf{Method}")
for _ in datasets:
    for mh in metric_headers:
        header_parts.append(f" & \\textbf{{{mh}}}")
header_parts.append(" \\\\\n")
header_parts.append("\\midrule\n")

latex_string = "".join(header_parts)

def _format_prior_rate(rate):
    r"""Return a prior rate as a LaTeX percentage, e.g. 0.29 -> ``29\%``.

    The scaled value is rounded to the nearest integer rather than truncated:
    ``0.29 * 100`` evaluates to ``28.999999999999996``, which ``int()`` alone
    would turn into ``28``.
    """
    return f"{int(round(float(rate) * 100))}\\%"


def _is_best_mean(value, best):
    """Return True when ``value`` equals the group's best mean within ``bolding_tolerance``."""
    if best is None or pd.isna(value):
        return False
    try:
        return math.isclose(value, best, rel_tol=0, abs_tol=bolding_tolerance)
    except TypeError:
        # A non-numeric value can never be the best mean.
        return False


# --- Table body: one row per (prior rate, algorithm) -------------------------
# Rows are collected in a list and appended to `latex_string` in one join.
body_rows = []
for pr in prior_rates:
    results_for_pr = results.get(pr, {})
    algs_for_this_pr = sorted(alg for alg in algorithms if alg in results_for_pr)
    num_algs = len(algs_for_this_pr)
    if num_algs == 0:
        continue
    pr_display = _format_prior_rate(pr)

    for i, alg in enumerate(algs_for_this_pr):
        # Only the first row of a group carries the \multirow prior-rate label.
        if i == 0:
            row_parts = [f"\\multirow{{{num_algs}}}{{*}}{{{pr_display}}} & "]
        else:
            row_parts = [" & "]
        row_parts.append(escape_latex(alg))

        for ds_name in datasets:
            ds_code = size_map[ds_name]
            for mh in metric_headers:
                # Missing combinations fall back to a plain "N/A" cell.
                raw_val, inner_content = (
                    results_for_pr.get(alg, {})
                    .get(ds_name, {})
                    .get(mh, (float('nan'), "N/A"))
                )

                if inner_content == "N/A":
                    cell_output = "N/A"
                else:
                    best_val = best_values.get(ds_code, {}).get(pr, {}).get(mh, None)
                    if _is_best_mean(raw_val, best_val):
                        cell_output = f"$\\bm{{{inner_content}}}$"
                    else:
                        cell_output = f"${inner_content}$"

                row_parts.append(f" & {cell_output}")

        row_parts.append(" \\\\\n")
        body_rows.append("".join(row_parts))

latex_string += "".join(body_rows)

# Close the tabular and table environments, then write the LaTeX source to disk.
latex_string += "".join((
    "\\bottomrule\n",
    "\\end{tabular}\n",
    "\\end{table}\n",
))

with open(output_latex_filename, 'w', encoding='utf-8') as f_out:
    f_out.write(latex_string)