#!/usr/bin/env python3
"""
Create a 1x3 grid figure for throughput (tokens/s) from throughput_summary.csv
or directly from results_summary.csv (auto-converted).

Cols: datasets in fixed order [openwebtext, code, multilingual]
Each cell shows tokens-per-packet on X-axis (discrete) and throughput on Y-axis
for PinTok, Rust, Python with error bars (std across trials).
"""

import argparse
import csv
import math
import os
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the script.

    Returns:
        A namespace with ``csv`` (required input path), ``out_dir`` (required
        output directory) and ``outfile`` (output filename, defaulting to
        ``plot_throughput_grid_1x3.pdf``).
    """
    parser = argparse.ArgumentParser(
        description="Create a 1x3 throughput grid from throughput_summary.csv or results_summary.csv"
    )
    # One (flag, keyword-arguments) entry per option, registered in order.
    option_table = (
        ("--csv", {"required": True, "help": "Path to throughput_summary.csv or results_summary.csv"}),
        ("--out-dir", {"required": True, "help": "Directory to save the grid plot"}),
        ("--outfile", {"default": "plot_throughput_grid_1x3.pdf", "help": "Output filename (PDF)"}),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()


# Dataset keys as they appear in the CSV 'dataset' column, in panel order.
DATASETS = ["openwebtext", "code", "multilingual"]
# Panel title shown for each dataset key.
DATASET_TITLES = dict(openwebtext="OpenWebText", code="Code", multilingual="Multilingual")
# Bar fill colour for each implementation.
COLORS = dict(python="#FDB515", rust="#002676", pintok="#8C1515")
# Legend label for each implementation.
LABELS = dict(python="Python", rust="Rust", pintok="PinTok")


def load_rows(csv_path: str) -> List[Dict[str, str]]:
    """Read a CSV file and return its data rows as dicts keyed by header name.

    The file is opened with ``newline=""`` as the ``csv`` module requires, so
    line breaks inside quoted fields reach the parser untranslated.  It is
    decoded as ``utf-8-sig`` so that a leading byte-order mark (as written by
    some spreadsheet exports) is stripped instead of becoming part of the
    first column name.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        One dict per data row, in file order; empty if there are no data rows.

    Raises:
        OSError: If the file cannot be opened.
    """
    with open(csv_path, "r", newline="", encoding="utf-8-sig") as f:
        return list(csv.DictReader(f))


def to_float(s):
    """Convert *s* with ``float()``; return None if the conversion raises."""
    try:
        value = float(s)
    except Exception:
        value = None
    return value


def _parse_tokens_per_packet(row: Dict[str, str]) -> Optional[int]:
    """Return the row's ``tokens_per_packet`` as an int, or None if missing or unparseable."""
    try:
        return int(float(row.get('tokens_per_packet', '')))
    except (TypeError, ValueError, OverflowError):
        return None


def _latency_to_tps(
    mean_us: Optional[str], std_us: Optional[str], tokens: int
) -> Tuple[Optional[float], Optional[float]]:
    """Convert a mean/std latency in microseconds into a tokens-per-second mean/std.

    Returns ``(None, None)`` when either value fails to parse or the mean is
    missing or not strictly positive.  A missing std yields a std of 0.0.
    """
    try:
        mu = float(mean_us) if mean_us else None
        sd = float(std_us) if std_us else None
    except (TypeError, ValueError):
        return (None, None)
    # ``not mu > 0`` also rejects NaN, which compares False against everything.
    if mu is None or not mu > 0:
        return (None, None)
    mean_tps = tokens * 1_000_000.0 / mu
    # First-order error propagation for f = tokens / mu:  sigma_f = tokens * sd / mu**2.
    std_tps = (tokens * 1_000_000.0 * sd / (mu * mu)) if sd is not None else 0.0
    return (mean_tps, std_tps)


