import argparse
import io
import json
import multiprocessing
import os
import pickle
import shutil
import signal
import sys
import tempfile
import threading
import time
import traceback

import torch
from flask import Flask, request, jsonify

# op_eval_framework/op_eval/server.py
# This module provides the HTTP server for operator evaluation.

from op_eval.utils.evaluation_utils import compile_single, evaluate_compiled
from op_eval.config import OP_EVAL_SLOT_TIMEOUT_S, OP_EVAL_WORKER_TIMEOUT_S, check_ascend_env

import queue

# Flask application object; the route handlers below register on it.
app = Flask(__name__)

# Global configuration defaults
DEFAULT_PORT = 5000

# Device pool: queue-based allocation ensures exactly 1 task per NPU
# - Each device ID is in the queue once (or N times for max_ops_per_npu > 1)
# - Blocking get() acquires a device, put() returns it
# - Simple and deadlock-free: single queue, FIFO order
# All three globals stay None until main() initializes them.
# DEVICE_POOL: spawn-context multiprocessing queue holding free device IDs.
DEVICE_POOL = None
# DEVICE_LEASES: manager-backed dict mapping str(worker PID) -> device ID that
# the worker took from the pool. The request handler pops the entry once the
# worker has stopped and puts the device ID back into DEVICE_POOL.
DEVICE_LEASES = None
# DEVICE_LEASES_MANAGER: the multiprocessing.Manager() that owns DEVICE_LEASES;
# a reference is kept at module level alongside the dict proxy.
DEVICE_LEASES_MANAGER = None

# Language/backend keys (compared lower-cased) that are treated as Ascend:
# workers call torch_npu.npu.set_device() for them and main() runs
# check_ascend_env() when the server backend is one of them.
ASCEND_LANGUAGES = {"ascendc", "triton_ascend", "tilelang_ascend"}
# Backend keys for which main() sizes the pool from torch.cuda.device_count().
CUDA_LANGUAGES = {"cuda", "triton"}

