"""Summarize evaluation scores across datasets and sparsity levels, and plot lines.

This script loads the evaluation routine from long_context_eval/evaluate.py and
computes scores for a set of LongBench datasets at baseline and multiple
mask-topk sparsity settings. Results are printed as a compact table, and a
line-plot (score vs. sparsity) is saved where each line corresponds to one
dataset (baseline is not plotted).

Inputs are read from long_context_eval/results/<model_key>/<dataset>/...
  - Baseline file: baseline/baseline.jsonl
  - MaskTopK grid: masktopk_grid/mask_topk_top{SPARSITY:0.0000}.jsonl

Command line usage (examples):
  python3 plot_sparsity_ablation.py \
    --model_key llama-3.1 \
    --datasets narrativeqa trec gov_report lcc multifieldqa_en multifieldqa_zh \
    --sparsities 0.001 0.005 0.01 0.05 0.1 0.5

The line plot is saved to:
  long_context_eval/figures/<model_key>/sparsity_ablation/accuracy_vs_sparsity.png

Time complexity: O(D * (1 + S) * N * L)
  - D: number of datasets
  - S: number of sparsity levels
  - N: number of examples per dataset result file
  - L: average tokenized string length for metric computation
Space complexity: O(1) auxiliary (files processed line-by-line).
"""

from __future__ import annotations

import argparse
import os
import sys
from typing import Dict, Iterable, List, Optional, Tuple
import json
import numpy as np
import matplotlib.pyplot as plt

# Global plotting style: small fonts and a compact default canvas, applied once
# at import time to every figure created afterwards.
_RC_OVERRIDES = {
    "font.size": 10,
    "axes.titlesize": 12,
    "axes.labelsize": 11,
    "xtick.labelsize": 9,
    "ytick.labelsize": 9,
    "legend.fontsize": 9,
    "figure.figsize": (3.5, 2.5),
}
plt.rcParams.update(_RC_OVERRIDES)


# Ensure we can import long_context_eval/evaluate.py as a module named "evaluate".
# _CUR_DIR is the directory containing this script and _EVAL_ROOT is its parent
# directory. _EVAL_ROOT is placed at the front of sys.path (only when it is not
# already listed) so that the import below resolves; the path tweak must
# therefore run before the import statement.
_CUR_DIR = os.path.dirname(os.path.abspath(__file__))
_EVAL_ROOT = os.path.abspath(os.path.join(_CUR_DIR, os.pardir))
if _EVAL_ROOT not in sys.path:
    sys.path.insert(0, _EVAL_ROOT)

from evaluate import evaluate_file  # type: ignore  # local module


# Datasets summarized when --datasets is not given on the command line.
# The commented-out entries are currently excluded from the defaults; they can
# still be requested explicitly via --datasets.
DEFAULT_DATASETS: List[str] = [
    "trec",
    # "multifieldqa_zh",
    # "multifieldqa_en",
    "repobench-p",
    "lcc",
    "lsht",
]


# Sparsity levels summarized when --sparsities is not given. Each value selects
# one result file under masktopk_grid/, whose name embeds the value formatted
# with four decimals.
DEFAULT_SPARSITIES: List[float] = [0.001, 0.005, 0.01, 0.05, 0.1, 0.5]


def _format_score(score: Optional[float]) -> str:
    """Format a score with two decimals, or return '-' when it is missing."""
    if score is None:
        return "-"
    return f"{score:.2f}"


def _file_exists(path: str) -> bool:
    """Return True when *path* names an existing regular file."""
    return os.path.isfile(path)


def _sparsity_key(sparsity: float) -> str:
    """Return the score-dict key (and table column label) for a sparsity level.

    Four decimals are used because that is the precision embedded in the
    result filenames, so two sparsities share a key only if they also share a
    result file. A two-decimal key would merge distinct levels such as 0.005
    and 0.01.
    """
    return f"{float(sparsity):0.4f}"


def _lookup_score(row: Dict[str, Optional[float]], sparsity: float) -> Optional[float]:
    """Fetch the score stored for *sparsity* in one dataset's score dict.

    The four-decimal key is preferred. When it is absent, the two-decimal key
    is tried so that score tables written with the older key format (for
    example a cache file already on disk) can still be displayed.
    """
    key = _sparsity_key(sparsity)
    if key in row:
        return row[key]
    return row.get(f"{float(sparsity):.2f}")


