"""Entry point for running long-context benchmarks with optional sparse attention.

Usage (example):
    export PYTHONNOUSERSITE=1
    python3 run_benchmark.py \
      --model_key llama-3.1 \
      --exp_name my_exp \
      --attention_strategy mask_topk \
      --sparsity_list 0.01 0.05 \
      --datasets narrativeqa qmsum \
      --num_samples 10

This entry point composes the tokenizer/model, the evaluator, optional sparse
attention patching, and dataset loading.
"""

from __future__ import annotations

import argparse
import json
import os
from typing import List

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

from src.utils.timer import GlobalTimer
from src.utils.helpers import seed_everything
from src.patching.patcher import (
    patch_model_for_sparse_attention,
    patch_model_for_mask_topk,
    patch_model_for_mask_topk_recall,
)
from src.evaluation.evaluator import BenchmarkEvaluator


def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    ``--model_key`` (a key of configs/model2path.json), ``--exp_name`` and at
    least one ``--datasets`` entry are mandatory. ``--attention_strategy``
    picks one of ``none`` (default), ``mask_topk`` or ``mask_topk_recall``.
    The sweep options ``--sparsity_list`` / ``--recall_list`` accept zero or
    more floats and are ``None`` when the flag is omitted entirely.

    Returns:
        The parsed argument namespace.
    """
    cli = argparse.ArgumentParser(description="Run LongBench with optional sparse attention")

    # Mandatory run identification (all plain strings).
    for flag, extra, text in (
        ("--model_key", {}, "Key in configs/model2path.json"),
        ("--exp_name", {}, "Experiment name for output directory"),
        ("--datasets", {"nargs": "+"}, "Dataset names from LongBench"),
    ):
        cli.add_argument(flag, type=str, required=True, help=text, **extra)

    # Attention strategy applied during decoding, plus its sweep parameters.
    strategies = ["none", "mask_topk", "mask_topk_recall"]
    cli.add_argument(
        "--attention_strategy",
        type=str,
        default="none",
        choices=strategies,
        help="Attention strategy to use during decoding",
    )
    cli.add_argument(
        "--sparsity_list",
        type=float,
        nargs="*",
        default=None,
        help="List of sparsity ratios for mask_topk, e.g., 0.01 0.05 0.10",
    )
    cli.add_argument(
        "--recall_list",
        type=float,
        nargs="*",
        default=None,
        help="Recall list for mask_topk_recall. Values can be in [0,1] or [0,100].",
    )
    cli.add_argument(
        "--base_ratio",
        type=float,
        default=0.05,
        help="Base top-k ratio for mask_topk_recall (default 0.05)",
    )

    # Sampling / reproducibility.
    cli.add_argument(
        "--num_samples",
        type=int,
        default=None,
        help="Number of test samples to evaluate per dataset (default: all)",
    )
    cli.add_argument("--seed", type=int, default=42)

    namespace = cli.parse_args()
    return namespace


def _load_config(cfg_dir: str, filename: str) -> dict:
    """Read and return the JSON document stored at ``cfg_dir/filename``.

    The file is opened in a ``with`` block so the handle is closed promptly.
    It is decoded as UTF-8, which JSON mandates, rather than with the
    platform's locale-dependent default encoding.
    """
    with open(os.path.join(cfg_dir, filename), "r", encoding="utf-8") as fh:
        return json.load(fh)


def _normalize_recall(value: float) -> float:
    """Return *value* as a fraction; values above 1.0 are treated as percentages."""
    value = float(value)
    return value if value <= 1.0 else value / 100.0


def main() -> None:
    """Run the configured LongBench evaluation end to end.

    Steps:
    1. Parse CLI arguments and seed RNGs.
    2. Load the JSON configs from ``configs/`` next to this file.
    3. Load the tokenizer and model, then optionally patch attention.
    4. Evaluate each requested dataset, writing results under
       ``results/<model_key>/<dataset>/<exp_name>/``.

    Raises:
        SystemExit: if a requested dataset has no entry in
            ``dataset2prompt.json`` or ``dataset2maxlen.json``. This is checked
            before the model is loaded.
    """
    args = parse_args()
    seed_everything(args.seed)

    base_dir = os.path.dirname(os.path.abspath(__file__))
    cfg_dir = os.path.join(base_dir, "configs")
    model2path = _load_config(cfg_dir, "model2path.json")
    model2maxlen = _load_config(cfg_dir, "model2maxlen.json")
    dataset2prompt = _load_config(cfg_dir, "dataset2prompt.json")
    dataset2maxlen = _load_config(cfg_dir, "dataset2maxlen.json")

    # Fail fast on dataset names without prompt/max-gen configs. Otherwise the
    # lookup error only surfaces after the expensive model load and download.
    unknown = [
        name for name in args.datasets
        if name not in dataset2prompt or name not in dataset2maxlen
    ]
    if unknown:
        raise SystemExit(
            f"No entry in dataset2prompt.json/dataset2maxlen.json for: {', '.join(unknown)}"
        )

    model_path = model2path[args.model_key]
    model_name = args.model_key
    max_length = model2maxlen[args.model_key]

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype="auto")
    model.eval()

    # Patch the model once, up front. The single-configuration path below
    # relies on this patch alone, so its output file is named after the
    # configuration actually applied. A patched run must never be written to
    # (and overwrite) "baseline.jsonl".
    single_run_filename = "baseline.jsonl"
    if args.attention_strategy == "mask_topk":
        sparsity = float(args.sparsity_list[0]) if args.sparsity_list else 0.05
        patch_model_for_mask_topk(model, sparsity_ratio=sparsity)
        single_run_filename = f"mask_topk_{sparsity}.jsonl"
    elif args.attention_strategy == "mask_topk_recall":
        base_ratio = float(args.base_ratio)
        # First recall from the list (given as a fraction or a percentage), else 0.7.
        recall = _normalize_recall(args.recall_list[0]) if args.recall_list else 0.7
        patch_model_for_mask_topk_recall(
            model,
            base_sparsity_ratio=base_ratio,
            recall_ratio=recall,
        )
        single_run_filename = f"mask_topk_recall_base{base_ratio}_recall{recall}.jsonl"

    global_timer = GlobalTimer()
    evaluator = BenchmarkEvaluator(model, tokenizer, global_timer)

    for dataset_name in args.datasets:
        print(f"Evaluating on {dataset_name}...")
        data = load_dataset("THUDM/LongBench", dataset_name, split="test")
        # num_samples of None, 0 or a negative value means "evaluate everything".
        subset_size = len(data)
        if args.num_samples is not None and args.num_samples > 0:
            subset_size = min(args.num_samples, subset_size)
        samples = list(data.select(range(subset_size)))
        prompt_format = dataset2prompt[dataset_name]
        max_gen = dataset2maxlen[dataset_name]

        out_dir = os.path.join(base_dir, "results", model_name, dataset_name, args.exp_name)
        os.makedirs(out_dir, exist_ok=True)
        run_kwargs = {
            "model_name": model_name,
            "max_length": max_length,
            "dataset_name": dataset_name,
        }
        if args.attention_strategy == "mask_topk" and args.sparsity_list:
            evaluator.run_mask_topk_multi(
                samples, prompt_format, max_gen, out_dir, args.sparsity_list, **run_kwargs
            )
        elif args.attention_strategy == "mask_topk_recall" and args.recall_list:
            # NOTE(review): the raw recall_list (possibly in percent) is passed
            # here, whereas the up-front patch normalizes percentages. The
            # evaluator presumably normalizes as well; confirm.
            evaluator.run_mask_topk_recall_multi(
                samples, prompt_format, max_gen, out_dir, args.base_ratio, args.recall_list,
                **run_kwargs,
            )
        else:
            out_path = os.path.join(out_dir, single_run_filename)
            evaluator.run(samples, prompt_format, max_gen, out_path, **run_kwargs)


# Script entry point: run main() only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()



