import ast
import types
import re
from typing import List, Tuple, Optional

class CodeFilter:
    """Policy checker for generated ``ModelNew`` Python source.

    Construction derives ``allowed_ops`` and ``forbidden_ops`` (sets of dotted
    torch names); ``filter_code`` then validates a source string against them.
    """

    def __init__(self):
        self._build_forbidden_ops()

    def _build_forbidden_ops(self):
        """Compute ``self.allowed_ops`` and ``self.forbidden_ops``.

        The forbidden set is every name collected from ``torch.ops.aten``,
        ``torch.nn`` and ``torch`` (sub-modules of ``torch`` excluded), minus
        the explicit allow-list below.
        """
        import torch

        # NOTE(review): torch.ops.aten seems to resolve operators lazily, so its
        # __dict__ may only contain operators that were already accessed -- confirm.
        aten_names = set()
        for name in torch.ops.aten.__dict__:
            if not name.startswith("__"):
                aten_names.add(f"torch.ops.aten.{name}")

        nn_names = set()
        for name in torch.nn.__dict__:
            if not name.startswith("__"):
                nn_names.add(f"torch.nn.{name}")

        torch_names = set()
        for name, obj in torch.__dict__.items():
            is_dunder = name.startswith("__") and name.endswith("__")
            if is_dunder or isinstance(obj, types.ModuleType):
                continue
            torch_names.add(f"torch.{name}")

        self.allowed_ops = {
            # torch.*
            'torch.tensor', 'torch.rand', 'torch.empty_strided',
            'torch.ParameterDict', 'torch.mode', 'torch.randn', 'torch.zeros',
            'torch.ones_like', 'torch.eye', 'torch.empty_like',
            'torch.FloatTensor', 'torch.zeros_like', 'torch.numel',
            'torch.empty', 'torch.Tensor', 'torch.BoolTensor', 'torch.full',
            'torch.IntTensor', 'torch.full_like', 'torch.ones',
            'torch.LongTensor', 'torch.is_tensor',
            # torch.nn.*
            'torch.nn.parameter', 'torch.nn.ParameterList', 'torch.nn.modules',
            'torch.nn.factory_kwargs', 'torch.nn.Parameter',
            'torch.nn.common_types', 'torch.nn.ModuleList',
            # torch.ops.aten.*
            'torch.ops.aten.eye', 'torch.ops.aten.mode',
            'torch.ops.aten.zeros_like', 'torch.ops.aten.new_zeros',
            'torch.ops.aten.new_empty', 'torch.ops.aten.view',
            'torch.ops.aten.copy', 'torch.ops.aten.new_ones',
            'torch.ops.aten.ones_like', 'torch.ops.aten.new_empty_strided',
            'torch.ops.aten.randn', 'torch.ops.aten.empty_strided',
            'torch.ops.aten.full_like', 'torch.ops.aten.new_full',
            'torch.ops.aten.full', 'torch.ops.aten.zeros',
            'torch.ops.aten.rand', 'torch.ops.aten.ones',
            'torch.ops.aten.empty_like',
        }

        everything = nn_names | aten_names | torch_names
        self.forbidden_ops = everything - self.allowed_ops

    def filter_code(self, code):
        """Validate ``code`` and return ``(ok, message)``.

        Rules enforced:
        1. Only class ModelNew is inspected, and only its __init__ and forward.
        2. forward must not directly call a forbidden torch / torch.nn /
           torch.ops.aten name.
        3. forward must not directly call self.xxx() when xxx was assigned a
           forbidden value in __init__.
        4. The file must import custom_ops_lib (plain, ``as`` or ``from`` form).
        5. ModelNew (any method) must call into custom_ops_lib at least once.

        Only the first recorded violation is reported. The outcome is also
        printed. A syntax error in ``code`` yields ``(False, <error text>)``.
        """
        try:
            analyzer = ClassAnalyzer(self.forbidden_ops, target_class_name="ModelNew") if False else None
            tree = ast.parse(code)
            analyzer = ClassAnalyzer(self.forbidden_ops, target_class_name="ModelNew")
            analyzer.visit(tree)

            # File-level requirements, appended after any forward() violations.
            global_checks = (
                (
                    analyzer.has_model_class,
                    "You must define a class `ModelNew` that inherits from torch.nn.Module and the kernel function method you implemented in the custom_ops_lib will be called within ModelNew..",
                ),
                (
                    analyzer.has_custom_lib_import,
                    "You must import custom_ops_lib (import custom_ops_lib as xxx / from custom_ops_lib import xxx)",
                ),
                (
                    analyzer.has_custom_lib_call,
                    "The kernel function method you implemented in the custom_ops_lib must be called within ModelNew.",
                ),
            )
            for satisfied, message in global_checks:
                if not satisfied:
                    analyzer.violations.append(message)

            if not analyzer.violations:
                print("Code meets requirements!")
                return True, "success"

            # Report just the earliest violation.
            first = analyzer.violations[0]
            print("Code does not meet requirements, found the following violation:")
            print(f"- {first}")
            return False, first

        except SyntaxError as e:
            print(f"Code syntax error: {e}")
            return False, f"Code syntax error: {e}"


