#!/usr/bin/env python3
"""
Parallel baseline profiling for KernelBench reference implementations.

This script profiles all reference operators under op_eval/reference/KernelBench/
using torch_npu.profiler to collect AIC metrics, WITHOUT requiring entries in dataset.py.

Performance timing comes from step_trace_time.csv (3-pass profiler approach).
Profiling metrics come from kernel_details.csv, with every kernel keyed by its name.

Usage:
    python -m op_eval.generate_baseline_profiling --language ascendc --workers 8 --device-ids 0,1,2,3
    python -m op_eval.generate_baseline_profiling --levels level1,level2
"""

from __future__ import annotations

import argparse
import importlib
import json
import multiprocessing
import os
import signal
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence

import numpy as np

from op_eval.config import project_root_path
from op_eval.backends.backend_registry import BACKEND_REGISTRY

# Root directory that holds the KernelBench reference implementations,
# laid out as <project_root_path>/reference/KernelBench/level*/<op>.py.
KERNELBENCH_ROOT = Path(project_root_path).joinpath("reference", "KernelBench")


@dataclass(eq=True, frozen=True)
class BaselineRequest:
    """Immutable description of one reference operator to profile.

    Being a frozen dataclass, instances compare by value and are hashable;
    they are also passed (pickled) to spawned worker processes.
    """

    op: str  # operator key; "<level>/<file stem>" when built by build_baseline_requests
    language: str  # backend name resolved through _ensure_backend
    ref_path: Path  # reference .py file that is exec'd in the backend context
    level: str  # KernelBench level directory name, e.g. "level1", "level2"


def _ensure_backend(language: str):
    """
    Return the backend registered for ``language``, importing it on demand.

    If ``language`` is not yet in ``BACKEND_REGISTRY``, the module
    ``op_eval.backends.<language>_backend`` is imported; the registry lookup
    that follows relies on that import registering the backend (presumably
    as an import-time side effect of the backend module).

    Raises:
        ValueError: if no backend module exists for ``language``, if the
            backend module exists but fails to import (message includes the
            underlying import error), or if importing it did not register a
            backend.
    """
    if language not in BACKEND_REGISTRY:
        # A relative import of ".backends.<x>" from package "op_eval" resolves
        # to this same absolute name, so a single import is sufficient.
        module_name = f"op_eval.backends.{language}_backend"
        try:
            importlib.import_module(module_name)
        except ModuleNotFoundError as e:
            missing = e.name or ""
            # Only "the backend module (or its package) does not exist" means
            # the language is unsupported; a dependency missing *inside* the
            # backend is a different failure and must not be misreported.
            if not missing or module_name == missing or module_name.startswith(missing + "."):
                raise ValueError(f"Unsupported language/platform: {language}") from e
            raise ValueError(f"Failed to load backend for {language!r}: {e}") from e
        except ImportError as e:
            raise ValueError(f"Failed to load backend for {language!r}: {e}") from e

    backend = BACKEND_REGISTRY.get(language)
    if backend is None:
        raise ValueError(f"Unsupported language/platform: {language}")
    return backend


