#!/usr/bin/env python3
"""
AST Equivalence Checker for Structure Preservation Verification.

Implements deterministic AST comparison to verify that JO's style
transformations preserve program structure (group invariance).
"""

import ast
import re
from typing import Dict, Set, Tuple, Optional


class ASTNormalizer(ast.NodeTransformer):
    """
    Normalize an AST so that structurally identical programs compare equal.

    Normalization steps:
    1. Reset location info (lineno, col_offset, ...) on nodes that pass
       through ``generic_visit``
    2. Alpha-normalize variable, argument and function names to
       ``v0``, ``v1``, ... in order of first appearance
    3. Remove docstrings (they contain name-dependent text)
    4. String quotes need no work (they are irrelevant at AST level)

    Names listed in ``self.builtins`` and function names starting with
    ``__`` are left untouched. The ``name`` field of a class definition
    and attribute names are not rewritten.
    """

    def __init__(self, strip_docstrings: bool = True):
        """
        Args:
            strip_docstrings: When True (default), a leading string
                expression is removed from module, class and function bodies.
        """
        # Stack of {original_name: canonical_name}; index 0 is the module scope.
        self.scope_stack = [{}]
        # Suffix for the next canonical name; shared by all scopes so that
        # canonical names are unique across the whole tree.
        self.name_counter = 0
        self.strip_docstrings = strip_docstrings
        # Names that are never renamed.
        self.builtins = {
            'print', 'len', 'range', 'int', 'str', 'list', 'dict', 'set',
            'tuple', 'True', 'False', 'None', 'open', 'sum', 'min', 'max',
            'abs', 'sorted', 'enumerate', 'zip', 'map', 'filter', 'any',
            'all', 'isinstance', 'type', 'float', 'bool', 'bytes', 'object',
            'Exception', 'ValueError', 'TypeError', 'KeyError', 'IndexError',
            'AttributeError', 'RuntimeError', 'StopIteration', 'super',
            'property', 'staticmethod', 'classmethod', 'reversed', 'iter',
            'next', 'input', 'format', 'repr', 'hash', 'id', 'callable',
            'getattr', 'setattr', 'hasattr', 'delattr', 'vars', 'dir',
            'globals', 'locals', 'exec', 'eval', 'compile', '__name__',
            '__doc__', '__file__', '__import__', 'heapq'
        }

    def _is_docstring(self, node):
        """Return True if node is an expression statement holding a string constant."""
        return (isinstance(node, ast.Expr) and
                isinstance(node.value, ast.Constant) and
                isinstance(node.value.value, str))

    def _strip_docstring_from_body(self, body):
        """
        Return body without its leading docstring, if stripping is enabled.

        A body that consisted only of a docstring becomes ``[Pass()]`` so the
        enclosing node keeps a non-empty body.
        """
        if self.strip_docstrings and body and self._is_docstring(body[0]):
            return body[1:] if len(body) > 1 else [ast.Pass()]
        return body

    def visit_Module(self, node):
        """Strip the module docstring, then normalize the module body."""
        node.body = self._strip_docstring_from_body(node.body)
        self.generic_visit(node)
        return node

    def _get_canonical_name(self, original: str) -> str:
        """
        Get or create the canonical name for a variable.

        Lookup goes from the innermost scope outwards; an unknown name is
        registered in the innermost scope. Protected names are returned as-is.
        """
        # Don't rename builtins
        if original in self.builtins:
            return original

        # Check current and outer scopes
        for scope in reversed(self.scope_stack):
            if original in scope:
                return scope[original]

        # Create new canonical name in current scope
        canonical = f'v{self.name_counter}'
        self.name_counter += 1
        self.scope_stack[-1][original] = canonical
        return canonical

    def _push_scope(self):
        """Enter a new scope."""
        self.scope_stack.append({})

    def _pop_scope(self):
        """Exit current scope."""
        self.scope_stack.pop()

    def visit_FunctionDef(self, node):
        """Normalize a function definition (name, arguments, body)."""
        # A function's name is bound in the *enclosing* scope, so it has to be
        # registered before the function's own scope is pushed. Otherwise the
        # mapping is discarded on exit and later references to the function
        # receive an unrelated canonical name.
        if not node.name.startswith('__'):  # Keep dunder methods
            node.name = self._get_canonical_name(node.name)

        self._push_scope()

        # Strip docstring from body
        node.body = self._strip_docstring_from_body(node.body)

        # Arguments are renamed exactly once, by visit_arg, as part of this
        # traversal (FunctionDef visits ``args`` before ``body``). Renaming
        # them here as well would canonicalize the already-canonical name a
        # second time and desynchronize it from its uses in the body.
        self.generic_visit(node)
        self._pop_scope()
        return node

    def visit_AsyncFunctionDef(self, node):
        """Normalize async function definition (same rules as a plain function)."""
        return self.visit_FunctionDef(node)

    def visit_ClassDef(self, node):
        """Normalize a class definition inside its own scope."""
        self._push_scope()
        # Strip docstring from body
        node.body = self._strip_docstring_from_body(node.body)
        # Keep class names as-is for now (they're often meaningful)
        self.generic_visit(node)
        self._pop_scope()
        return node

    def visit_Name(self, node):
        """Replace a variable reference with its canonical name."""
        # _get_canonical_name already passes protected names through unchanged.
        node.id = self._get_canonical_name(node.id)
        return node

    def visit_arg(self, node):
        """Replace a function argument name with its canonical name."""
        node.arg = self._get_canonical_name(node.arg)
        return node

    def generic_visit(self, node):
        """Zero the location attributes of node, then visit its children."""
        for attr in ('lineno', 'col_offset', 'end_lineno', 'end_col_offset'):
            if hasattr(node, attr):
                setattr(node, attr, 0)
        return super().generic_visit(node)


