from typing import List, Optional, Tuple, Dict, Any
import re
import os
import json
import argparse
from statistics import mean

import numpy as np

# Written with significant assistance by ChatGPT
# Generated tables were manually checked for consistency with the raw data

# This script was used for generating Table 1 in the paper


# Setup
# Metric cells are laid out dataset-major, then horizon (DATASETS x HORIZONS);
# this order is hard-coded to match the table's column order.
DATASETS = ["metr-la", "pems-bay", "solar", "electricity"]
HORIZONS = [6, 24, 48, 96]
KERNELS = ["morlet", "hat"]
# (results sub-folder, display label) for every newly evaluated method.
METHOD_FOLDERS = [("sgd", "SGD"), ("sswim", "SSWIM"), ("sswim_sgd", "SSWIM+SGD")]
# NOTE(review): not referenced anywhere in this script.
TEMPLATE_LABEL = r"\label{tab:tsf_table}"
CELL_COUNT = len(DATASETS) * len(HORIZONS)  # 4 datasets x 4 horizons = 16
# Absolute tolerance under which two values count as tied when bolding.
EPS = 1e-3

# Total table columns per row: name, spiking marker, kernel,
# CELL_COUNT metric cells and the row average (= 20).
NUM_COLS = 3 + CELL_COUNT + 1

# Placeholder printed for a missing value.
NONE_TOKEN = "--"

# Each entry: (output file stem, caption text, indices of existing rows to print).
EXPORT_LIST = [
    (
        "full_table",
        "\\label{tab:results_table}"
        "Experimental results of time-series forecasting on $4$ benchmarks with various prediction lengths $6, 24, 48, 96$. "
        "\\tquote{Kernel} denotes the \\textit{PSPK} used in the first layer. "
        "Bold font indicates the best \\textit{SNN} result. "
        "Underlined results indicate \\name performing at least as well as gradient-based methods. "
        "Italic font indicates that \\textit{SGD} optimisation did not fully converge, i.e. the best epoch was within 30 epochs of the maximum. "
        "Results are given in the \\textit{RSE} Metric, where lower is better. "
        # "\\cite" (not "\cite"): "\c" is an invalid escape sequence and warns on 3.12+.
        "Results highlighted with shading are ours; the remaining results were sourced from~\\cite{tsf_reference}{Table 1}. "
        "All results are averaged across $3$ seeds.",
        [0, 1, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14],
    ),
]


# ------------------------------------------------

# Accept only decimal numbers ("1.5", "-0.25") or a leading-dot form (".123", "-.5");
# integers are deliberately ignored so LaTeX arguments like {1-20} are skipped.
float_re = re.compile(r'(-?\d+\.\d+)|(-?\.\d+)')

def strip_formatting(s: str) -> str:
    r"""Return *s* with LaTeX styling macros stripped so numbers can be parsed.

    Applied in this fixed order: ``\rowcolor{...}`` is deleted,
    ``\textbf{...}`` is replaced by its contents, bare ``\bf`` switches are
    deleted, and ``\color{...}`` is deleted.
    """
    substitutions = [
        (r'\\rowcolor\{[^\}]*\}', ''),
        (r'\\textbf\{([^\}]*)\}', r'\1'),
        (r'\\bf\b', ''),
        (r'\\color\{[^\}]+\}', ''),
    ]
    cleaned = s
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned

def find_table_block(template_text: str, label: Optional[str] = None) -> Tuple[str, int, int]:
    r"""Return a ``table*`` environment from *template_text* together with its span.

    Args:
        template_text: Full LaTeX source of the template.
        label: Optional anchor text (e.g. ``\label{tab:tsf_table}``). When given
            and present in *template_text*, the ``\begin{table*}`` closest before
            it is used. Otherwise the last ``\begin{table*}`` in the text is used.

    Returns:
        ``(block, start, end)`` with ``block == template_text[start:end]``; the
        block runs from ``\begin{table*}`` through the first ``\end{table*}``
        that follows it.

    Raises:
        RuntimeError: if no ``\begin{table*}`` or matching ``\end{table*}`` exists.
    """
    begin_tok = r"\begin{table*}"
    end_tok = r"\end{table*}"
    anchor = template_text.find(label) if label else -1
    search_limit = anchor if anchor != -1 else len(template_text)
    bpos = template_text.rfind(begin_tok, 0, search_limit)
    if bpos == -1:
        raise RuntimeError(r"Couldn't find '\begin{table*}' before the label")
    # Search for the end *after* the chosen begin. A plain find() from the start
    # of the file returns the first table's end, which lies before bpos whenever
    # the template contains more than one table* environment.
    epos = template_text.find(end_tok, bpos)
    if epos == -1:
        raise RuntimeError(r"Couldn't find '\end{table*}' after the label")
    end = epos + len(end_tok)
    return template_text[bpos:end], bpos, end