def build_series(rows: List[Dict[str, str]], dataset: str) -> Dict[str, List[Tuple[int, float, float]]]:
    """Return throughput series for a dataset.

    Supports two CSV formats:
      - throughput_summary.csv with *_tps_mean/std columns
      - results_summary.csv with avg_*_mean_us/std_us columns (auto-converted)

    Args:
        rows: Parsed CSV rows (e.g. from ``csv.DictReader``).
        dataset: Only rows whose ``dataset`` column equals this value are used.

    Returns:
        A dict with keys ``"pintok"``, ``"rust"`` and ``"python"``; each value
        is a list of ``(tokens_per_packet, mean_tps, std_tps)`` tuples sorted
        by ``tokens_per_packet``.  Rows with an unparseable
        ``tokens_per_packet`` and points without a usable mean are skipped; a
        missing std is reported as 0.0.
    """
    # Detect which column scheme is present.  Test key membership only:
    # csv.DictReader files surplus fields under the key None, so scanning the
    # keys with a substring test would raise TypeError on such rows.
    has_tps = any('pintok_tps_mean' in r for r in rows)

    series: Dict[str, List[Tuple[int, float, float]]] = {"pintok": [], "rust": [], "python": []}
    for r in rows:
        if r.get('dataset') != dataset:
            continue
        tokens = _parse_tokens_per_packet(r)
        if tokens is None:
            continue
        for mode, points in series.items():
            if has_tps:
                mean = to_float(r.get(f"{mode}_tps_mean"))
                std = to_float(r.get(f"{mode}_tps_std"))
            else:
                # results_summary rows carry latencies; convert them to throughput.
                mean, std = _latency_to_tps(
                    r.get(f"avg_{mode}_mean_us"), r.get(f"avg_{mode}_std_us"), tokens
                )
            if mean is not None:
                points.append((tokens, mean, std or 0.0))

    for points in series.values():
        points.sort(key=lambda t: t[0])
    return series


def plot_cell(ax, series: Dict[str, List[Tuple[int, float, float]]], y_limits=None):
    """Draw one panel: grouped bars per token count with thin capped error bars.

    Args:
        ax: Axes-like object to draw on (presumably a matplotlib ``Axes``).
        series: Mapping of mode name to ``(tokens, mean, std)`` points.
        y_limits: Optional y-axis limits passed to ``ax.set_ylim`` when truthy.
    """
    # One x slot per distinct token count, taken from the union over all modes
    # so that bars stay aligned even when a mode lacks some token counts.
    token_values = sorted({tok for points in series.values() for (tok, _, _) in points})
    group_count = len(token_values)
    slot_of = {tok: slot for slot, tok in enumerate(token_values)}
    centers = list(range(group_count))
    ax.set_xticks(centers)
    ax.set_xticklabels([f"{tok:,}" for tok in token_values])

    # Within each group: Python on the left, Rust centred, PinTok on the right.
    width = 0.28 if group_count <= 1 else 0.26
    shift = {'python': -width, 'rust': 0.0, 'pintok': width}

    for mode in ("python", "rust", "pintok"):
        points = series.get(mode, [])
        if not points:
            continue
        positions = []
        heights = []
        errors = []
        for tok, mean, std in points:
            positions.append(slot_of[tok] + shift[mode])
            heights.append(mean)
            errors.append(std)
        # Only pass error bars when at least one of them is non-zero.
        has_error = any(err > 0 for err in errors)
        ax.bar(
            positions,
            heights,
            width=width,
            color=COLORS[mode],
            label=LABELS[mode],
            linewidth=0.0,
            yerr=errors if has_error else None,
            error_kw=dict(ecolor='black', elinewidth=0.7, capsize=2.0, capthick=0.7),
        )

    # Light dashed horizontal grid lines only.
    ax.grid(True, linestyle='--', alpha=0.25, linewidth=0.4, axis='y')
    right_edge = (max(centers) + 0.6) if centers else 0.4
    ax.set_xlim(-0.6, right_edge)
    # Shared limits, when supplied, keep all panels on the same scale.
    if y_limits:
        ax.set_ylim(y_limits)
    # Tick labels were fixed above with set_xticklabels; no formatter is installed.


