from __future__ import annotations


import os
import sys
import shlex
import argparse
from pathlib import Path
from op_eval.config import check_ascend_env

from op_eval.batch import build_requests_from_dir, evaluate_requests





def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the op_eval batch evaluator.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the behaviour existing
            callers get; passing an explicit list allows programmatic use.

    Returns:
        The populated ``argparse.Namespace``. ``--input-dir`` and
        ``--language`` are required; ``--result-path`` and ``--merge-code``
        are ``None`` when omitted.
    """
    parser = argparse.ArgumentParser(
        description="op_eval: Evaluate generated kernels from a folder of *.txt files.",
    )

    # Input selection.
    parser.add_argument(
        "--input-dir", type=Path, required=True,
        help="Directory containing *.txt responses",
    )
    parser.add_argument(
        "--language", required=True,
        help="Language/backend (e.g., ascendc, cuda)",
    )
    parser.add_argument(
        "--ops", nargs="+",
        help="Optional subset of operator names to evaluate",
    )

    # Execution settings.
    parser.add_argument("--runs", type=int, default=1, help="Number of runs")
    parser.add_argument(
        "--build-workers", type=int, default=4,
        help="Max concurrent build/eval tasks",
    )
    parser.add_argument(
        "--num-devices", type=int, default=8,
        help="Number of devices to lease for Ascend backends",
    )
    parser.add_argument(
        "--device-offset", type=int, default=0,
        help="Device index offset",
    )
    parser.add_argument(
        "--run-name", default="default",
        help="Workspace label for Ascend builds",
    )
    parser.add_argument(
        "--result-path", type=Path,
        help="Where to write the aggregated JSON results",
    )

    # Dataset-based filters; the literal value "all" disables a filter.
    parser.add_argument(
        "--categories", nargs="+", default=["all"],
        help="Filter ops by category (e.g., enable 'activation')",
    )
    parser.add_argument(
        "--levels", nargs="+", default=["all"],
        help="Filter ops by level (e.g., 'level1')",
    )

    parser.add_argument(
        "--merge-code", type=Path,
        help="Target JSON file for merged results (defaults to {input_dir}/all_results.json)",
    )
    parser.add_argument(
        "--disable-npu-parallelism", action="store_true",
        help="Disable concurrent tasks per NPU device (default: parallelism enabled, matching server behavior).",
    )
    parser.add_argument(
        "--remote-url", type=str,
        help="URL of the remote op_eval_server (e.g. http://host:5000). If set, evaluation runs remotely.",
    )
    parser.add_argument(
        "--skip-dataset-check", action="store_true",
        help="Skip dataset validation; compile and evaluate even if operator is not in dataset.",
    )
    return parser.parse_args(argv)


def _select_ops(
    requested: list[str] | None,
    categories: list[str],
    levels: list[str],
    dataset,
) -> list[str] | None:
    """Combine the explicit op list with the category/level filters.

    Args:
        requested: Operator names given via ``--ops``; ``None`` or empty
            means no explicit subset was requested.
        categories: Accepted category values; exactly ``["all"]`` disables
            the category filter.
        levels: Accepted level values; exactly ``["all"]`` disables the
            level filter.
        dataset: Mapping-like object whose ``.items()`` yields
            ``(op_name, meta)`` pairs, where ``meta.get("category")`` and
            ``meta.get("level")`` give the op's category and level.

    Returns:
        ``None`` when neither an explicit list nor a filter applies (no
        restriction); otherwise a sorted, de-duplicated list of op names.
        The list may be empty when nothing satisfies the filters.
    """
    selection = set(requested) if requested else None

    any_category = categories == ["all"]
    any_level = levels == ["all"]
    if not (any_category and any_level):
        matching = {
            op
            for op, meta in dataset.items()
            if (any_category or meta.get("category") in categories)
            and (any_level or meta.get("level") in levels)
        }
        # An explicit --ops list is narrowed by the filters, never widened.
        selection = matching if selection is None else selection & matching

    if selection is None:
        return None
    # Sorting makes the op order reproducible; iterating a set of strings
    # directly can change order from one interpreter run to the next.
    return sorted(selection)


def main():
    """Run the batch evaluation described by the command-line arguments.

    Steps: parse arguments, call ``check_ascend_env()`` unless a remote URL
    is given, resolve which operators to evaluate, build the requests from
    the input directory, evaluate them, and merge the results with the code
    into the merge target (also after each progress callback).

    Raises:
        SystemExit: If the op/category/level filters match no operator, or
            if no requests are found to evaluate.
    """
    args = parse_args()
    if not args.remote_url:
        check_ascend_env()

    # Default merge-code path if not specified. A merge target therefore
    # always exists, so merging is unconditional below.
    merge_path = args.merge_code
    if merge_path is None:
        merge_path = args.input_dir / "all_results.json"

    # Imported here (not at module top) as in the original layout.
    from op_eval.dataset import dataset

    final_ops = _select_ops(args.ops, args.categories, args.levels, dataset)
    if final_ops is not None and not final_ops:
        # Do not hand an empty list to build_requests_from_dir: how it treats
        # an empty `ops` is not visible here, and reading it as "no filter"
        # would evaluate everything, the opposite of the user's intent.
        raise SystemExit("No operators match the requested --ops/--categories/--levels filters.")

    requests = build_requests_from_dir(
        input_dir=args.input_dir,
        language=args.language,
        ops=final_ops,
        run_name=args.run_name,
        skip_dataset_check=args.skip_dataset_check,
    )
    if not requests:
        raise SystemExit("No requests found to evaluate.")

    result_path = args.result_path
    if result_path is None:
        result_path = args.input_dir / "results.json"

    from op_eval.merge import merge_results_with_code

    def _incremental_merge():
        # Progress callback: refresh the merged file while evaluation runs.
        print(f"[INFO] Updating merged results at {merge_path}...")
        merge_results_with_code(result_path, merge_path)

    evaluate_requests(
        requests,
        runs=args.runs,
        build_workers=args.build_workers,
        device_ids=list(range(args.num_devices)),
        device_offset=args.device_offset,
        result_path=result_path,
        enable_npu_parallelism=not args.disable_npu_parallelism,
        progress_callback=_incremental_merge,
        remote_url=args.remote_url,
    )
    print(f"[INFO] Evaluation complete. Results written to {result_path}")

    # Final merge so the merged file reflects the complete result set.
    print(f"[INFO] Merging results with code into {merge_path}...")
    merge_results_with_code(result_path, merge_path)

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