def collect_floats_from_lines(lines: List[str], start_idx: int, needed: int, lookahead: int = 3) -> List[Optional[float]]:
    r"""Collect `needed` decimal numbers starting at ``lines[start_idx]``.

    Scans ``lines[start_idx]`` plus up to `lookahead` following lines, each
    passed through ``strip_formatting`` first. Scanning stops after the first
    line at which at least `needed` numbers have been gathered. Only tokens
    accepted by ``float_re`` (numbers containing a decimal point such as
    ``0.123`` or ``-.5``) are collected, so integer LaTeX arguments like
    ``\multirow{-1}`` are ignored.

    Args:
        lines: Table body lines.
        start_idx: Index of the first line to scan.
        needed: Number of values to return.
        lookahead: How many additional lines may be scanned.

    Returns:
        A list of `needed` entries; missing trailing values are ``None``.
    """
    stop = min(len(lines), start_idx + 1 + lookahead)
    nums: List[Optional[float]] = []
    for j in range(start_idx, stop):
        text = strip_formatting(lines[j])
        # float_re has two alternative groups; exactly one is non-empty per match.
        for whole, leading_dot in float_re.findall(text):
            token = whole or leading_dot
            try:
                nums.append(float(token))
            except ValueError:
                # Defensive only: float_re yields numeric text. Unlike the former
                # bare `except:`, this no longer swallows KeyboardInterrupt etc.
                continue
        if len(nums) >= needed:
            break
    if len(nums) < needed:
        # Pad so every row has a value (or None) for each cell.
        nums += [None] * (needed - len(nums))
    return nums[:needed]

# --- robust parse_existing_rows with reliable name and spiking detection ---
def _clean_latex_text(s: str) -> str:
    """Remove common latex macros and braces to yield readable text for name/kernel extraction."""
    if s is None:
        return ''
    # remove \rowcolor, \color{...}, \textcolor{...}{...}, \multirow{...}{...}{...} handled separately
    t = s
    # remove math $...$
    t = re.sub(r'\$[^\$]*\$', '', t)
    # remove \textcolor{...}{...} but keep inner text
    t = re.sub(r'\\textcolor\{[^\}]+\}\{([^\}]*)\}', r'\1', t)
    # remove \color{...} tokens
    t = re.sub(r'\\color\{[^\}]+\}', '', t)
    # remove \rowcolor{...}
    t = re.sub(r'\\rowcolor\{[^\}]+\}', '', t)
    # remove \multirow{...}{...}{...} but capture inner name if present
    t = re.sub(r'\\multirow\{[^\}]*\}\{[^\}]*\}\{([^\}]*)\}', r'\1', t)
    # remove other commands like \bf, \textbf{...} keeping inner text
    t = re.sub(r'\\textbf\{([^\}]*)\}', r'\1', t)
    t = re.sub(r'\\bf\b', '', t)
    # remove any other backslash commands \cmd{...} -> keep inner if simple, else remove
    t = re.sub(r'\\[a-zA-Z]+\*?\{([^\}]*)\}', r'\1', t)
    # remove remaining bare backslash commands like \cmark or \ding{51}
    t = re.sub(r'\\[a-zA-Z]+\{?[0-9]*\}?', '', t)
    # strip braces and extra whitespace
    t = t.replace('{', '').replace('}', '')
    t = re.sub(r'\s+', ' ', t).strip()
    return t

