#!/usr/bin/env python3
import os
import json
from collections import defaultdict

# -------------------------
# Constants & config
# -------------------------
# Root of the ICL results tree. collect_table_data_icls walks it as
# <model>/<prompt>/<encoding>/<language>/*.json.
root_dir = "./results/ICL"
# Directory that receives the generated .tex files (created at import time).
tables_dir = "./tables"
os.makedirs(tables_dir, exist_ok=True)
# Intended display order of LLMs (folder keys under root_dir).
# NOTE(review): the original comment said "only these will be shown"; confirm
# which table writers actually honor this list, since several of them derive
# their model set from the scanned results instead.
LLM_ORDER = [
    "deepseek-chat",
    "chatgpt",
    "Llama-3.1-8B-Instruct",
    "Qwen2.5-32B-Instruct",
    "Qwen2.5-7B-Instruct",
    # add/remove models here; names must match folder names under ./results/ICL
]
# Fixed order of language classes and languages (for consistent table output).
# Class labels are presumably Chomsky-hierarchy levels (regular, deterministic
# context-free, context-free, context-sensitive) -- TODO confirm.
LANGUAGE_CLASS_ORDER = ["R", "DCF", "CF", "CS"]
LANGUAGE_ORDER = [
    "even-pairs",
    "repeat-01",
    "parity",
    "cycle-navigation",
    "modular-arithmetic-simple",
    "dyck-2-3",
    "first",
    "majority",
    "stack-manipulation",
    "marked-reversal",
    "unmarked-reversal",
    "marked-copy",
    "missing-duplicate-string",
    "odds-first",
    "binary-addition",
    "binary-multiplication",
    "compute-sqrt",
    "bucket-sort",
]
# Language slug -> class label (one of LANGUAGE_CLASS_ORDER). Languages missing
# here match no class and are therefore left out of the per-class tables.
LANGUAGE_CLASSES = {
    "even-pairs": "R",
    "repeat-01": "R",
    "parity": "R",
    "cycle-navigation": "R",
    "modular-arithmetic-simple": "R",
    "dyck-2-3": "R",
    "first": "R",
    "majority": "DCF",
    "stack-manipulation": "DCF",
    "marked-reversal": "DCF",
    "unmarked-reversal": "CF",
    "marked-copy": "CS",
    "missing-duplicate-string": "CS",
    "odds-first": "CS",
    "binary-addition": "CS",
    "binary-multiplication": "CS",
    "compute-sqrt": "CS",
    "bucket-sort": "CS",
}

# Prompt folder name -> human-readable label used in the row headers.
# Insertion order is the row order within each model.
prompting_strategies = {
    "io_prompt": "immediate output",
    "zsr_prompt": "zero-shot reasoning",
}
# Encoding folder names, in the column order of the detailed tables.
encoding_order = [
    "many_to_one",
    "one_to_one",
    "one_to_many_2",
    "one_to_many_3",
    "one_to_many_4",
    "one_to_many_5",
]

# Pretty model names (folder key -> display name); unknown keys fall back to
# the raw folder name at the call sites.
mod_dic = {
    "deepseek-chat": "DeepSeek-V3",
    "chatgpt": "GPT-4o mini",
    "Qwen2.5-32B-Instruct": "Qwen2.5 32B",
    "Qwen2.5-7B-Instruct": "Qwen2.5 7B",
    "Llama-3.1-8B-Instruct": "Llama-3.1 8B",
}

# Language macros for the summary tables (fallback to slug if missing).
# The macros themselves must be defined in the consuming LaTeX document.
LANGUAGE_MACROS = {
    "even-pairs": r"\languageEvenPairs",
    "repeat-01": r"\languageRepeatZeroOne",
    "parity": r"\languageParity",
    "cycle-navigation": r"\languageCycleNavigation",
    "modular-arithmetic-simple": r"\languageModularArithmeticSimple",
    "dyck-2-3": r"\languageDyckTwoThree",
    "first": r"\languageFirst",
    "majority": r"\languageMajority",
    "stack-manipulation": r"\languageStackManipulation",
    "marked-reversal": r"\languageMarkedReversal",
    "unmarked-reversal": r"\languageUnmarkedReversal",
    "marked-copy": r"\languageMarkedCopy",
    "missing-duplicate-string": r"\languageMissingDuplicateString",
    "odds-first": r"\languageOddsFirst",
    "binary-addition": r"\languageBinaryAddition",
    "binary-multiplication": r"\languageBinaryMultiplication",
    "compute-sqrt": r"\languageComputeSqrt",
    "bucket-sort": r"\languageBucketSort",
}

# Pretty-name (key in transformer_results.json) -> slug. Entries of that file
# whose key is not listed here are ignored by load_transformer_results.
PRETTY_TO_SLUG = {
    "Even Pairs": "even-pairs",
    "Repeat 01": "repeat-01",
    "Parity": "parity",
    "Cycle Navigation": "cycle-navigation",
    "Modular Arithmetic": "modular-arithmetic-simple",
    "Dyck-(2, 3)": "dyck-2-3",
    "First": "first",
    "Majority": "majority",
    "Stack Manipulation": "stack-manipulation",
    "Marked Reversal": "marked-reversal",
    "Unmarked Reversal": "unmarked-reversal",
    "Marked Copy": "marked-copy",
    "Missing Duplicate": "missing-duplicate-string",
    "Odds First": "odds-first",
    "Binary Addition": "binary-addition",
    "Binary Multiplication": "binary-multiplication",
    "Compute Sqrt": "compute-sqrt",
    "Bucket Sort": "bucket-sort",
}