def run_worker(op, code_content, language, run_name, device_pool, device_leases, result_queue, workspace_dir, slot_timeout_s):
    """
    Worker process entry point: compile and evaluate a single operator.

    Compilation runs first (CPU-bound) without holding a device slot. Only
    after a successful compile is a device ID taken from ``device_pool`` and
    the operator evaluated on it. The outcome is reported by putting a single
    result object on ``result_queue``.

    Device hand-back: the parent process reclaims the device through
    ``device_leases`` (keyed by this process's PID) after the worker has
    exited. If the lease could not be recorded (``device_leases`` is None, or
    writing to it raised), the parent has no way to learn which device was
    taken, so this function returns the device to ``device_pool`` itself
    before exiting.

    Timeouts are handled granularly in correctness.py (FIRST_RUN_TIMEOUT_S).

    Args:
        op: Operator name, forwarded to ``compile_single`` / ``evaluate_compiled``.
        code_content: Candidate source code to compile.
        language: Language/backend key; lower-cased and matched against
            ``ASCEND_LANGUAGES`` to decide whether to bind the NPU device.
        run_name: Accepted for interface compatibility; not used here.
        device_pool: Queue of free device IDs, or None if the server was not
            initialized.
        device_leases: Shared mapping ``str(pid) -> device_id``, or None.
        result_queue: Queue on which the result is reported to the parent.
        workspace_dir: Per-request directory used as the operator root and as
            the default location for temporary files.
        slot_timeout_s: Maximum time in seconds to wait for a free device.
    """
    device_id = None
    # True only once the parent can see which device this process holds.
    lease_recorded = False
    try:
        print(f"[DEBUG] run_worker: Starting op={op}, pid={os.getpid()}", file=sys.stderr, flush=True)

        # Enforce workspace location
        os.environ["ASCEND_OP_ROOT"] = workspace_dir
        os.environ["ASCEND_OP_RUNS_ROOT"] = workspace_dir
        os.environ.setdefault("OP_EVAL_SKIP_NPU_CLEANUP", "1")

        # Keep temp files in workspace (setdefault: an explicit TMPDIR/TMP/TEMP
        # inherited from the server's environment wins).
        tmp_scratch = os.path.join(workspace_dir, "_tmp")
        os.makedirs(tmp_scratch, exist_ok=True)
        for tmp_var in ("TMPDIR", "TMP", "TEMP"):
            os.environ.setdefault(tmp_var, tmp_scratch)

        # Compile first (CPU-bound) without holding a device slot.
        print(f"[DEBUG] run_worker: Compiling {op}...", file=sys.stderr, flush=True)
        backend, result, finalize, has_reference, compiled = compile_single(code_content, op, language)
        if not compiled:
            result_queue.put(finalize())
            return

        if device_pool is None:
            result["error"] = "Server not initialized with device pool"
            result_queue.put(finalize())
            return

        try:
            device_id = device_pool.get(timeout=slot_timeout_s)
        except queue.Empty:
            result["error"] = f"No evaluation slots available after {slot_timeout_s}s wait."
            result["error_type"] = "slot_timeout"
            result_queue.put(finalize())
            return

        # Publish the lease so the parent can reclaim the device even if this
        # process is killed before it finishes.
        if device_leases is not None:
            device_leases[str(os.getpid())] = device_id
            lease_recorded = True

        os.environ["ASCEND_DEVICE_ID"] = str(device_id)
        os.environ["DEVICE_ID"] = str(device_id)

        # CRITICAL: Set NPU device BEFORE any torch_npu operations
        # This prevents default allocation on NPU 0 during backend initialization
        language_key = str(language).lower()
        if language_key in ASCEND_LANGUAGES:
            try:
                import torch_npu
                torch_npu.npu.set_device(device_id)
                print(f"[DEBUG] run_worker: NPU device set to {device_id}", file=sys.stderr, flush=True)
            except Exception as e:
                # Best effort: evaluation is still attempted with the
                # explicit device_id passed below.
                print(f"[WARN] Failed to set NPU device {device_id}: {e}", file=sys.stderr, flush=True)

        print(f"[DEBUG] run_worker: Evaluating {op} on device {device_id}...", file=sys.stderr, flush=True)
        result = evaluate_compiled(
            backend,
            op,
            result,
            has_reference,
            device_id=device_id,
        )
        print(f"[DEBUG] run_worker: Evaluation returned. compiled={result.get('compiled')}, correctness={result.get('correctness')}", file=sys.stderr, flush=True)
        result_queue.put(result)
    except Exception as e:
        # Catch-all for process crashes/errors
        print(f"[ERROR] run_worker: Worker exception for {op}: {e}", file=sys.stderr, flush=True)
        traceback.print_exc()
        result_queue.put({
            "compiled": False,
            "correctness": None,
            "performance": None,
            "error": f"Worker Process Failure: {str(e)}"
        })
    finally:
        # Leased devices are reclaimed by the parent after the worker exits.
        # A device that was acquired but never recorded as a lease is
        # invisible to the parent, so hand it back here to avoid losing the
        # slot permanently.
        if device_id is not None and not lease_recorded:
            try:
                device_pool.put(device_id)
            except Exception as e:
                print(
                    f"[WARN] run_worker: Failed to return device {device_id} to pool: {e}",
                    file=sys.stderr,
                    flush=True,
                )

@app.route('/health', methods=['GET'])
def health():
    """Liveness probe: report status, interpreter path and CUDA visibility.

    The JSON body carries ``status`` (always "ok"), ``python`` (the running
    interpreter's path), ``gpu`` (result of ``torch.cuda.is_available()``) and
    ``gpu_name`` (name of CUDA device 0, or "N/A" when CUDA is unavailable).
    """
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        device_name = torch.cuda.get_device_name(0)
    else:
        device_name = "N/A"

    payload = {
        "status": "ok",
        "python": sys.executable,
        "gpu": has_cuda,
        "gpu_name": device_name,
    }
    return jsonify(payload)