def extract_name_and_spike(block_lines: List[str]) -> Tuple[str, Optional[str], bool]:
    r"""
    Heuristically extract a method's display name, kernel and spiking flag.

    From a list of body lines for a single method block, find:
      - name: the third argument of \multirow{...}{...}{NAME} (optionally
        preceded by a \textcolor{...} wrapper); otherwise the cleaned text
        before the first '&' in the first 3 lines; otherwise a letter-led
        token that follows a '}' and precedes '&'. Falls back to "(unknown)".
      - kernel: the first entry of KERNELS found as a case-insensitive
        substring anywhere in the block, else None.
      - is_spiking: True if the block contains \cmark, \ding{51}, or the
        case-insensitive substring 'spike' or 'spiking'.

    Args:
        block_lines: Raw LaTeX lines for one method's rows (the caller may
            include neighbouring lines as well).

    Returns:
        (name, kernel, is_spiking)
    """
    block_text = '\n'.join(block_lines)
    # 1) name via \multirow pattern (tolerates a leading \textcolor{...} wrapper)
    name = None
    m = re.search(r'\\(?:textcolor\{[^\}]+\})?\s*\\?multirow\{[^\}]*\}\{[^\}]*\}\{([^\}]*)\}', block_text)
    if m:
        name_raw = m.group(1).strip()
        name = _clean_latex_text(name_raw)
    else:
        # fallback: first of the first 3 lines with an '&'; take the cleaned text before it
        for ln in block_lines[:3]:
            if '&' in ln:
                candidate = ln.split('&', 1)[0]
                cand_clean = _clean_latex_text(candidate)
                if cand_clean:
                    name = cand_clean
                    break
        if name is None:
            # last resort: a letter-led run of up to 51 name-like characters that
            # directly follows a '}' and precedes an '&' anywhere in the block
            m2 = re.search(r'\}\s*([A-Za-z][A-Za-z0-9 \-\_()]{1,50})\s*&', block_text)
            if m2:
                name = _clean_latex_text(m2.group(1))
    if not name:
        name = "(unknown)"

    # 2) kernel detection: plain substring search over the lower-cased block.
    # NOTE(review): 'hat' also matches inside words such as "that"/"what" or a
    # \hat macro — confirm the template never contains such text.
    kernel = None
    lowered = block_text.lower()
    for k in KERNELS:
        if k in lowered:
            kernel = k
            break

    # 3) spiking detection (any one marker suffices)
    is_spiking = False
    # directly present cmark
    if '\\cmark' in block_text:
        is_spiking = True
    # common alternative: \ding{51} (pifont)
    elif re.search(r'\\ding\{\s*51\s*\}', block_text):
        is_spiking = True
    # colored \cmark inside \textcolor{...}{...}; unreachable in practice, since
    # any match contains '\cmark' and is already caught by the first check
    elif re.search(r'\\textcolor\{[^\}]+\}\{[^\}]*\\cmark[^\}]*\}', block_text):
        is_spiking = True
    # textual heuristic: the words 'spike'/'spiking' anywhere in the block
    elif 'spike' in lowered or 'spiking' in lowered:
        # NOTE(review): no further check is made, so a non-spiking method whose
        # name contains "spike" would be marked spiking — confirm acceptable.
        is_spiking = True

    return name, kernel, is_spiking