# -------------------------
# Utilities
# -------------------------
def fmt(x):
    """Render *x* with two decimals; ``None`` becomes the empty string."""
    if x is None:
        return ""
    return format(x, ".2f")

def bold(x):
    r"""Return ``\textbf{...}`` wrapped around the ``fmt`` rendering of *x*."""
    return r"\textbf{%s}" % fmt(x)

def read_latest_json_accuracy(dir_path):
    """
    Best-effort read of the ``accuracy`` field from the newest JSON in a folder.

    "Newest" means the lexicographically last ``*.json`` file name in
    ``dir_path``. Returns the accuracy rounded to two decimals, or ``None``
    when there is no JSON file, the field is absent/null, or anything at all
    goes wrong (missing directory, unreadable or malformed file, ...).
    """
    try:
        candidates = sorted(
            name for name in os.listdir(dir_path) if name.endswith(".json")
        )
        if len(candidates) == 0:
            return None
        newest_path = os.path.join(dir_path, candidates[-1])
        with open(newest_path, "r") as handle:
            payload = json.load(handle)
        raw_accuracy = payload.get("accuracy")
        if raw_accuracy is None:
            return None
        return round(float(raw_accuracy), 2)
    except Exception:
        # Deliberately swallow everything: callers treat None as "no data".
        return None

def collect_table_data_icls(root):
    """
    Scan ``root`` laid out as <model>/<prompt>/<encoding>/<language>/*.json and
    return a nested mapping:

      table_data[language][model][prompt][encoding] = accuracy

    For every language folder only the lexicographically last ``*.json`` file
    is read. Its ``accuracy`` field is rounded to two decimals; a missing field
    is stored as -1. Files that cannot be read or parsed are reported on
    stdout and skipped. A missing ``root`` yields an empty mapping.
    """
    def child_dirs(parent):
        # Yield (name, full_path) for each sub-directory, in listing order.
        for entry in os.listdir(parent):
            full_path = os.path.join(parent, entry)
            if os.path.isdir(full_path):
                yield entry, full_path

    table_data = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
    if not os.path.isdir(root):
        return table_data

    for model, model_dir in child_dirs(root):
        for prompt_dir, prompt_dir_path in child_dirs(model_dir):
            for encoding_dir, encoding_path in child_dirs(prompt_dir_path):
                for language, lang_path in child_dirs(encoding_path):
                    reports = sorted(
                        name for name in os.listdir(lang_path) if name.endswith(".json")
                    )
                    if not reports:
                        continue
                    file_path = os.path.join(lang_path, reports[-1])
                    try:
                        with open(file_path, "r") as f:
                            payload = json.load(f)
                        # Compute first so a failure leaves table_data untouched.
                        accuracy = round(payload.get("accuracy", -1), 2)
                        table_data[language][model][prompt_dir][encoding_dir] = accuracy
                    except Exception as e:
                        print(f"Skipping {file_path}: {e}")
    return table_data

def llm_max_for_lang_model(table_data, lang, model_key):
    """
    Max over all prompts+encodings for a given language+model (for summary table).

    Args:
        table_data: table_data[language][model][prompt][encoding] = accuracy.
        lang: language slug to look up.
        model_key: model folder key to look up.

    Returns:
        The largest non-negative numeric accuracy, or ``None`` when the
        language/model is absent or has no valid value. Negative values are
        the "accuracy missing" sentinel (-1) and are ignored.
    """
    def is_valid(acc):
        # Accept ints as well as floats: an accuracy of exactly 0 or 1 read
        # from JSON stays an int after round(), and must still count.
        # bool is excluded explicitly because it is a subclass of int.
        return (
            isinstance(acc, (int, float))
            and not isinstance(acc, bool)
            and acc >= 0
        )

    prompts = table_data.get(lang, {}).get(model_key, {})
    valid = [
        acc
        for encs in prompts.values()
        for acc in encs.values()
        if is_valid(acc)
    ]
    return max(valid, default=None)