# Marker for "the worker did not deliver a result" (distinct from any value a
# worker could legitimately put on the queue, including None).
_NO_RESULT = object()

# How often the request thread re-checks worker liveness while waiting.
_RESULT_POLL_INTERVAL_S = 0.2


def _reclaim_device_lease(worker_pid):
    """Return the device leased to ``worker_pid`` (if any) to the device pool."""
    if DEVICE_LEASES is None:
        return
    leaked = DEVICE_LEASES.pop(str(worker_pid), None)
    if leaked is not None and DEVICE_POOL is not None:
        DEVICE_POOL.put(leaked)


def _terminate_worker(process):
    """Stop ``process``: terminate, wait up to 5s, then kill if still alive."""
    # NOTE(review): the multiprocessing docs warn that terminating a process
    # while it is using a shared queue can corrupt that queue. Keep the join
    # timeout comfortably above compile time + slot timeout so that workers
    # blocked on the device pool normally give up on their own first.
    process.terminate()
    process.join(timeout=5)
    if process.is_alive():
        process.kill()
        process.join()


def _await_worker_result(process, result_queue, deadline):
    """Wait for the worker's result until it arrives, the worker dies, or ``deadline``.

    The queue is read *before* the process is joined: a child that has put a
    large object on a multiprocessing queue cannot exit until the parent has
    consumed it, so joining first can stall until the timeout and then
    discard a perfectly good result.

    Returns the object the worker put on ``result_queue``, or ``_NO_RESULT``.
    ``deadline`` is a ``time.monotonic()`` timestamp.
    """
    while time.monotonic() < deadline:
        try:
            return result_queue.get(timeout=_RESULT_POLL_INTERVAL_S)
        except queue.Empty:
            pass
        if not process.is_alive():
            # The worker may have delivered its result right before exiting,
            # after the get() above gave up; look one final time.
            try:
                return result_queue.get(timeout=_RESULT_POLL_INTERVAL_S)
            except queue.Empty:
                return _NO_RESULT
    return _NO_RESULT


