import importlib
import os
import subprocess
import shutil
import sys
from datetime import datetime
from pathlib import Path
from op_eval.config import (
    get_op_engineer_dir,
    ascendc_device,
    project_root_path,
    ASCEND_TEMPLATE_ROOT,
    ASCEND_BUILD_TIMEOUT_S,
)
from op_eval.utils.utils import underscore_to_pascalcase, normalize_op_name


def _create_operator_workspace(base_root: str, op_capital: str) -> str:
    """Copy the Ascend project template into a new timestamped workspace.

    The workspace is created under ``<base_root>/operators`` and named
    ``<op_capital>_<YYYYmmddHHMMSS>_<microseconds>``. The microsecond suffix
    makes name collisions between repeated compilations unlikely.

    Returns:
        The path of the freshly copied workspace directory.
    """
    operators_dir = os.path.join(base_root, "operators")
    os.makedirs(operators_dir, exist_ok=True)

    stamp = datetime.now().strftime("%Y%m%d%H%M%S_%f")
    destination = os.path.join(operators_dir, f"{op_capital}_{stamp}")

    # copytree creates `destination` itself and fails if it already exists.
    shutil.copytree(ASCEND_TEMPLATE_ROOT, destination)
    return destination


def _validate_acl_symbols(customize_dst: Path, op_capital: str):
    lib_path = customize_dst / "op_api" / "lib" / "libcust_opapi.so"
    if not lib_path.exists():
        raise FileNotFoundError(f"Missing libcust_opapi.so under {customize_dst}")
    expected_symbols = [
        f"aclnn{op_capital}",
        f"aclnn{op_capital}GetWorkspaceSize",
    ]
    try:
        nm_output = subprocess.run(
            ["nm", "-D", str(lib_path)],
            check=True,
            capture_output=True,
            text=True,
        ).stdout
    except FileNotFoundError:
        print("[WARN] nm not found; skipping ACL symbol validation")
        return
    except subprocess.CalledProcessError as exc:
        raise RuntimeError(
            f"Failed to inspect symbols in {lib_path}: {exc.stdout or exc.stderr}"
        ) from exc
    missing = [symbol for symbol in expected_symbols if symbol not in nm_output]
    if missing:
        raise RuntimeError(
            f"Missing ACLNN symbols {missing} in {lib_path}. "
            "Generated operator will not be callable."
        )


def _stage_custom_opp(target_directory: str, workspace_root: str, op_capital: str):
    """
    Copy the packaged custom operator artifacts into a workspace-local opp directory
    so ASCEND_CUSTOM_OPP_PATH can point to an isolated tree. Creates a libopapi.so
    symlink for compatibility with runtime lookups.
    """
    build_out_dir = Path(target_directory) / "build_out"
    # Make installer search robust to OS/Arch differences (e.g., openEuler vs HCE)
    installers = list(build_out_dir.glob("custom_opp_*.run"))
    if not installers:
        raise FileNotFoundError(f"No custom_opp_*.run installer found in {build_out_dir}")
    # Prioritize if multiple? Usually only one exists. Take first.
    installer = installers[0]
    
    opp_root = Path(workspace_root) / "opp"
    if opp_root.exists():
        shutil.rmtree(opp_root)
    os.makedirs(opp_root, exist_ok=True)
    install_cmd = [
        str(installer),
        "--quiet",
        f"--install-path={opp_root}",
    ]
    try:
        subprocess.run(
            install_cmd,
            check=True,
            capture_output=True,
            text=True,
            cwd=build_out_dir,
        )
    except subprocess.CalledProcessError as exc:
        raise RuntimeError(
            f"Failed to install custom opp via {installer}:\n{exc.stdout}\n{exc.stderr}"
        ) from exc
    customize_dst = opp_root / "vendors" / "customize"
    if not customize_dst.exists():
        raise FileNotFoundError(f"customize directory missing under {opp_root}")
    libcust = customize_dst / "op_api" / "lib" / "libcust_opapi.so"
    libopapi = customize_dst / "op_api" / "lib" / "libopapi.so"
    if not libcust.exists():
        raise FileNotFoundError(f"Expected libcust_opapi.so under {libcust.parent}")
    if libopapi.exists() or libopapi.is_symlink():
        libopapi.unlink()
    os.symlink(libcust.name, libopapi)
    _validate_acl_symbols(customize_dst, op_capital)


