#!/usr/bin/env python3
import pandas as pd
import numpy as np
from pathlib import Path

# Maximum number of models (table rows) per emitted LaTeX table. main() splits
# the summary into ceil(total_models / NUM_PER_TABLE) separate tables, so this
# must be a positive integer; lower it to produce more, shorter tables.
NUM_PER_TABLE = 10

def first_valid(series):
    """Return the first element of *series* that is not missing.

    Missing-ness is decided by ``pd.notna``. When every element is missing
    (or *series* is empty), ``np.nan`` is returned instead.
    """
    return next((value for value in series if pd.notna(value)), np.nan)

def short_name(full):
    """Return the part of *full* after its last ``/`` (the whole string if none).

    Missing values (``pd.isna``) map to the empty string.
    """
    if pd.isna(full):
        return ""
    return str(full).split("/")[-1]

# Single-pass translation table: each LaTeX special character maps straight to
# its escaped form. Using str.translate (instead of chained str.replace) means
# text inserted by one escape -- e.g. the braces in \textbackslash{} -- is never
# re-escaped by a later one.
_LATEX_ESCAPES = str.maketrans({
    "\\": r"\textbackslash{}",
    "&": r"\&",
    "%": r"\%",
    "$": r"\$",
    "#": r"\#",
    "_": r"\_",
    "{": r"\{",
    "}": r"\}",
    "~": r"\textasciitilde{}",
    "^": r"\textasciicircum{}",
})


def latex_escape(s):
    """Escape *s* for safe use as plain text inside a LaTeX document.

    Missing values (``pd.isna``) are rendered as an em dash ("—"); anything
    else is converted with ``str()`` and has the LaTeX special characters
    ``\\ & % $ # _ { } ~ ^`` replaced by their escaped equivalents.
    """
    if pd.isna(s):
        return "—"
    return str(s).translate(_LATEX_ESCAPES)

def fmt_size(x):
    """Format a model size for display.

    Values within 1e-9 of a whole number are printed without decimals
    (``7.0`` -> ``"7"``); anything else gets one decimal place. Missing
    values become an em dash.
    """
    if pd.isna(x):
        return "—"
    value = float(x)
    nearest = round(value)
    if abs(value - nearest) < 1e-9:
        return str(int(nearest))
    return format(value, ".1f")

def fmt_float(x):
    """Format *x* with exactly two decimals, or an em dash when it is missing."""
    return "—" if pd.isna(x) else format(float(x), ".2f")

def _summarize(df):
    """Collapse raw result rows into one formatted, LaTeX-safe row per model.

    Groups by ("Model", "Model Family"), keeps the first non-missing value of
    each numeric column, derives a Yes/No "OpenLLM metric" flag from whether an
    "Average" value exists, sorts by family then short model name, and returns
    string columns ready to be written into a tabular.
    """
    agg_map = {
        "Model Size (B)": first_valid,
        "Pretraining Data Size (T)": first_valid,
        "FLOPs (1E21)": first_valid,
        "Average": first_valid,
    }
    # dropna=False keeps models whose "Model Family" (or "Model") is missing;
    # pandas' default dropna=True would silently remove them from every table.
    g = df.groupby(["Model", "Model Family"], as_index=False, dropna=False).agg(agg_map)
    g["OpenLLM metric"] = g["Average"].apply(lambda x: "Yes" if pd.notna(x) else "No")
    g["ModelShort"] = g["Model"].apply(short_name)

    # Sort on the raw text, before LaTeX escaping can perturb the ordering.
    g = g.sort_values(["Model Family", "ModelShort"], kind="stable").reset_index(drop=True)

    g["Model Size (B)"] = g["Model Size (B)"].apply(fmt_size)
    g["Pretraining Data Size (T)"] = g["Pretraining Data Size (T)"].apply(fmt_float)
    g["FLOPs (1E21)"] = g["FLOPs (1E21)"].apply(fmt_float)
    g["ModelShort"] = g["ModelShort"].apply(latex_escape)
    g["Model Family"] = g["Model Family"].apply(latex_escape)

    return g[[
        "ModelShort", "Model Family", "Model Size (B)",
        "Pretraining Data Size (T)", "FLOPs (1E21)", "OpenLLM metric",
    ]].rename(columns={
        "ModelShort": "Model",
        "Model Family": "Family",
        "Model Size (B)": "Size (B)",
        "Pretraining Data Size (T)": "Tokens (T)",
    })


def _render_table(chunk, part, num_parts):
    """Render one slice of the summary as a complete LaTeX ``table`` float.

    *part* is the 1-based index of this slice and *num_parts* the total number
    of slices; both appear in the caption, and *part* also in the label.
    """
    cols = list(chunk.columns)
    row_indent = " " * 6
    header = row_indent + " & ".join(latex_escape(c) for c in cols) + r" \\"
    body = [
        row_indent + " & ".join(str(v) for v in row) + r" \\"
        for row in chunk.itertuples(index=False, name=None)
    ]
    # \texttt instead of backticks: LaTeX renders `x` as two opening quotes.
    caption = (
        f"  \\caption{{Model summary (part {part} of {num_parts}). "
        "Models sorted by family then name; "
        "OpenLLM metric = non-NA \\texttt{Average}.}"
    )
    return "\n".join([
        r"\begin{table}[t]",
        r"  \centering",
        r"  \small",
        r"  \setlength{\tabcolsep}{6pt}",
        r"  \resizebox{\textwidth}{!}{%",
        r"    \begin{tabular}{" + "l" * len(cols) + "}",
        r"      \toprule",
        header,
        r"      \midrule",
        *body,
        r"      \bottomrule",
        r"    \end{tabular}%",
        r"  }",
        caption,
        f"  \\label{{tab:model_scaling_summary_part{part}}}",
        r"\end{table}",
    ])


def main():
    r"""Build chunked LaTeX summary tables from the VirtualHome results CSV.

    Reads ``virtualhome_action_sequencing_results_with_flops_and_openllm.csv``
    from this script's directory and writes ``model_scaling_tables_chunked.tex``
    next to it. The output holds one ``table`` float per ``NUM_PER_TABLE``
    models, separated by ``% ---- next part ----`` comment lines. The output
    uses \toprule/\midrule/\bottomrule (booktabs) and \resizebox (graphicx),
    so the including document must load those packages.

    Raises:
        ValueError: if ``NUM_PER_TABLE`` is less than 1.
    """
    if NUM_PER_TABLE < 1:
        raise ValueError(f"NUM_PER_TABLE must be >= 1, got {NUM_PER_TABLE!r}")

    here = Path(__file__).parent
    csv_path = here / "virtualhome_action_sequencing_results_with_flops_and_openllm.csv"
    out_path = here / "model_scaling_tables_chunked.tex"

    summary = _summarize(pd.read_csv(csv_path))

    total = len(summary)
    num_tables = (total + NUM_PER_TABLE - 1) // NUM_PER_TABLE  # ceiling division
    tables = [
        _render_table(summary.iloc[start:start + NUM_PER_TABLE], part, num_tables)
        for part, start in enumerate(range(0, total, NUM_PER_TABLE), start=1)
    ]
    # Trailing newline so the .tex file ends like a normal text file.
    out_path.write_text(
        "\n\n% ---- next part ----\n\n".join(tables) + "\n", encoding="utf-8"
    )

# Generate the .tex file only when run as a script, not when imported.
if __name__ == "__main__":
    main()
