import importlib
import re
import time
import os

import numpy as np
from contextlib import nullcontext

from op_eval.backends.backend_registry import BACKEND_REGISTRY
from op_eval.dataset import dataset
from op_eval.utils.utils import get_ref_src_path


def extract_first_code(output_string: str, code_language_types: list[str]) -> "str | None":
    """
    Extract the first fenced code block from model output.

    The text between the first pair of triple-backtick fences is returned
    with surrounding whitespace removed. When the opening fence is followed
    on the same line by one of ``code_language_types`` (for example a fence
    opened with ``cpp``), that language tag is removed from the result.

    Args:
        output_string: Raw model output that may contain fenced code.
        code_language_types: Language tags to recognise after the opening
            fence, e.g. ``["python", "cpp"]``.

    Returns:
        The code inside the first fenced block, or ``None`` when the output
        contains no complete fenced block.
    """
    trimmed = output_string.strip()

    # Non-greedy match so only the first fenced block is captured.
    code_match = re.search(r"```(.*?)```", trimmed, re.DOTALL)
    if code_match is None:
        return None

    # A language tag lives on the fence line itself, so only spaces/tabs
    # (never a newline) may separate it from the opening backticks.
    fenced = code_match.group(1).lstrip(" \t")

    for code_type in code_language_types:
        if not code_type or not fenced.startswith(code_type):
            continue
        remainder = fenced[len(code_type):]
        # Require a word boundary after the tag: this keeps code such as
        # "cpp_kernel()" intact and stops a short tag ("c") from eating the
        # start of a longer one ("cpp").
        if not remainder or remainder[0].isspace():
            # Strip at most one tag; the rest is the code itself.
            return remainder.strip()

    return fenced.strip()

def _ensure_backend(language: str):
    """
    Look up the backend registered for ``language``, importing it on demand.

    If ``language`` is not yet a key of ``BACKEND_REGISTRY`` the module
    ``op_eval.backends.<language>_backend`` is imported first; the registry
    lookup that follows relies on that import having populated the registry.
    The backend's ``language`` attribute is set to ``language`` before it is
    returned.

    Raises:
        ValueError: if the backend module cannot be imported, or if the
            registry still has no entry for ``language`` afterwards.
    """
    module_name = f"op_eval.backends.{language}_backend"

    already_registered = language in BACKEND_REGISTRY
    if not already_registered:
        try:
            importlib.import_module(module_name)
        except ImportError as import_err:
            print(f"[ERROR] Failed to import backend {module_name}: {import_err}")
            raise ValueError(
                f"Unsupported language/platform: {language} (module not found)"
            ) from import_err

    registered = BACKEND_REGISTRY.get(language)
    if registered is not None:
        # Record which language this backend was requested under.
        registered.language = language
        return registered

    raise ValueError(f"Unsupported language/platform: {language}")


def _init_result(backend) -> dict:
    """
    Build the starting result record for one evaluation.

    Everything is marked as "not done yet" (``compiled`` is ``False``, the
    other outcome fields are ``None``); ``hardware`` is filled in from
    ``backend.get_hardware_name()``.
    """
    fresh_result = dict(
        compiled=False,
        correctness=None,
        performance=None,
        hardware=backend.get_hardware_name(),
        profiling=None,
    )
    return fresh_result


def _finalize_result(backend, result: dict) -> dict:
    """
    Run backend cleanup on a best-effort basis and return ``result``.

    Args:
        backend: Backend object exposing ``cleanup()``.
        result: Result record to hand back; it is not modified here.

    Returns:
        The same ``result`` object that was passed in.

    Cleanup failures never propagate: the NPU may be in a crashed state, in
    which case cleanup itself throws. The failure is reported on stderr so
    it is not lost, and the result is still returned.
    """
    import sys

    try:
        backend.cleanup()
    except Exception as cleanup_err:
        # Deliberately non-fatal - NPU may be irrecoverably crashed - but
        # leave a trace instead of swallowing the error silently.
        print(
            f"[WARN] _finalize_result: backend cleanup failed: {cleanup_err}",
            file=sys.stderr,
            flush=True,
        )
    return result


def _compile_backend(
    response_txt: str,
    op: str,
    backend,
    result: dict,
) -> bool:
    """
    Compile the code found in ``response_txt`` with ``backend``.

    The first fenced code block of the response is used as the source; if
    the response has no fenced block, the whole response text is compiled.

    On success ``result["compiled"]`` is set to ``True`` and ``True`` is
    returned. On failure the backend's compile info is stored under
    ``result["compile_info"]`` and ``False`` is returned.
    """
    import sys

    def _debug(message: str) -> None:
        # All diagnostics from this function go to stderr, unbuffered.
        print(f"[DEBUG] _compile_backend: {message}", file=sys.stderr, flush=True)

    _debug(f"Compiling {op}...")

    source = extract_first_code(response_txt, ['python', 'cpp'])
    if source is None:
        # No fenced block found: treat the whole response as code.
        source = response_txt

    ok, details = backend.compile(source, op)
    if ok:
        _debug("Compilation SUCCESS")
        result["compiled"] = True
        return True

    # Only the first 200 characters of the compile info are logged.
    _debug(f"Compilation FAILED: {details[:200] if details else 'None'}")
    result["compile_info"] = details
    return False