def parse_existing_rows(table_block: str, debug: bool=False) -> Tuple[List[Dict[str,Any]], List[str], str]:
    r"""
    Split a LaTeX table block into header lines, parsed method rows and footer.

    Layout heuristics (tied to the template's formatting):
      - The body ends at the first line containing \bottomrule; that line and
        everything after it form the footer.
      - The body starts at the first line containing 'R$^2$' (or 'R^2') and an
        '&'; if there is none, at the first line with \multirow or >= 6 '&'.
      - Each method is an "R2 line" (mentions R$^2$/R^2) followed, within the
        next 5 lines, by an "RSE line" (mentions 'RSE', \cmark or \xmark, or
        has >= 6 '&'); if none qualifies, the next line (clamped) is used.
        If no body line mentions R$^2$/R^2, no rows are produced.

    Args:
        table_block: Text of the table* environment.
        debug: Print a per-row summary when True.

    Returns:
        (rows, header_lines, footer_text). Each row is a dict with keys
        name, kernel, is_spiking, r2_vals, rse_vals, orig_block; r2_vals and
        rse_vals have CELL_COUNT entries, padded with None.

    Raises:
        RuntimeError: if no \bottomrule line or no data start line is found.
    """
    lines = table_block.splitlines()
    # find bottomrule line index
    bottom_idx = None
    for i, ln in enumerate(lines):
        if '\\bottomrule' in ln:
            bottom_idx = i
            break
    if bottom_idx is None:
        raise RuntimeError("Could not find '\\bottomrule' in table block.")
    # find start of data area: first line containing R$^2$ and a & (heuristic)
    first_data_idx = None
    for i, ln in enumerate(lines):
        if ('R$^2$' in ln or 'R^2' in ln) and '&' in ln:
            first_data_idx = i
            break
    if first_data_idx is None:
        # fallback to first \multirow or heavy & line
        for i, ln in enumerate(lines):
            if '\\multirow' in ln or ln.count('&') >= 6:
                first_data_idx = i
                break
    if first_data_idx is None:
        raise RuntimeError("Could not find start of data rows in table block.")

    header_lines = lines[:first_data_idx]
    body_lines = lines[first_data_idx:bottom_idx]
    footer_text = '\n'.join(lines[bottom_idx:])

    rows = []
    i = 0
    L = len(body_lines)
    while i < L:
        ln = body_lines[i]
        # an R^2 mention starts a new method block
        if 'R$^2$' in ln or 'R^2' in ln:
            r2_start = i
            # RSE line: first of the next (up to) 5 lines that mentions 'RSE',
            # has \cmark/\xmark, or has many '&'s
            rse_idx = None
            for j in range(r2_start+1, min(L, r2_start+6)):
                if 'RSE' in body_lines[j] or '\\cmark' in body_lines[j] or '\\xmark' in body_lines[j] or body_lines[j].count('&') >= 6:
                    rse_idx = j
                    break
            if rse_idx is None:
                # NOTE(review): if the R2 line is the last body line this equals
                # r2_start, so rse_vals would repeat the R2 numbers
                rse_idx = min(L-1, r2_start+1)

            # collect floats from r2_start and from rse_idx
            # NOTE(review): lookahead=4 means a short R2 line would borrow
            # numbers from the following (RSE) line
            r2_vals = collect_floats_from_lines(body_lines, r2_start, CELL_COUNT, lookahead=4)
            rse_vals = collect_floats_from_lines(body_lines, rse_idx, CELL_COUNT, lookahead=4)

            # name/kernel/spike come from the line before r2_start through the
            # line after rse_idx; NOTE(review): this can include a neighbouring
            # row's text (e.g. its \cmark or kernel) — confirm the template
            # separates rows so this cannot leak
            block_slice = body_lines[max(0, r2_start-1):min(L, rse_idx+2)]
            name, kernel, is_spiking = extract_name_and_spike(block_slice)

            rows.append({
                'name': name,
                'kernel': kernel,
                'is_spiking': is_spiking,
                'r2_vals': r2_vals,
                'rse_vals': rse_vals,
                'orig_block': block_slice
            })
            # continue scanning after the RSE line
            i = rse_idx + 1
        else:
            i += 1

    if debug:
        print(f"[DEBUG] parse_existing_rows_fixed: parsed {len(rows)} rows")
        for idx, r in enumerate(rows):
            print(f"  [{idx}] name={r['name']!r}, kernel={r['kernel']}, spiking={r['is_spiking']}, r2_first={r['r2_vals'][0]}, rse_first={r['rse_vals'][0]}")

    return rows, header_lines, footer_text

def _mean_or_float(value: Any) -> Optional[float]:
    """Coerce a JSON metric to a float.

    Lists (presumably one entry per seed) are averaged, bare numbers are
    converted with ``float``, and empty lists or any other type yield ``None``.
    """
    if isinstance(value, list):
        return float(mean(value)) if value else None
    if isinstance(value, (int, float)):
        return float(value)
    return None


def _load_cell_metrics(path: str) -> Optional[Tuple[Optional[float], Optional[float], Optional[float], Any]]:
    """Read one results JSON and return ``(r2, rse, best_epoch, max_epochs)``.

    Values come from ``averages.r2_test``, ``averages.rse_test``,
    ``averages.best_epoch`` and ``config.train.num_epochs``; any missing entry
    becomes ``None``. Returns ``None`` when the file does not exist, or when it
    cannot be read/parsed (a warning is printed in that case).
    """
    if not os.path.isfile(path):
        return None
    try:
        with open(path, 'r', encoding='utf-8') as fh:
            js = json.load(fh)
        avg = js.get('averages', {})
        r2 = _mean_or_float(avg.get('r2_test', None))
        rse = _mean_or_float(avg.get('rse_test', None))
        # Missing epoch info must stay None: a default of 0 for both values made
        # compute_italic flag the cell as "not converged" (0 > 0 - 30).
        best_epoch = _mean_or_float(avg.get('best_epoch', None))
        # .get chain so a missing config no longer discards valid r2/rse values.
        max_epochs = js.get('config', {}).get('train', {}).get('num_epochs', None)
    except Exception as e:  # best-effort per file: one bad JSON must not abort the table
        print(f"[WARN] failed to read {path}: {e}")
        return None
    return (r2, rse, best_epoch, max_epochs)