def _profile_single_op(
    req: BaselineRequest,
    device_id: int,
    profiling_output_dir: Optional[str] = None,
    num_warmup: int = 3,
) -> Dict[str, Any]:
    """
    Profile a single reference operator using torch_npu.profiler.

    Binds ``device_id``, executes the reference source in the backend's
    context, builds ``Model(*get_init_inputs())`` on the device, runs
    ``num_warmup`` warmup forward passes and then hands the forward pass to
    ``op_eval.utils.performance.profile_execution``.

    Returns a dict with:
    - performance: {mean, std, min, max} in ms from step_trace_time.csv
    - profiling: {kernel_name: {metrics}, ...} for all kernels
    - device: device name
    - level: KernelBench level (present on success and on every failure path)
    - error: error message (if failed, or if profiling yielded no
      performance data)
    - traceback: formatted traceback (only when an exception was raised)

    Ordinary exceptions are never propagated; they are converted into an
    ``error`` entry so the caller can record them per operator.
    """
    try:
        import torch
        import torch_npu

        # Bind this process to the requested NPU.
        torch_npu.npu.set_device(device_id)
        device = torch.device(f"npu:{device_id}")

        backend = _ensure_backend(req.language)
        if hasattr(backend, 'setup_context'):
            backend.setup_context()

        with open(req.ref_path, 'r', encoding='utf-8') as f:
            ref_src = f.read()

        # SECURITY: the reference file is executed as trusted code in the
        # backend's namespace; only point this at vetted reference sources.
        exec(ref_src, backend.context)

        Model = backend.context.get('Model')
        get_inputs = backend.context.get('get_inputs')
        get_init_inputs = backend.context.get('get_init_inputs', lambda: [])

        if Model is None or get_inputs is None:
            return {"error": "Reference code missing Model or get_inputs", "level": req.level}

        inputs = get_inputs()
        init_inputs = get_init_inputs()

        # Accept a single object as well as a list/tuple from the helpers.
        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]
        if not isinstance(init_inputs, (list, tuple)):
            init_inputs = [init_inputs]

        inputs = [x.to(device) if isinstance(x, torch.Tensor) else x for x in inputs]
        init_inputs = [x.to(device) if isinstance(x, torch.Tensor) else x for x in init_inputs]

        model = Model(*init_inputs).to(device)
        model.eval()

        # Warmup runs (minimal, since the profiler also warms up).
        with torch.no_grad():
            for _ in range(num_warmup):
                model(*inputs)
                torch_npu.npu.synchronize(device=device)

        # Profiling provides both performance timing (step_trace_time.csv,
        # 3 passes) and per-kernel metrics (kernel_details.csv).
        from op_eval.utils.performance import profile_execution

        # NOTE(review): warmup runs under torch.no_grad() but run_fn does not,
        # so measured passes may record autograd state that warmup never
        # exercised. Confirm which mode the baseline is meant to measure.
        def run_fn():
            model(*inputs)

        profiling_result = profile_execution(
            run_fn, device, req.op, profiling_output_dir,
        )

        result = {
            "device": backend.get_hardware_name() if hasattr(backend, 'get_hardware_name') else f"npu:{device_id}",
            "level": req.level,
        }

        if profiling_result:
            # performance: {mean, max, min, std} in ms from step_trace_time
            if profiling_result.get('performance'):
                result['performance'] = profiling_result['performance']

            # profiling: {kernel_name: {metrics}, ...}
            if profiling_result.get('profiling'):
                result['profiling'] = profiling_result['profiling']

        # Without timing data this run is a failure. Mark it explicitly so
        # resume logic that only skips error-free entries retries it.
        if 'performance' not in result:
            result['error'] = "Profiling produced no performance data"

        return result

    except Exception as e:
        import traceback
        return {"error": f"{type(e).__name__}: {str(e)}", "traceback": traceback.format_exc(), "level": req.level}