class ClassAnalyzer(ast.NodeVisitor):
    """AST visitor that collects rule violations for one target model class.

    Import aliases are recorded for the whole tree. When a class named
    ``target_class_name`` derives from ``torch.nn.Module``, its ``__init__``
    and ``forward`` methods are analysed.

    Results available after ``visit(tree)``:
      * ``violations`` -- messages in discovery order.
      * ``has_model_class`` -- the target class exists and derives from
        ``torch.nn.Module``.
      * ``has_custom_lib_import`` -- ``custom_ops_lib`` is imported.
      * ``has_custom_lib_call`` -- a ``custom_ops_lib`` call occurs inside the
        target class.
      * ``init_attributes`` -- category of every ``self.<attr> = ...``
        assignment seen in the target ``__init__``.
    """

    def __init__(self, forbidden_ops, target_class_name="ModelNew"):
        self.forbidden_ops = forbidden_ops
        self.target_class_name = target_class_name

        # attribute name -> category string ("forbidden_value", "simple_tensor", ...)
        self.init_attributes = {}
        self.violations = []

        self.in_init = False
        self.in_forward = False
        self.current_class = None
        # local name -> fully qualified imported name
        self.import_aliases = {}

        # custom_ops_lib state
        self.has_custom_lib_import = False   # custom_ops_lib is imported somewhere
        self.has_custom_lib_call = False     # custom_ops_lib is called inside the target class

        self.has_model_class = False

    # ---------- import check ----------

    def visit_Import(self, node):
        """Record ``import x`` / ``import x as y`` aliases."""
        for alias in node.names:
            self.import_aliases[alias.asname or alias.name] = alias.name
            # import custom_ops_lib / import custom_ops_lib as col
            if alias.name == "custom_ops_lib":
                self.has_custom_lib_import = True
        self.generic_visit(node)

    def visit_ImportFrom(self, node):
        """Record ``from m import x`` / ``from m import x as y`` aliases."""
        # node.module is None for `from . import x`; avoid building "None.x".
        module_name = node.module or ""
        for alias in node.names:
            if module_name:
                full_name = f"{module_name}.{alias.name}"
            else:
                full_name = alias.name
            self.import_aliases[alias.asname or alias.name] = full_name
            # from custom_ops_lib import xxx
            if module_name == "custom_ops_lib":
                self.has_custom_lib_import = True
        self.generic_visit(node)

    # ---------- Class processing, only the target class is analysed ----------

    def visit_ClassDef(self, node):
        """Analyse the target class; merely traverse every other class."""
        if node.name != self.target_class_name:
            # Other classes: current_class is left untouched, so the
            # custom_ops_lib call check does not fire for them.
            self.generic_visit(node)
            return

        prev_class = self.current_class
        self.current_class = node.name

        # Only a target class that inherits from torch.nn.Module is analysed.
        is_module = any(
            self._resolve_alias(self.get_full_attr(base)) == "torch.nn.Module"
            for base in node.bases
        )
        if is_module:
            self.has_model_class = True
            self.analyze_model_class(node)

        # Traverse the whole body (flags off) so custom_ops_lib calls in any
        # method of the target class are detected.
        self.generic_visit(node)
        self.current_class = prev_class

    def analyze_model_class(self, node):
        """Run the ``__init__`` and ``forward`` checks on the class body.

        ``__init__`` is always analysed before ``forward``, regardless of the
        order in which they appear in the source: the ``self.xxx()`` check in
        ``forward`` depends on ``init_attributes`` gathered from ``__init__``.
        """
        methods = [item for item in node.body if isinstance(item, ast.FunctionDef)]

        for item in methods:
            if item.name == '__init__':
                self.in_init = True
                self.visit(item)
                self.in_init = False

        for item in methods:
            if item.name == 'forward':
                self.in_forward = True
                self.visit(item)
                self.in_forward = False
                # No early exit on violations: the traversal must continue so
                # custom_ops_lib detection is not affected.

    # ---------- Assignment (record the category of self.xxx) ----------

    def visit_Assign(self, node):
        """While inside ``__init__``, classify every ``self.<attr> = value``."""
        if self.in_init:
            for target in node.targets:
                if not (
                    isinstance(target, ast.Attribute)
                    and isinstance(target.value, ast.Name)
                    and target.value.id == 'self'
                ):
                    continue
                self.init_attributes[target.attr] = self._classify_init_value(node.value)

        self.generic_visit(node)

    def _classify_init_value(self, value):
        """Return the category string for the right-hand side of an assignment."""
        if isinstance(value, ast.Constant):
            return "simple_value"
        if isinstance(value, ast.Name):
            return "simple_name"
        if isinstance(value, ast.Call):
            # self.xxx = self.some_method(...)
            if (
                isinstance(value.func, ast.Attribute)
                and isinstance(value.func.value, ast.Name)
                and value.func.value.id == 'self'
            ):
                return "simple_func"
            is_forbidden, _ = self._check_forbidden_ops(value)
            return "forbidden_value" if is_forbidden else "simple_tensor"
        return "unknown"

    # ---------- Function call ----------

    def visit_Call(self, node):
        """Detect custom_ops_lib calls and, inside ``forward``, forbidden calls."""
        # 1. Independently of earlier violations, look for custom_ops_lib calls first.
        if (
            self.current_class == self.target_class_name
            and not self.has_custom_lib_call
            and self._is_custom_lib_call(node)
        ):
            self.has_custom_lib_call = True

        # 2. Rules that only apply inside forward().
        if self.in_forward:
            is_forbidden, full_str = self._check_forbidden_ops(node)
            if is_forbidden:
                self.violations.append(
                    f"In the forward method, a prohibited method is directly called: {full_str}(). You should implement the operations from forward() in class `Model` in the custom_ops_lib and call it from there."
                )
                # Returning only stops descending into this call; later
                # statements are still visited.
                return

            # self.xxx(...) where xxx was assigned a forbidden value in __init__
            func = node.func
            if (
                isinstance(func, ast.Attribute)
                and isinstance(func.value, ast.Name)
                and func.value.id == 'self'
                and self.init_attributes.get(func.attr) == "forbidden_value"
            ):
                self.violations.append(
                    f"In the forward method, the model layer is directly called: self.{func.attr}(). You should implement the operations from forward() in class `Model` in the custom_ops_lib and call it from there."
                )
                return

        self.generic_visit(node)

    # ---------- Utility functions ----------

    def get_full_attr(self, node, init=None):
        """Return the dotted name of an ``ast.Name`` / ``ast.Attribute`` chain.

        Only trailing attribute names are returned when the chain is rooted in
        something other than a plain name (e.g. a call); the result is ``""``
        when ``node`` is neither a name nor an attribute. If ``init`` is
        truthy it replaces the root component.
        """
        attrs = []
        while isinstance(node, ast.Attribute):
            attrs.append(node.attr)
            node = node.value
        if isinstance(node, ast.Name):
            attrs.append(node.id)
        if init and attrs:
            attrs[-1] = init
        return ".".join(reversed(attrs))

    def _resolve_alias(self, dotted):
        """Expand the first component of ``dotted`` through ``import_aliases``."""
        root = dotted.split(".")[0]
        if root in self.import_aliases:
            return self.import_aliases[root] + dotted[len(root):]
        return dotted

    def _check_forbidden_ops(self, node):
        """Return ``(True, full_name)`` if the ``ast.Call`` targets a forbidden op.

        Returns ``(False, "")`` otherwise.
        """
        if isinstance(node.func, ast.Attribute):
            # e.g. "torch.matmul", or "nn.Linear" expanded through the aliases
            full_attr_str = self._resolve_alias(self.get_full_attr(node.func))
            if full_attr_str in self.forbidden_ops:
                return True, full_attr_str

        elif isinstance(node.func, ast.Name):
            # e.g. `from torch import matmul; matmul(...)`
            full_attr_str = self.import_aliases.get(node.func.id)
            if (
                full_attr_str is not None
                and full_attr_str.startswith('torch')
                and full_attr_str in self.forbidden_ops
            ):
                return True, full_attr_str

        return False, ""

    def _is_custom_lib_call(self, node):
        """
        Determine if a call comes from custom_ops_lib:
        - import custom_ops_lib; custom_ops_lib.xxx(...)
        - import custom_ops_lib as col; col.xxx(...)
        - from custom_ops_lib import my_op; my_op(...)
        """
        # custom_ops_lib.xxx(...) or col.xxx(...)
        if isinstance(node.func, ast.Attribute):
            root = self.get_full_attr(node.func).split(".")[0]
            # root might be an alias
            return self.import_aliases.get(root, root) == "custom_ops_lib"

        # from custom_ops_lib import xxx; xxx(...)
        if isinstance(node.func, ast.Name):
            full_name = self.import_aliases.get(node.func.id, "")
            return full_name.startswith("custom_ops_lib.")

        return False