def read_new_rows_from_results(debug: bool=False) -> Tuple[List[Dict[str,Any]], List[Tuple[str,str]]]:
    """Read JSON results and build one table row per (method, kernel).

    For every ``(folder, label)`` in METHOD_FOLDERS and kernel in KERNELS, reads
    ``<folder>/results/<dataset>_<kernel>_<H>.json`` for every dataset/horizon
    (dataset-major order, matching CELL_COUNT). Missing or unreadable files
    yield ``None`` cells.

    Returns:
        (new_rows, insertion_order). Each row dict has keys name, kernel,
        is_spiking (always True), r2_vals, rse_vals, best_epoch_vals,
        max_epochs_vals and added (always True). insertion_order lists the
        ``(label, kernel)`` pairs in row order.
    """
    new_rows = []
    insertion_order = []
    for folder, label in METHOD_FOLDERS:
        for kernel in KERNELS:
            insertion_order.append((label, kernel))
            r2_cells, rse_cells, best_epoch_cells, max_epochs_cells = [], [], [], []
            for ds in DATASETS:
                for H in HORIZONS:
                    path = os.path.join(folder, 'results', f'{ds}_{kernel}_{H}.json')
                    metrics = _load_cell_metrics(path)
                    if debug:
                        print(f"[DEBUG] read new result: folder={folder} ds={ds} kernel={kernel} H={H} -> {path} -> {metrics}")
                    r2, rse, best_epoch, max_epochs = metrics if metrics is not None else (None, None, None, None)
                    r2_cells.append(r2)
                    rse_cells.append(rse)
                    best_epoch_cells.append(best_epoch)
                    max_epochs_cells.append(max_epochs)
            new_rows.append({
                'name': f"{label} ({kernel.capitalize()})",
                'kernel': kernel,
                'is_spiking': True,
                'r2_vals': r2_cells,
                'rse_vals': rse_cells,
                'best_epoch_vals': best_epoch_cells,
                'max_epochs_vals': max_epochs_cells,
                'added': True
            })
    return new_rows, insertion_order

def compute_bolding(existing_rows: List[Dict], new_rows: List[Dict], debug: bool=False):
    """
    Mark, for every table cell, the best value among spiking rows.

    Rows are concatenated as existing_rows followed by new_rows. For each cell
    index, the maximum R2 and the minimum RSE over spiking rows (truthy
    ``is_spiking``) with a non-None value are found. Every spiking row whose
    value lies within EPS of that optimum is marked. Non-spiking rows are
    never marked.

    Returns:
        (combined, bold_r2, bold_rse): the concatenated row list and two
        len(combined) x CELL_COUNT boolean masks indexed [row][cell].
    """
    combined = [*existing_rows, *new_rows]
    n_rows = len(combined)
    bold_r2 = [[False] * CELL_COUNT for _ in range(n_rows)]
    bold_rse = [[False] * CELL_COUNT for _ in range(n_rows)]

    spiking = [idx for idx, row in enumerate(combined) if row.get('is_spiking', False)]

    for cell in range(CELL_COUNT):
        r2_column = [combined[idx]['r2_vals'][cell] for idx in spiking]
        rse_column = [combined[idx]['rse_vals'][cell] for idx in spiking]
        present_r2 = [val for val in r2_column if val is not None]
        present_rse = [val for val in rse_column if val is not None]

        # Higher R2 is better.
        if present_r2:
            top = max(present_r2)
            for idx, val in zip(spiking, r2_column):
                if val is not None and abs(val - top) < EPS:
                    bold_r2[idx][cell] = True

        # Lower RSE is better.
        if present_rse:
            top = min(present_rse)
            for idx, val in zip(spiking, rse_column):
                if val is not None and abs(val - top) < EPS:
                    bold_rse[idx][cell] = True

    if debug:
        print(f"[DEBUG] computed bold masks for {n_rows} combined rows, {len(spiking)} spiking rows.")
    return combined, bold_r2, bold_rse

