import importlib
import logging
import multiprocessing as mp
import os
import queue
import signal
import traceback
from typing import List, Tuple, Optional, Union

import numpy as np

from src.verification.backends.backend_registry import BACKEND_REGISTRY
from src.configs.verify import VerifyConfig
from src.models import OpVerificationResult, PerformanceMetric
from src.utils import copy_dir
from src.log import logger_setup

logger = logging.getLogger(__name__)

class Verifier:
    """
    Multi-process producer-consumer verifier.

    - On construction, starts ``config.max_npus`` long-lived worker processes;
      worker ``i`` is bound to NPU ``device_id=i``.
    - ``run(tasks)`` is the unified entry point (supports multiple jobs): it
      enqueues every task on a shared queue and blocks until one result per
      task has been collected.
    - Each worker executes one task at a time in a fresh "spawn" subprocess,
      so a crash or hang inside the backend does not kill the worker itself.
    - ``close()`` stops the workers and the result-queue manager.
    """

    # Seconds between worker-liveness checks while ``run`` waits for results.
    _RESULT_POLL_INTERVAL: float = 5.0

    def __init__(self, config: VerifyConfig):
        self.config: VerifyConfig = config
        self.max_npus = config.max_npus
        self._ctx = mp.get_context("spawn")
        # Create the manager from the same "spawn" context as the workers rather
        # than the platform-default start method, so every helper process is
        # started the same way.
        self._manager = self._ctx.Manager()
        self._task_queue: mp.Queue = self._ctx.Queue()
        self._workers: List[mp.Process] = []
        self.verifier_setup()

    def verifier_setup(self):
        """Start one worker process per NPU (device ids ``0 .. max_npus - 1``)."""
        for device_id in range(self.max_npus):
            p = self._ctx.Process(
                target=self._worker_loop,
                args=(
                    device_id,
                    self.config.ascendc_device,
                    self._task_queue,
                    self.config.workspace,
                    self.config.num_correct_trials,
                    self.config.timestamp,
                    self.config.timeout,
                ),
                name=f"Verifier-Worker-{device_id}",
            )
            # Must be non-daemonic: workers start their own verify subprocesses,
            # and daemonic processes are not allowed to create children.
            p.daemon = False
            p.start()
            self._workers.append(p)
            logger.info(f"Started worker process {p.name} (PID: {p.pid})")

    def run(self, tasks):
        """
        Verify a batch of tasks and return their results in submission order.

        Args:
            tasks: iterable of ``(generated_code, ref_src, op)`` tuples. ``None``
                or an empty iterable yields ``[]``.

        Returns:
            List of ``OpVerificationResult`` where ``results[i]`` belongs to the
            i-th submitted task.

        Raises:
            RuntimeError: if every worker process has exited while results are
                still outstanding (otherwise this call would block forever).
        """
        if tasks is None:
            return []
        # Materialize so generators/iterators work with len() and enumerate().
        tasks = list(tasks)
        if not tasks:
            return []

        num_tasks = len(tasks)
        logger.info(f"Submitting {num_tasks} tasks to {self.max_npus} workers")
        # A fresh manager queue per call, so results from different run() calls
        # are never mixed.
        local_result_queue = self._manager.Queue()
        for task_id, (generated_code, ref_src, op) in enumerate(tasks):
            self._task_queue.put((local_result_queue, task_id, generated_code, ref_src, op))

        results = [None] * num_tasks
        pending = set(range(num_tasks))
        while pending:
            try:
                task_id, result = local_result_queue.get(timeout=self._RESULT_POLL_INTERVAL)
            except queue.Empty:
                # Nothing arrived yet: fail fast instead of hanging if no worker
                # is left to produce the outstanding results.
                if not any(p.is_alive() for p in self._workers):
                    raise RuntimeError(
                        f"All verifier workers have exited with {len(pending)} "
                        f"of {num_tasks} task results still pending"
                    )
                continue
            if task_id not in pending:
                # A task may be reported twice: e.g. the verify subprocess put its
                # result but then timed out / crashed during teardown, and the
                # worker reported that failure as well. Keep the first result.
                logger.warning(f"Ignoring duplicate result for task {task_id}")
                continue
            logger.info(f"Collecting result for task {task_id}")
            results[task_id] = result
            pending.discard(task_id)
        return results

    def close(self):
        """
        Actively close all workers and the result-queue manager.

        Sends one ``None`` sentinel per worker, waits for each worker to exit,
        then shuts the manager down. Calling it again is harmless.
        """
        for _ in self._workers:
            self._task_queue.put(None)
        for p in self._workers:
            p.join()
        self._workers.clear()
        # Without this the manager's server process outlives the verifier.
        self._manager.shutdown()

    @staticmethod
    def _error_result(device_id: int, compile_info: str, correctness_info: str) -> OpVerificationResult:
        """Build a failed result (not compiled, not correct, no performance) for ``device_id``."""
        return OpVerificationResult(
            compiled=False,
            compile_info=compile_info,
            correctness=False,
            correctness_info=correctness_info,
            performance=None,
            device_id=device_id,
        )

    @staticmethod
    def _worker_loop(
        device_id: int,
        ascendc_device: str,
        task_queue: mp.Queue,
        workspace: str,
        num_correct_trials: int,
        timestamp: str,
        timeout: Optional[float] = None,
    ):
        """
        Worker process main loop (one process per NPU).

        - Binds to one NPU (``device_id``) and prepares a private workspace
          ``<workspace>/<timestamp>/op_projects_dev<device_id>`` holding a copy
          of ``<workspace>/CppExtension``.
        - Continuously fetches tasks from ``task_queue`` (producer-consumer
          model) until it receives ``None``.
        - Runs every task in a fresh "spawn" subprocess
          (``_verify_once_entry``), waiting at most ``timeout`` seconds
          (``None`` = no limit). If that subprocess times out, crashes, or cannot
          be started, the worker itself puts a failure result for the task on the
          task's result queue.
        """
        op_engineer_id = f"op_projects_dev{device_id}"
        local_workspace = os.path.join(workspace, timestamp, op_engineer_id)
        worker_tag = f"[worker-{device_id}]"
        logger_setup(os.path.join(local_workspace, f'worker-{device_id}.log'))
        logger.info(
            "%s started. ascendc_device=%s, workspace=%s, num_correct_trials=%s, timestamp=%s",
            worker_tag, ascendc_device, workspace, num_correct_trials, timestamp
        )

        cpp_src = os.path.join(workspace, "CppExtension")
        cpp_dst = os.path.join(local_workspace, "CppExtension")
        logger.info(
            "%s preparing workspace: local_workspace=%s, copying CppExtension: %s -> %s",
            worker_tag, local_workspace, cpp_src, cpp_dst
        )
        copy_dir(cpp_src, cpp_dst)

        # Loop-invariant; "spawn" keeps consistent with the rest of the verifier.
        ctx = mp.get_context("spawn")

        while True:
            logger.info("%s waiting for next task...", worker_tag)
            task = task_queue.get()
            if task is None:
                logger.info("%s received termination signal. Shutting down.", worker_tag)
                break
            try:
                result_queue, task_id, generated_code, ref_src, op = task
            except Exception:
                # No result queue can be recovered from a malformed task, so
                # nothing can be reported back; log and move on.
                logger.exception("%s received invalid task object: %r", worker_tag, task)
                continue

            logger.info(
                "%s processing task_id=%s, op=%s, generated_len=%s, ref_len=%s",
                worker_tag,
                task_id,
                op,
                len(generated_code) if isinstance(generated_code, str) else "N/A",
                len(ref_src) if isinstance(ref_src, str) else "N/A",
            )

            try:
                p = ctx.Process(
                    target=Verifier._verify_once_entry,
                    args=(
                        result_queue,
                        task_id,
                        generated_code,
                        ref_src,
                        op,
                        device_id,
                        ascendc_device,
                        local_workspace,
                        num_correct_trials,
                    ),
                )
                logger.info("%s starting verify subprocess for task_id=%s", worker_tag, task_id)
                p.start()
                # Block here until the subprocess finishes (or times out); the
                # worker handles exactly one task at a time.
                p.join(timeout=timeout)

                if p.is_alive():
                    logger.error("%s verify subprocess timeout for task_id=%s", worker_tag, task_id)
                    p.terminate()
                    p.join(timeout=5)
                    if p.is_alive():
                        # SIGTERM was not honoured in time: escalate to SIGKILL.
                        p.kill()
                        p.join(timeout=5)
                    result_queue.put((
                        task_id,
                        Verifier._error_result(
                            device_id,
                            compile_info="verify subprocess timeout",
                            correctness_info="verify subprocess timeout",
                        ),
                    ))
                elif p.exitcode == -signal.SIGSEGV:
                    logger.error("%s verify subprocess segmentation fault for task_id=%s", worker_tag, task_id)
                    result_queue.put((
                        task_id,
                        Verifier._error_result(
                            device_id,
                            compile_info="",
                            correctness_info="Segmentation fault",
                        ),
                    ))
                elif p.exitcode != 0:
                    logger.error(
                        "%s verify subprocess exited with code %s for task_id=%s",
                        worker_tag, p.exitcode, task_id
                    )
                    result_queue.put((
                        task_id,
                        Verifier._error_result(
                            device_id,
                            compile_info=f"verify subprocess exitcode={p.exitcode}",
                            correctness_info=f"verify subprocess exitcode={p.exitcode}",
                        ),
                    ))
                else:
                    # Exit code 0: the subprocess already put its own result.
                    logger.info("%s verify subprocess finished for task_id=%s", worker_tag, task_id)

            except Exception:
                details = traceback.format_exc()
                logger.exception("%s failed to start/join verify subprocess for task_id=%s",
                                 worker_tag, task_id)
                result_queue.put((
                    task_id,
                    Verifier._error_result(
                        device_id,
                        compile_info=f"failed to start/join verify subprocess, details: {details}",
                        correctness_info=f"failed to start/join verify subprocess, details: {details}",
                    ),
                ))

        logger.info("%s worker shut down completed.", worker_tag)

    @staticmethod
    def _verify_once_entry(
        result_queue,
        task_id: int,
        generated_code: str,
        ref_src: str,
        op: str,
        device_id: int,
        ascendc_device: str,
        op_engineer_dir: str,
        num_correct_trials: int,
    ):
        """
        Subprocess entry point: run ``Verifier._verify_once`` and put
        ``(task_id, result)`` on the result queue given by the parent process.

        Any ``Exception`` from verification is converted into a failed result
        (so the caller always gets an answer); ``KeyboardInterrupt`` is re-raised.
        """
        try:
            logger_setup(os.path.join(op_engineer_dir, f'worker-{device_id}.log'))
            result: OpVerificationResult = Verifier._verify_once(
                generated_code=generated_code,
                ref_src=ref_src,
                op=op,
                device_id=device_id,
                ascendc_device=ascendc_device,
                op_engineer_dir=op_engineer_dir,
                num_correct_trials=num_correct_trials,
            )
        except KeyboardInterrupt:
            logger.warning("[subprocess] KeyboardInterrupt in _verify_once for task_id=%s", task_id)
            raise
        except Exception as e:
            logger.exception("[subprocess] _verify_once failed for task_id=%s", task_id)
            # The hardware name is unknown here (the backend may never have been
            # created), so it is not filled in.
            result = Verifier._error_result(
                device_id,
                compile_info=f"_verify_once exception: {type(e).__name__}: {e}",
                correctness_info=f"_verify_once exception: {type(e).__name__}: {traceback.format_exc()}",
            )

        # Write to main process's result_queue
        result_queue.put((task_id, result))

    @staticmethod
    def _verify_once(
        generated_code: str,
        ref_src: str,
        op: str,
        device_id: int,
        ascendc_device: str,
        op_engineer_dir: str,
        num_correct_trials: int,
        language: str = 'ascendc'
    ) -> OpVerificationResult:
        """
        Single verification: compile -> correctness -> performance timing.

        Args:
            generated_code: candidate kernel source passed to ``backend.compile``.
            ref_src: reference source passed to ``backend.correctness_execution``.
            op: operator name.
            device_id: NPU index the backend is bound to.
            ascendc_device: device string forwarded to the backend.
            op_engineer_dir: per-worker workspace directory for the backend.
            num_correct_trials: number of correctness trials.
            language: key into ``BACKEND_REGISTRY``. If not registered yet, the
                module ``src.verification.backends.<language>_backend`` is
                imported first (presumably it registers itself on import).

        Returns:
            The result. Verification stops at the first failing stage.
            ``performance`` stays ``None`` if timing returned no samples.

        Raises:
            ValueError: if no backend is available for ``language``.
            Any exception raised by the backend propagates after
            ``backend.cleanup()`` has been attempted.
        """
        if language not in BACKEND_REGISTRY:
            try:
                importlib.import_module(f"src.verification.backends.{language}_backend")
            except ImportError as e:
                raise ValueError(f"Unsupported language/platform: {language} (module not found)") from e
        backend_cls = BACKEND_REGISTRY.get(language)
        if backend_cls is None:
            raise ValueError(f"Unsupported language/platform: {language}")
        logger.info(
            "Initializing backend: %s with op_engineer_dir=%s, ascendc_device=%s, device_id=%d",
            language, op_engineer_dir, ascendc_device, device_id
        )
        backend = backend_cls(
            op_engineer_dir=op_engineer_dir,
            ascendc_device=ascendc_device,
            device_id=device_id
        )

        try:
            result, outcome = Verifier._run_verification_stages(
                backend, generated_code, ref_src, op, device_id, num_correct_trials
            )
        except BaseException:
            # Still release backend resources, but never let a cleanup failure
            # mask the original error.
            try:
                backend.cleanup()
            except Exception:
                logger.exception("Backend cleanup failed after verification error")
            raise

        backend.cleanup()
        logger.info("Backend cleanup done after %s", outcome)
        logger.info("Finished _verify_once for op=%s on device_id=%d", op, device_id)
        return result

    @staticmethod
    def _run_verification_stages(
        backend,
        generated_code: str,
        ref_src: str,
        op: str,
        device_id: int,
        num_correct_trials: int,
    ) -> Tuple[OpVerificationResult, str]:
        """
        Run compile -> correctness -> timing on an initialized backend.

        Does not call ``backend.cleanup()``; the caller owns that.

        Returns:
            ``(result, outcome)``, where ``outcome`` is one of
            ``"compilation failure"``, ``"correctness failure"`` or
            ``"successful verification"`` (used in the cleanup log line).
        """
        hardware = backend.get_hardware_name()
        logger.info("Using hardware: %s (device_id=%d)", hardware, device_id)

        result: OpVerificationResult = OpVerificationResult(
            compiled=False,
            compile_info="",
            correctness=False,  # False until the correctness check has run and passed
            correctness_info="",
            performance=None,
            hardware=hardware,
            device_id=device_id
        )

        logger.info("Compiling generated code for op=%s", op)
        compiled, compile_info = backend.compile(generated_code, op)
        if not compiled:
            logger.warning("Compilation failed for op=%s on device_id=%d: %s", op, device_id, compile_info)
            result.compile_info = compile_info
            # Correctness check not performed: correctness/correctness_info keep
            # their initial False / "" values.
            logger.info(result)
            return result, "compilation failure"
        logger.info("Compilation succeeded for op=%s on device_id=%d", op, device_id)
        result.compiled = True  # compile_info stays "" on success

        logger.info(
            "Starting correctness_execution for op=%s, num_correct_trials=%d",
            op, num_correct_trials
        )
        correctness, info = backend.correctness_execution(ref_src, num_correct_trials)
        if not correctness:
            logger.warning(
                "Correctness check failed for op=%s on device_id=%d: %s",
                op, device_id, info
            )
            result.correctness_info = info
            logger.info(result)
            return result, "correctness failure"
        logger.info("Correctness check passed for op=%s on device_id=%d", op, device_id)
        result.correctness = True  # correctness_info stays "" on success

        logger.info("Starting performance timing for op=%s", op)
        elapsed_times = backend.time_execution()
        # len() instead of truthiness so list, tuple and ndarray all work.
        if elapsed_times is None or len(elapsed_times) == 0:
            # np.min/np.max raise on empty input: leave performance as None.
            logger.warning("No elapsed times returned from backend.time_execution for op=%s", op)
        else:
            logger.info("Raw elapsed times: %s", elapsed_times)
            # Statistics are rounded to 3 significant digits.
            result.performance = PerformanceMetric(
                mean=float(f"{np.mean(elapsed_times):.3g}"),
                std=float(f"{np.std(elapsed_times):.3g}"),
                min=float(f"{np.min(elapsed_times):.3g}"),
                max=float(f"{np.max(elapsed_times):.3g}"),
                num_trials=len(elapsed_times),
            )

        logger.info(result)
        return result, "successful verification"