def find_EXEC_NPU_CMD(generated_code: str):
    """Check whether any line of ``generated_code`` contains ``EXEC_NPU_CMD``.

    Returns ``(True, message)`` when the token is present on some line and
    ``(False, explanation)`` otherwise. This is a plain substring test.
    """
    dispatched = any(
        "EXEC_NPU_CMD" in row for row in generated_code.splitlines()
    )
    if not dispatched:
        return False, (
            "Invalid pybind implementation: no kernel dispatch detected.\n"
            "- Missing `EXEC_NPU_CMD`.\n"
            "- Reason: Without dispatching the kernel, operator semantics cannot be executed in `kernel_src`.\n"
            "- This indicates that the operator either performs no computation or bypasses the kernel entirely."
        )
    return True, "Kernel dispatch detected (EXEC_NPU_CMD found)."

def extract_custom_impl_npu_func_name(generated_code: str):
    """Find the function registered by address inside ``PYBIND11_MODULE``.

    The first ``&identifier`` after the module's opening brace is taken as the
    registered implementation. Namespace qualifiers are skipped, so
    ``&ns::my_impl`` (or ``&::my_impl``) yields ``my_impl`` rather than the
    namespace name; a namespace name could never be located as a function
    definition afterwards.

    Returns ``(True, function_name)`` on success, or ``(False, explanation)``
    when no such registration is found.
    """
    m = re.search(
        r"PYBIND11_MODULE\s*\([^)]+\)\s*\{[\s\S]*?&\s*"
        r"(?:::\s*)?(?:[A-Za-z_]\w*\s*::\s*)*"  # optional leading `::` and `ns::` qualifiers
        r"([A-Za-z_]\w*)\b",
        generated_code,
    )
    if not m:
        return False, (
            "Failed to locate the NPU implementation function registered in `PYBIND11_MODULE`.\n"
            "- Reason: The pybind module must register a function pointer (e.g., `&xxx_custom_impl_npu`).\n"
            "- Without a registered implementation, the custom operator cannot be invoked correctly from Python."
        )
    return True, m.group(1)