def compute_underline(existing_rows: List[Dict], new_rows: List[Dict], debug: bool=False):
    """
    Mark SSWIM cells that are at least as good as the best plain-SGD cell.

    Rows are concatenated as existing_rows followed by new_rows. The reference
    rows are added rows (truthy ``'added'``) whose name contains "SGD" but no
    "+". The candidate rows are added rows whose name contains "SSWIM" but no
    "+", so combined "SSWIM+SGD" rows belong to neither group. Per cell, a
    candidate's R2 is marked if it is >= the best (max) reference R2, and its
    RSE if it is <= the best (min) reference RSE.

    A cell with no reference value is never marked. Previously the comparison
    ran against -inf/+inf there, so every SSWIM value was underlined even when
    there was no SGD result to compare against.

    Returns:
        (combined, underline_r2, underline_rse): the concatenated row list and
        two len(combined) x CELL_COUNT boolean masks indexed [row][cell].
    """
    combined = list(existing_rows) + list(new_rows)
    M = len(combined)
    N = CELL_COUNT
    underline_r2 = [[False] * N for _ in range(M)]
    underline_rse = [[False] * N for _ in range(M)]

    def _is_plain_added(row: Dict, tag: str) -> bool:
        # Added (newly computed) row whose name contains `tag` but is not a "+" combo.
        name = row.get("name", "")
        return bool(row.get('added', False)) and tag in name and '+' not in name

    sgd_idcs = [i for i, r in enumerate(combined) if _is_plain_added(r, "SGD")]
    sswim_idcs = [i for i, r in enumerate(combined) if _is_plain_added(r, "SSWIM")]

    for c in range(N):
        ref_r2 = [combined[i]['r2_vals'][c] for i in sgd_idcs if combined[i]['r2_vals'][c] is not None]
        ref_rse = [combined[i]['rse_vals'][c] for i in sgd_idcs if combined[i]['rse_vals'][c] is not None]
        best_r2 = max(ref_r2) if ref_r2 else None
        best_rse = min(ref_rse) if ref_rse else None
        for i in sswim_idcs:
            v = combined[i]['r2_vals'][c]
            if best_r2 is not None and v is not None and v >= best_r2:
                underline_r2[i][c] = True
            s = combined[i]['rse_vals'][c]
            if best_rse is not None and s is not None and s <= best_rse:
                underline_rse[i][c] = True

    if debug:
        print(f"[DEBUG] computed underline masks for {M} combined rows, {len(sswim_idcs)} sswim rows.")
    return combined, underline_r2, underline_rse

def compute_italic(existing_rows: List[Dict], new_rows: List[Dict], debug: bool=False, margin: int = 30):
    """
    Flag SGD-trained cells whose best epoch lies close to the epoch budget.

    Rows are concatenated as existing_rows followed by new_rows. Only added
    rows (truthy ``'added'``) whose name contains "SGD" are considered; this
    includes both the plain "SGD" and the "SSWIM+SGD" labels from
    METHOD_FOLDERS. A cell is marked when both ``best_epoch_vals[c]`` and
    ``max_epochs_vals[c]`` are present and ``best_epoch > max_epochs - margin``.

    Args:
        existing_rows: Rows parsed from the template.
        new_rows: Rows built from JSON results; the considered ones must carry
            'best_epoch_vals' and 'max_epochs_vals'.
        debug: Print a short summary when True.
        margin: Epoch distance from the maximum that still counts as "not
            converged"; the default 30 matches the caption text.

    Returns:
        (combined, it_experiment): the concatenated row list and a
        len(combined) x CELL_COUNT boolean mask indexed [row][cell].
    """
    combined = list(existing_rows) + list(new_rows)
    M = len(combined)
    N = CELL_COUNT
    it_experiment = [[False] * N for _ in range(M)]

    sgd_trained_idcs = [i for i, r in enumerate(combined) if r.get('added', False) and 'SGD' in r.get("name", "")]
    for i in sgd_trained_idcs:
        best_epochs = combined[i]['best_epoch_vals']
        max_epochs_list = combined[i]['max_epochs_vals']
        for c in range(N):
            best_epoch = best_epochs[c]
            max_epochs = max_epochs_list[c]
            if best_epoch is not None and max_epochs is not None and best_epoch > max_epochs - margin:
                it_experiment[i][c] = True

    if debug:
        print(f"[DEBUG] computed italic masks for {M} combined rows, {len(sgd_trained_idcs)} SGD-trained rows.")
    return combined, it_experiment


# --- helper: minimal LaTeX escaping for names/kernels ---
def latex_escape(s: str) -> str:
    r"""
    Return *s* with common LaTeX special characters escaped for a table cell.

    Escapes & % $ # _ { } ~ ^ and backslash (the latter three become
    \textasciitilde{}, \^{} and \textbackslash{}); every other character is
    passed through unchanged. ``None`` yields ``''``, and non-string inputs
    are converted with ``str`` first. Not exhaustive.
    """
    if s is None:
        return ''
    translation = str.maketrans({
        '&': r'\&',
        '%': r'\%',
        '$': r'\$',
        '#': r'\#',
        '_': r'\_',
        '{': r'\{',
        '}': r'\}',
        '~': r'\textasciitilde{}',
        '^': r'\^{}',
        '\\': r'\textbackslash{}',
    })
    return str(s).translate(translation)