@app.route('/evaluate', methods=['POST'])
def evaluate():
    """Compile and evaluate one operator candidate in an isolated worker process.

    Request:
        Query parameters ``op`` (required; a plain name without path
        separators) and ``language`` (default ``'ascendc'``). The candidate
        code is taken from a multipart upload named ``file`` (UTF-8) or, if no
        such upload is present, from the ``code`` string of a JSON body.

    Response:
        * 400 with ``{"error": ...}`` for a missing/invalid ``op`` or missing
          code content.
        * 500 with ``{"error": ...}`` if the device pool is not initialized or
          an unexpected exception occurs (the latter also carries ``trace``).
        * 200 with the worker's result; or a ``Timeout after ...``
          ``correctness_info`` payload if the worker had to be stopped without
          delivering a result; or an ``Evaluation worker returned no result``
          payload if it exited silently.

    The per-request workspace directory is removed before returning, and any
    device the worker leased is returned to the pool once the worker has
    stopped.
    """
    tmp_dir = None
    try:
        # 1. Parse Parameters
        op = request.args.get('op')
        language = request.args.get('language', 'ascendc')

        if not op:
            return jsonify({"error": "Missing 'op' query parameter"}), 400

        # `op` is client-controlled and is interpolated into the workspace
        # directory prefix and the candidate file name below; refuse anything
        # that could escape the workspace.
        if any(ch in op for ch in ("/", "\\", "\x00")):
            return jsonify({"error": "Invalid 'op' query parameter: path separators are not allowed"}), 400

        # Support both file upload and raw JSON content
        code_content = None
        if 'file' in request.files:
            try:
                code_content = request.files['file'].read().decode('utf-8')
            except UnicodeDecodeError:
                return jsonify({"error": "Uploaded 'file' is not valid UTF-8 text"}), 400
        elif request.is_json:
            # silent=True: a malformed body is reported as "no code" (400)
            # below instead of surfacing as a 500.
            payload = request.get_json(silent=True)
            if isinstance(payload, dict):
                code_content = payload.get('code')

        if not code_content or not isinstance(code_content, str):
            return jsonify({"error": "No code content provided (use multipart 'file' or json 'code')"}), 400

        if DEVICE_POOL is None:
            return jsonify({"error": "Server not initialized with device pool"}), 500

        # 2. Setup Temporary Workspace
        # Default to /cache/EVAL-TMP, but allow override via env var
        eval_tmp_base = os.environ.get("OP_EVAL_SERVER_TMP", "/cache/EVAL-TMP")
        os.makedirs(eval_tmp_base, exist_ok=True)

        # We use a distinct directory per request to avoid collision, inside the cache dir
        tmp_dir = tempfile.mkdtemp(prefix=f"eval_server_{op}_", dir=eval_tmp_base)

        candidate_path = os.path.join(tmp_dir, f"{op}.txt")
        with open(candidate_path, 'w', encoding='utf-8') as f:
            f.write(code_content)

        print(f"[INFO] Received request for {op}, written to {candidate_path}")

        # 3. Trigger Evaluation in a Separate Process
        # We use 'spawn' to ensure clean process state (re-import backends correctly)
        ctx = multiprocessing.get_context("spawn")

        # Safety timeout - very long since granular timeouts are in correctness.py
        join_timeout_s = int(
            os.environ.get(
                "OP_EVAL_SERVER_JOIN_TIMEOUT_S",
                str(OP_EVAL_WORKER_TIMEOUT_S),
            )
        )
        slot_timeout_s = int(os.environ.get("OP_EVAL_SLOT_TIMEOUT_S", str(OP_EVAL_SLOT_TIMEOUT_S)))

        # Fresh queue for this execution
        result_queue = ctx.Queue()
        p = ctx.Process(
            target=run_worker,
            args=(
                op,
                code_content,
                language,
                f"server_{op}",
                DEVICE_POOL,
                DEVICE_LEASES,
                result_queue,
                tmp_dir,
                slot_timeout_s,
            ),
        )
        p.start()
        deadline = time.monotonic() + join_timeout_s

        op_result = _NO_RESULT
        timed_out = False
        try:
            op_result = _await_worker_result(p, result_queue, deadline)
            # Give the worker whatever is left of the budget to exit cleanly.
            p.join(timeout=max(0.0, deadline - time.monotonic()))
            timed_out = p.is_alive()
            if timed_out:
                # Safety timeout hit - this shouldn't happen with proper granular timeouts
                print(
                    f"[WARN] Worker for {op} exceeded safety timeout ({join_timeout_s}s). Terminating...",
                    file=sys.stderr,
                )
        finally:
            # Whatever happened above, never leave a worker running against a
            # workspace that is about to be deleted, and only hand its device
            # back once it has really stopped.
            if p.is_alive():
                _terminate_worker(p)
            _reclaim_device_lease(p.pid)

        # A delivered result wins even if the worker then hung while exiting.
        if op_result is not _NO_RESULT:
            return jsonify(op_result)

        if timed_out:
            return jsonify(
                {
                    "compiled": False,
                    "correctness": False,
                    "performance": None,
                    "correctness_info": f"Timeout after {join_timeout_s}s",
                }
            )

        return jsonify(
            {
                "compiled": False,
                "correctness": None,
                "performance": None,
                "error": "Evaluation worker returned no result",
            }
        )
    except Exception as e:
        traceback.print_exc()
        return jsonify({"error": str(e), "trace": traceback.format_exc()}), 500
    finally:
        if tmp_dir and os.path.exists(tmp_dir):
            try:
                # Keep artifacts if failed? User didn't specify. Cleaning up is safe.
                shutil.rmtree(tmp_dir)
            except Exception as e:
                print(f"[WARN] Failed to cleanup temp dir {tmp_dir}: {e}")