def load_transformer_results(json_path):
    """
    Load the trained-Transformer accuracies.

    Args:
        json_path: path to a JSON object mapping pretty language names (keys
            of PRETTY_TO_SLUG) to objects holding ``100_examples`` and
            ``10k_examples`` values.

    Returns:
        dict: slug -> {'100_examples': val, '10k_examples': val}, values
        rounded to two decimals or ``None`` when absent or not numeric.
        An empty dict is returned when the file is missing, unreadable, or
        not a JSON object. Entries with an unknown pretty name or a
        non-object value are skipped instead of aborting the whole run.
    """
    out = {}
    if not os.path.isfile(json_path):
        return out
    try:
        with open(json_path, "r") as f:
            raw = json.load(f)
    except Exception as e:
        print(f"⚠️ Could not read Transformer results: {e}")
        return out

    if not isinstance(raw, dict):
        print(
            "⚠️ Could not read Transformer results: "
            f"expected a JSON object, got {type(raw).__name__}"
        )
        return out

    def to_accuracy(value):
        # None and non-numeric payloads both mean "no value for this cell".
        if value is None:
            return None
        try:
            return round(float(value), 2)
        except (TypeError, ValueError):
            return None

    for pretty, vals in raw.items():
        slug = PRETTY_TO_SLUG.get(pretty)
        if not slug:
            continue
        if not isinstance(vals, dict):
            print(f"⚠️ Ignoring malformed Transformer entry for {pretty!r}")
            continue
        out[slug] = {
            "100_examples": to_accuracy(vals.get("100_examples")),
            "10k_examples": to_accuracy(vals.get("10k_examples")),
        }
    return out

# -------------------------
# Table 1: Detailed encoding table (by class, language, model, prompt)
# -------------------------
def write_detailed_encoding_table(table_data, out_path):
    r"""
    Write the detailed accuracy table as LaTeX: one row per
    (class, language, model, prompting strategy), one column per encoding.

    Uses globals: LANGUAGE_* , LLM_ORDER, prompting_strategies, encoding_order, mod_dic.

    Args:
        table_data: table_data[language][model][prompt][encoding] = accuracy.
        out_path: destination .tex file (overwritten).

    Formatting:
      - \textbf{...} (black): best per language (over all models/prompts/encodings)
      - \textcolor{blue}{...}: best per (language, model) (over prompts/encodings)
      - If both, black wins.
      - Negative accuracies (the -1 "accuracy missing" sentinel) are rendered
        as empty cells and never count as a best value.
      - Models appear in LLM_ORDER first, then any other model alphabetically,
        so the output does not depend on directory listing order.
      - \midrule separates classes only; the table ends with one \bottomrule.
    """
    def is_score(value):
        # Real accuracies are non-negative numbers; -1 marks a missing value.
        return isinstance(value, (int, float)) and value >= 0

    def best_of(prompts):
        # Best valid accuracy over every prompt and encoding of one model.
        scores = [
            acc
            for encs in prompts.values()
            for acc in encs.values()
            if is_score(acc)
        ]
        return max(scores, default=None)

    def ordered_model_items(models):
        # Known models in the configured order, unknown ones alphabetically.
        known = [m for m in LLM_ORDER if m in models]
        extras = sorted(m for m in models if m not in LLM_ORDER)
        return [(m, models[m]) for m in known + extras]

    # Only languages we actually have (and with at least one model)
    present_langs = [l for l in LANGUAGE_ORDER if l in table_data and len(table_data[l]) > 0]

    # Non-empty classes, resolved up front so rules go *between* classes only.
    class_groups = []
    for cls in LANGUAGE_CLASS_ORDER:
        langs_in_cls = [l for l in present_langs if LANGUAGE_CLASSES.get(l) == cls]
        if langs_in_cls:
            class_groups.append((cls, langs_in_cls))

    with open(out_path, "w") as f:
        f.write(r"\scalebox{0.6}{" + "\n")
        f.write(r"\begin{tabular}{llllc ccccc}" + "\n")
        f.write(r"\toprule" + "\n")
        f.write(r"\multirow{3}{*}{\textbf{Class}} & \multirow{3}{*}{\textbf{Language}} & \multirow{3}{*}{\textbf{Model}} & \multirow{3}{*}{\textbf{Prompting Strategy}} & \multicolumn{6}{c}{\textbf{Encoding Strategy}} \\" + "\n")
        f.write(r"\cmidrule(lr){5-10}" + "\n")
        f.write(r" & & & & \textbf{many $\rightarrow$ one} & \textbf{one $\rightarrow$ one} & \textbf{2} & \textbf{3} & \textbf{4} & \textbf{5} \\" + "\n")
        f.write(r"\midrule" + "\n")

        for cls_idx, (cls, langs_in_cls) in enumerate(class_groups):
            # Rows spanned by a language = #models * #prompts; a class spans the sum.
            per_lang_row_counts = {
                lang: len(table_data[lang]) * len(prompting_strategies)
                for lang in langs_in_cls
            }
            class_row_count = sum(per_lang_row_counts.values())

            class_cell_written = False  # write class multirow once

            for lang_idx, lang in enumerate(langs_in_cls):
                model_items = ordered_model_items(table_data[lang])

                # Best per (language, model), and from those the best per language.
                best_val_by_model = {m: best_of(prompts) for m, prompts in model_items}
                best_lang_val = max(
                    (v for v in best_val_by_model.values() if v is not None),
                    default=None,
                )

                language_cell_written = False  # write language multirow once

                for m_i, (model, prompts) in enumerate(model_items):
                    model_best = best_val_by_model[model]

                    for p_i, (prompt_key, prompt_label) in enumerate(prompting_strategies.items()):
                        cells = []

                        # Class multirow (only once at the very first row of the class)
                        if not class_cell_written:
                            cells.append(rf"\multirow{{{class_row_count}}}{{*}}{{\textbf{{{cls}}}}}")
                            class_cell_written = True
                        else:
                            cells.append("")

                        # Language multirow (once per language)
                        if not language_cell_written:
                            cells.append(rf"\multirow{{{per_lang_row_counts[lang]}}}{{*}}{{\textbf{{{lang}}}}}")
                            language_cell_written = True
                        else:
                            cells.append("")

                        # Model multirow (spans the prompt rows of this model)
                        if p_i == 0:
                            pretty_model = mod_dic.get(model, model)
                            cells.append(rf"\multirow{{{len(prompting_strategies)}}}{{*}}{{\textbf{{{pretty_model}}}}}")
                        else:
                            cells.append("")

                        # Prompt label
                        cells.append(r"\textbf{" + prompt_label + "}")

                        # Encodings: black (language-best) > blue (model-best)
                        encs = prompts.get(prompt_key, {})
                        for enc in encoding_order:
                            acc = encs.get(enc)
                            if not is_score(acc):
                                cells.append("")
                            elif acc == best_lang_val:
                                cells.append(r"\textbf{" + f"{acc:.2f}" + "}")
                            elif acc == model_best:
                                cells.append(r"\textcolor{blue}{" + f"{acc:.2f}" + "}")
                            else:
                                cells.append(f"{acc:.2f}")

                        f.write(" & ".join(cells) + r" \\" + "\n")

                    # Midrule between models within the same language
                    if m_i < len(model_items) - 1:
                        f.write(r"\cmidrule(lr){3-10}" + "\n")

                # Long rule between languages (skip class column)
                if lang_idx < len(langs_in_cls) - 1:
                    f.write(r"\cmidrule(lr){2-10}" + "\n")

            # Rule between classes (not after the last one: \bottomrule follows)
            if cls_idx < len(class_groups) - 1:
                f.write(r"\midrule" + "\n")

        f.write(r"\bottomrule" + "\n")
        f.write(r"\end{tabular}" + "\n")
        f.write(r"}" + "\n")  # end scalebox

    print(f"✅ LaTeX table (scaled 60%) written with language-best (bold) and language+model-best (blue): {out_path}")