def _shared_y_limits(
    all_series: Dict[str, Dict[str, List[Tuple[int, float, float]]]]
) -> Tuple[float, float]:
    """Return padded ``(y_min, y_max)`` covering mean +/- std of every point.

    Non-finite means are ignored and a non-finite std is treated as 0 so that
    a single bad cell cannot poison the shared axis.  With no usable points
    the fallback ``(0, 100000)`` is returned.  The lower limit never drops
    below 0.
    """
    low = math.inf
    high = -math.inf
    for series in all_series.values():
        for points in series.values():
            for _, mean, std in points:
                if not math.isfinite(mean):
                    continue
                spread = std if (std and math.isfinite(std)) else 0.0
                low = min(low, mean - spread)
                high = max(high, mean + spread)

    if low == math.inf or high == -math.inf:
        return (0, 100000)

    low = max(0, low)
    span = high - low
    if span > 0:
        pad = 0.05 * span
    else:
        # All values identical: pad relative to the value (or by 1.0 at zero)
        # so the limits passed to set_ylim are never equal.
        pad = 0.05 * abs(high) or 1.0
    return (max(0, low - pad), high + pad)


def main():
    """Read the CSV given on the command line and save the 1x3 throughput grid.

    Prints a message and returns without writing anything when the CSV holds
    no data rows.  Otherwise the output directory is created if needed, one
    panel per entry of ``DATASETS`` is drawn with shared y-limits, a single
    legend is placed to the right of the panels, and the figure is saved to
    ``<out-dir>/<outfile>``.
    """
    args = parse_args()
    rows = load_rows(args.csv)
    if not rows:
        print(f"No rows found in {args.csv}")
        return

    os.makedirs(args.out_dir, exist_ok=True)

    plt.rcParams.update({
        'font.size': 7.0,
        'axes.titlesize': 7.0,
        'axes.labelsize': 6.8,
        'xtick.labelsize': 6.2,
        'ytick.labelsize': 6.2,
        'legend.fontsize': 6.4,
        'axes.linewidth': 0.6,
        'xtick.major.width': 0.5,
        'ytick.major.width': 0.5,
        'xtick.major.size': 2.0,
        'ytick.major.size': 2.0,
        'grid.linewidth': 0.4,
    })

    fig, axs = plt.subplots(1, 3, figsize=(5.2, 1.44))

    # Build every series first so the y-limits can be shared across panels.
    all_series = {dataset: build_series(rows, dataset) for dataset in DATASETS}
    shared_y_limits = _shared_y_limits(all_series)

    for c, dataset in enumerate(DATASETS):
        ax = axs[c]
        plot_cell(ax, all_series[dataset], y_limits=shared_y_limits)
        ax.set_title(DATASET_TITLES.get(dataset, dataset))
        if c == 0:
            ax.set_ylabel("Throughput\n(tokens/s)")
        else:
            # Panels share the y scale; only the first one shows tick labels.
            ax.set_yticklabels([])
        ax.set_xlabel("Tokens")

    # Single legend.  Collect handles from every panel, not just the first,
    # so a mode that is missing from the first dataset still gets an entry.
    handle_by_label = {}
    for ax in axs:
        panel_handles, panel_labels = ax.get_legend_handles_labels()
        for handle, label in zip(panel_handles, panel_labels):
            handle_by_label.setdefault(label, handle)
    desired = ["Python", "Rust", "PinTok"]
    labels = [label for label in desired if label in handle_by_label]
    handles = [handle_by_label[label] for label in labels]

    # Reserve the right-hand strip of the figure for the legend.
    fig.tight_layout(rect=[0.01, 0.01, 0.88, 0.99])
    if handles and labels:
        fig.legend(handles, labels, loc='center left', ncol=1, frameon=False, bbox_to_anchor=(0.90, 0.5), borderaxespad=0.0, handlelength=1.5, handletextpad=0.5, markerscale=1.0, borderpad=0.2)

    out_path = os.path.join(args.out_dir, args.outfile)
    fig.savefig(out_path, bbox_inches='tight', pad_inches=0.02)
    # pyplot keeps figures alive until closed; release it once it is on disk.
    plt.close(fig)
    print(f"Saved throughput grid to {out_path}")


# Script entry point: build the figure only when executed directly, not on import.
if __name__ == "__main__":
    raise SystemExit(main())
