#!/usr/bin/env python3
"""parse_artifact.py
---------------------------------
Generate **summary performance tables** for 3‑layer and 5‑layer runs.

For each layer count we produce **one PrettyTable** whose rows are:

1. **HongTu** – baseline wall‑clock time (s) for Epoch 1.
2. **GRD** – fastest time among ``GRD-G`` and ``GRD-GC``.
3. **Speedup** – ratio *HongTu ÷ GRD*, printed with two decimals and a "×"
   suffix (e.g. ``1.87×``).  Values show how many times faster GRD is; "–"
   where data are missing.

Columns correspond to partitioning methods **METIS**, **Random**, **GRD**.

Tables for 3 layers are printed first, then for 5 layers.  All results are also
written—separated by blank lines—to
``artifact_results/artifact_results.txt``.
"""
from __future__ import annotations

import re
from pathlib import Path
from typing import Dict, Optional

from prettytable import PrettyTable

# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
# Input logs (``*.log``) and the generated summary both live in the
# ``artifact_results`` directory beside this script; stop right away with a
# readable message when that directory does not exist.
LOG_DIR = Path(__file__).resolve().parent.joinpath("artifact_results")
if not LOG_DIR.is_dir():
    raise SystemExit(f"Log directory not found: {LOG_DIR}")

# Layer counts that get a summary table, in print order.
LAYERS = [3, 5]
# Lower-cased <partition> token from a log file name -> table column label.
PARTITION_TOKEN_TO_COL: Dict[str, str] = dict(
    metis="METIS",
    random="Random",
    grinnder="GRD",
)
# Column order used by every summary table.
COLS = ["METIS", "Random", "GRD"]
# Model names whose minimum time forms the "GRD" row.
GRD_COMPONENTS = {"GRD-G", "GRD-GC"}

# Log file names look like ``<dataset>_<model>_<layers>layers_<partition>.log``
# and are matched case-insensitively.
FILENAME_RE = re.compile(
    r"^(?P<dataset>[^_]+)_(?P<model>[^_]+)_(?P<layers>\d+)layers_(?P<part>[^.]+)\.log$",
    flags=re.IGNORECASE,
)

# Marker line for the first epoch; the trailing ``\b`` keeps "Epoch 10" out.
EPOCH1_PAT = re.compile(r"Epoch\s+1\b", flags=re.IGNORECASE)
# Timing line such as "Time (s): 12.34s"; group 1 holds the seconds value.
TIME_PAT = re.compile(
    r"Time \(s\):\s*([0-9]+(?:\.[0-9]+)?)s",
    flags=re.IGNORECASE,
)

# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------

def epoch1_time(path: Path) -> Optional[float]:
    """Return the first ``Time (s): <N>s`` value logged after the Epoch 1 marker.

    The file is read as UTF-8, with undecodable bytes replaced rather than
    raising.  The first line matching ``EPOCH1_PAT`` is consumed as the marker
    and is not itself searched for a time.  The first later line matching
    ``TIME_PAT`` supplies the result.

    Returns ``None`` when there is no Epoch 1 marker, or when no timing line
    follows it.

    NOTE: the scan does not stop at later epoch markers.  If Epoch 1 has no
    timing line, a later epoch's time is returned instead.
    """
    with path.open(encoding="utf-8", errors="replace") as handle:
        remaining = iter(handle)
        # Skip forward past the Epoch 1 marker line.
        for line in remaining:
            if EPOCH1_PAT.search(line):
                break
        else:
            return None
        # Continue on the same iterator: only lines after the marker count.
        for line in remaining:
            hit = TIME_PAT.search(line)
            if hit:
                return float(hit.group(1))
    return None

# -----------------------------------------------------------------------------
# Load data -> times[layer][model][partition] = seconds
# -----------------------------------------------------------------------------
# Lower-cased model name -> exact spelling used by the table lookups.
# FILENAME_RE is case-insensitive and partition tokens are lower-cased, so
# model names are normalised too: "hongtu" and "HongTu" land in one bucket.
# Models not listed here are stored verbatim.
_CANONICAL_MODELS = {name.lower(): name for name in ("HongTu", *GRD_COMPONENTS)}