def safe_tex_num(v: Optional[float], bold: bool = False, underline: bool = False, it:bool=False) -> str:
    r"""
    Format a metric value as an inline-math LaTeX token with 3 decimals.

    A leading "0" is dropped (0.1234 -> ".123") and negative values keep
    their sign (-0.5 -> "-.500"). ``None`` is rendered as NONE_TOKEN with no
    styling.

    Styling is composed inside the math:
      - it and bold  -> \bmit{...}
      - it only      -> \mathit{...}
      - bold only    -> \mathbf{...}
      - underline    -> wraps whichever of the above applies in \underline{...}

    Previously, italic returned early and silently dropped `underline`.
    NOTE(review): \bmit is not a core LaTeX command and is presumably provided
    by the document's preamble/packages — confirm.
    """
    if v is None:
        return NONE_TOKEN
    digits = f"{abs(v):.3f}"
    if digits.startswith("0."):
        digits = digits[1:]
    if v < 0:
        digits = "-" + digits
    if it:
        body = ("\\bmit{" if bold else "\\mathit{") + digits + "}"
    elif bold:
        body = "\\mathbf{" + digits + "}"
    else:
        body = digits
    if underline:
        body = "\\underline{" + body + "}"
    return "$" + body + "$"

# --- replacement rebuild_table_text with safe handling ---
def rebuild_table_text(header_lines: List[str], footer_text: str,
                       existing_rows: List[Dict], new_rows: List[Dict],
                       combined: List[Dict], bold_r2, bold_rse,
                       underline_r2: List[List[bool]], underline_rse: List[List[bool]],
                       it_experiment: List[List[bool]],
                       insertion_order: List[Tuple[str,str]], print_existing: Optional[List[int]] = None, metric:str="", caption:str="") -> str:
    r"""
    Reconstruct the LaTeX table block for one metric.

    Output order: header_lines unchanged, then the selected existing rows
    (unshaded), then all new_rows shaded with \rowcolor{cpgcolor}, then
    footer_text. Each row is one line of the form
    ``\multirow{-1}{*}{NAME} & <\cmark|\xmark> & Kernel & <CELL_COUNT cells> & <average> \\``.
    Among printed existing rows, a \cdashline follows every second row and the
    last one. Among new rows, one follows every second row except the last.

    Args:
        header_lines, footer_text: Text kept verbatim around the rows.
        existing_rows, new_rows: Row dicts. Mask rows are indexed by position in
            existing_rows followed by new_rows.
        combined: Unused here; kept for interface compatibility.
        bold_r2, bold_rse, underline_r2, underline_rse, it_experiment:
            [row][cell] boolean style masks.
        insertion_order: Unused here; kept for interface compatibility.
        print_existing: Indices into existing_rows to emit; None emits all.
            (None previously crashed on len(None).)
        metric: "R2" renders R2 values; any other value renders RSE values.
        caption: When non-empty, inserted as \caption{...} right after the
            first \centering.

    Returns:
        The rebuilt table text.
    """
    dashline_end = r' \\ \cdashline{1-' + str(NUM_COLS) + '}'

    def build_row_block(idx: int, row: Dict, is_added: bool = False) -> str:
        # Render one row for the selected metric using the masks at mask row `idx`.
        if metric == "R2":
            values = row.get('r2_vals', [None]*CELL_COUNT)
            bold_mask, underline_mask = bold_r2, underline_r2
        else:
            values = row.get('rse_vals', [None]*CELL_COUNT)
            bold_mask, underline_mask = bold_rse, underline_rse
        cells = [
            safe_tex_num(values[c], bold=bool(bold_mask[idx][c]),
                         underline=bool(underline_mask[idx][c]),
                         it=bool(it_experiment[idx][c]))
            for c in range(CELL_COUNT)
        ]
        # Row average over all available (non-None) values; rendered unstyled.
        present = [v for v in values if v is not None]
        avg = (sum(present) / len(present)) if present else None
        name = latex_escape(row.get('name') or '')
        kernel_pretty = latex_escape((row.get('kernel') or '--').capitalize())
        spike_marker = r'\cmark' if row.get('is_spiking', False) else r'\xmark'
        prefix = r'\rowcolor{cpgcolor} ' if is_added else ''
        return (prefix + r'\multirow{-1}{*}{' + name + r'} & ' + spike_marker + ' & '
                + kernel_pretty + ' & ' + ' & '.join(cells) + ' & ' + safe_tex_num(avg) + r' \\')

    lines_out = list(header_lines)

    # Existing rows in parsed order (no shading). The "last printed row" is
    # counted over rows actually printed, not len(print_existing), so
    # out-of-range indices cannot misplace the final dashline.
    wanted = None if print_existing is None else set(print_existing)
    printed = [(idx, ex) for idx, ex in enumerate(existing_rows) if wanted is None or idx in wanted]
    last_pos = len(printed) - 1
    for pos, (idx, ex) in enumerate(printed):
        line = build_row_block(idx, ex, is_added=False)
        if pos % 2 == 1 or pos == last_pos:
            line = line.replace(r' \\', dashline_end)
        lines_out.append(line)

    # New rows (shaded); their mask rows follow all existing rows.
    offset = len(existing_rows)
    for pos, new in enumerate(new_rows):
        line = build_row_block(offset + pos, new, is_added=True)
        if pos % 2 == 1 and pos != len(new_rows) - 1:
            line = line.replace(r' \\', dashline_end)
        lines_out.append(line)

    # Finally append footer (\bottomrule etc.).
    lines_out.append(footer_text)
    text = '\n'.join(lines_out)
    if caption:
        # Only the first \centering gets the caption, so a header containing
        # several \centering commands does not produce duplicate captions.
        text = text.replace("\\centering", "\\centering\n\\caption{" + caption + "}\n", 1)
    return text