def compute_scores_for_dataset(
    dataset: str,
    model_key: str,
    eval_root: str,
    sparsities: Iterable[float],
) -> Dict[str, Optional[float]]:
    """Compute baseline and per-sparsity scores for one dataset.

    Result files are looked up under
    ``<eval_root>/results/<model_key>/<dataset>/``; a file that does not exist
    yields ``None`` instead of a score.

    Args:
        dataset: Dataset name; also passed to ``evaluate_file``.
        model_key: Model directory name under ``results``.
        eval_root: Root directory that contains ``results``.
        sparsities: Sparsity levels to evaluate.

    Returns:
        A dict with the key ``"baseline"`` plus one key per sparsity level,
        formatted with four decimals (e.g. ``"0.0050"``).
    """
    results_dir = os.path.join(eval_root, "results", model_key, dataset)

    def _score_or_none(path: str) -> Optional[float]:
        # Missing result files are expected (partial grids), not an error.
        return evaluate_file(path, dataset) if _file_exists(path) else None

    out: Dict[str, Optional[float]] = {
        "baseline": _score_or_none(os.path.join(results_dir, "baseline", "baseline.jsonl")),
    }
    for s in sparsities:
        key = _sparsity_key(s)
        # The filename embeds the same four-decimal rendering used as the key.
        s_path = os.path.join(results_dir, "masktopk_grid", f"mask_topk_top{key}.jsonl")
        out[key] = _score_or_none(s_path)
    return out


def format_table(
    datasets: List[str],
    sparsities: List[float],
    scores: Dict[str, Dict[str, Optional[float]]],
) -> str:
    """Return a simple fixed-width table of scores as a string.

    The table has one header row, one dashed separator row and one row per
    dataset. Missing scores are rendered as '-'. An empty *datasets* list
    produces just the header and separator.

    Args:
        datasets: Row order; every name must be a key of *scores*.
        sparsities: Column order for the sparsity columns.
        scores: Mapping ``dataset -> {"baseline" | sparsity key -> score}``.
    """
    labels = [_sparsity_key(s) for s in sparsities]
    headers: List[str] = ["dataset", "baseline"] + labels

    # Column widths; default=0 keeps an empty dataset list from raising.
    col_widths: List[int] = [max(7, max((len(d) for d in datasets), default=0))]
    col_widths.append(max(8, len(headers[1])))
    col_widths.extend(max(6, len(label)) for label in labels)

    lines: List[str] = [
        " ".join(h.ljust(w) for h, w in zip(headers, col_widths)),
        " ".join("-" * w for w in col_widths),
    ]

    for dataset in datasets:
        row = scores[dataset]
        cells: List[str] = [
            dataset.ljust(col_widths[0]),
            _format_score(row.get("baseline")).rjust(col_widths[1]),
        ]
        cells.extend(
            _format_score(_lookup_score(row, s)).rjust(w)
            for s, w in zip(sparsities, col_widths[2:])
        )
        lines.append(" ".join(cells))
    return "\n".join(lines)


def print_table(
    datasets: List[str],
    sparsities: List[float],
    scores: Dict[str, Dict[str, Optional[float]]],
) -> None:
    """Print the table produced by ``format_table`` to stdout."""
    print(format_table(datasets, sparsities, scores))


def plot_lines(
    datasets: List[str],
    sparsities: List[float],
    scores: Dict[str, Dict[str, Optional[float]]],
    model_key: str,
    eval_root: str,
) -> str:
    """Plot score vs sparsity for each dataset (baseline excluded).

    One line is drawn per dataset; missing scores become NaN so they leave a
    gap instead of a point. The figure is written to
    ``<eval_root>/figures/<model_key>/sparsity_ablation/accuracy_vs_sparsity.png``
    (the directory is created when needed).

    Returns the path of the saved figure.
    """
    # x is plotted as (1 - s) * 100 for each input value s.
    # NOTE(review): the axis is labelled "Sparsity (%)", which is consistent
    # only if the input values are kept (top-k) fractions - confirm.
    x_vals = (1.0 - np.array(sparsities, dtype=float)) * 100.0
    # Sort by transformed x for clean left-to-right lines.
    x_order = np.argsort(x_vals)

    plt.figure(figsize=(3.5, 2.5))
    for dataset in datasets:
        row = scores[dataset]
        y_raw = [_lookup_score(row, s) for s in sparsities]
        y = np.array([np.nan if v is None else float(v) for v in y_raw], dtype=float)
        plt.plot(x_vals[x_order], y[x_order], marker="o", label=dataset)

    plt.xlabel("Sparsity (%)")
    plt.ylabel("Accuracy")
    # Fixed y-range; scores outside [-10, 85] would be clipped from view.
    plt.ylim(-10, 85)
    # No title
    plt.grid(True, linestyle="--", alpha=0.3)
    plt.legend(frameon=False, loc="lower left")

    out_dir = os.path.join(eval_root, "figures", model_key, "sparsity_ablation")
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, "accuracy_vs_sparsity.png")
    plt.tight_layout()
    plt.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.close()
    return out_path