raw: Dict[int, Dict[str, Dict[str, float]]] = {}
# Logs are visited in sorted order so the outcome is reproducible.  When two
# logs map to the same (layer, model, partition) key, the later name wins,
# instead of whichever entry the filesystem happened to list last.
# NOTE(review): the <dataset> part of the file name is parsed but not used as
# a key, so logs for different datasets overwrite each other here.
for log in sorted(LOG_DIR.glob("*.log")):
    m = FILENAME_RE.match(log.name)
    if not m:
        continue  # not a result log
    layer = int(m.group("layers"))
    if layer not in LAYERS:
        continue
    model_token = m.group("model")
    model = _CANONICAL_MODELS.get(model_token.lower(), model_token)
    part_token = m.group("part").lower()
    part = PARTITION_TOKEN_TO_COL.get(part_token)
    if part is None:
        continue  # partitioning method not shown in the tables
    t = epoch1_time(log)
    if t is None:
        continue  # no Epoch 1 timing found in this log
    raw.setdefault(layer, {}).setdefault(model, {})[part] = t

# -----------------------------------------------------------------------------
# Build summary table per layer
# -----------------------------------------------------------------------------

def build_table(layer: int) -> PrettyTable:
    """Return the HongTu / GRD / Speedup summary table for *layer* layers.

    Data come from the module-level ``raw`` mapping
    (``raw[layer][model][partition] = seconds``).  There is one column per
    entry of ``COLS``; a cell without data shows "–".

    Rows, in order:
      * ``HongTu``  – the HongTu model's time for each partitioning.
      * ``GRD``     – the smallest time among the models in ``GRD_COMPONENTS``.
      * ``Speedup`` – HongTu ÷ GRD with two decimals and a "×" suffix.  It is
        shown only when both times are present and the GRD time is positive.
    """
    missing = "–"
    per_model = raw.get(layer, {})

    def grd_best(col: str):
        # Fastest GRD variant that has a time for this partitioning column.
        times = [per_model.get(name, {}).get(col) for name in GRD_COMPONENTS]
        present = [sec for sec in times if sec is not None]
        return min(present) if present else missing

    def speedup(hongtu, grd) -> str:
        # Placeholders are strings, so only real float measurements with a
        # positive divisor produce a ratio.
        if isinstance(hongtu, float) and isinstance(grd, float) and grd > 0:
            return f"{hongtu / grd:.2f}×"
        return missing

    hongtu_row = [per_model.get("HongTu", {}).get(col, missing) for col in COLS]
    grd_row = [grd_best(col) for col in COLS]
    speedup_row = [speedup(h, g) for h, g in zip(hongtu_row, grd_row)]

    table = PrettyTable()
    # NOTE(review): the title names the "Products" dataset, but ``raw`` is not
    # keyed by dataset — confirm only Products logs are present.
    table.title = f"Products {layer}-Layer Summary (seconds)"
    table.field_names = ["Method / Partitioning", *COLS]
    for label, cells in (
        ("HongTu", hongtu_row),
        ("GRD", grd_row),
        ("Speedup", speedup_row),
    ):
        table.add_row([label, *cells])
    return table

# -----------------------------------------------------------------------------
# Generate, print, save
# -----------------------------------------------------------------------------
all_tables = [build_table(n_layers) for n_layers in LAYERS]

# Echo every table to stdout, each followed by a blank line.
for table in all_tables:
    print(table)
    print()

out = LOG_DIR / "artifact_results.txt"
# Tables are separated by a blank line, and the file ends with a newline.
# The encoding is pinned to UTF-8 because cells contain non-ASCII glyphs
# ("–", "×").  The platform's default encoding may not represent them
# (UnicodeEncodeError), or may store them in a non-UTF-8 form.
out.write_text("\n\n".join(map(str, all_tables)) + "\n", encoding="utf-8")
print(f"Saved summary tables with speed‑ups to {out}")
