import os
import numpy as np
from collections import defaultdict
import utils.helpers as UH  # your helper module

# Bin values used as table columns: the powers of two from 1 to 1024.
bins = [2 ** exponent for exponent in range(11)]

# Encoder identifiers in "organisation/model" form.
encoders = [
    "sentence-transformers/all-MiniLM-L6-v2",
    "sentence-transformers/all-mpnet-base-v2",
    "sentence-transformers/all-roberta-large-v1",
]

# Model name only (text after the last "/"), with underscores turned into
# hyphens.
encoders_short = [
    model_id.rsplit("/", 1)[-1].replace("_", "-") for model_id in encoders
]


def _encoder_table():
    """Return the per-dataset mapping: encoder short name -> plain dict."""
    return defaultdict(dict)


# Three-level store, indexed as data[dataset][encoder short name][key];
# the first two levels are created on first access.
data = defaultdict(_encoder_table)

# --- Load all data (same as before) ---
# Fills data[dataset][encoder short name][K variant] with one median per entry
# of `bins`. Files are discovered under cache/<encoder>/<K variant>/ and loaded
# through UH.load_from_cache (presumably reading that same directory - confirm).
for K_path in ["K1", "K2", "K3"]:
    for encoder in encoders:
        emb_name = encoder.split("/")[1]
        emb_short = emb_name.replace("_", "-")
        cache_subdir = f"{emb_name}/{K_path}"
        base_path = f"cache/{cache_subdir}"
        if not os.path.isdir(base_path):
            # Nothing cached for this encoder / K combination.
            continue

        json_files = [name for name in os.listdir(base_path) if name.endswith(".json")]
        for file in json_files:
            cached = UH.load_from_cache(cache_subdir, file)

            # Keep only finite samples: NaN and +/-inf entries are dropped
            # before binning.
            finite = {}
            for key, samples in cached.items():
                kept = [s for s in samples if not (np.isnan(s) or np.isinf(s))]
                finite.setdefault(key, []).extend(kept)

            binned = UH.bin_tuple_dict_log_to_tuple(finite)
            diagonal = UH.diagonalize(binned)

            # Summarise every bucket by its median; an empty bucket becomes NaN.
            medians = {
                bucket: float(np.median(samples)) if len(samples) else np.nan
                for bucket, samples in diagonal.items()
            }

            # The dataset name is the file name without ".json", underscores
            # replaced by hyphens.
            dataset = file.replace(".json", "").replace("_", "-")

            # Only the (b, b) entries are reported; bins without one are NaN.
            data[dataset][emb_short][K_path] = [
                medians.get((b, b), np.nan) for b in bins
            ]

# === LATEX TABLE GENERATION ===
# Column headings shared by every table: one column per entry of `bins`.
n_bins = len(bins)
bin_header = " & ".join(map(str, bins))

# Group datasets by "data family" (the text before the first "/"); names
# without a "/" are collected under "misc".
# NOTE(review): dataset names built from bare file names cannot contain "/",
# in which case every dataset ends up in "misc" - confirm how the keys of
# `data` are produced.
families = defaultdict(list)
for dataset in sorted(data):
    prefix, slash, _ = dataset.partition("/")
    families[prefix if slash else "misc"].append(dataset)

# --- Generate one longtable per K value ---
# The emitted LaTeX uses the longtable environment, the booktabs rules
# (\toprule, \midrule, \bottomrule) and \multirow, so the including document
# must load the longtable, booktabs and multirow packages.
tables = []
for K_path in ["K1", "K2", "K3"]:
    # One list of rows per family that has something to show for this K.
    sections = []
    for family, datasets in sorted(families.items()):
        family_rows = []

        for dataset in datasets:
            # .get() keeps the nested defaultdict from growing on lookups.
            per_encoder = data.get(dataset)
            if not per_encoder:
                continue

            # collect encoder medians if available
            enc_vals = []
            for emb_short in encoders_short:
                vals = per_encoder.get(emb_short, {}).get(K_path)
                if vals is None:
                    continue
                # NaN marks a bin without data and is rendered as "-".
                vals_str = [f"{v:.4f}" if not np.isnan(v) else "-" for v in vals]
                enc_vals.append((emb_short, vals_str))

            for i, (emb_short, vals_str) in enumerate(enc_vals):
                # Only the first encoder row carries the dataset label; it
                # spans all encoder rows of that dataset via \multirow.
                if i == 0:
                    label = rf"\multirow{{{len(enc_vals)}}}{{*}}{{{dataset}}}"
                else:
                    label = ""
                family_rows.append(" & ".join([label, emb_short, *vals_str]) + r" \\")

        # Skip families with no rows for this K so the table never shows a
        # heading followed by nothing.
        if family_rows:
            header = rf"\multicolumn{{{2 + n_bins}}}{{l}}{{\textbf{{{family}}}}} \\"
            sections.append([header, r"\midrule", *family_rows])

    # Separate consecutive families with a rule; no rule after the last one
    # because \bottomrule closes the table.
    rows = []
    for index, section in enumerate(sections):
        if index:
            rows.append(r"\midrule")
        rows.extend(section)

    # Joined outside the template: a backslash inside an f-string replacement
    # field is a SyntaxError before Python 3.12.
    body = "\n".join(rows)
    column_spec = "ll" + "c" * n_bins

    table_latex = rf"""
\begin{{longtable}}{{{column_spec}}}
\caption{{median $K$ values per log bin for {K_path}}} \\
\toprule
Dataset & Encoder & {bin_header} \\
\midrule
\endfirsthead

\toprule
Dataset & Encoder & {bin_header} \\
\midrule
\endhead

\bottomrule
\endfoot

{body}
\end{{longtable}}
"""
    tables.append(table_latex)

# Save all three tables to one .tex file.
# The encoding is pinned to UTF-8 so the output does not depend on the
# platform's locale encoding.
with open("results_medianK_tables.tex", "w", encoding="utf-8") as f:
    f.write("\n\n".join(tables))

print("=== Saved to results_medianK_tables.tex ===")