import atexit
import os
import sys
import types

from hook.utils import TorchNVTXTracer

# The TorchNVTXTracer installed by initialize_nvtx_hook(); it stays None if the
# hook was already installed elsewhere in the process before this module ran.
tracer: "TorchNVTXTracer | None" = None

def initialize_nvtx_hook():
    """Install a single ``TorchNVTXTracer`` at the front of ``sys.meta_path``.

    The tracer is constructed with ``json_path`` taken from the
    ``NVTX_JSON_PATH`` environment variable (default
    ``"config/pytorch.json"`` -- a relative path; presumably resolved against
    the current working directory by ``TorchNVTXTracer``, confirm) and is
    stored in the module-level ``tracer``.

    A marker entry in ``sys.modules`` records that installation already
    happened. Because ``sys.modules`` is process-wide, calling this again --
    including from another copy of this module imported under a different
    name -- returns without doing anything.
    """
    global tracer
    if "torch_nvtx_initialized" in sys.modules:
        return
    # Remove any tracer left over from an earlier install. Filter in place so
    # anything holding a reference to the meta_path list sees the change.
    sys.meta_path[:] = [h for h in sys.meta_path if not isinstance(h, TorchNVTXTracer)]
    tracer = TorchNVTXTracer(json_path=os.getenv("NVTX_JSON_PATH", "config/pytorch.json"))
    # Position 0 makes the tracer the first finder consulted on every import.
    sys.meta_path.insert(0, tracer)
    # Use an empty module object as the marker: sys.modules is expected to
    # contain modules, and a bare ``True`` can break code that iterates it.
    sys.modules["torch_nvtx_initialized"] = types.ModuleType("torch_nvtx_initialized")

def _cleanup():
    """Exit the tracer loader's ``emit_nvtx()`` context manager at shutdown.

    Registered with :func:`atexit`. Does nothing when no tracer was
    installed, when the tracer has no truthy ``loader_instance``, or when
    the loader's ``emit_nvtx_cm`` is falsy. Errors while exiting the
    context manager are printed rather than raised, since an exception
    escaping an atexit handler would only produce a traceback at shutdown.
    """
    # Only reads the module-level ``tracer``, so no ``global`` is needed.
    if tracer is None:
        return
    loader = getattr(tracer, "loader_instance", None)
    if not loader:
        return
    try:
        cm = loader.emit_nvtx_cm
        if not cm:
            # Returning from the try suite skips the ``else`` clause below.
            return
        cm.__exit__(None, None, None)
    except Exception as e:
        print(f"[torch_nvtx] ❌ emit_nvtx 关闭失败: {e}")
    else:
        # Report success outside the ``try`` so a failure in ``print`` itself
        # (e.g. a stdout that cannot encode the emoji) is not misreported as
        # a failure to close emit_nvtx.
        print("[torch_nvtx] 🔥 emit_nvtx() 已关闭")

# Import-time side effects: install the NVTX import hook, then register the
# shutdown handler that closes emit_nvtx(). The order matters --
# initialize_nvtx_hook() runs first, so if it raises, the import of this
# module fails before any exit handler is registered.
initialize_nvtx_hook()
atexit.register(_cleanup)