def normalize_ast(code: str) -> Optional[ast.AST]:
    """
    Parse code and return its normalized AST.

    Args:
        code: Python source text.

    Returns:
        The tree produced by running ``ASTNormalizer`` over the parsed
        source, or None when the source cannot be parsed.
    """
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        # Besides SyntaxError, some Python versions raise ValueError for
        # source containing null bytes; both mean "not parseable" here.
        return None

    normalized = ASTNormalizer().visit(tree)
    ast.fix_missing_locations(normalized)
    return normalized


def ast_to_signature(tree: ast.AST) -> str:
    """
    Serialize tree into the string used for structural comparison.

    Field names are included in the dump; node attributes such as line and
    column numbers are left out.
    """
    dump_options = {'annotate_fields': True, 'include_attributes': False}
    return ast.dump(tree, **dump_options)


def are_ast_equivalent(code1: str, code2: str) -> Tuple[bool, Optional[str]]:
    """
    Check if two code snippets are AST-equivalent after normalization.

    Args:
        code1: First Python source text.
        code2: Second Python source text.

    Returns:
        (equivalent: bool, diff_info: Optional[str])

        ``diff_info`` is None when the snippets are equivalent. Otherwise it
        names the snippet that failed to parse, or the first position at
        which the two signatures differ (with surrounding context taken from
        the first signature), or the two lengths when one signature is a
        strict prefix of the other.
    """
    tree1 = normalize_ast(code1)
    tree2 = normalize_ast(code2)

    if tree1 is None:
        return False, "code1 has syntax error"
    if tree2 is None:
        return False, "code2 has syntax error"

    sig1 = ast_to_signature(tree1)
    sig2 = ast_to_signature(tree2)

    if sig1 == sig2:
        return True, None

    # Find first difference for debugging
    first_diff = next(
        (i for i, (c1, c2) in enumerate(zip(sig1, sig2)) if c1 != c2),
        None,
    )
    if first_diff is not None:
        context = sig1[max(0, first_diff - 20):first_diff + 20]
        return False, f"Differs at position {first_diff}: ...{context}..."

    # Unequal strings that agree on their whole common prefix can only differ
    # in length, so no further fallback is needed.
    return False, f"Length mismatch: {len(sig1)} vs {len(sig2)}"


def compute_ast_edit_distance(code1: str, code2: str) -> int:
    """
    Compute an approximate edit distance between two normalized ASTs.

    Both trees are flattened with ``ast.walk`` into sequences of node type
    names. The result is the difference in sequence length plus the number
    of positions (over the shared length) whose type names disagree.

    Returns:
        The distance, or -1 if either snippet fails to parse.
    """
    left_tree = normalize_ast(code1)
    right_tree = normalize_ast(code2)
    if left_tree is None or right_tree is None:
        return -1  # Invalid code

    left_kinds = [type(node).__name__ for node in ast.walk(left_tree)]
    right_kinds = [type(node).__name__ for node in ast.walk(right_tree)]

    size_gap = abs(len(left_kinds) - len(right_kinds))
    mismatches = sum(1 for a, b in zip(left_kinds, right_kinds) if a != b)
    return size_gap + mismatches