# -------------------------
# Table 2: Summary “by size” table: one column per tested model found in the results, plus Tf(100) and Tf(10k)
# -------------------------
def write_max_by_size_table(table_data, out_path):
    r"""
    Dynamically builds a summary table that includes *all* tested models found in
    `table_data` (max over prompts+encodings per model), plus Tf(100) and Tf(10k).

    Uses globals: root_dir, LANGUAGE_*, mod_dic, and the helpers fmt, bold,
    llm_max_for_lang_model and load_transformer_results.

    Args:
        table_data: table_data[language][model][prompt][encoding] = accuracy.
        out_path: destination .tex file (overwritten).

    Bolding:
      - In the 100-examples block, bold the max among {all models, Tf(100)} (ties bold).
      - In the 10k column, bold if it equals the global max among {all models, Tf(100), Tf(10k)} (ties bold).

    Layout:
      - Columns: Class, Language, one per model, Tf(100), Tf(10k).
      - \midrule separates classes only; the table ends with one \bottomrule.
    """
    # NOTE(review): "trained_trasformer" looks like a typo but is a directory
    # name on disk; confirm before renaming.
    tf_json_path = os.path.join(root_dir, "trained_trasformer", "transformer_results.json")
    tf = load_transformer_results(tf_json_path)  # slug -> {100_examples, 10k_examples}

    def mark(value, reference):
        # Bold when the value ties the reference maximum, plain otherwise.
        if value is not None and reference is not None and value == reference:
            return bold(value)
        return fmt(value)

    # ---------- gather languages ----------
    observed_langs = set(table_data.keys()) | set(tf.keys())
    languages = [l for l in LANGUAGE_ORDER if l in observed_langs]

    # ---------- gather & order models (global, across all languages) ----------
    all_models = set()
    for lang_data in table_data.values():
        all_models.update(lang_data.keys())

    # Known models in the key order of mod_dic, then the rest alphabetically.
    known_order = [k for k in mod_dic if k in all_models]
    unknown_models = sorted(
        (m for m in all_models if m not in mod_dic),
        key=lambda k: k.lower(),
    )
    ordered_models = known_order + unknown_models

    # Column math:
    # 1: Class, 2: Language, 3..(2+M): models, (3+M): Tf(100), (4+M): Tf(10k)
    M = len(ordered_models)
    total_cols = 2 + M + 2
    hundred_start = 3
    hundred_end = 2 + M + 1  # inclusive; models + Tf(100)
    tenk_col = hundred_end + 1

    # Non-empty classes, resolved up front so rules go *between* classes only.
    class_groups = []
    for cls in LANGUAGE_CLASS_ORDER:
        langs_in_cls = [l for l in languages if LANGUAGE_CLASSES.get(l) == cls]
        if langs_in_cls:
            class_groups.append((cls, langs_in_cls))

    # ---------- write LaTeX ----------
    with open(out_path, "w") as f:
        f.write(r"\scalebox{0.6}{" + "\n")
        f.write(r"\begin{tabular}{" + "cl" + "c"*M + "cc" + "}" + "\n")
        f.write(r"\toprule" + "\n")

        # Group headers: "100 examples" spans the model columns plus Tf(100).
        f.write(r"& & \multicolumn{" + str(M + 1) + r"}{c}{\textbf{100 examples}} & \textbf{10k examples} \\" + "\n")
        f.write(rf"\cmidrule(lr){{{hundred_start}-{hundred_end}}} \cmidrule(lr){{{tenk_col}-{tenk_col}}}" + "\n")

        # Second header row with per-model names + Tf columns
        header_cells = [r"\textbf{Class}", r"\textbf{Language}"]
        header_cells += [r"\textbf{" + mod_dic.get(m, m) + r"}" for m in ordered_models]
        header_cells += [r"\textbf{Tf}", r"\textbf{Tf}"]
        f.write(" & ".join(header_cells) + r" \\" + "\n")
        f.write(r"\midrule" + "\n")

        # Body
        for cls_idx, (cls, langs_in_cls) in enumerate(class_groups):
            for i, lang in enumerate(langs_in_cls):
                # Per-model maxima for this language
                model_vals = [
                    llm_max_for_lang_model(table_data, lang, m)
                    for m in ordered_models
                ]

                # Transformer values
                tf_entry = tf.get(lang, {})
                t100 = tf_entry.get("100_examples")
                t10k = tf_entry.get("10k_examples")

                # Reference maxima for the two bolding rules
                max_100 = max(
                    (x for x in model_vals + [t100] if x is not None),
                    default=None,
                )
                max_all = max(
                    (x for x in model_vals + [t100, t10k] if x is not None),
                    default=None,
                )

                row_cells = [mark(v, max_100) for v in model_vals]
                row_cells.append(mark(t100, max_100))
                row_cells.append(mark(t10k, max_all))

                lang_label = LANGUAGE_MACROS.get(lang, lang)
                prefix = (rf"\multirow{{{len(langs_in_cls)}}}{{*}}{{{cls}}}  & " if i == 0 else "  & ")
                f.write(prefix + rf"{lang_label} & " + " & ".join(row_cells) + r" \\" + "\n")

                # Separator between languages in the same class (skip the Class column)
                if i < len(langs_in_cls) - 1:
                    f.write(rf"\cmidrule(lr){{2-{total_cols}}}" + "\n")

            # Rule between classes (not after the last one: \bottomrule follows)
            if cls_idx < len(class_groups) - 1:
                f.write(r"\midrule" + "\n")

        f.write(r"\bottomrule" + "\n")
        f.write(r"\end{tabular}" + "\n")
        f.write(r"}" + "\n")  # end scalebox

    print(f"✅ Summary table (dynamic models, scaled 60%) written: {out_path}")

