import os
import numpy as np
from collections import defaultdict
import utils.helpers as UH

# Bin labels: the powers of two from 2**0 (1) through 2**10 (1024).
# They index the diagonal (b, b) bins and become the table's column headers.
bins = [2**exponent for exponent in range(11)]

# Encoder identifiers whose cached results are scanned.
encoders = ["sentence-transformers/all-MiniLM-L6-v2"]

# Last path component of each identifier, with underscores turned into hyphens.
encoders_short = [
    identifier.rsplit("/", 1)[-1].replace("_", "-") for identifier in encoders
]

# Nested result store, filled as data[dataset][encoder][K_path].
data = defaultdict(lambda: defaultdict(dict))
# Per-bin item counts for each dataset (the first value recorded wins).
counts_per_dataset = {}

# Scan every cached ".json" result for each (K variant, encoder) pair and
# record, per dataset, the median value and the item count of each diagonal
# (b, b) bin listed in `bins`.
for K_path in ["K1"]:
    for encoder in encoders:
        # Use the last path component so identifiers without an "org/" prefix
        # do not raise IndexError (same rule as `encoders_short`).
        emb_name = encoder.split("/")[-1]
        emb_short = emb_name.replace("_", "-")
        base_path = f"cache/{emb_name}/{K_path}"
        if not os.path.isdir(base_path):
            # Nothing cached for this encoder / K variant.
            continue
        # os.listdir() order is arbitrary; sort so that the "first seen wins"
        # rule for counts_per_dataset below is deterministic.
        for file in sorted(os.listdir(base_path)):
            if not file.endswith(".json"):
                continue

            K = UH.load_from_cache(f"{emb_name}/{K_path}", file)
            # Copy K, keeping only finite values (drops NaN and +/-inf).
            G = {k: [a for a in v if np.isfinite(a)] for k, v in K.items()}

            # NOTE(review): the UH helpers are opaque from here; presumably
            # they bin the values on a log scale and return dicts keyed by
            # (bin, bin) tuples — confirm in utils.helpers.
            binned = UH.bin_tuple_dict_log_to_tuple(G)
            diagonal = UH.diagonalize(binned)
            # Number of items in each bin whose two key components are equal.
            counts_diag = {k: len(diagonal[k]) for k in diagonal if k[0] == k[1]}
            D = UH.get_median_Ks(diagonal)

            # Strip only the trailing extension (str.replace would also remove
            # ".json" from the middle of a name); hyphens replace underscores.
            dataset = file[: -len(".json")].replace("_", "-")
            # Bins absent from the results get NaN as value and 0 as count.
            values = [D.get((b, b), np.nan) for b in bins]
            count_bins = [counts_diag.get((b, b), 0) for b in bins]

            data[dataset][emb_short][K_path] = (values, count_bins)

            # Keep the counts from the first occurrence of each dataset.
            counts_per_dataset.setdefault(dataset, count_bins)

# === LATEX TABLE GENERATION ===
# Column headers: one per bin, joined with the LaTeX cell separator.
bin_header = " & ".join(map(str, bins))
n_bins = len(bins)

# Bucket the dataset names by "data family": the text before the first "/",
# or "misc" when the name has no "/".
# NOTE(review): the names in counts_per_dataset are derived from os.listdir()
# file names, which cannot contain "/", so as written every dataset appears to
# land in "misc" — confirm the intended family separator.
families = defaultdict(list)
for name in sorted(counts_per_dataset):
    prefix, separator, _rest = name.partition("/")
    families[prefix if separator else "misc"].append(name)

# One table row per dataset, numbered from 1 within its family. No per-family
# header row is emitted; consecutive families are separated by a \midrule
# (none after the last family).
rows = []
for family in sorted(families):
    if rows:
        rows.append(r"\midrule")
    for position, name in enumerate(families[family], start=1):
        # Empty bins are shown as "-" instead of 0.
        cells = [
            str(int(count)) if count > 0 else "-"
            for count in counts_per_dataset[name]
        ]
        rows.append(rf"{name} & \#${position}$ & " + " & ".join(cells) + r" \\")

# Build the pieces that need string operations outside the template: a
# backslash inside an f-string replacement field (the original
# {"\n".join(rows)}) is a SyntaxError on Python versions before 3.12.
table_body = "\n".join(rows)
# Two left-aligned columns (dataset name, index) plus one centred column per bin.
column_spec = "ll" + "c" * n_bins

# longtable template: the header is repeated on the first page (\endfirsthead)
# and on continuation pages (\endhead); \bottomrule closes every page (\endfoot).
table_latex = rf"""
\begin{{longtable}}{{{column_spec}}}
\caption{{Number of items per log bin across datasets (grouped by data family)}} \\
\toprule
Dataset &  & {bin_header} \\
\midrule
\endfirsthead

\toprule
Dataset &  & {bin_header} \\
\midrule
\endhead

\bottomrule
\endfoot

{table_body}
\end{{longtable}}
"""

# Save the LaTeX table. The encoding is explicit so the output does not depend
# on the platform's default locale (dataset names may contain non-ASCII text).
with open("results_counts_table.tex", "w", encoding="utf-8") as f:
    f.write(table_latex)

print("=== Saved to results_counts_table.tex ===")