def evaluate_pybind_src(func_body: str, func_name: str):
    """Check a pybind function body against the kernel-only rules.

    Two checks run in order:
      1. every ``at::name(`` / ``torch::name(`` call must have a name that,
         lower-cased, is in the allocation allow-list;
      2. the body must contain the text ``EXEC_NPU_CMD``.

    Returns ``(True, "success")`` when both pass, otherwise
    ``(False, explanation)`` for the first failing check. ``func_name`` is
    only used in the messages.
    """
    # Allocation-style helpers the pybind layer may call (compared lower-cased).
    permitted = frozenset((
        "empty", "zeros", "ones", "full", "empty_like",
        "zeros_like", "ones_like", "full_like", "empty_strided",
        "from_blob", "tensor", "new_empty", "new_zeros",
        "new_ones", "new_full", "new_empty_strided", "intarrayref", "tensoroptions",
        "scalar_tensor",
    ))

    call_pattern = re.compile(r"\b(?:at|torch)::([A-Za-z_]\w*)\s*\(", re.MULTILINE)
    called_names = (hit.group(1) for hit in call_pattern.finditer(func_body))
    offending = next(
        (name for name in called_names if name.lower() not in permitted),
        None,
    )
    if offending is not None:
        return False, (
            f"Invalid implementation in `{func_name}`: semantic computation detected outside the kernel side and host side.\n"
            f"- Forbidden call: `at::/torch::{offending}`\n"
            f"- Rule: Outside kernel side and host side, only tensor allocation and kernel dispatch are allowed."
        )

    # The kernel dispatch itself is mandatory.
    if "EXEC_NPU_CMD" in func_body:
        return True, "success"

    return False, (
        f"Invalid implementation in `{func_name}`: kernel is never invoked.\n"
        f"- Missing `EXEC_NPU_CMD`.\n"
        f"- This indicates a semantic bypass."
    )