def compile_single(response_txt: str, op: str, language: str):
    """
    Compile a single operator submission without running evaluation.

    Returns:
        A 5-tuple ``(backend, result, finalize, has_reference, compiled)``:
        the backend for ``language``, the result record, a zero-argument
        callback that cleans up the backend and returns ``result``, whether
        ``op`` is present in ``dataset``, and whether compilation succeeded.

    The caller is responsible for cleanup by invoking the returned
    ``finalize`` callback. Exceptions raised while compiling are caught,
    logged to stderr and recorded in ``result`` as a failed outcome.
    """
    import sys
    import traceback

    backend = _ensure_backend(language)
    result = _init_result(backend)

    def _cleanup_and_return():
        return _finalize_result(backend, result)

    has_reference = op in dataset

    try:
        compiled = _compile_backend(response_txt, op, backend, result)
    except Exception as exc:
        print(f"[FAIL] compile_single: Kernel exception for {op}: {exc}", file=sys.stderr, flush=True)
        traceback.print_exc(file=sys.stderr)
        # Keep an existing "compiled" value; only fill it in if it is missing.
        result.setdefault("compiled", False)
        result["correctness"] = False
        result["correctness_info"] = f"[FAIL] Kernel exception: {exc!s}"
        compiled = False

    return backend, result, _cleanup_and_return, has_reference, compiled


def evaluate_compiled(
    backend,
    op: str,
    result: dict,
    has_reference: bool,
    *,
    device_id=None,
):
    """
    Run the post-compilation evaluation on an already compiled backend.

    If ``device_id`` is given and the backend exposes ``set_device``, the
    device is selected first. Any exception raised during evaluation is
    logged to stderr and recorded in ``result`` as a correctness failure.
    Backend cleanup is run before returning; the return value is ``result``.
    """
    import sys
    import traceback

    def _cleanup_and_return():
        return _finalize_result(backend, result)

    try:
        device_requested = device_id is not None
        if device_requested and hasattr(backend, "set_device"):
            backend.set_device(device_id)
        outcome = _eval_after_compile(
            op, backend, has_reference, result, _cleanup_and_return, device_id
        )
    except Exception as exc:
        print(f"[FAIL] evaluate_compiled: Kernel exception for {op}: {exc}", file=sys.stderr, flush=True)
        traceback.print_exc(file=sys.stderr)
        result["correctness"] = False
        result["correctness_info"] = f"[FAIL] Kernel exception: {exc!s}"
        return _cleanup_and_return()
    else:
        return outcome


def eval_single(
    response_txt: str,
    op,
    language,
    device_id=None,
):
    """
    Compile and evaluate one operator code submission end to end.

    Returns the result record after backend cleanup. A compilation failure
    returns early with ``compiled`` left ``False``. Any exception raised
    while compiling or evaluating is treated as a kernel FAILURE rather
    than an infrastructure error: it is logged to stderr and recorded in
    the result with ``correctness=False`` (no 'error' field is added).

    Timeouts are handled granularly in correctness.py:
    - FIRST_RUN_TIMEOUT_S for the first correctness run
    Overall evaluation is bounded by the worker timeout in the server.
    """
    import sys
    import traceback

    backend = _ensure_backend(language)
    has_reference = op in dataset
    result = _init_result(backend)

    def _cleanup_and_return():
        return _finalize_result(backend, result)

    try:
        # Select the requested device when the backend supports it.
        if device_id is not None and hasattr(backend, "set_device"):
            backend.set_device(device_id)

        if _compile_backend(response_txt, op, backend, result):
            return _eval_after_compile(
                op, backend, has_reference, result, _cleanup_and_return, device_id
            )
        return _cleanup_and_return()
    except Exception as exc:
        # Crash, timeout, etc. somewhere in the kernel path.
        print(f"[FAIL] eval_single: Kernel exception for {op}: {exc}", file=sys.stderr, flush=True)
        traceback.print_exc(file=sys.stderr)
        # Keep an existing 'compiled' value; only fill it in if it is missing.
        result.setdefault('compiled', False)
        result['correctness'] = False
        result['correctness_info'] = f"[FAIL] Kernel exception: {exc!s}"
        return _cleanup_and_return()