def write_neg1_percent_table(root, out_path):
    r"""
    Write a LaTeX table with the percentage of entries whose 'model_output'
    equals -1 in the latest JSON of each (language, model, prompt, encoding).
    No highlighting.

    Uses globals: LANGUAGE_* , LLM_ORDER, prompting_strategies, encoding_order, mod_dic.

    Args:
        root: results tree laid out as <model>/<prompt>/<encoding>/<language>/*.json.
        out_path: destination .tex file (overwritten).

    Notes:
      - "Latest" is the lexicographically last ``*.json`` name in a folder.
      - Every dict anywhere in the JSON that has a 'model_output' key counts
        as one entry; files without such a key produce an empty cell.
      - Unreadable files are reported on stdout and skipped.
      - Models appear in LLM_ORDER first, then any other model alphabetically.
      - \midrule separates classes only; the table ends with one \bottomrule.
    """
    # -------- helpers --------
    def cell_fmt(x):
        return "" if x is None else f"{x:.1f}\\%"

    def tally(obj):
        # Single recursive pass returning (#entries equal to -1, #entries).
        bad = total = 0
        if isinstance(obj, dict):
            if "model_output" in obj:
                total += 1
                if obj["model_output"] == -1:
                    bad += 1
            children = obj.values()
        elif isinstance(obj, list):
            children = obj
        else:
            return 0, 0
        for child in children:
            child_bad, child_total = tally(child)
            bad += child_bad
            total += child_total
        return bad, total

    def child_dirs(parent):
        # Yield (name, full_path) for each sub-directory of parent.
        for entry in os.listdir(parent):
            full_path = os.path.join(parent, entry)
            if os.path.isdir(full_path):
                yield entry, full_path

    def ordered_model_items(models):
        # Known models in the configured order, unknown ones alphabetically.
        known = [m for m in LLM_ORDER if m in models]
        extras = sorted(m for m in models if m not in LLM_ORDER)
        return [(m, models[m]) for m in known + extras]

    # collect percentages: perc[language][model][prompt][encoding] = float|None
    perc = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))

    if os.path.isdir(root):
        for model, model_dir in child_dirs(root):
            for prompt_dir, p_path in child_dirs(model_dir):
                for enc_dir, e_path in child_dirs(p_path):
                    for language, lang_path in child_dirs(e_path):
                        json_files = sorted(
                            name for name in os.listdir(lang_path) if name.endswith(".json")
                        )
                        if not json_files:
                            continue
                        fp = os.path.join(lang_path, json_files[-1])
                        try:
                            with open(fp, "r") as fh:
                                data = json.load(fh)
                            n_bad, n_tot = tally(data)
                            perc_val = round(100.0 * n_bad / n_tot, 2) if n_tot > 0 else None
                            perc[language][model][prompt_dir][enc_dir] = perc_val
                        except Exception as e:
                            print(f"Skipping {fp}: {e}")

    # Only languages we actually have (and with at least one model)
    present_langs = [l for l in LANGUAGE_ORDER if l in perc and len(perc[l]) > 0]

    # Non-empty classes, resolved up front so rules go *between* classes only.
    class_groups = []
    for cls in LANGUAGE_CLASS_ORDER:
        langs_in_cls = [l for l in present_langs if LANGUAGE_CLASSES.get(l) == cls]
        if langs_in_cls:
            class_groups.append((cls, langs_in_cls))

    with open(out_path, "w") as f:
        f.write(r"\scalebox{0.6}{" + "\n")
        f.write(r"\begin{tabular}{llllc ccccc}" + "\n")
        f.write(r"\toprule" + "\n")
        f.write(r"\multirow{3}{*}{\textbf{Class}} & \multirow{3}{*}{\textbf{Language}} & \multirow{3}{*}{\textbf{Model}} & \multirow{3}{*}{\textbf{Prompting Strategy}} & \multicolumn{6}{c}{\textbf{Encoding Strategy} (\% with \texttt{model\_output} = -1)} \\" + "\n")
        f.write(r"\cmidrule(lr){5-10}" + "\n")
        f.write(r" & & & & \textbf{many $\rightarrow$ one} & \textbf{one $\rightarrow$ one} & \textbf{2} & \textbf{3} & \textbf{4} & \textbf{5} \\" + "\n")
        f.write(r"\midrule" + "\n")

        for cls_idx, (cls, langs_in_cls) in enumerate(class_groups):
            # Rows spanned by a language = #models * #prompts; a class spans the sum.
            per_lang_row_counts = {
                lang: len(perc[lang]) * len(prompting_strategies)
                for lang in langs_in_cls
            }
            class_row_count = sum(per_lang_row_counts.values())

            class_cell_written = False

            for lang_idx, lang in enumerate(langs_in_cls):
                model_items = ordered_model_items(perc[lang])
                language_cell_written = False

                for m_i, (model, prompts) in enumerate(model_items):
                    pretty_model = mod_dic.get(model, model)

                    for p_i, (prompt_key, prompt_label) in enumerate(prompting_strategies.items()):
                        cells = []

                        # Class multirow: once at the first row of this class
                        if not class_cell_written:
                            cells.append(rf"\multirow{{{class_row_count}}}{{*}}{{\textbf{{{cls}}}}}")
                            class_cell_written = True
                        else:
                            cells.append("")

                        # Language multirow: once per language
                        if not language_cell_written:
                            cells.append(rf"\multirow{{{per_lang_row_counts[lang]}}}{{*}}{{\textbf{{{lang}}}}}")
                            language_cell_written = True
                        else:
                            cells.append("")

                        # Model multirow: spans the prompt rows of this model
                        if p_i == 0:
                            cells.append(rf"\multirow{{{len(prompting_strategies)}}}{{*}}{{\textbf{{{pretty_model}}}}}")
                        else:
                            cells.append("")

                        # Prompt label
                        cells.append(r"\textbf{" + prompt_label + "}")

                        # Encodings: plain percentages
                        encs = prompts.get(prompt_key, {})
                        cells.extend(cell_fmt(encs.get(enc)) for enc in encoding_order)

                        f.write(" & ".join(cells) + r" \\" + "\n")

                    # midrule between models of the same language
                    if m_i < len(model_items) - 1:
                        f.write(r"\cmidrule(lr){3-10}" + "\n")

                # long rule between languages (skip class column)
                if lang_idx < len(langs_in_cls) - 1:
                    f.write(r"\cmidrule(lr){2-10}" + "\n")

            # Rule between classes (not after the last one: \bottomrule follows)
            if cls_idx < len(class_groups) - 1:
                f.write(r"\midrule" + "\n")

        f.write(r"\bottomrule" + "\n")
        f.write(r"\end{tabular}" + "\n")
        f.write(r"}" + "\n")  # end scalebox

    # "%" needs no escaping inside an f-string ("%%" would print two signs).
    print(f"✅ Neg1 % table (scaled 60%) written: {out_path}")