def _cache_paths(eval_root: str, model_key: str) -> Tuple[str, str]:
    out_dir = os.path.join(eval_root, "figures", model_key, "sparsity_ablation")
    os.makedirs(out_dir, exist_ok=True)
    return os.path.join(out_dir, "scores_table.json"), os.path.join(out_dir, "scores_table.txt")


def save_scores_cache(
    eval_root: str,
    model_key: str,
    datasets: List[str],
    sparsities: List[float],
    scores: Dict[str, Dict[str, Optional[float]]],
) -> None:
    """Persist the score table as JSON and as a human-readable text table.

    The JSON file stores the model key, dataset list, sparsity list and the
    scores; the text file stores the output of ``format_table`` followed by a
    newline. Both paths come from ``_cache_paths``.
    """
    json_path, txt_path = _cache_paths(eval_root, model_key)

    with open(json_path, "w", encoding="utf-8") as json_file:
        json.dump(
            {
                "model_key": model_key,
                "datasets": datasets,
                "sparsities": sparsities,
                "scores": scores,
            },
            json_file,
            ensure_ascii=False,
            indent=2,
        )

    # Render the table before opening the text file for writing.
    table_text = format_table(datasets, sparsities, scores) + "\n"
    with open(txt_path, "w", encoding="utf-8") as txt_file:
        txt_file.write(table_text)


def load_scores_cache(
    eval_root: str,
    model_key: str,
) -> Optional[Tuple[List[str], List[float], Dict[str, Dict[str, Optional[float]]]]]:
    """Load a previously saved score table, or return None if unusable.

    Returns ``(datasets, sparsities, scores)`` from the JSON cache file. None
    is returned when the file is absent, belongs to another model key, holds
    a non-dict ``scores`` entry, or cannot be read/parsed at all (loading is
    best-effort: any exception while reading is treated as "no cache").
    """
    json_path = _cache_paths(eval_root, model_key)[0]
    if not os.path.isfile(json_path):
        return None

    try:
        with open(json_path, "r", encoding="utf-8") as handle:
            cached = json.load(handle)
        cached_scores = cached.get("scores", {})
        if cached.get("model_key") != model_key or not isinstance(cached_scores, dict):
            return None
        return (
            list(cached.get("datasets", [])),
            [float(value) for value in cached.get("sparsities", [])],
            cached_scores,
        )
    except Exception:
        # Deliberately broad: a corrupt or malformed cache just means recompute.
        return None


def parse_args() -> argparse.Namespace:
    """Parse the command line.

    Options: ``--model_key`` (str), ``--datasets`` (zero or more str),
    ``--sparsities`` (zero or more float) and the ``--recompute`` flag.
    """
    cli = argparse.ArgumentParser(description="Print evaluation table over datasets and sparsities.")
    cli.add_argument("--model_key", type=str, default="llama-3.1", help="Model key under results directory")

    # The two list-valued options share everything except name, type, default and help.
    list_options = (
        ("--datasets", str, DEFAULT_DATASETS, "Datasets to summarize"),
        ("--sparsities", float, DEFAULT_SPARSITIES, "Sparsity levels to summarize"),
    )
    for flag, value_type, fallback, help_text in list_options:
        cli.add_argument(flag, type=value_type, nargs="*", default=fallback, help=help_text)

    cli.add_argument(
        "--recompute",
        action="store_true",
        help="Recompute scores even if a cached table exists",
    )
    return cli.parse_args()


def main() -> None:
    """Entry point: load or compute scores, print the table, save the plot.

    A cached table is reused only when ``--recompute`` is not set and the
    cached dataset and sparsity lists equal the requested ones; otherwise all
    scores are recomputed and the cache is rewritten.
    """
    args = parse_args()
    eval_root = _EVAL_ROOT  # long_context_eval directory
    datasets = list(args.datasets)
    sparsities = list(args.sparsities)

    # None means "no usable cache"; a usable cache always yields a dict.
    all_scores = None
    if not args.recompute:
        cached = load_scores_cache(eval_root, args.model_key)
        if cached is not None:
            cached_datasets, cached_sparsities, cached_scores = cached
            same_request = (
                cached_datasets == datasets
                and [float(s) for s in cached_sparsities] == sparsities
            )
            if same_request:
                all_scores = cached_scores

    if all_scores is None:
        all_scores = {
            dataset: compute_scores_for_dataset(dataset, args.model_key, eval_root, args.sparsities)
            for dataset in args.datasets
        }
        # Save cache (json + pretty txt table)
        save_scores_cache(eval_root, args.model_key, datasets, sparsities, all_scores)

    print_table(args.datasets, sparsities, all_scores)
    fig_path = plot_lines(args.datasets, sparsities, all_scores, args.model_key, eval_root)
    print(f"Saved plot: {fig_path}")


if __name__ == "__main__":
    main()