def _worker_process(
    req: BaselineRequest,
    device_queue: multiprocessing.Queue,
    result_queue: multiprocessing.Queue,
    profiling_output_dir: Optional[str],
    num_warmup: int,
    timeout_s: int,
):
    """
    Spawned-process entry point: profile ``req`` on a device borrowed from a queue.

    Takes one device id from ``device_queue``, runs ``_profile_single_op``,
    posts one ``(req.op, result_dict)`` tuple to ``result_queue`` and returns
    the device id to ``device_queue``.

    ``timeout_s`` is enforced with ``SIGALRM`` (POSIX only; ``0`` disables
    it). The alarm is armed before the device is acquired, so the budget
    covers waiting for a device as well as profiling. Because the handler is a
    Python-level signal handler, it cannot interrupt a call that never returns
    to the interpreter.
    """
    # Point fd 1 at stderr, so everything written to stdout (including by
    # native code) ends up on stderr. This is best-effort.
    try:
        sys.stdout.flush()
        os.dup2(sys.stderr.fileno(), sys.stdout.fileno())
    except Exception:
        pass

    def _timeout_handler(signum, frame):
        raise TimeoutError(f"Timeout profiling {req.op}")

    signal.signal(signal.SIGALRM, _timeout_handler)
    signal.alarm(timeout_s)

    # Tracks whether we actually own a device that must be handed back.
    device_id: Optional[int] = None
    try:
        device_id = device_queue.get()
        result = _profile_single_op(
            req, device_id,
            profiling_output_dir=profiling_output_dir,
            num_warmup=num_warmup,
        )
        # Disarm before touching the queues. Otherwise a late SIGALRM could
        # interrupt the puts below, discarding a finished result or leaking
        # the device.
        signal.alarm(0)
        result_queue.put((req.op, result))
    except TimeoutError:
        signal.alarm(0)
        result_queue.put((req.op, {"error": "Timeout", "timeout_s": timeout_s, "level": req.level}))
    except Exception as e:
        signal.alarm(0)
        result_queue.put((req.op, {"error": f"Critical: {str(e)}", "level": req.level}))
    finally:
        signal.alarm(0)
        if device_id is not None:
            # Return the device so the next worker can use it.
            device_queue.put(device_id)


