# save as make_segment_tables.py
import argparse, pandas as pd, numpy as np
from pathlib import Path

def latex_escape(x):
    r"""Return ``str(x)`` with LaTeX special characters escaped for text mode.

    Each of ``\ & % $ # _ { } ~ ^`` is replaced by its text-mode LaTeX
    equivalent in a single pass, so inserted replacements are never
    re-escaped. The string form of an int or float contains none of these
    characters, so numbers pass through unchanged.
    """
    # str.translate applies all substitutions in one pass; doing them with
    # chained .replace() calls would re-escape the braces of \textbackslash{}.
    table = str.maketrans({
        "\\": r"\textbackslash{}",
        "&": r"\&",
        "%": r"\%",
        "$": r"\$",
        "#": r"\#",
        "_": r"\_",
        "{": r"\{",
        "}": r"\}",
        "~": r"\textasciitilde{}",
        "^": r"\textasciicircum{}",
    })
    return str(x).translate(table)

def fmt_overhead(r):
    """Format a ratio as a multiplier with one decimal place, e.g. 2.5 -> "2.5×"."""
    multiplier = format(r, ".1f")
    return multiplier + "×"

def to_latex_table_energy(per_seg):
    """Render per-segment IMLP vs MLP energy as a LaTeX ``table`` float.

    Args:
        per_seg: DataFrame with columns ``segment``, ``imlp_energy``,
            ``mlp_energy`` and ``overhead``. Rows are emitted in ascending
            ``segment`` order. If a segment appears more than once, only its
            first row is used.

    Returns:
        LaTeX source for a table labelled ``tab:energy-per-segment``. The rules
        use booktabs commands (``\\toprule`` etc.), so the document must load
        the ``booktabs`` package.
    """
    ordered = per_seg.drop_duplicates("segment").sort_values("segment")
    rows = [
        f"{r.segment} & {r.imlp_energy:.1f} & {r.mlp_energy:.1f} & {fmt_overhead(r.overhead)} \\\\"
        for r in ordered.itertuples(index=False)
    ]
    body = "\n".join(rows)
    # In an rf-string every literal LaTeX brace must be doubled; a single
    # "{table}" or "{c|cc|c}" would be evaluated as a Python expression.
    return rf"""\begin{{table}}[h]
\centering
\begin{{tabular}}{{c|cc|c}}
\toprule
\textbf{{Segment}} & \textbf{{IMLP Energy (J)}} & \textbf{{MLP Energy (J)}} & \textbf{{MLP Overhead}} \\
\midrule
{body}
\bottomrule
\end{{tabular}}
\caption{{Per-segment energy consumption.}}
\label{{tab:energy-per-segment}}
\end{{table}}"""

def to_latex_table_accuracy(per_seg_acc):
    """Render per-segment IMLP vs MLP accuracy as a LaTeX ``table`` float.

    Args:
        per_seg_acc: DataFrame with columns ``segment``, ``imlp_acc`` and
            ``mlp_acc``. Rows are emitted in ascending ``segment`` order. If a
            segment appears more than once, only its first row is used.

    Returns:
        LaTeX source for a table labelled ``tab:data-efficiency``. The rules
        use booktabs commands, so the document must load ``booktabs``.
    """
    ordered = per_seg_acc.drop_duplicates("segment").sort_values("segment")
    rows = []
    for r in ordered.itertuples(index=False):
        # Ratio of cumulative MLP training data to the single-segment IMLP data
        # ("s+1 segments vs 1"). NOTE(review): assumes 0-based segment ids and
        # equally sized segments; confirm against the experiment setup.
        ratio = f"{r.segment + 1}:1"
        rows.append(f"{r.segment} & {r.imlp_acc:.3f} & {r.mlp_acc:.3f} & {ratio} \\\\")
    body = "\n".join(rows)
    # Literal LaTeX braces are doubled so the rf-string does not treat
    # "{table}" / "{c|c|c|c}" as Python expressions.
    return rf"""\begin{{table}}[h]
\centering
\begin{{tabular}}{{c|c|c|c}}
\toprule
\textbf{{Segment}} & \textbf{{IMLP Accuracy}} & \textbf{{MLP Accuracy}} & \textbf{{Training Data Ratio (MLP:IMLP)}} \\
\midrule
{body}
\bottomrule
\end{{tabular}}
\caption{{Performance vs training data consumption.}}
\label{{tab:data-efficiency}}
\end{{table}}"""