def write_nns_vs_icl_max_table(table_data, out_path):
    r"""
    Create a 4-col table comparing, for each language:
      - Accuracy (NNS): max across the Transformer JSON entries (e.g., 100_examples, 10k_examples)
      - Accuracy (ICL): max across ALL ICL configs (all models × prompts × encodings)

    The higher of the two is \textcolor{blue}{...} (ties: both blue).
    Output is wrapped in \scalebox{0.6}{...}.

    Uses globals: root_dir, LANGUAGE_*, and load_transformer_results.

    Args:
        table_data: table_data[language][model][prompt][encoding] = accuracy.
        out_path: destination .tex file (overwritten).

    Notes:
      - Negative ICL accuracies (the -1 "accuracy missing" sentinel) are
        ignored, matching llm_max_for_lang_model.
      - \midrule separates classes only; the table ends with one \bottomrule.
    """
    # --- helpers ---
    def fmt3(x):
        return "" if x is None else f"{x:.3f}"

    def highlight(value, top):
        # Blue when the value ties the row maximum, plain otherwise.
        if value is not None and value == top:
            return r"\textcolor{blue}{" + fmt3(value) + "}"
        return fmt3(value)

    # Load Transformer (NNS) results and compute max per language
    # NOTE(review): "trained_trasformer" looks like a typo but is a directory
    # name on disk; confirm before renaming.
    tf_json_path = os.path.join(root_dir, "trained_trasformer", "transformer_results.json")
    tf = load_transformer_results(tf_json_path)  # slug -> {'100_examples': v, '10k_examples': v}
    nns_max = {
        lang: max(
            (v for v in d.values() if isinstance(v, (int, float))),
            default=None,
        )
        for lang, d in tf.items()
    }

    # ICL max per language (over all models/prompts/encodings)
    def icl_max_for_lang(lang):
        scores = [
            acc
            for prompts in table_data.get(lang, {}).values()
            for encs in prompts.values()
            for acc in encs.values()
            if isinstance(acc, (int, float)) and acc >= 0
        ]
        return max(scores, default=None)

    # Languages to include: present in either NNS or ICL, ordered
    observed = set(table_data.keys()) | set(nns_max.keys())
    languages = [l for l in LANGUAGE_ORDER if l in observed]

    # Non-empty classes, resolved up front so rules go *between* classes only.
    class_groups = []
    for cls in LANGUAGE_CLASS_ORDER:
        langs_in_cls = [l for l in languages if LANGUAGE_CLASSES.get(l) == cls]
        if langs_in_cls:
            class_groups.append((cls, langs_in_cls))

    with open(out_path, "w") as f:
        f.write(r"\scalebox{0.6}{" + "\n")
        f.write(r"\begin{tabular}{llll}" + "\n")
        f.write(r"\toprule" + "\n")
        f.write(r"\textbf{Class} & \textbf{Task} & \textbf{Accuracy (NNS)} & \textbf{Accuracy (ICL)} \\" + "\n")
        f.write(r"\midrule" + "\n")

        for cls_idx, (cls, langs_in_cls) in enumerate(class_groups):
            class_span = len(langs_in_cls)

            for row_idx, lang in enumerate(langs_in_cls):
                nns_val = nns_max.get(lang)
                icl_val = icl_max_for_lang(lang)

                # Row maximum over whichever of the two values exist.
                mx = max((v for v in (nns_val, icl_val) if v is not None), default=None)
                nns_cell = highlight(nns_val, mx)
                icl_cell = highlight(icl_val, mx)

                lang_label = LANGUAGE_MACROS.get(lang, lang)
                # The class label is written once, on the first row of the class.
                if row_idx == 0:
                    row_prefix = rf"\multirow{{{class_span}}}{{*}}{{{cls}}}" + " & "
                else:
                    row_prefix = " & "

                # IMPORTANT: end rows with double backslash
                f.write(f"{row_prefix}{lang_label} & {nns_cell} & {icl_cell}" + r" \\" + "\n")

            # Rule between classes (not after the last one: \bottomrule follows)
            if cls_idx < len(class_groups) - 1:
                f.write(r"\midrule" + "\n")

        f.write(r"\bottomrule" + "\n")
        f.write(r"\end{tabular}" + "\n")
        f.write(r"}" + "\n")

    print(f"✅ NNS vs ICL max table (scaled 60%) written: {out_path}")