def _load_custom_module(module_name: str):
    if not module_name:
        return
    if "custom_ops_lib" in sys.modules:
        del sys.modules["custom_ops_lib"]
    try:
        if module_name in sys.modules:
            module = importlib.reload(sys.modules[module_name])
        else:
            module = importlib.import_module(module_name)
    except ModuleNotFoundError as exc:
        raise RuntimeError(
            f"Failed to import per-op module '{module_name}'. "
            "Ensure build_and_run.sh installed the wheel into pybind_install."
        ) from exc
    sys.modules["custom_ops_lib"] = module


def _push_env(custom_opp_path: str, custom_lib_path: str, pybind_path: str | None = None):
    token = {
        "ASCEND_CUSTOM_OPP_PATH": os.environ.get("ASCEND_CUSTOM_OPP_PATH"),
        "LD_LIBRARY_PATH": os.environ.get("LD_LIBRARY_PATH"),
        "PYTHONPATH": os.environ.get("PYTHONPATH"),
        "SYS_PATH": list(sys.path),
    }
    os.environ["ASCEND_CUSTOM_OPP_PATH"] = custom_opp_path
    existing_ld_path = os.environ.get("LD_LIBRARY_PATH", "")
    if custom_lib_path:
        if existing_ld_path:
            if custom_lib_path not in existing_ld_path.split(":"):
                os.environ["LD_LIBRARY_PATH"] = f"{custom_lib_path}:{existing_ld_path}"
        else:
            os.environ["LD_LIBRARY_PATH"] = custom_lib_path

    if pybind_path:
        existing_py = os.environ.get("PYTHONPATH", "")
        if existing_py:
            if pybind_path not in existing_py.split(":"):
                os.environ["PYTHONPATH"] = f"{pybind_path}:{existing_py}"
        else:
            os.environ["PYTHONPATH"] = pybind_path
        if pybind_path not in sys.path:
            sys.path.insert(0, pybind_path)
    return token


def _pop_env(token):
    if token.get("ASCEND_CUSTOM_OPP_PATH") is None:
        os.environ.pop("ASCEND_CUSTOM_OPP_PATH", None)
    else:
        os.environ["ASCEND_CUSTOM_OPP_PATH"] = token["ASCEND_CUSTOM_OPP_PATH"]
    if token.get("LD_LIBRARY_PATH") is None:
        os.environ.pop("LD_LIBRARY_PATH", None)
    else:
        os.environ["LD_LIBRARY_PATH"] = token["LD_LIBRARY_PATH"]
    if token.get("PYTHONPATH") is None:
        os.environ.pop("PYTHONPATH", None)
    else:
        os.environ["PYTHONPATH"] = token["PYTHONPATH"]
    sys.path[:] = token.get("SYS_PATH", sys.path)


def _restore_env_var(name: str, value: str | None) -> None:
    """Set ``name`` to ``value`` in os.environ, or remove it when ``value`` is None."""
    if value is None:
        os.environ.pop(name, None)
    else:
        os.environ[name] = value


def _write_context_source(path: str, context: dict, key: str) -> None:
    """Write the generated source ``context[key]`` to ``path`` as UTF-8.

    Raises:
        Exception: the generated code did not define ``key`` as a string. The
            message names the key instead of surfacing a bare TypeError from
            ``file.write(None)``.
    """
    content = context.get(key)
    if not isinstance(content, str):
        raise Exception(
            f"Generated code must define '{key}' as a string, "
            f"got {type(content).__name__}"
        )
    with open(path, 'w', encoding='utf-8') as f:
        f.write(content)


def _called_process_feedback(exc: subprocess.CalledProcessError) -> str:
    """Format a failed subprocess as feedback text: exit code plus stdout and stderr."""
    output = "\n".join(part for part in (exc.stdout, exc.stderr) if part)
    return f'Exit Code: {exc.returncode}\nError Output:\n{output}'


def _build_error_summary(stdout: str | None, stderr: str | None, tail_lines: int = 50) -> str:
    """Extract error lines from build output, falling back to the output tail.

    Lines containing ``[ERROR]`` or ``error:`` are kept, stdout first and then
    stderr. When nothing matches (a failure worded differently), the last
    ``tail_lines`` lines of the combined output are returned instead, so the
    feedback is never empty.
    """
    lines = (stdout or "").splitlines() + (stderr or "").splitlines()
    errors = [line for line in lines if '[ERROR]' in line or 'error:' in line]
    selected = errors or lines[-tail_lines:]
    return "".join(f"{line}\n" for line in selected)