def run_parallel_baseline(
    requests: Sequence[BaselineRequest],
    *,
    device_ids: Sequence[int],
    max_workers: int = 4,
    profiling_output_dir: Optional[str] = None,
    num_warmup: int = 3,
    timeout_s: int = 300,
    result_path: Optional[Path] = None,
    resume: bool = True,
) -> Dict[str, Dict]:
    """
    Run baseline profiling in parallel, one spawned process per operator.

    Device ids are handed out through a shared queue. Concurrency is capped at
    ``len(device_ids)`` because each worker arms its timeout before it acquires
    a device. Extra workers would only burn that budget while they wait.

    Args:
        requests: List of BaselineRequest objects
        device_ids: NPU device IDs to use (must be non-empty)
        max_workers: Maximum number of parallel workers (>= 1)
        profiling_output_dir: Directory to store profiling artifacts
        num_warmup: Number of warmup iterations
        timeout_s: Timeout per operator in seconds
        result_path: Path to save results (checkpointed every 10 results,
            and on exit, including on interruption)
        resume: Whether to resume from existing results. Only entries that
            have ``performance`` and no ``error`` are treated as done.

    Returns:
        Dict mapping op names to results. A worker that exits without
        reporting gets an ``error`` entry that includes its exit code.

    Raises:
        ValueError: if ``device_ids`` is empty or ``max_workers`` < 1.
    """
    from collections import deque
    from queue import Empty

    if not device_ids:
        raise ValueError("device_ids must contain at least one device id")
    if max_workers < 1:
        raise ValueError(f"max_workers must be >= 1, got {max_workers}")

    results: Dict[str, Dict] = {}

    # Resume from existing results. Only successful entries count as done, so
    # failed or incomplete operators are profiled again.
    if resume and result_path and result_path.exists():
        try:
            with open(result_path, 'r', encoding='utf-8') as f:
                existing = json.load(f)
            for op, val in existing.items():
                if isinstance(val, dict) and "error" not in val and "performance" in val:
                    results[op] = val
            print(f"[INFO] Resuming from {result_path}, {len(results)} ops already done.")
        except Exception as e:
            print(f"[WARN] Could not load {result_path}: {e}")

    pending = [r for r in requests if r.op not in results]
    if not pending:
        print("[INFO] All operators already profiled.")
        return results

    num_workers = min(max_workers, len(device_ids))
    print(f"[INFO] Profiling {len(pending)} operators with {num_workers} workers on devices {list(device_ids)}")

    ctx = multiprocessing.get_context("spawn")
    device_queue = ctx.Queue()
    for dev_id in device_ids:
        device_queue.put(dev_id)
    result_queue = ctx.Queue()

    waiting = deque(pending)
    running: List[tuple] = []  # (request, process) pairs still executing
    total = len(pending)
    completed = 0

    try:
        from tqdm import tqdm
        pbar = tqdm(total=total, desc="KernelBench Baseline")
    except ImportError:
        pbar = None

    def _record(op: str, result: Dict) -> None:
        """Store one result, advance progress, and checkpoint every 10 results."""
        nonlocal completed
        results[op] = result
        completed += 1
        if pbar is not None:
            status = "OK" if "performance" in result else "ERR"
            pbar.set_postfix_str(f"{op}: {status}")
            pbar.update(1)
        else:
            print(f"[{completed}/{total}] {op}")
        if result_path and completed % 10 == 0:
            _save_results(result_path, results)

    def _drain() -> None:
        """Record every result currently available on the result queue."""
        while True:
            try:
                op, result = result_queue.get_nowait()
            except Empty:
                return
            _record(op, result)

    # NOTE(review): a worker stuck in a native call is not interrupted by its
    # SIGALRM, and this loop waits for it indefinitely. Consider adding a
    # parent-side hard timeout.
    try:
        while waiting or running:
            while waiting and len(running) < num_workers:
                req = waiting.popleft()
                proc = ctx.Process(
                    target=_worker_process,
                    args=(req, device_queue, result_queue, profiling_output_dir, num_warmup, timeout_s),
                )
                proc.start()
                running.append((req, proc))

            _drain()

            # Partition with a single is_alive() call per process.
            still_running, finished = [], []
            for req, proc in running:
                (still_running if proc.is_alive() else finished).append((req, proc))
            running = still_running

            if finished:
                # A process that put items on a multiprocessing queue flushes
                # them before it terminates. Drain again so results sent just
                # before exit are not lost.
                _drain()
                for req, proc in finished:
                    proc.join()
                    if req.op not in results:
                        _record(req.op, {
                            "error": f"Worker exited with code {proc.exitcode} without reporting a result",
                            "level": req.level,
                        })
                    proc.close()

            time.sleep(0.1)
    finally:
        # On interruption, stop leftover workers and keep whatever finished.
        for _, proc in running:
            if proc.is_alive():
                proc.terminate()
            proc.join(timeout=5)
        if pbar is not None:
            pbar.close()
        if result_path:
            _save_results(result_path, results)

    return results