def main():
    """Parse CLI options, build the device pool, and run the Flask server.

    Populates the module globals ``DEVICE_POOL`` (a spawn-context queue that
    holds every device ID ``--max-ops-per-npu`` times), ``DEVICE_LEASES`` and
    ``DEVICE_LEASES_MANAGER``, then blocks in ``app.run``.

    When ``--devices`` is omitted, the device count comes from
    ``torch.cuda.device_count()`` for CUDA backends and from the
    ``NPU_VISIBLE_DEVICES`` environment variable (default 8) otherwise.

    Exits with an argparse usage error if the resulting device count or
    ``--max-ops-per-npu`` is below 1, because an empty pool would make every
    evaluation request wait for a slot that can never become available.
    """
    global DEVICE_POOL, DEVICE_LEASES, DEVICE_LEASES_MANAGER
    parser = argparse.ArgumentParser(description="Start the op_eval HTTP server.")
    parser.add_argument("--port", type=int, default=DEFAULT_PORT, help="Port to listen on")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
    parser.add_argument(
        "--backend",
        type=str,
        default="ascendc",
        help="Backend used by the server for device setup (ascendc, tilelang_ascend, cuda, triton)",
    )
    parser.add_argument(
        "--devices",
        type=int,
        default=None,
        help="Number of devices to manage (default depends on backend)",
    )
    # Default concurrency from env var
    default_concurrency = int(os.environ.get("MAX_OPS_PER_NPU", 1))
    parser.add_argument("--max-ops-per-npu", type=int, default=default_concurrency, help="Max concurrent tasks per NPU")
    args = parser.parse_args()

    backend_key = str(args.backend).lower()
    use_cuda = backend_key in CUDA_LANGUAGES
    use_ascend = backend_key in ASCEND_LANGUAGES

    if use_ascend:
        check_ascend_env()
    elif not use_cuda:
        print(f"[WARN] Unknown backend '{args.backend}', defaulting to Ascend device pool", file=sys.stderr)
        use_ascend = True
        check_ascend_env()

    if args.devices is None:
        if use_cuda:
            args.devices = torch.cuda.device_count()
        else:
            args.devices = int(os.environ.get("NPU_VISIBLE_DEVICES", 8))

    device_label = "GPU" if use_cuda else "NPU"

    # Fail fast on an empty pool instead of starting a server whose
    # evaluation requests can only ever time out waiting for a slot.
    if args.devices < 1:
        parser.error(f"No {device_label} devices to manage (device count is {args.devices})")
    if args.max_ops_per_npu < 1:
        parser.error(f"--max-ops-per-npu must be at least 1 (got {args.max_ops_per_npu})")

    # Initialize device pool: each device is added max_ops_per_npu times
    # This guarantees exactly N concurrent tasks per specific NPU
    ctx = multiprocessing.get_context("spawn")
    DEVICE_POOL = ctx.Queue()
    DEVICE_LEASES_MANAGER = multiprocessing.Manager()
    DEVICE_LEASES = DEVICE_LEASES_MANAGER.dict()
    for _ in range(args.max_ops_per_npu):
        for device_id in range(args.devices):
            DEVICE_POOL.put(device_id)

    total_slots = args.devices * args.max_ops_per_npu
    print(
        f"Initializing Device Pool: {args.devices} {device_label}s (0-{args.devices-1}), "
        f"{args.max_ops_per_npu} slot(s) per {device_label}, {total_slots} total slots"
    )
    if args.max_ops_per_npu > 1:
        print(
            f"[WARN] max_ops_per_npu > 1 enables concurrent processes per {device_label}. "
            f"This can deadlock or crash the {device_label} runtime under load.",
            file=sys.stderr,
        )

    print(f"Starting OpEval Server on {args.host}:{args.port}...")
    print(f"Using Python: {sys.executable}")

    # Threaded=True allows concurrent handling of HTTP requests (Build phase)
    app.run(host=args.host, port=args.port, threaded=True)

if __name__ == '__main__':
    main()