def main():
    """Parse the template table, merge in new JSON results and write LaTeX tables.

    Writes one file per (metric, EXPORT_LIST entry) to ``--out-dir`` (default
    ``artefacts``, created if missing), named ``<name>_<metric>.tex``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--template', default='neurips_2025_table_source_split.tex')
    # NOTE(review): --out is parsed but never used; output goes to --out-dir.
    parser.add_argument('--out', default='tsf_table_extended.tex')
    parser.add_argument('--out-dir', default='artefacts',
                        help='directory for the generated .tex files (created if missing)')
    parser.add_argument('--debug', action='store_true')
    args = parser.parse_args()

    if not os.path.isfile(args.template):
        raise SystemExit(f"Template file '{args.template}' not found. (If you want me to run here, re-upload your template and results folders.)")

    with open(args.template, 'r', encoding='utf-8') as fh:
        tpl = fh.read()

    table_block, _bpos, _epos = find_table_block(tpl)
    if args.debug:
        print("[DEBUG] extracted table block")

    # 1) parse existing table rows
    existing_rows, header_lines, footer_text = parse_existing_rows(table_block, debug=args.debug)

    # 2) parse new results
    new_rows, insertion_order = read_new_rows_from_results(debug=args.debug)

    # 3) style masks; each is indexed by position in existing_rows + new_rows
    combined, bold_r2, bold_rse = compute_bolding(existing_rows, new_rows, debug=args.debug)
    combined, underline_r2, underline_rse = compute_underline(existing_rows, new_rows, debug=args.debug)
    combined, it_experiment = compute_italic(existing_rows, new_rows, debug=args.debug)

    # 4) rebuild and write one table per metric and export entry.
    # Create the output directory so a fresh checkout does not fail in open().
    os.makedirs(args.out_dir, exist_ok=True)
    for metric in ["R2", "RSE"]:
        for export_name, caption, print_existing in EXPORT_LIST:
            # NOTE(review): the same caption (which names the RSE metric) is also
            # used for the R2 export — confirm this is intended.
            out_path = os.path.join(args.out_dir, f"{export_name}_{metric}.tex")
            rebuilt = rebuild_table_text(
                header_lines, footer_text, existing_rows, new_rows, combined,
                bold_r2, bold_rse, underline_r2, underline_rse, it_experiment,
                insertion_order, print_existing, metric=metric, caption=caption,
            )
            with open(out_path, 'w', encoding='utf-8') as fh:
                fh.write(rebuilt)
            print(f"[OK] Wrote rebuilt table block to {out_path}")

    if args.debug:
        # print a small summary of bold winners per cell
        M = len(combined)
        N = CELL_COUNT
        winners_r2 = [[i for i in range(M) if bold_r2[i][c]] for c in range(N)]
        winners_rse = [[i for i in range(M) if bold_rse[i][c]] for c in range(N)]
        print("[DEBUG] Winners per cell (R2):", winners_r2)
        print("[DEBUG] Winners per cell (RSE):", winners_rse)
        # print parsed existing rows first entries
        for idx, r in enumerate(existing_rows):
            print(f"[EXIST] idx={idx} name={r['name']!r} spiking={r['is_spiking']} r2_first={r['r2_vals'][0]}")


if __name__ == '__main__':
    main()