def extract_custom_impl_npu_func_body(source_code: str, func_name: str):

    pattern = r'(?:^|[^&])\b' + re.escape(func_name) + r'\b\s*\([^;]*?\)\s*(?=[^{};]*\{)'
    match = re.search(pattern, source_code)
    
    if not match:
        return False, (
            f"Implementation of `{func_name}` not found.\n"
            f"- Reason: The function body corresponding to the registered pybind entry could not be located.\n"
            f"- This usually indicates a mismatch between registration and implementation."
        )

    # Bracket matching algorithm
    start_pos = source_code.find('{', match.end())
    if start_pos == -1: return False, (
        "Opening brace not found.\n"
        "- Reason: The function signature was detected, but its body could not be parsed.\n"
        "- The source code may be incomplete or malformed."
    )

    brace_count = 0
    for i in range(start_pos, len(source_code)):
        if source_code[i] == '{': brace_count += 1
        elif source_code[i] == '}': brace_count -= 1
        
        if brace_count == 0:
            body = source_code[start_pos + 1 : i]
            # Clean indentation
            lines = [l for l in body.splitlines() if l.strip()]
            if not lines: return False, (
                "Function body is empty.\n"
            )
            common_indent = len(lines[0]) - len(lines[0].lstrip())
            cleaned_body = "\n".join(l[common_indent:] if l[:common_indent].isspace() else l.lstrip() for l in lines)
            return True, cleaned_body

    return False, (
        "Matching closing brace not found.\n"
        "- Reason: Function body appears truncated or contains unbalanced braces."
    )


def filter_pybind_src(generated_code: str):
    """Run the pybind checks in sequence and return ``(ok, message)``.

    Steps: find the registered function name, extract that function's body,
    then evaluate the body. The first failing step's result is returned as-is.
    """
    found, name_or_error = extract_custom_impl_npu_func_name(generated_code)
    if found:
        located, body_or_error = extract_custom_impl_npu_func_body(
            generated_code, name_or_error
        )
        if located:
            return evaluate_pybind_src(body_or_error, name_or_error)
        return located, body_or_error
    return found, name_or_error
    
 
def filter_model_src(code: str):
    """Validate model source with a fresh ``CodeFilter``; returns its ``(ok, message)``."""
    return CodeFilter().filter_code(code)

def filter_code_result_all(generated_code: str):
    """Execute a generated script and validate the sources it defines.

    ``generated_code`` is executed in a fresh namespace and is expected to
    define the string variables ``python_bind_src`` and ``model_src``. The
    checks run in order (dispatch-token presence, pybind body rules, model
    rules) and the first failing ``(False, message)`` is returned; if all
    pass, the model check's result is returned.

    Raises:
        Exception: if compiling or executing ``generated_code`` fails; the
            original error is attached as ``__cause__``.
        KeyError: if the executed code does not define ``python_bind_src``
            (or ``model_src`` once the pybind checks have passed).
    """
    context = {}
    try:
        # SECURITY NOTE(review): this runs arbitrary generated code in-process
        # with full interpreter privileges. Only call it on trusted input or
        # inside a sandbox.
        compiled = compile(generated_code, "<string>", "exec")
        # Execute the already-compiled object instead of compiling the text again.
        exec(compiled, context)
    except Exception as e:
        raise Exception(f'Error in generated code {e}') from e
    python_bind_src = context['python_bind_src']

    valid1 = find_EXEC_NPU_CMD(python_bind_src)
    if not valid1[0]:
        return valid1
    valid2 = filter_pybind_src(python_bind_src)
    if not valid2[0]:
        return valid2
    valid3 = filter_model_src(context['model_src'])
    return valid3

if __name__ == "__main__":
    # Sample inputs for a manual run; only the last entry is actually used.
    candidate_paths = [
        "outputs/gpt-5.2/level2-api-select-shot-complete-err-code-all-correct-12-24_26/20251225_162623/sparse_attention/iteration_1.txt",
        "outputs/gpt-5.2/level2-api-select-shot-complete-err-code-all-correct-12-24_26/20251225_162623/gemm_max_subtract_gelu/iteration_1.txt",
        "outputs/gpt-5.2/level2-api-select-shot-complete-err-code-all-correct-12-24_26/20251225_162623/kv_cached_chat_batch_attention/iteration_1.txt",
    ]
    code_file_path = candidate_paths[-1]
    with open(code_file_path, 'r') as f:
        example_code = f.read()
    # Print the (ok, message) verdict for the selected file.
    print(filter_code_result_all(example_code))