"""Summarize evaluation scores across datasets and recall targets for a fixed
mask-topk base ratio, and plot lines.

This script mirrors plot_sparsity_ablation.py but for the mask_topk_recall
experiment. It loads the evaluation routine from long_context_eval/evaluate.py
and computes scores for a set of LongBench datasets at several target recalls
using a fixed base top-k ratio (default 0.05). Results are printed as a compact
table, and a line-plot (accuracy vs. recall) is saved where each line
corresponds to one dataset.

Inputs are read from long_context_eval/results/<model_key>/<dataset>/...
  - Recall grid files (base_ratio fixed, formatted with four decimals):
    masktopk_recall_grid/mask_topk_recall_top{BASE_RATIO:.4f}_recall{R}.jsonl
    e.g. masktopk_recall_grid/mask_topk_recall_top0.0500_recall90.jsonl

Command line usage (examples):
  python3 plot_recall_trend.py \
    --model_key llama-3.1 \
    --datasets narrativeqa trec gov_report lcc multifieldqa_en multifieldqa_zh \
    --base_ratio 0.05 \
    --recalls 70 75 80 85 90 95

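The line plot is saved to:
  long_context_eval/figures/<model_key>/recall_trend/accuracy_vs_recall.png
Computed scores are cached next to it as scores_table.json (reused on later
runs with identical datasets, recalls and base ratio) and scores_table.txt
(the printed table). Pass --recompute to ignore the cache.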

Time complexity: O(D * R * N * L)
  - D: number of datasets
  - R: number of recall targets
  - N: number of examples per dataset result file
  - L: average tokenized string length for metric computation
Space complexity: O(D * R) for the collected score table, plus whatever
evaluate_file needs per result file.
"""

from __future__ import annotations

import argparse
import os
import sys
from typing import Dict, Iterable, List, Optional, Tuple
import json
import numpy as np
import matplotlib.pyplot as plt

# Global plotting style (match plot_sparsity_ablation.py).
# Applied once at import time; matplotlib rcParams are process-global, so these
# defaults affect every figure created afterwards in this process. The
# 3.5 x 2.5 inch default figure size is presumably chosen for a single-column
# paper layout (assumption — confirm against the sibling plotting script).
plt.rcParams.update({
    "font.size": 10,
    "axes.titlesize": 12,
    "axes.labelsize": 11,
    "xtick.labelsize": 9,
    "ytick.labelsize": 9,
    "legend.fontsize": 9,
    "figure.figsize": (3.5, 2.5),
})


# Ensure we can import long_context_eval/evaluate.py as a module named "evaluate".
# _CUR_DIR is the directory containing this script; _EVAL_ROOT is its parent.
# _EVAL_ROOT is prepended to sys.path (so it takes precedence over later
# entries) before the import below runs. main() also uses _EVAL_ROOT as the
# root for the results/ and figures/ directories.
_CUR_DIR = os.path.dirname(os.path.abspath(__file__))
_EVAL_ROOT = os.path.abspath(os.path.join(_CUR_DIR, os.pardir))
if _EVAL_ROOT not in sys.path:
    sys.path.insert(0, _EVAL_ROOT)

# Must stay after the sys.path tweak above.
from evaluate import evaluate_file  # type: ignore  # local module


# Datasets summarized when --datasets is not given. The commented-out entries
# are presumably kept for quick toggling.
DEFAULT_DATASETS: List[str] = [
    "trec",
    # "multifieldqa_zh",
    # "multifieldqa_en",
    "repobench-p",
    "lcc",
    "lsht",
]

# Recall targets (percent) evaluated when --recalls is not given.
DEFAULT_RECALLS: List[int] = [70, 75, 80, 85, 90, 95]
# Base top-k ratio when --base_ratio is not given; it appears in result
# filenames formatted as e.g. "top0.0500", and is the fallback when a cache
# file has no "base_ratio" field.
DEFAULT_BASE_RATIO: float = 0.05


def _format_score(score: Optional[float]) -> str:
    """Format score to two decimals or return '-' if missing."""
    if score is None:
        return "-"
    return f"{score:.2f}"


def _file_exists(path: str) -> bool:
    return os.path.isfile(path)


def compute_scores_for_dataset(
    dataset: str,
    model_key: str,
    eval_root: str,
    base_ratio: float,
    recalls: Iterable[int],
) -> Dict[str, Optional[float]]:
    """Score one dataset at every recall target for a fixed base top-k ratio.

    For each recall ``r`` the result file
    ``<eval_root>/results/<model_key>/<dataset>/masktopk_recall_grid/``
    ``mask_topk_recall_top<base_ratio:.4f>_recall<r>.jsonl`` is scored with
    ``evaluate_file(path, dataset)``. If that file does not exist, the score
    is ``None``.

    Returns:
        Mapping from the recall as a decimal string (e.g. ``"90"``) to its
        score or ``None``, in the iteration order of *recalls*.
    """
    grid_dir = os.path.join(eval_root, "results", model_key, dataset, "masktopk_recall_grid")

    def _score_at(recall: int) -> Optional[float]:
        # File naming must match the producer of the recall-grid runs.
        path = os.path.join(
            grid_dir,
            f"mask_topk_recall_top{float(base_ratio):0.4f}_recall{recall}.jsonl",
        )
        if _file_exists(path):
            return evaluate_file(path, dataset)
        return None

    return {f"{int(r):d}": _score_at(int(r)) for r in recalls}


def format_table(
    datasets: List[str],
    recalls: List[int],
    scores: Dict[str, Dict[str, Optional[float]]],
) -> str:
    """Render scores as a fixed-width plain-text table.

    Layout: a header row (``dataset`` followed by each recall target), a
    dashed separator, then one row per dataset. The dataset column is
    left-aligned. Score columns, headers included, are right-aligned so that
    headers line up with their values. Scores are formatted by
    ``_format_score`` (two decimals, ``-`` when missing).

    Args:
        datasets: Row labels in display order. May be empty.
        recalls: Recall targets (percent), used as column headers and as
            ``str(int(r))`` lookup keys into each dataset's score dict.
        scores: ``scores[dataset][key]`` -> score or ``None``. A dataset
            absent from ``scores`` renders as a row of ``-``.

    Returns:
        The table as one string without a trailing newline.
    """
    keys = [f"{int(r):d}" for r in recalls]

    # default=0 keeps an empty dataset list from crashing max().
    name_width = max(7, max((len(d) for d in datasets), default=0))
    score_widths = [max(6, len(k)) for k in keys]

    header = ["dataset".ljust(name_width)]
    header.extend(k.rjust(w) for k, w in zip(keys, score_widths))
    lines: List[str] = [
        " ".join(header),
        " ".join("-" * w for w in [name_width, *score_widths]),
    ]

    for dataset in datasets:
        row_scores = scores.get(dataset, {})
        cells = [dataset.ljust(name_width)]
        cells.extend(
            _format_score(row_scores.get(k)).rjust(w) for k, w in zip(keys, score_widths)
        )
        lines.append(" ".join(cells))
    return "\n".join(lines)


def print_table(
    datasets: List[str],
    recalls: List[int],
    scores: Dict[str, Dict[str, Optional[float]]],
) -> None:
    """Write the ``format_table`` rendering of *scores* to stdout."""
    rendered = format_table(datasets, recalls, scores)
    print(rendered)


def plot_lines(
    datasets: List[str],
    recalls: List[int],
    scores: Dict[str, Dict[str, Optional[float]]],
    model_key: str,
    eval_root: str,
    ylim: Optional[Tuple[float, float]] = (-10, 85),
) -> str:
    """Plot score vs. recall target, one line per dataset, and save a PNG.

    Args:
        datasets: Dataset names; each becomes one labelled line.
        recalls: Recall targets (percent) used as x values. They are sorted
            for plotting, and the x axis is inverted so that recall decreases
            left to right.
        scores: ``scores[dataset][str(int(recall))]`` -> score or ``None``.
            Missing datasets or recalls are treated as ``None``. They are
            plotted as NaN, which matplotlib leaves as gaps in the line.
        model_key: Model sub-directory under ``figures/``.
        eval_root: Root directory under which ``figures/`` is created.
        ylim: y-axis limits. The default keeps the previous fixed range
            ``(-10, 85)``. Pass ``None`` to let matplotlib autoscale, e.g.
            when scores above 85 would otherwise be clipped.

    Returns:
        Path of the saved figure:
        ``<eval_root>/figures/<model_key>/recall_trend/accuracy_vs_recall.png``.
    """
    x_vals = np.array([int(r) for r in recalls], dtype=float)
    x_order = np.argsort(x_vals)
    keys = [f"{int(r):d}" for r in recalls]

    fig, ax = plt.subplots(figsize=(3.5, 2.5))
    try:
        for dataset in datasets:
            per_recall = scores.get(dataset, {})
            raw = [per_recall.get(k) for k in keys]
            y = np.array([np.nan if v is None else float(v) for v in raw], dtype=float)
            ax.plot(x_vals[x_order], y[x_order], marker="o", label=dataset)

        ax.set_xlabel("Recall target (%)")
        ax.set_ylabel("Accuracy")
        if ylim is not None:
            ax.set_ylim(*ylim)
        ax.grid(True, linestyle="--", alpha=0.3)
        if datasets:
            # Skip the legend when nothing is labelled (avoids an empty legend warning).
            ax.legend(frameon=False, loc="lower left")
        ax.invert_xaxis()

        out_dir = os.path.join(eval_root, "figures", model_key, "recall_trend")
        os.makedirs(out_dir, exist_ok=True)
        out_path = os.path.join(out_dir, "accuracy_vs_recall.png")
        fig.tight_layout()
        fig.savefig(out_path, dpi=300, bbox_inches="tight")
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)
    return out_path


def _cache_paths(eval_root: str, model_key: str) -> Tuple[str, str]:
    out_dir = os.path.join(eval_root, "figures", model_key, "recall_trend")
    os.makedirs(out_dir, exist_ok=True)
    return os.path.join(out_dir, "scores_table.json"), os.path.join(out_dir, "scores_table.txt")


def save_scores_cache(
    eval_root: str,
    model_key: str,
    datasets: List[str],
    recalls: List[int],
    base_ratio: float,
    scores: Dict[str, Dict[str, Optional[float]]],
) -> None:
    """Persist *scores* as a JSON payload and as a rendered text table.

    Both files come from ``_cache_paths`` and are overwritten on every call.
    The JSON records ``model_key``, ``datasets``, ``recalls`` (as ints),
    ``base_ratio`` (as float) and the raw ``scores`` mapping; these are the
    fields ``load_scores_cache`` reads back. The text file holds the
    ``format_table`` rendering plus a trailing newline.
    """
    json_path, txt_path = _cache_paths(eval_root, model_key)
    payload = dict(
        model_key=model_key,
        datasets=datasets,
        recalls=list(map(int, recalls)),
        base_ratio=float(base_ratio),
        scores=scores,
    )
    with open(json_path, "w", encoding="utf-8") as handle:
        json.dump(payload, handle, ensure_ascii=False, indent=2)

    # Render before opening the txt file so a rendering error leaves it untouched.
    rendered = format_table(datasets, recalls, scores) + "\n"
    with open(txt_path, "w", encoding="utf-8") as handle:
        handle.write(rendered)


def load_scores_cache(
    eval_root: str,
    model_key: str,
) -> Optional[Tuple[List[str], List[int], float, Dict[str, Dict[str, Optional[float]]]]]:
    """Read the score cache written by ``save_scores_cache``, if it is usable.

    Returns:
        ``(datasets, recalls, base_ratio, scores)`` on success. A missing
        ``base_ratio`` falls back to ``DEFAULT_BASE_RATIO``.

        ``None`` when any of the following holds:
        - the cache file is absent or unreadable;
        - it is not valid JSON;
        - it was written for a different model;
        - its structure is unexpected (non-dict payload, non-dict ``scores``
          or per-dataset entries, or non-numeric recalls/base ratio).

    Caching is best-effort, so those expected failures never raise. Other
    (programming) errors are no longer swallowed.
    """
    json_path, _ = _cache_paths(eval_root, model_key)
    if not os.path.isfile(json_path):
        return None

    try:
        with open(json_path, "r", encoding="utf-8") as f:
            payload = json.load(f)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return None

    if not isinstance(payload, dict) or payload.get("model_key") != model_key:
        return None

    scores = payload.get("scores", {})
    # Callers call .get() on every per-dataset entry, so all must be mappings.
    if not isinstance(scores, dict) or not all(isinstance(v, dict) for v in scores.values()):
        return None

    try:
        datasets = list(payload.get("datasets", []))
        recalls = [int(x) for x in payload.get("recalls", [])]
        base_ratio = float(payload.get("base_ratio", DEFAULT_BASE_RATIO))
    except (TypeError, ValueError, OverflowError):
        # e.g. null/non-iterable lists, non-numeric or infinite values.
        return None
    return datasets, recalls, base_ratio, scores


def parse_args() -> argparse.Namespace:
    """Parse the command-line options from ``sys.argv``.

    Options: ``--model_key``, ``--datasets`` (zero or more names),
    ``--recalls`` (zero or more ints, percent), ``--base_ratio`` (float) and
    the ``--recompute`` flag. Defaults come from the module-level
    ``DEFAULT_*`` constants.
    """
    cli = argparse.ArgumentParser(
        description="Print evaluation table over datasets and recall targets."
    )
    option_specs = (
        (
            "--model_key",
            dict(type=str, default="llama-3.1", help="Model key under results directory"),
        ),
        (
            "--datasets",
            dict(type=str, nargs="*", default=DEFAULT_DATASETS, help="Datasets to summarize"),
        ),
        (
            "--recalls",
            dict(
                type=int,
                nargs="*",
                default=DEFAULT_RECALLS,
                help="Recall targets to summarize (percent)",
            ),
        ),
        (
            "--base_ratio",
            dict(
                type=float,
                default=DEFAULT_BASE_RATIO,
                help="Base top-k ratio used for the recall runs (e.g., 0.05)",
            ),
        ),
        (
            "--recompute",
            dict(
                action="store_true",
                help="Recompute scores even if a cached table exists",
            ),
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


def main() -> None:
    """CLI entry point: load or compute scores, print the table, save the plot.

    Cached scores are reused only when all of these hold:
    - ``--recompute`` is not given;
    - the cache lists exactly the requested datasets, recall targets and
      base ratio;
    - the cache holds an entry for every requested dataset.

    Otherwise every dataset is re-scored and the cache is rewritten.
    """
    args = parse_args()
    eval_root = _EVAL_ROOT  # long_context_eval directory

    # Normalize once; reused for the cache check, printing and plotting.
    datasets = list(args.datasets)
    recalls = [int(r) for r in args.recalls]
    base_ratio = float(args.base_ratio)

    all_scores: Optional[Dict[str, Dict[str, Optional[float]]]] = None
    cached = None if args.recompute else load_scores_cache(eval_root, args.model_key)
    if cached is not None:
        cached_datasets, cached_recalls, cached_base_ratio, cached_scores = cached
        # Exact float comparison is fine: JSON round-trips Python floats exactly.
        if (
            cached_datasets == datasets
            and [int(r) for r in cached_recalls] == recalls
            and float(cached_base_ratio) == base_ratio
            # A partial cache would later raise KeyError when indexing scores.
            and all(d in cached_scores for d in datasets)
        ):
            all_scores = cached_scores

    if all_scores is None:
        all_scores = {
            dataset: compute_scores_for_dataset(
                dataset, args.model_key, eval_root, args.base_ratio, args.recalls
            )
            for dataset in datasets
        }
        # Save cache (json + pretty txt table)
        save_scores_cache(eval_root, args.model_key, datasets, recalls, base_ratio, all_scores)

    print_table(datasets, recalls, all_scores)
    fig_path = plot_lines(datasets, recalls, all_scores, args.model_key, eval_root)
    print(f"Saved plot: {fig_path}")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()