def ascend_compile(generated_code, op, context, apply_env: bool = True):
    """Build, install and load a generated Ascend C operator.

    Stages:
      1. exec ``generated_code`` into ``context``;
      2. ``msopgen`` project generation;
      3. write the sources;
      4. ``build.sh``;
      5. install the custom opp into the workspace;
      6. build the pybind extension;
      7. import it as ``custom_ops_lib``;
      8. exec ``context['model_src']``.

    Args:
        generated_code: Python source executed with ``context`` as globals. It
            must define the string sources ``project_json_src``,
            ``host_tiling_src``, ``host_operator_src``, ``kernel_src``,
            ``python_bind_src`` and ``model_src``.
        op: Raw operator name. It is normalized with ``normalize_op_name`` and
            suffixed with ``_custom``.
        context: Namespace dict used as globals for both ``exec`` calls.
        apply_env: When True, the env pushed for the new build
            (ASCEND_CUSTOM_OPP_PATH, LD_LIBRARY_PATH, PYTHONPATH, sys.path)
            is left in place. When False it is restored before returning.
            If the build fails before that env is pushed, the caller's
            ASCEND_CUSTOM_OPP_PATH is restored either way.

    Returns:
        A dict describing the workspace, project dir and op names. After a
        successful build it also holds the custom opp, library and pybind
        paths and the module name.

    Raises:
        Exception: a human-readable feedback message for any failed stage.

    Note:
        Changes the process working directory while building, and always
        leaves it at ``project_root_path``.
    """
    base_root = get_op_engineer_dir()
    # Normalize op name: strip numeric prefix, lowercase, clean underscores
    # e.g., "1_Square_matrix_multiplication_" -> "square_matrix_multiplication_custom"
    op = normalize_op_name(op) + '_custom'
    op_capital = underscore_to_pascalcase(op)
    op_engineer_dir = _create_operator_workspace(base_root, op_capital)
    previous_op_root = os.environ.get("ASCEND_OP_ROOT")
    # The build stage clears ASCEND_CUSTOM_OPP_PATH in this process. Remember the
    # caller's value so it can be restored unless the pushed env is kept.
    previous_custom_opp = os.environ.get("ASCEND_CUSTOM_OPP_PATH")
    keep_pushed_env = False
    os.environ["ASCEND_OP_ROOT"] = op_engineer_dir
    target_directory = os.path.join(op_engineer_dir, op_capital)
    build_info = {
        "workspace_root": op_engineer_dir,
        "op_project_dir": target_directory,
        "op_name": op,
        "op_capital": op_capital,
    }
    module_name = f"custom_ops_lib_{op}"
    pybind_path = os.path.join(op_engineer_dir, 'CppExtension', 'pybind_install')
    log_prefix = f"[{op}]"
    try:
        try:
            generated_obj = compile(generated_code, "<string>", "exec")
            # SECURITY: runs model-generated code in-process (be careful with untrusted code).
            exec(generated_obj, context)
        except Exception as e:
            raise Exception(f'Error in generated code {e}') from e

        # create ascendc project
        if os.path.exists(target_directory):
            print(f"[INFO]{log_prefix} Operator project already exists, deleted")
            shutil.rmtree(target_directory)
        _write_context_source(
            os.path.join(op_engineer_dir, f'{op}.json'), context, 'project_json_src'
        )
        try:
            print(f"[INFO]{log_prefix} Begin create operator project")
            os.chdir(op_engineer_dir)
            subprocess.run(
                ["msopgen", 'gen', '-i', f'{op}.json', '-c', ascendc_device, '-lan', 'cpp', '-out', op_capital],
                check=True,
                capture_output=True,
                text=True,
                timeout=ASCEND_BUILD_TIMEOUT_S,
            )
            print(f"[INFO]{log_prefix} Create operator project succeeded")
        except subprocess.CalledProcessError as e:
            print(f"[INFO]{log_prefix} Create operator project failed!")
            print("Error Output:\n", e.stdout)
            print("Error Output:\n", e.stderr)
            raise Exception(_called_process_feedback(e)) from e
        except subprocess.TimeoutExpired as e:
            print(f"[INFO]{log_prefix} Create operator project timed out ({ASCEND_BUILD_TIMEOUT_S}s)!")
            raise Exception(
                f"Compilation timed out after {ASCEND_BUILD_TIMEOUT_S}s during msopgen"
            ) from e

        # write code to specific location
        generated_files = (
            (os.path.join(target_directory, 'op_host', f'{op}_tiling.h'), 'host_tiling_src'),
            (os.path.join(target_directory, 'op_host', f'{op}.cpp'), 'host_operator_src'),
            (os.path.join(target_directory, 'op_kernel', f'{op}.cpp'), 'kernel_src'),
            (os.path.join(op_engineer_dir, 'CppExtension', 'csrc', 'op.cpp'), 'python_bind_src'),
        )
        for path, key in generated_files:
            _write_context_source(path, context, key)

        try:
            # Build (and the later deploy/pybind stages) without any inherited
            # custom opp path. The caller's value is restored in the finally below.
            os.environ.pop('ASCEND_CUSTOM_OPP_PATH', None)
            print(f"[INFO]{log_prefix} Begin build")
            os.chdir(target_directory)
            subprocess.run(
                ["./build.sh"],
                check=True,
                capture_output=True,
                text=True,
                timeout=ASCEND_BUILD_TIMEOUT_S,
            )
            print(f"[INFO]{log_prefix} Build succeeded")
        except subprocess.CalledProcessError as e:
            print(f"[INFO]{log_prefix} Build failed!")
            error_output = _build_error_summary(e.stdout, e.stderr)
            print(error_output, end='')
            raise Exception(f'Exit Code: {e.returncode}\nError Output:\n{error_output}') from e
        except subprocess.TimeoutExpired as e:
            print(f"[INFO]{log_prefix} Build timed out ({ASCEND_BUILD_TIMEOUT_S}s)!")
            raise Exception(
                f"Compilation timed out after {ASCEND_BUILD_TIMEOUT_S}s during build"
            ) from e

        try:
            print(f"[INFO]{log_prefix} Begin deploy")
            _stage_custom_opp(target_directory, op_engineer_dir, op_capital)
            print(f"[INFO]{log_prefix} Deploy succeeded")
        except Exception as e:
            print(f"[INFO]{log_prefix} Deploy failed!")
            raise Exception(str(e)) from e

        try:
            print(f"[INFO]{log_prefix} Begin pybind")
            os.chdir(os.path.join(op_engineer_dir, 'CppExtension'))
            pybind_env = os.environ.copy()
            pybind_env["CUSTOM_OPS_MODULE_NAME"] = module_name
            pybind_env["PYBIND_DEST"] = pybind_path
            subprocess.run(
                ['bash', "build_and_run.sh"],
                check=True,
                capture_output=True,
                text=True,
                env=pybind_env,
                timeout=ASCEND_BUILD_TIMEOUT_S,
            )
            print(f"[INFO]{log_prefix} Pybind succeeded\n")
        except subprocess.CalledProcessError as e:
            print(f"[INFO]{log_prefix} Pybind failed!")
            # Compiler diagnostics usually go to stderr, so include both streams.
            raise Exception(_called_process_feedback(e)) from e
        except subprocess.TimeoutExpired as e:
            print(f"[INFO]{log_prefix} Pybind timed out ({ASCEND_BUILD_TIMEOUT_S}s)!")
            raise Exception(
                f"Compilation timed out after {ASCEND_BUILD_TIMEOUT_S}s during pybind"
            ) from e

        custom_opp_path = os.path.join(op_engineer_dir, "opp", "vendors", "customize")
        custom_lib_path = os.path.join(custom_opp_path, "op_api", "lib")
        build_info.update(
            {
                "custom_opp_path": custom_opp_path,
                "custom_lib_path": custom_lib_path,
                "pybind_path": pybind_path,
                "pybind_module": module_name,
            }
        )
        env_token = _push_env(custom_opp_path, custom_lib_path, pybind_path)
        # From here on the pushed env owns ASCEND_CUSTOM_OPP_PATH when apply_env is set.
        keep_pushed_env = apply_env

        try:
            _load_custom_module(module_name)
            model_obj = compile(context['model_src'], "<string>", "exec")
            exec(model_obj, context)
        except Exception as e:
            raise Exception(f'Error in generated code {e}') from e
        finally:
            if not apply_env:
                _pop_env(env_token)
    finally:
        os.chdir(project_root_path)
        _restore_env_var("ASCEND_OP_ROOT", previous_op_root)
        if not keep_pushed_env:
            _restore_env_var("ASCEND_CUSTOM_OPP_PATH", previous_custom_opp)

    return build_info


if __name__ == '__main__':
    # Manual smoke check (requires an Ascend NPU environment): import the torch
    # NPU stack, then resolve one operator entry point on custom_ops_lib.
    import torch  # noqa: F401 -- imported but not otherwise used here
    import torch_npu  # noqa: F401 -- imported but not otherwise used here
    import custom_ops_lib

    target_op = 'relu'
    generated_method = getattr(custom_ops_lib, target_op)
