from typing import Dict, List, Tuple, Optional, Union
import multiprocessing as mp
import numpy as np
from datetime import datetime
import importlib
import logging
import os
import json
import traceback
import urllib.request
import urllib.parse
import random
from src.verification.backends.backend_registry import BACKEND_REGISTRY
from src.verification.checker import Checker
from src.configs.verify import VerifyConfig
from src.utils import copy_dir
from src.log import logger_setup

# Module-level logger used by the parent process and by the worker/subprocess entry points below.
logger = logging.getLogger(__name__)

class Verifier:
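    """
    Multi-process producer-consumer verifier:
    - On construction, starts max_npus worker processes; each worker is bound to one NPU (device id)
    - self.run(...) is the unified entry point: it submits a batch of tasks and returns their results in task order
    - Call self.close() when finished so the non-daemon worker processes exit
    """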
    
    def __init__(self, config: VerifyConfig):
        self.config: VerifyConfig = config
        self.max_npus = config.max_npus
        self._ctx = mp.get_context("spawn")
        self._manager = mp.Manager()
        self._task_queue: mp.Queue = self._ctx.Queue()
        self._workers: List[mp.Process] = []
        if self.config.remote_urls:
            self.npu_per_remote = self.max_npus // len(self.config.remote_urls)
            self.to_remote_url_map = {device_id: self.config.remote_urls[device_id // self.npu_per_remote] for device_id in range(self.max_npus)}
        elif self.config.remote_url:
            self.to_remote_url_map = {device_id: self.config.remote_url for device_id in range(self.max_npus)}
        else:
            self.to_remote_url_map = {device_id: None for device_id in range(self.max_npus)}
        self.verifier_setup()

    def verifier_setup(self):
        for device_id in range(self.max_npus):
            p = self._ctx.Process(
                target=self._worker_loop,
                args=(
                    device_id, 
                    self.to_remote_url_map[device_id],
                    self._task_queue, 
                    self.config
                ),
                name=f"Verifier-Worker-{device_id}"
            )
            p.daemon = False
            p.start()
            self._workers.append(p)
            logger.info(f"Started worker process {p.name} (PID: {p.pid})")
    
    def run(self, tasks):
        if not tasks:
            return []

        num_tasks = len(tasks)
        logger.info(f"Submitting {num_tasks} tasks to {self.max_npus} workers")
        local_result_queue: mp.Queue = self._manager.Queue()
        for task_id, (generated_code, ref_src, op) in enumerate(tasks):
            self._task_queue.put((local_result_queue, task_id, generated_code, ref_src, op))
        
        results = [None] * num_tasks
        for _ in range(num_tasks):
            task_id, result = local_result_queue.get()
            logger.info(f"Collecting result for task {task_id}")
            results[task_id] = result
        return results

    def close(self):
        """Actively close all workers."""
        for _ in self._workers:
            self._task_queue.put(None)
        for p in self._workers:
            p.join()
        self._workers.clear()

    @staticmethod
    def _worker_loop(
        device_id: int,
        remote_url: Optional[str],
        task_queue: mp.Queue,
        config: Optional[VerifyConfig] = None,
    ):
        """
        Entry point of one worker process.

        The worker prepares a private workspace for ``device_id`` and then consumes
        tasks from ``task_queue`` until it receives ``None``.

        Per task:
        1. Rule-based validation, always run in this process. A failure is
           reported immediately and the task is not compiled.
        2. Compilation and evaluation: delegated to ``remote_url`` when it is set,
           otherwise run in a freshly spawned subprocess that is given at most
           ``config.timeout`` seconds. On success that subprocess puts the result
           on the task's result queue itself; this loop only reports timeouts,
           abnormal exits and failures to start the subprocess.

        Args:
            device_id: Index of the NPU this worker is bound to.
            remote_url: Base URL of the remote evaluation server, or a falsy value
                to verify locally.
            task_queue: Queue of ``(result_queue, task_id, generated_code, ref_src,
                op)`` tuples; ``None`` is the shutdown sentinel.
            config: Verification settings.
                NOTE(review): declared optional, but the body dereferences it
                unconditionally, so ``None`` is not actually supported.
        """
        # Create Checker instance in worker process (if config is provided)
        checker = None
        if config is not None:
            checker = Checker(config)

        op_engineer_id = f"op_projects_dev{device_id}"
        local_workspace = os.path.join(config.workspace, config.timestamp, op_engineer_id)
        worker_tag = f"[worker-{device_id}]"
        logger_setup(os.path.join(local_workspace, f'worker-{device_id}.log'))
        logger.info(
            "%s started. ascendc_device=%s, workspace=%s, num_correct_trials=%s, timestamp=%s",
            worker_tag, config.ascendc_device, config.workspace, config.num_correct_trials, config.timestamp
        )

        # Every worker gets its own copy of CppExtension inside its workspace.
        cpp_src = os.path.join(config.workspace, "CppExtension")
        cpp_dst = os.path.join(local_workspace, "CppExtension")
        logger.info(
            "%s preparing workspace: local_workspace=%s, copying CppExtension: %s -> %s",
            worker_tag, local_workspace, cpp_src, cpp_dst
        )
        copy_dir(cpp_src, cpp_dst)

        while True:
            logger.info("%s waiting for next task...", worker_tag)
            task = task_queue.get()
            if task is None:
                logger.info("%s received termination signal. Shutting down.", worker_tag)
                break
            try:
                result_queue, task_id, generated_code, ref_src, op = task
            except Exception:
                # Without a result queue there is nowhere to report to; skip the item.
                logger.exception("%s received invalid task object: %r", worker_tag, task)
                continue

            logger.info(
                "%s processing task_id=%s, op=%s, generated_len=%s, ref_len=%s",
                worker_tag,
                task_id,
                op,
                len(generated_code) if isinstance(generated_code, str) else "N/A",
                len(ref_src) if isinstance(ref_src, str) else "N/A",
            )
            result: Dict = {
                'compiled': False,
                'correctness': None,
                'performance': None,
                'device_id': device_id,
            }

            # Step 1: Rule-based validation (ALWAYS LOCAL)
            try:
                if checker is not None:
                    valid = checker.run(op, ref_src, generated_code)
                else:
                    from src.verification.rules_checker import filter_code_result_all
                    valid = filter_code_result_all(generated_code)
            except Exception as e:
                logger.error("%s failed to validate code for task_id=%s: %s", worker_tag, task_id, e)
                result['compile_info'] = f"failed to validate code: {e}"
                result['valid_info'] = f"failed to validate code: {e}"
                result_queue.put((task_id, result))
                continue

            # valid is indexed as (passed, details).
            if not valid[0]:
                logger.error("%s failed to validate code for task_id=%s: %s", worker_tag, task_id, valid[1])
                invalid_msg = f"Error in validating code: [Invalid Error] Your implementation does not meet the requirements. You must implement the operations in forward() as custom kernels in custom_ops_lib and call them from there. Details:\n{valid[1]}"
                result['compile_info'] = invalid_msg
                result['valid'] = False
                result['valid_info'] = invalid_msg
                result_queue.put((task_id, result))
                continue

            # Step 2: Compilation & Evaluation (REMOTE or LOCAL)
            if remote_url:
                # Remote mode: send to remote server. Log the URL this worker
                # actually uses rather than the whole configured list.
                logger.info("%s using REMOTE mode for task_id=%s, url=%s", worker_tag, task_id, remote_url)
                try:
                    remote_result = Verifier._verify_remote(
                        remote_url=remote_url,
                        generated_code=generated_code,
                        op=op,
                        language=config.remote_language,
                        timeout=config.remote_timeout,
                    )
                    # Merge remote result with local result structure
                    result.update(remote_result)
                    result['valid'] = True  # Passed rule check
                    result['device_id'] = device_id
                    logger.info("%s remote verification completed for task_id=%s: compiled=%s, correctness=%s",
                                worker_tag, task_id, result.get('compiled'), result.get('correctness'))
                except Exception as e:
                    logger.exception("%s remote verification failed for task_id=%s", worker_tag, task_id)
                    result['compile_info'] = f"Remote verification error: {e}"
                    result['error'] = str(e)
                result_queue.put((task_id, result))
                continue
            else:
                try:
                    ctx = mp.get_context("spawn")  # Keep consistent with other places
                    p = ctx.Process(
                        target=Verifier._verify_once_entry,
                        args=(
                            result_queue,
                            task_id,
                            generated_code,
                            ref_src,
                            op,
                            device_id,
                            config.ascendc_device,
                            local_workspace,
                            config.num_correct_trials,
                        ),
                    )
                    logger.info("%s starting verify subprocess for task_id=%s", worker_tag, task_id)
                    p.start()
                    p.join(timeout=config.timeout)  # Wait for subprocess to finish here, then enter next while loop

                    if p.is_alive():
                        logger.error("%s verify subprocess timeout for task_id=%s", worker_tag, task_id)
                        p.terminate()
                        p.join(timeout=5)
                        if p.is_alive():
                            # terminate() did not stop it; kill it so a stuck
                            # subprocess is not left running on this device
                            # while the next task starts.
                            logger.error(
                                "%s verify subprocess did not terminate for task_id=%s; killing it",
                                worker_tag, task_id
                            )
                            p.kill()
                            # Bounded wait so this worker cannot block forever on a stuck process.
                            p.join(timeout=5)
                        result_queue.put((
                            task_id,
                            {
                                "compiled": False,
                                "compile_info": "verify subprocess timeout",
                                "correctness": None,
                                "correctness_info": "verify subprocess timeout",
                                "performance": None,
                                "device_id": device_id,
                            },
                        ))
                    else:
                        # A negative exitcode -N means the process was ended by signal N; 11 is SIGSEGV.
                        if p.exitcode == -11:
                            logger.error("%s verify subprocess segmentation fault for task_id=%s", worker_tag, task_id)
                            result_queue.put((
                                task_id,
                                {
                                    # NOTE(review): reported as not compiled, although the
                                    # original comment assumed the segfault happens at runtime
                                    # after a successful compile - confirm which is intended.
                                    "compiled": False,
                                    "correctness": None,
                                    "performance": None,
                                    "correctness_info": "Segmentation fault",
                                    "device_id": device_id,
                                }
                            ))
                        elif p.exitcode != 0:
                            logger.error(
                                "%s verify subprocess exited with code %s for task_id=%s",
                                worker_tag, p.exitcode, task_id
                            )
                            result_queue.put((
                                task_id,
                                {
                                    "compiled": False,
                                    "correctness": None,
                                    "performance": None,
                                    "correctness_info": f"verify subprocess exitcode={p.exitcode}",
                                    "device_id": device_id,
                                },
                            ))
                        else:
                            # Clean exit: the subprocess has already put its result on result_queue.
                            logger.info("%s verify subprocess finished for task_id=%s", worker_tag, task_id)

                except Exception:
                    logger.exception("%s failed to start/join verify subprocess for task_id=%s",
                                    worker_tag, task_id)
                    result_queue.put((
                        task_id,
                        {
                            "compiled": False,
                            "correctness": None,
                            "performance": None,
                            "correctness_info": f"failed to start/join verify subprocess, details: {traceback.format_exc()}",
                            "device_id": device_id,
                        },
                    ))

        logger.info("%s worker shut down completed.", worker_tag)

    @staticmethod
    def _verify_once_entry(
        result_queue,
        task_id: str,
        generated_code: str,
        ref_src: str,
        op: str,
        device_id: int,
        ascendc_device: str,
        op_engineer_dir: str,
        num_correct_trials: int,
    ):
        """
        Subprocess entry: call Verifier._verify_once here,
        then directly put (task_id, result) into the result_queue given by parent process.
        """
        try:
            logger_setup(os.path.join(op_engineer_dir, f'worker-{device_id}.log'))
            result = Verifier._verify_once(
                generated_code=generated_code,
                ref_src=ref_src,
                op=op,
                device_id=device_id,
                ascendc_device=ascendc_device,
                op_engineer_dir=op_engineer_dir,
                num_correct_trials=num_correct_trials,
            )
        except KeyboardInterrupt:
            logger.warning("[subprocess] KeyboardInterrupt in _verify_once for task_id=%s", task_id)
            raise
        except Exception as e:
            logger.exception("[subprocess] _verify_once failed for task_id=%s", task_id)
            result = {
                "compiled": False,
                "compile_info": f"_verify_once exception: {type(e).__name__}: {e}",
                "correctness": None,
                "traceback": traceback.format_exc(),
            }

        # Directly write to the main process's result_queue
        result_queue.put((task_id, result))
        
    @staticmethod
    def _verify_once(
        generated_code: str,
        ref_src: str,
        op: str,
        device_id: int,
        ascendc_device: str,
        op_engineer_dir: str,
        num_correct_trials: int,
        language: str = 'ascendc'
    ) -> Dict:
        """
        Single verification logic
        """
        if language not in BACKEND_REGISTRY:
            try:
                importlib.import_module(f"src.verification.backends.{language}_backend")
            except ImportError as e:
                raise ValueError(f"Unsupported language/platform: {language} (module not found)") from e
        backend = BACKEND_REGISTRY.get(language)
        if backend is None:
            raise ValueError(f"Unsupported language/platform: {language}")
        backend = BACKEND_REGISTRY.get(language)(
            op_engineer_dir=op_engineer_dir,
            ascendc_device=ascendc_device,
            device_id=device_id,
        )
        logger.info(
            "Initializing backend: %s with op_engineer_dir=%s, ascendc_device=%s, device_id=%d",
            language, op_engineer_dir, ascendc_device, device_id
        )
        hardware = backend.get_hardware_name()
        logger.info("Using hardware: %s (device_id=%d)", hardware, device_id)

        result: Dict = {
            'compiled': False,
            'correctness': None,
            'performance': None,
            'hardware': hardware,
            'device_id': device_id,
        }

        logger.info("Compiling generated code for op=%s", op)
        compiled, compile_info = backend.compile(generated_code, op)
        if not compiled:
            logger.warning("Compilation failed for op=%s on device_id=%d: %s", op, device_id, compile_info)
            if "[Invalid Error]" in compile_info:
                result['valid'] = False
                result['valid_info'] = compile_info
            else:
                result['valid'] = True
            result['compile_info'] = compile_info
            backend.cleanup()
            logger.info("Backend cleanup done after compilation failure")
            logger.info(result)
            return result
        
        logger.info("Compilation succeeded for op=%s on device_id=%d", op, device_id)
        result['compiled'] = True
        result['compile_info'] = ""  # Set to empty string when compilation succeeds

        logger.info(
            "Starting correctness_execution for op=%s, num_correct_trials=%d",
            op, num_correct_trials
        )
        correctness, info = backend.correctness_execution(ref_src, num_correct_trials)
        if not correctness:
            logger.warning(
                "Correctness check failed for op=%s on device_id=%d: %s",
                op, device_id, info
            )
            result['correctness'] = False
            result['correctness_info'] = info
            backend.cleanup()
            logger.info("Backend cleanup done after correctness failure")
            logger.info(result)
            return result
        logger.info("Correctness check passed for op=%s on device_id=%d", op, device_id)
        result['correctness'] = True
        result['correctness_info'] = ""  # Set to empty string when correctness check passes

        logger.info("Starting performance timing for op=%s", op)
        elapsed_times = backend.time_execution()
        if not elapsed_times:
            logger.warning("No elapsed times returned from backend.time_execution for op=%s", op)
        else:
            logger.info("Raw elapsed times: %s", elapsed_times)
            
        result['performance'] = {
            "mean": float(f"{np.mean(elapsed_times):.3g}"),
            "std": float(f"{np.std(elapsed_times):.3g}"),
            "min": float(f"{np.min(elapsed_times):.3g}"),
            "max": float(f"{np.max(elapsed_times):.3g}"),
            "num_trials": len(elapsed_times),
        }
        logger.info(result)
        backend.cleanup()
        logger.info("Backend cleanup done after successful verification")
        logger.info("Finished _verify_once for op=%s on device_id=%d", op, device_id)
        return result

    @staticmethod
    def _verify_remote(
        remote_url: str,
        generated_code: str,
        op: str,
        language: str,
        timeout: float,
    ) -> Dict:
        """
        Send code to remote server for compilation and evaluation.
        
        Args:
            remote_url: Base URL of the remote op_eval server (e.g., 'http://127.0.0.1:5001')
            generated_code: The generated kernel code to evaluate
            op: Operator name
            language: Language/backend (e.g., 'ascendc', 'tilelang_ascend')
            timeout: HTTP request timeout in seconds
            
        Returns:
            Dict containing 'compiled', 'correctness', 'performance', and optionally 'error' fields
        """
        endpoint = remote_url.rstrip("/") + "/evaluate"
        params = urllib.parse.urlencode({"op": op, "language": language})
        url = f"{endpoint}?{params}"
        
        logger.info("Sending remote evaluation request: op=%s, language=%s, url=%s", op, language, url)
        
        payload = json.dumps({"code": generated_code}).encode("utf-8")
        req = urllib.request.Request(
            url,
            data=payload,
            headers={"Content-Type": "application/json"},
        )
        
        try:
            with urllib.request.urlopen(req, timeout=timeout) as response:
                if response.status == 200:
                    result = json.loads(response.read())
                    logger.info("Remote evaluation response: compiled=%s, correctness=%s", 
                               result.get('compiled'), result.get('correctness'))
                    return result
                else:
                    error_msg = f"HTTP {response.status}: {response.read().decode()}"
                    logger.error("Remote evaluation failed: %s", error_msg)
                    return {
                        "compiled": False,
                        "correctness": None,
                        "performance": None,
                        "error": error_msg,
                    }
        except urllib.error.HTTPError as e:
            error_msg = f"HTTP Error {e.code}: {e.reason}"
            logger.error("Remote evaluation HTTP error: %s", error_msg)
            return {
                "compiled": False,
                "correctness": None,
                "performance": None,
                "error": error_msg,
            }
        except urllib.error.URLError as e:
            error_msg = f"URL Error: {e.reason}"
            logger.error("Remote evaluation URL error: %s", error_msg)
            return {
                "compiled": False,
                "correctness": None,
                "performance": None,
                "error": f"Remote Error: {error_msg}",
            }
        except Exception as e:
            error_msg = str(e)
            logger.exception("Remote evaluation unexpected error")
            return {
                "compiled": False,
                "correctness": None,
                "performance": None,
                "error": f"Remote Error: {error_msg}",
            }