def to_latex_table_cumulative(per_seg, segments=(0, 2, 4, 6, 7)):
    """Render running-total IMLP vs MLP energy as a LaTeX ``table`` float.

    Cumulative sums are taken over *all* rows of ``per_seg`` in ascending
    ``segment`` order, but only the segments listed in ``segments`` are shown.

    Args:
        per_seg: DataFrame with columns ``segment``, ``imlp_energy`` and
            ``mlp_energy``. It is not modified.
        segments: Segment ids to display, in the given order. Ids that do not
            occur in ``per_seg`` are skipped. The default matches the
            previously hard-coded selection.

    Returns:
        LaTeX source for a table labelled ``tab:cumulative-cost``. The
        "Efficiency Advantage" column is ``mlp_cum / imlp_cum``. The rules use
        booktabs commands, so the document must load ``booktabs``.
    """
    # sort_values returns a new frame, so the columns added below do not leak
    # into the caller's DataFrame.
    cum = per_seg.sort_values("segment")
    cum["imlp_cum"] = cum["imlp_energy"].cumsum()
    cum["mlp_cum"] = cum["mlp_energy"].cumsum()
    cum["adv"] = cum["mlp_cum"] / cum["imlp_cum"]

    present = set(cum["segment"])  # build once, not per requested segment
    rows = []
    for s in segments:
        if s not in present:
            continue
        r = cum.loc[cum["segment"] == s].iloc[0]
        rows.append(f"{s} & {r['imlp_cum']:.1f} & {r['mlp_cum']:.1f} & {fmt_overhead(r['adv'])} \\\\")
    body = "\n".join(rows)
    # Literal LaTeX braces are doubled so the rf-string does not evaluate
    # "{table}" / "{c|c|c|c}" as Python expressions.
    return rf"""\begin{{table}}[h]
\centering
\begin{{tabular}}{{c|c|c|c}}
\toprule
\textbf{{Segment}} & \textbf{{IMLP Cumulative (J)}} & \textbf{{MLP Cumulative (J)}} & \textbf{{Efficiency Advantage}} \\
\midrule
{body}
\bottomrule
\end{{tabular}}
\caption{{Cumulative energy consumption.}}
\label{{tab:cumulative-cost}}
\end{{table}}"""

def _paired_segment_means(df, value_col, imlp_name, mlp_name, imlp_col, mlp_col):
    """Return one row per segment with each model's mean ``value_col`` side by side.

    Rows are averaged per (segment, model). The IMLP and MLP means are renamed
    to ``imlp_col`` / ``mlp_col`` and inner-joined on ``segment``, so only
    segments present for both models are kept.
    """
    means = df.groupby(["segment", "model"], as_index=False)[value_col].mean()

    def _one_model(model_name, new_col):
        # Select one model's means and rename the value column for the join.
        sub = means[means["model"] == model_name]
        return sub.rename(columns={value_col: new_col}).drop(columns="model")

    return pd.merge(
        _one_model(imlp_name, imlp_col),
        _one_model(mlp_name, mlp_col),
        on="segment",
        how="inner",
    )


def main():
    """Command-line entry point: build per-segment summaries and LaTeX tables.

    Reads a CSV with columns ``segment,model,energy_j,accuracy`` and keeps the
    rows of the two selected models. It averages energy and accuracy per
    segment and writes ``per_segment_energy.csv``, ``per_segment_accuracy.csv``
    and three ``.tex`` tables into ``--outdir``. The tables are also printed
    to stdout.

    Raises:
        SystemExit: if a required column is missing, if either selected model
            has no rows, or if the two models share no segment.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("csv", help="CSV with columns: segment,model,energy_j,accuracy")
    ap.add_argument("--imlp", default="imlp", help="Name of IMLP model in CSV")
    ap.add_argument("--mlp", default="mlp", help="Name of MLP model in CSV")
    ap.add_argument("--outdir", default=".", help="Output directory")
    args = ap.parse_args()

    df = pd.read_csv(args.csv)
    for col in ["segment", "model", "energy_j", "accuracy"]:
        if col not in df.columns:
            raise SystemExit(f"Missing column '{col}' in {args.csv}")

    # keep only requested models
    df2 = df[df["model"].isin([args.imlp, args.mlp])].copy()
    if df2.empty:
        raise SystemExit("No rows for selected models.")
    # The per-segment inner join would otherwise silently produce empty
    # tables when only one of the two models is present.
    present = set(df2["model"])
    for name in (args.imlp, args.mlp):
        if name not in present:
            raise SystemExit(f"No rows for model '{name}' in {args.csv}")

    # per-segment energy
    per_seg = _paired_segment_means(
        df2, "energy_j", args.imlp, args.mlp, "imlp_energy", "mlp_energy"
    )
    if per_seg.empty:
        raise SystemExit(
            f"Models '{args.imlp}' and '{args.mlp}' share no segments in {args.csv}"
        )
    per_seg["overhead"] = per_seg["mlp_energy"] / per_seg["imlp_energy"]

    # per-segment accuracy
    per_seg_acc = _paired_segment_means(
        df2, "accuracy", args.imlp, args.mlp, "imlp_acc", "mlp_acc"
    )

    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)
    per_seg.round(4).to_csv(outdir / "per_segment_energy.csv", index=False)
    per_seg_acc.round(4).to_csv(outdir / "per_segment_accuracy.csv", index=False)

    tables = {
        "energy_per_segment.tex": to_latex_table_energy(per_seg),
        "data_efficiency.tex": to_latex_table_accuracy(per_seg_acc),
        "cumulative_energy.tex": to_latex_table_cumulative(per_seg),
    }
    for filename, tex in tables.items():
        # Explicit UTF-8: the tables contain "×", and the platform default
        # encoding (e.g. cp1252) may mis-encode it or fail outright.
        (outdir / filename).write_text(tex, encoding="utf-8")

    # quick console preview (tables separated by a blank line)
    print("\n\n".join(tables.values()))

# Run the CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