def _save_results(path: Path, results: Dict) -> None:
    """
    Atomically save ``results`` to ``path`` as indented JSON.

    The data is written to a sibling ``<name>.tmp`` file, flushed to disk, and
    then moved over ``path`` with ``os.replace``. An interrupted or failed
    write therefore leaves the previous checkpoint intact, rather than a
    truncated file that resume could not load. Parent directories are created
    as needed.

    Raises:
        TypeError: if ``results`` contains values ``json`` cannot serialize
            (``path`` is left untouched).
        OSError: on filesystem errors.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the original error.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise


def discover_kernelbench_ops(
    levels: Optional[Sequence[str]] = None,
    root: Optional[Path] = None,
) -> List[tuple]:
    """
    Discover all operators in the KernelBench directory.

    Every sub-directory of the root whose name starts with ``level`` is a
    level. Every public ``*.py`` file directly inside it is one operator.
    Files whose names start with ``_``, such as ``__init__.py``, are skipped.

    Args:
        levels: Optional list of levels to include (e.g., ["level1", "level2"]).
            Requested levels that do not exist produce a warning.
        root: Directory to scan. Defaults to ``KERNELBENCH_ROOT``.

    Returns:
        List of (op_key, ref_path, level) tuples, where op_key is
        "<level>/<file stem>". Ordered by level directory name, then file
        name. Empty if the root directory does not exist.
    """
    root = KERNELBENCH_ROOT if root is None else Path(root)
    ops: List[tuple] = []

    if not root.is_dir():
        print(f"[ERROR] KernelBench root not found: {root}")
        return ops

    level_dirs = sorted(d for d in root.iterdir() if d.is_dir() and d.name.startswith("level"))

    if levels:
        wanted = set(levels)
        missing = sorted(wanted - {d.name for d in level_dirs})
        if missing:
            print(f"[WARN] Requested levels not found under {root}: {', '.join(missing)}")
        level_dirs = [d for d in level_dirs if d.name in wanted]

    for level_dir in level_dirs:
        level_name = level_dir.name
        for py_file in sorted(level_dir.glob("*.py")):
            # Package markers / private helpers are not operators.
            if py_file.name.startswith("_") or not py_file.is_file():
                continue
            # Unique op key: level/op_name (file name without .py).
            ops.append((f"{level_name}/{py_file.stem}", py_file, level_name))

    return ops


def build_baseline_requests(
    language: str,
    levels: Optional[Sequence[str]] = None,
    ops: Optional[Sequence[str]] = None,
) -> List[BaselineRequest]:
    """
    Build baseline requests for KernelBench operators.

    Args:
        language: Backend name stored on every request.
        levels: Optional level directory names to restrict discovery to.
        ops: Optional operator filter. Each entry matches either the full key
            ("level1/foo") or the bare file stem ("foo"). Entries that match
            nothing are reported with a warning. Empty/None means all.

    Returns:
        One BaselineRequest per selected operator, in discovery order.
    """
    discovered = discover_kernelbench_ops(levels)
    print(f"[INFO] Discovered {len(discovered)} operators in KernelBench")

    wanted = set(ops) if ops else None
    matched = set()
    requests: List[BaselineRequest] = []

    for op_key, ref_path, level in discovered:
        if wanted is not None:
            # Match by full key or just the file name.
            hits = wanted.intersection((op_key, ref_path.stem))
            if not hits:
                continue
            matched.update(hits)

        requests.append(BaselineRequest(
            op=op_key,
            language=language,
            ref_path=ref_path,
            level=level,
        ))

    if wanted is not None:
        unmatched = sorted(wanted - matched)
        if unmatched:
            print(f"[WARN] Requested operators not found: {', '.join(unmatched)}")

    return requests


def _split_csv_arg(
    parser: argparse.ArgumentParser, option: str, value: Optional[str]
) -> Optional[List[str]]:
    """Split a comma-separated CLI value into stripped non-empty items (None if unset/empty)."""
    if not value:
        return None
    items = [item.strip() for item in value.split(',') if item.strip()]
    if not items:
        parser.error(f"{option} did not contain any values: {value!r}")
    return items


def _default_result_path(device_ids: Sequence[int], levels: Optional[Sequence[str]]) -> Path:
    """Build <project_root>/baselines/kernelbench_<levels|all>_<device name>.json."""
    device_name = f"npu{device_ids[0]}"
    try:
        import torch_npu
        torch_npu.npu.set_device(device_ids[0])
        device_name = torch_npu.npu.get_device_name(device_ids[0]).replace(" ", "_")
    except Exception:
        # Best-effort: keep the "npu<id>" fallback when torch_npu is
        # unavailable or the device query fails.
        pass

    level_suffix = "_".join(levels) if levels else "all"
    return Path(project_root_path) / 'baselines' / f'kernelbench_{level_suffix}_{device_name}.json'


def _print_summary(results: Dict[str, Dict], result_path: Path) -> None:
    """Print overall and per-level counts; an entry succeeded iff it has "performance"."""
    success = sum(1 for r in results.values() if "performance" in r)
    failed = len(results) - success
    print(f"\n[DONE] {success} succeeded, {failed} failed")
    print(f"[INFO] Results saved to {result_path}")

    level_stats: Dict[str, Dict[str, int]] = {}
    for r in results.values():
        stats = level_stats.setdefault(r.get("level", "unknown"), {"success": 0, "failed": 0})
        stats["success" if "performance" in r else "failed"] += 1

    print("\n[LEVEL BREAKDOWN]")
    for level, stats in sorted(level_stats.items()):
        print(f"  {level}: {stats['success']} OK, {stats['failed']} ERR")


def main():
    """
    Command-line entry point.

    Discovers KernelBench operators, profiles them in parallel, writes the
    JSON results, and prints a per-level summary. Exits with status 1 when no
    operator matches the filters. Exits through argparse's usage error on
    malformed arguments.
    """
    parser = argparse.ArgumentParser(
        description="Generate baseline statistics for KernelBench reference implementations (without dataset.py)."
    )
    parser.add_argument(
        '--language', type=str, default='ascendc',
        help='Target language/backend (default: ascendc)'
    )
    parser.add_argument(
        '--workers', type=int, default=4,
        help='Number of parallel workers (default: 4)'
    )
    parser.add_argument(
        '--device-ids', type=str, default='0',
        help='Comma-separated list of NPU device IDs (default: 0)'
    )
    parser.add_argument(
        '--device-offset', type=int, default=0,
        help='Offset to add to device IDs (default: 0)'
    )
    parser.add_argument(
        '--levels', type=str, default=None,
        help='Comma-separated list of levels to profile (e.g., level1,level2). Default: all'
    )
    parser.add_argument(
        '--ops', type=str, default=None,
        help='Comma-separated list of specific operators to profile (default: all)'
    )
    parser.add_argument(
        '--warmup', type=int, default=3,
        help='Number of warmup iterations (default: 3)'
    )
    parser.add_argument(
        '--timeout', type=int, default=300,
        help='Timeout per operator in seconds (default: 300)'
    )
    parser.add_argument(
        '--output', type=str, default=None,
        help='Output JSON file path (default: baselines/kernelbench_<levels|all>_<device>.json)'
    )
    parser.add_argument(
        '--profile-dir', type=str, default=None,
        help='Directory to store profiling artifacts (default: temp)'
    )
    parser.add_argument(
        '--no-resume', action='store_true',
        help='Start fresh, ignoring existing results'
    )

    args = parser.parse_args()

    # With zero workers the scheduler could never start a process.
    if args.workers < 1:
        parser.error(f"--workers must be >= 1, got {args.workers}")

    device_tokens = _split_csv_arg(parser, '--device-ids', args.device_ids)
    if not device_tokens:
        parser.error("--device-ids must list at least one device id")
    try:
        device_ids = [int(tok) + args.device_offset for tok in device_tokens]
    except ValueError:
        parser.error(f"--device-ids must be comma-separated integers, got {args.device_ids!r}")

    levels = _split_csv_arg(parser, '--levels', args.levels)
    ops = _split_csv_arg(parser, '--ops', args.ops)

    requests = build_baseline_requests(args.language, levels, ops)
    if not requests:
        print("[ERROR] No valid operators found")
        sys.exit(1)

    print(f"[INFO] Found {len(requests)} operators to profile")

    result_path = Path(args.output) if args.output else _default_result_path(device_ids, levels)
    print(f"[INFO] Output: {result_path}")

    results = run_parallel_baseline(
        requests,
        device_ids=device_ids,
        max_workers=args.workers,
        profiling_output_dir=args.profile_dir,
        num_warmup=args.warmup,
        timeout_s=args.timeout,
        result_path=result_path,
        resume=not args.no_resume,
    )

    _print_summary(results, result_path)


# Script entry point: runs the CLI when executed directly or via
# `python -m op_eval.generate_baseline_profiling`.
if __name__ == "__main__":
    main()