def write_avg_by_config_table(table_data, out_path):
    r"""
    Average accuracy across languages for each model × prompting × encoding.
    Header uses \encodingManyToOne{}, \encodingOneToOne{}, \encodingOneToMany{}.
    Rows bold the best value per model+prompt (ties bold).

    Args:
        table_data: table_data[language][model][prompt][encoding] = accuracy.
        out_path: destination .tex file (overwritten).

    Notes:
      - Only valid accuracies enter an average: numeric, not NaN and
        non-negative. The -1 "accuracy missing" sentinel is therefore skipped
        instead of dragging the mean down.
      - A configuration with no valid value produces an empty cell.
      - Models follow a fixed preferred order; any other observed model is
        appended alphabetically.
      - Model labels use LaTeX macros; models without a macro fall back to
        the global mod_dic, then to the raw key.
    """
    import math

    prompting_keys = [
        ("io_prompt", r"\promptingImmediateOutput{}"),
        ("zsr_prompt", r"\promptingZeroShotReasoning{}"),
    ]
    enc_cols = ["many_to_one", "one_to_one", "one_to_many_2", "one_to_many_3", "one_to_many_4", "one_to_many_5"]

    # Desired row order of the models
    preferred_order = [
        "deepseek-chat",
        "chatgpt",
        "Llama-3.1-70B-Instruct",
        "Llama-3.1-8B-Instruct",
        "Qwen2.5-32B-Instruct",
        "Qwen2.5-7B-Instruct",
    ]

    # Keep only models that actually appear; append any extras at the end
    observed_models = set()
    for models in table_data.values():
        observed_models.update(models.keys())
    model_order = [m for m in preferred_order if m in observed_models]
    model_order += [m for m in sorted(observed_models) if m not in model_order]

    # LaTeX model macros; fallback to mod_dic or raw key
    model_macros = {
        "deepseek-chat": r"\modelDeepSeek{}",
        "chatgpt": r"\modelGPT{}",
        "Llama-3.1-70B-Instruct": r"\modelLlamaLarge{}",
        "Llama-3.1-8B-Instruct": r"\modelLlamaSmall{}",
        "Qwen2.5-32B-Instruct": r"\modelQwenLarge{}",
        "Qwen2.5-7B-Instruct": r"\modelQwenSmall{}",
    }

    def pretty_model(m):
        # Consult mod_dic only when there is no macro for this model.
        if m in model_macros:
            return model_macros[m]
        return mod_dic.get(m, m)

    def is_score(v):
        return isinstance(v, (int, float)) and not math.isnan(v) and v >= 0

    def avg_for(model, prompt, enc):
        # Mean over the languages that have a valid value for this config.
        vals = []
        for models in table_data.values():
            v = models.get(model, {}).get(prompt, {}).get(enc, None)
            if is_score(v):
                vals.append(v)
        return round(sum(vals) / len(vals), 2) if vals else None

    def plain(x):
        return "" if x is None else f"{x:.2f}"

    def strong(x):
        return r"\textbf{" + plain(x) + "}"

    with open(out_path, "w") as f:
        f.write(r"\scalebox{0.9}{" + "\n")
        f.write(r"\begin{tabular}{cccccccc}" + "\n")
        f.write(r"\toprule" + "\n")
        f.write(r" \multirow{3}{*}{\textbf{Model}}&\multirow{3}{*}{\textbf{Prompting Strategy}} & \multicolumn{6}{c}{\textbf{Encoding Strategy}} \\" + "\n")
        f.write(r" \cmidrule(lr){3-8}" + "\n")
        f.write(r" &&\multirow{2}{*}{\textbf{\encodingManyToOne{}}} & \multirow{2}{*}{\textbf{\encodingOneToOne{}}} & \multicolumn{4}{c}{\textbf{\encodingOneToMany{}}} \\" + "\n")
        f.write(r" \cmidrule(lr){5-8}" + "\n")
        f.write(r"  & &  & & $\mathbf{2}$ & $\mathbf{3}$ & $\mathbf{4}$ & $\mathbf{5}$ \\" + "\n")
        f.write(r"\midrule" + "\n")

        for mi, model in enumerate(model_order):
            mlabel = pretty_model(model)
            for pi, (pkey, plabel) in enumerate(prompting_keys):
                vals = [avg_for(model, pkey, enc) for enc in enc_cols]
                row_max = max((v for v in vals if v is not None), default=None)
                cells = [
                    strong(v) if (row_max is not None and v == row_max) else plain(v)
                    for v in vals
                ]

                # The model label is written once and spans both prompt rows.
                model_cell = rf"\multirow{{2}}{{*}}{{\textbf{{{mlabel}}}}}" if pi == 0 else ""
                f.write(f"{model_cell} & \\textbf{{{plabel}}} & " + " & ".join(cells) + " \\\\\n")

            if mi < len(model_order) - 1:
                f.write(r"\midrule" + "\n")

        f.write(r"\bottomrule" + "\n")
        f.write(r"\end{tabular}" + "\n")
        f.write(r"}" + "\n")

    print(f"✅ timesteps-comparison (scaled 90%) written: {out_path}")

# -------------------------
# Main
# -------------------------
if __name__ == "__main__":
    # Walk the results tree a single time; most writers share this mapping.
    table_data = collect_table_data_icls(root_dir)
    os.makedirs(tables_dir, exist_ok=True)

    # (writer, first argument, output file name) -- executed top to bottom:
    #   detailed accuracy table (bold = best per language; blue = best per language+model),
    #   summary "by size" table,
    #   share of model_output == -1 per configuration (rescans root_dir itself),
    #   NNS vs ICL maxima,
    #   per-configuration averages.
    jobs = (
        (write_detailed_encoding_table, table_data, "big_language_accuracy_table.tex"),
        (write_max_by_size_table, table_data, "summary_max_language_accuracy_table.tex"),
        (write_neg1_percent_table, root_dir, "configuration_unparsable_output_percent_table.tex"),
        (write_nns_vs_icl_max_table, table_data, "nns_vs_icl_max_table.tex"),
        (write_avg_by_config_table, table_data, "timesteps-comparison.tex"),
    )
    for writer, source, filename in jobs:
        writer(source, os.path.join(tables_dir, filename))
