import os

import torch_npu
import torch
from op_eval.backends.backend_registry import register_backend, Backend
from op_eval.utils.ascend_compile_pipeline import ascend_compile
from op_eval.utils.correctness import execute_template
from op_eval.utils.performance import time_execution_event_template
from op_eval.config import project_root_path, ascendc_device


def _skip_npu_cleanup() -> bool:
    """Report whether NPU cleanup has been switched off via the environment.

    Reads ``OP_EVAL_SKIP_NPU_CLEANUP``. After stripping surrounding whitespace
    and lower-casing, the values ``1``, ``true``, ``yes`` and ``on`` mean
    "skip"; an unset variable or any other value means "do not skip".
    """
    # An unset variable falls back to "", which is not an enabling value.
    flag = os.environ.get("OP_EVAL_SKIP_NPU_CLEANUP", "")
    normalized = flag.strip().lower()
    return normalized in {"1", "true", "yes", "on"}

@register_backend('ascendc')
class AscendBackend(Backend):
    """Evaluation backend registered as ``'ascendc'`` that targets an NPU device.

    The instance keeps a ``context`` dict that is handed to the compile
    pipeline, used as the namespace for ``exec`` of the reference source, and
    passed to the correctness and timing templates.
    """

    def __init__(self):
        # Shared namespace for compile / exec / correctness / timing steps.
        self.context = {}
        self.device = self.get_device()
        self.model_key = 'ModelNew'  # Custom model class name for evaluation

    def get_device(self):
        """Return the ``npu:<id>`` torch device selected by ``ASCEND_DEVICE_ID``.

        Raises:
            ValueError: if ``ASCEND_DEVICE_ID`` is set to a non-integer value.
        """
        # Read from env var set by server worker, default to 0
        device_id = int(os.environ.get("ASCEND_DEVICE_ID", "0"))
        return torch.device(f"npu:{device_id}")

    def set_device(self, device_id: int):
        """Point this backend at ``npu:<device_id>`` and make it the active NPU."""
        self.device = torch.device(f"npu:{device_id}")
        torch_npu.npu.set_device(device_id)

    def get_hardware_name(self):
        """Return the configured hardware name (``ascendc_device`` from config)."""
        return ascendc_device  # torch_npu.npu.get_device_name(device) causes crash

    def synchronize(self, device=None):
        """Synchronize NPU device to ensure all operations are complete.

        Args:
            device: Device to synchronize; defaults to ``self.device``.
        """
        if device is None:
            device = self.device
        torch_npu.npu.synchronize(device=device)

    @property
    def event_class(self):
        """Return the NPU Event class for timing."""
        return torch_npu.npu.Event

    def compile(self, generated_code, op, apply_env: bool = True):
        """Compile ``generated_code`` for ``op`` through ``ascend_compile``.

        On success the build info and the generated code are stored in
        ``self.context`` under ``_op_eval_build_info`` and
        ``_op_eval_generated_code``.

        Args:
            generated_code: Source handed to the compile pipeline.
            op: Operator identifier forwarded to the compile pipeline.
            apply_env: Forwarded unchanged to ``ascend_compile``.

        Returns:
            ``(True, None)`` on success, ``(False, <error message>)`` on failure.
        """
        try:
            build_info = ascend_compile(generated_code, op, self.context, apply_env=apply_env)
            self.context["_op_eval_build_info"] = build_info
            self.context["_op_eval_generated_code"] = generated_code
            return True, None
        except Exception as e:
            # Return to the project root after a failed build.
            # NOTE(review): presumably the pipeline may leave the process in a
            # different working directory when it fails - confirm.
            os.chdir(project_root_path)
            return False, str(e)

    def correctness_execution(self, ref_src):
        """Execute the reference source into the context and run the correctness check.

        Args:
            ref_src: Python source of the reference model.

        Returns:
            Whatever ``execute_template`` returns for this device and context.

        Raises:
            RuntimeError: if executing ``ref_src`` raises; the original
                exception is attached as ``__cause__``.
        """
        synchronize = torch_npu.npu.synchronize
        try:
            # SECURITY: ``exec`` runs ``ref_src`` as arbitrary Python code in
            # this process; callers must only pass trusted reference sources.
            exec(ref_src, self.context)
        except Exception as e:
            raise RuntimeError(f"Failed to compile reference model: {str(e)}") from e
        return execute_template(synchronize, self.device, self.context)

    def time_execution(self, eval_target='ModelNew'):
        """Time ``eval_target`` using NPU events.

        Args:
            eval_target: Name of the model to time; defaults to ``'ModelNew'``.

        Returns:
            Whatever ``time_execution_event_template`` returns.
        """
        synchronize = torch_npu.npu.synchronize
        return time_execution_event_template(
            self.context, self.device, synchronize, self.event_class, eval_target,
        )

    def cleanup(self):
        """Drop the evaluation context and release cached NPU memory.

        The context is always reset. When ``OP_EVAL_SKIP_NPU_CLEANUP`` is set
        to an enabling value (see ``_skip_npu_cleanup``), the NPU cache flush
        and synchronize are skipped.
        """
        self.context = {}
        if _skip_npu_cleanup():
            return
        torch_npu.npu.empty_cache()
        torch_npu.npu.synchronize(device=self.device)
