import torch_npu
import torch
from src.verification.backends.backend_registry import register_backend, Backend
from src.verification.utils.ascend_compile_pipeline import ascend_compile
from src.verification.utils.correctness import execute_template
from src.verification.utils.performance import time_execution_event_template
import os
import traceback
import logging
logger = logging.getLogger(__name__)

@register_backend('ascendc')
class AscendBackend(Backend):
    """Verification backend that compiles, checks and times AscendC kernels on an NPU.

    A single dict, ``self.context``, is shared across the run. It is handed to
    ``ascend_compile`` (presumably to populate it with the compiled candidate
    model — TODO confirm). The reference source is then ``exec``'d into the same
    dict, and the correctness/timing templates read from it.
    """

    def __init__(self, op_engineer_dir: str, ascendc_device, device_id: int = 0):
        """Create the backend and select NPU ``device_id`` as the current device.

        Args:
            op_engineer_dir: Directory for the compile pipeline, resolved
                relative to the current working directory at construction time.
            ascendc_device: Target hardware identifier; it is forwarded to
                ``ascend_compile`` and returned by ``get_hardware_name``.
            device_id: Index of the NPU to run on.
        """
        # Namespace used as exec() globals and shared with the compile/eval helpers.
        self.context = {}
        self.device = self.get_device(device_id)
        # Remembered so compile() can restore the cwd after the pipeline runs.
        self.project_root_path = os.getcwd()
        self.op_engineer_dir = os.path.join(self.project_root_path, op_engineer_dir)
        self.ascendc_device = ascendc_device

    def get_device(self, device_id: int = 0):
        """Make NPU ``device_id`` current and return the matching ``torch.device``."""
        torch_npu.npu.set_device(device_id)
        return torch.device(f'npu:{device_id}')

    def set_device(self):
        """No-op: the device is already selected by ``get_device`` during ``__init__``."""
        pass

    def get_hardware_name(self):
        """Return the hardware identifier this backend was constructed with."""
        return self.ascendc_device

    def compile(self, generated_code, op):
        """Compile ``generated_code`` for ``op`` through the AscendC pipeline.

        Returns:
            ``(True, None)`` on success, or ``(False, error_message)`` if
            compilation raised. Exceptions are reported, not propagated.
        """
        try:
            ascend_compile(generated_code, op, self.context, self.op_engineer_dir, self.ascendc_device)
        except Exception as e:
            # Boundary: report the failure to the caller instead of raising.
            logger.exception("Failed to compile generated code")
            return False, str(e)
        finally:
            # The pipeline may change the working directory (NOTE(review): confirm);
            # restore it on every exit path so later relative paths still resolve.
            os.chdir(self.project_root_path)
        return True, None

    def correctness_execution(self, ref_src, num_correct_trials):
        """Load the reference model source and run the correctness template.

        Args:
            ref_src: Python source of the reference model, exec'd into ``self.context``.
            num_correct_trials: Number of trials passed to ``execute_template``.

        Raises:
            RuntimeError: If executing ``ref_src`` fails. The original error is chained.
        """
        synchronize = torch_npu.npu.synchronize
        try:
            # SECURITY: exec() runs ref_src with full interpreter privileges;
            # only pass trusted source here.
            exec(ref_src, self.context)
        except Exception as e:
            raise RuntimeError(f"Failed to compile reference model: {str(e)}") from e
        return execute_template(synchronize, self.device, self.context, num_correct_trials)

    def time_execution(self, eval_target='ModelNew'):
        """Time ``eval_target`` from ``self.context`` using NPU events."""
        event_class = torch_npu.npu.Event
        synchronize = torch_npu.npu.synchronize
        return time_execution_event_template(self.context, self.device, synchronize, event_class, eval_target)

    def cleanup(self):
        """Drop the execution namespace and release cached NPU memory.

        The namespace is replaced with a fresh empty dict rather than deleted. The
        old dict's reference is still dropped, but the backend stays usable and
        repeated ``cleanup`` calls no longer raise ``AttributeError``.
        """
        self.context = {}
        torch_npu.npu.empty_cache()
        torch_npu.npu.synchronize(device=self.device)