class InvarianceChecker:
    """
    Check JO's stability under semantics-preserving transformations.

    Tests that JO produces consistent outputs under the group H of:
    - Alpha-renaming
    - Whitespace changes
    - Comment insertion/removal
    """

    def __init__(self, jo_pipeline):
        """
        Initialize with a JO pipeline function.

        Args:
            jo_pipeline: Function code -> processed_code
        """
        self.jo = jo_pipeline

    @staticmethod
    def _same_ast(original: str, candidate: str) -> bool:
        """
        Return False only when a text edit provably changed the program.

        Both texts are parsed and their dumps (which exclude line/column
        attributes) compared. If the original itself does not parse there is
        nothing to verify against, so the edit is accepted.
        """
        try:
            reference = ast.dump(ast.parse(original))
        except (SyntaxError, ValueError):
            return True
        try:
            return ast.dump(ast.parse(candidate)) == reference
        except (SyntaxError, ValueError):
            return False

    @staticmethod
    def _pinned_names(tree: ast.AST) -> set:
        """
        Collect names that occur in positions the renamer does not rewrite.

        Renaming only touches ``Name`` and ``arg`` nodes, so a name that also
        appears as a keyword argument, in a global/nonlocal statement, as a
        function/class name, as an ``except ... as`` target, as an import
        binding, or as a class-level assignment target (reachable through
        attribute access) must keep its spelling.
        """
        own_scope = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef,
                     ast.Lambda)
        pinned = set()
        for node in ast.walk(tree):
            if isinstance(node, ast.keyword):
                if node.arg is not None:
                    pinned.add(node.arg)
            elif isinstance(node, (ast.Global, ast.Nonlocal)):
                pinned.update(node.names)
            elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                pinned.add(node.name)
            elif isinstance(node, ast.ExceptHandler):
                if node.name is not None:
                    pinned.add(node.name)
            elif isinstance(node, ast.alias):
                pinned.add((node.asname or node.name).split('.')[0])
            elif isinstance(node, ast.ClassDef):
                pinned.add(node.name)
                # Names stored directly in the class body become attributes.
                pending = list(node.body)
                while pending:
                    child = pending.pop()
                    if isinstance(child, own_scope):
                        continue
                    if (isinstance(child, ast.Name)
                            and isinstance(child.ctx, ast.Store)):
                        pinned.add(child.id)
                    pending.extend(ast.iter_child_nodes(child))
        return pinned

    def generate_variants(self, code: str, n_variants: int = 5) -> list:
        """
        Generate semantics-preserving variants of code.

        Each variant independently applies each of the three transforms with
        probability 0.5, drawing from the global ``random`` generator.

        Args:
            code: Source text to vary.
            n_variants: Number of variants to produce.

        Returns:
            List of ``n_variants`` source strings (some may equal ``code``).
        """
        import random
        variants = []

        for i in range(n_variants):
            variant = code

            # Variant type 1: Change whitespace
            if random.random() < 0.5:
                variant = self._add_blank_lines(variant)

            # Variant type 2: Change variable names (alpha-rename)
            if random.random() < 0.5:
                variant = self._alpha_rename(variant, seed=i)

            # Variant type 3: Add/remove comments
            if random.random() < 0.5:
                variant = self._toggle_comments(variant)

            variants.append(variant)

        return variants

    def _add_blank_lines(self, code: str) -> str:
        """
        Add random blank lines after non-blank lines (semantics-preserving).

        If an inserted line would change the parsed program (for example
        inside a multi-line string or after a backslash continuation), the
        input is returned unchanged.
        """
        import random
        result = []
        for line in code.split('\n'):
            result.append(line)
            if random.random() < 0.2 and line.strip():
                result.append('')
        candidate = '\n'.join(result)
        return candidate if self._same_ast(code, candidate) else code

    def _alpha_rename(self, code: str, seed: int = 0) -> str:
        """
        Consistently rename assigned variables and arguments to ``var_N``.

        Args:
            code: Source text to rename.
            seed: Seed for the private generator that shuffles the mapping;
                the global ``random`` state is not touched.

        Returns:
            The re-rendered source, or ``code`` unchanged if it cannot be
            parsed or rendered.
        """
        import builtins
        import random
        # Private generator: reseeding the module-level one would make every
        # later random draw of the caller deterministic.
        rng = random.Random(seed)

        try:
            tree = ast.parse(code)

            protected = {'print', 'len', 'range', 'sum', 'min', 'max',
                         'int', 'str', 'list', 'dict', 'set', 'True',
                         'False', 'None', 'abs', 'sorted', 'enumerate'}
            protected.update(dir(builtins))
            protected.update(self._pinned_names(tree))

            # Collect renameable names and every identifier already in use
            stored = set()
            in_use = set(protected)
            for node in ast.walk(tree):
                if isinstance(node, ast.Name):
                    in_use.add(node.id)
                    if isinstance(node.ctx, ast.Store):
                        stored.add(node.id)
                elif isinstance(node, ast.arg):
                    in_use.add(node.arg)
                    stored.add(node.arg)

            # Sort before shuffling: set order varies between processes, which
            # would make the mapping irreproducible for a given seed.
            names = sorted(stored - protected)
            rng.shuffle(names)

            # Create rename mapping, skipping names that would collide with
            # identifiers already present in the code.
            alt_names = []
            index = 0
            while len(alt_names) < len(names):
                candidate = f'var_{index}'
                if candidate not in in_use:
                    alt_names.append(candidate)
                index += 1
            rename_map = dict(zip(names, alt_names))

            # Apply renaming
            class Renamer(ast.NodeTransformer):
                def visit_Name(self, node):
                    if node.id in rename_map:
                        node.id = rename_map[node.id]
                    return node

                def visit_arg(self, node):
                    if node.arg in rename_map:
                        node.arg = rename_map[node.arg]
                    return node

            tree = Renamer().visit(tree)
            ast.fix_missing_locations(tree)
            return ast.unparse(tree)
        except Exception:
            # Deliberately best-effort: any failure leaves the code untouched.
            # Unlike a bare except, KeyboardInterrupt/SystemExit propagate.
            return code

    def _toggle_comments(self, code: str) -> str:
        """
        Add inline comments to random lines (semantics-preserving).

        Lines already containing ``#`` are skipped. If an appended comment
        would change the parsed program (for example inside a multi-line
        string), the input is returned unchanged.
        """
        import random
        result = []
        for line in code.split('\n'):
            if random.random() < 0.1 and line.strip() and '#' not in line:
                line = line + '  # added comment'
            result.append(line)
        candidate = '\n'.join(result)
        return candidate if self._same_ast(code, candidate) else code

    def check_invariance(self, code: str, n_variants: int = 5) -> Dict:
        """
        Check JO's invariance under semantics-preserving transforms.

        Args:
            code: Source text fed to the pipeline.
            n_variants: Number of variants to generate and test.

        Returns dict with:
        - ast_invariance_rate: fraction of variants with equivalent output AST
          (0.0 when no variants were tested)
        - max_output_distance: Maximum edit distance between outputs
        - n_variants: the requested number of variants

        or ``{'error': ...}`` when the pipeline output for ``code`` itself
        does not parse.
        """
        # Get base output
        base_output = self.jo(code)
        base_ast = normalize_ast(base_output)

        if base_ast is None:
            return {'error': 'Base output has syntax error'}

        # Generate and test variants
        variants = self.generate_variants(code, n_variants)
        ast_matches = 0
        max_distance = 0

        for variant in variants:
            variant_output = self.jo(variant)

            # Check AST equivalence
            equiv, _ = are_ast_equivalent(base_output, variant_output)
            if equiv:
                ast_matches += 1

            # Compute distance (negative means the output did not parse)
            dist = compute_ast_edit_distance(base_output, variant_output)
            if dist >= 0:
                max_distance = max(max_distance, dist)

        # Guard against division by zero when no variants were requested.
        tested = len(variants)
        rate = ast_matches / tested if tested else 0.0

        return {
            'ast_invariance_rate': rate,
            'max_output_distance': max_distance,
            'n_variants': n_variants
        }


if __name__ == '__main__':
    # Smoke test: code1 and code2 differ only in identifier names, while
    # code3 uses a different operator than code1.
    code1 = (
        '\n'
        'def foo(x):\n'
        '    y = x + 1\n'
        '    return y\n'
    )

    code2 = (
        '\n'
        'def bar(a):\n'
        '    b = a + 1\n'
        '    return b\n'
    )

    code3 = (
        '\n'
        'def foo(x):\n'
        '    y = x * 2  # different operation\n'
        '    return y\n'
    )

    print("Testing AST equivalence:")
    renamed_result = are_ast_equivalent(code1, code2)
    print(f"code1 vs code2 (should be equivalent): {renamed_result}")
    changed_result = are_ast_equivalent(code1, code3)
    print(f"code1 vs code3 (should differ): {changed_result}")

    renamed_distance = compute_ast_edit_distance(code1, code2)
    print(f"\nAST edit distance code1 vs code2: {renamed_distance}")
    changed_distance = compute_ast_edit_distance(code1, code3)
    print(f"AST edit distance code1 vs code3: {changed_distance}")