def _profile_into_result(backend, result: dict) -> None:
    """
    Time the compiled operator and store the outcome in ``result``.

    Sets ``result['performance']`` from the profiler's own ``performance``
    entry when present, otherwise from the mean/std of the elapsed times
    (rounded to three decimals). ``result['profiling']`` receives the
    profiler's ``profiling`` entry, or the whole profiling object when that
    entry is absent. A ``TimeoutError`` or any other exception is recorded
    in ``result`` instead of being raised.
    """
    import sys

    print(f"[DEBUG] _eval_after_compile: Starting profiling...", file=sys.stderr, flush=True)
    # Imported outside the try on purpose: an import failure is an
    # infrastructure problem and must propagate to the caller.
    from op_eval.utils.performance import time_execution_event_template

    try:
        elapsed_times, profiling = time_execution_event_template(
            backend.context,
            device=backend.get_device(),
            synchronize=backend.synchronize,
            event_class=backend.event_class,
            eval_target=backend.model_key,
        )
        print(f"[DEBUG] _eval_after_compile: Profiling complete. elapsed_times count={len(elapsed_times) if elapsed_times else 0}", file=sys.stderr, flush=True)
        if profiling and profiling.get('performance'):
            result['performance'] = profiling['performance']
        elif elapsed_times:
            result['performance'] = {
                'mean': float(f"{np.mean(elapsed_times):.3f}"),
                'std': float(f"{np.std(elapsed_times):.3f}"),
            }
        if profiling:
            # Extract only the kernel metrics, not the redundant performance/profiling wrapper
            result["profiling"] = profiling.get('profiling', profiling)

    except TimeoutError:
        result['performance'] = {
            "mean": None,
            "std": None,
            "note": "Timeout during performance runs"
        }
    except Exception as e:
        # Kernel crashed during profiling (e.g. vector core exception)
        # Keep correctness=True in the result but mark profiling as failed
        print(f"[WARN] _eval_after_compile: Profiling failed with exception: {e}", file=sys.stderr, flush=True)
        result['performance'] = None
        result['profiling_error'] = f"Kernel crashed during profiling: {str(e)}"


def _eval_after_compile(
    op,
    backend,
    has_reference: bool,
    result: dict,
    _finalize,
    device_id,
):
    """
    Evaluation implementation after successful compilation.

    Steps, each of which may end the evaluation early:
      1. Without a reference implementation, correctness is skipped.
      2. The reference source for ``op`` is read and executed inside
         ``backend.context``; it must define ``Model`` and ``get_inputs``.
      3. The correctness check runs; an exception or a failed check stops
         here with ``result['correctness']`` set accordingly.
      4. Profiling runs only when the backend language is ``ascendc``.

    Args:
        op: Operator name; used to locate the reference source.
        backend: Compiled backend providing ``context``, ``synchronize``
            and ``get_device``.
        has_reference: Whether a reference implementation exists for ``op``.
        result: Result record, updated in place.
        _finalize: Zero-argument callback whose return value is returned
            from every exit path of this function.
        device_id: Device identifier; only used for diagnostics here.

    Returns:
        Whatever ``_finalize()`` returns.
    """
    import sys
    import traceback

    print(f"[DEBUG] _eval_after_compile: Starting op={op}, device_id={device_id}", file=sys.stderr, flush=True)

    if device_id is not None and hasattr(backend, "get_device"):
        try:
            print(f"[DEBUG] _eval_after_compile: Device set to {backend.get_device()}", file=sys.stderr, flush=True)
        except Exception:
            # Purely diagnostic output: a failing device query must not
            # abort the evaluation.
            pass

    if not has_reference:
        result["correctness_info"] = "Skipped: no reference implementation available"
        return _finalize()

    ref_src_path = get_ref_src_path(op)
    # The text is executed as Python source, whose default encoding is
    # UTF-8; decode explicitly instead of relying on the platform locale.
    with open(ref_src_path, 'r', encoding='utf-8') as f:
        ref_src = f.read()

    # NOTE(review): exec runs the reference source with full interpreter
    # privileges. This is only safe while reference files are trusted
    # project content - never route untrusted input through this path.
    exec(ref_src, backend.context)

    reference_ready = (
        backend.context.get('Model') is not None
        and backend.context.get('get_inputs') is not None
    )
    if not reference_ready:
        result['correctness_info'] = 'Reference code missing Model or get_inputs'
        return _finalize()

    backend.context['_op_eval_op_name'] = op

    # Run correctness check with timeout-aware function
    print(f"[DEBUG] _eval_after_compile: Starting correctness check...", file=sys.stderr, flush=True)
    from op_eval.utils.correctness import execute_template

    try:
        # The third element (first-run time) is not used here.
        correct, info, _first_run_time_s = execute_template(
            synchronize=backend.synchronize,
            device=backend.get_device(),
            context=backend.context,
            op_name=op,
        )
        print(f"[DEBUG] _eval_after_compile: Correctness check returned: correct={correct}, info={info[:100] if info else 'None'}", file=sys.stderr, flush=True)
    except Exception as e:
        print(f"[DEBUG] _eval_after_compile: Correctness check EXCEPTION: {e}", file=sys.stderr, flush=True)
        traceback.print_exc(file=sys.stderr)
        result['correctness'] = False
        result['correctness_info'] = f"[FAIL] Exception during correctness: {str(e)}"
        return _finalize()

    result['correctness'] = correct
    result['correctness_info'] = info

    if not correct:
        # Correctness failed - skip profiling
        return _finalize()

    backend_language = (getattr(backend, "language", None) or "").lower()
    if backend_language != "ascendc":
        print(f"[DEBUG] _eval_after_compile: Skipping profiling for backend={backend_language or 'unknown'}", file=sys.stderr, flush=True)
        result["performance"] = None
        result["profiling"] = None
        return _finalize()

    # Full profiling for correct operators
    _profile_into_result(backend, result)

    return _finalize()
