import ast
import builtins
from typing import Dict, List, Set


class FunctionRenamer(ast.NodeTransformer):
    """
    Rename function definitions and all references to them.

    Every ``def``/``async def`` whose name is a key of ``mapping`` is renamed,
    and so is every ``ast.Name`` node with such an id. This covers direct
    calls (``f()``), attribute bases (``f.cache_clear()``) and plain
    references (``callback=f``). Matching is purely textual: no scope
    analysis is done, so any variable sharing a mapped name is renamed too.

    Each node is renamed exactly once. Chained or swapping mappings such as
    ``{"a": "b", "b": "a"}`` therefore act as a simultaneous substitution.

    Attributes:
        mapping: {old_name -> new_name}
    """
    def __init__(self, mapping: Dict[str, str]):
        super().__init__()
        self.mapping = mapping
    
    
    # ---------- Helpers ----------
    def rename(self, name: str) -> str:
        """Return the new name for *name*, or *name* itself if unmapped."""
        return self.mapping.get(name, name)
    
    
    # ---------- Visitors ----------
    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Rename the definition, then its decorators, arguments and body."""
        node.name = self.rename(node.name)
        self.generic_visit(node)
        return node
    
    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
        """Same as visit_FunctionDef, for ``async def``."""
        node.name = self.rename(node.name)
        self.generic_visit(node)
        return node
    
    def visit_Call(self, node: ast.Call):
        """Visit a call.

        ``node.func`` is either a Name or an expression that contains one,
        such as the base of an Attribute. generic_visit reaches that Name
        through visit_Name. Renaming it again here would apply the mapping
        twice and break chained or swapping mappings.
        """
        self.generic_visit(node)
        return node
    
    def visit_Attribute(self, node: ast.Attribute):
        """Visit ``value.attr``.

        A Name in ``value`` is renamed by visit_Name. ``attr`` is a plain
        string and is deliberately left alone.
        """
        self.generic_visit(node)
        return node
    
    def visit_Name(self, node: ast.Name):
        """Single place where references are renamed."""
        node.id = self.rename(node.id)
        return node


class ScopeAwareIdentifierRenamer(ast.NodeTransformer):
    """
    Rename identifiers according to ``mapping`` without hijacking built-ins.

    Rules:
    - Names: every ``ast.Name`` whose id is a key of ``mapping`` is renamed.
      The exception is a built-in name (``len``, ``print``, ...) that is not
      bound in any scope visible from the current one. Such a reference still
      means the built-in and is left alone.
    - Bindings are recorded while the tree is walked. A built-in is therefore
      renamed only after a binding for it has been seen.
    - Binding sites that are plain strings rather than ``ast.Name`` nodes are
      renamed consistently with their references:
      - parameters, including positional-only and lambda parameters;
      - ``except ... as NAME``;
      - ``global``/``nonlocal`` declarations;
      - import aliases;
      - ``match`` capture names.
    - Keyword argument names in calls are renamed so they keep matching the
      renamed parameters. Calls to ``dict``/``list``/``set`` are excluded
      (``dict(a=1)`` turns keyword names into data: dict keys).
    - Function and class names are recorded as bindings but are not renamed.
    - Class-body bindings are visible only in the class body itself, not in
      functions or comprehensions nested inside it, mirroring Python.

    Attributes:
        mapping: {old_name -> new_name}
        builtins: names defined in the ``builtins`` module.
        scope_stack: bound names per open scope, innermost last.
    """

    class _ClassScope(set):
        """Bindings of a class body; hidden from scopes nested inside it."""

    # Calls whose keyword names are never renamed (``dict(a=1)`` makes "a" a key).
    _DATA_KEYWORD_CALLS = frozenset({"dict", "list", "set"})

    def __init__(self, mapping: Dict[str, str]):
        super().__init__()
        self.mapping = mapping
        self.builtins: Set[str] = set(dir(builtins))
        self.scope_stack: List[Set[str]] = []
    
    
    # ---------- Scope helpers ----------
    def push_scope(self, class_body: bool = False):
        """Open a new innermost scope; ``class_body`` marks a class body."""
        self.scope_stack.append(self._ClassScope() if class_body else set())
    
    def pop_scope(self):
        """Close the innermost scope."""
        self.scope_stack.pop()
    
    def add_binding(self, name: str):
        """Record *name* as bound in the innermost scope (no-op outside any scope)."""
        if self.scope_stack:
            self.scope_stack[-1].add(name)
    
    def is_bound_in_chain(self, name: str) -> bool:
        """Return True if *name* is bound in a scope visible from the innermost one.

        Enclosing class bodies are skipped unless they are the innermost scope,
        because Python does not expose class-level names to nested functions
        or comprehensions.
        """
        innermost = len(self.scope_stack) - 1
        return any(
            name in scope
            for depth, scope in enumerate(self.scope_stack)
            if depth == innermost or not isinstance(scope, self._ClassScope)
        )

    def _renamed(self, name: str) -> str:
        """Return the new spelling of a reference to *name* in the current scope."""
        if name not in self.mapping:
            return name
        if name in self.builtins and not self.is_bound_in_chain(name):
            # Unshadowed built-in: renaming would break the reference.
            return name
        return self.mapping[name]

    def _bind_declared(self, name: str) -> str:
        """Bind *name* in the innermost scope and return its new spelling."""
        self.add_binding(name)
        return self._renamed(name)

    def _collect_target_bindings(self, target):
        """Bind every plain name assigned by an assignment/loop/with target."""
        if isinstance(target, ast.Name):
            self.add_binding(target.id)
        elif isinstance(target, (ast.Tuple, ast.List)):
            for elt in target.elts:
                self._collect_target_bindings(elt)
        elif isinstance(target, ast.Starred):
            self._collect_target_bindings(target.value)
        # Attribute/Subscript targets bind no new name.

    @staticmethod
    def _parameters(args: ast.arguments) -> List[ast.arg]:
        """Return every parameter of *args*, including ``*args``/``**kwargs``."""
        params = [*args.posonlyargs, *args.args, *args.kwonlyargs]
        if args.vararg is not None:
            params.append(args.vararg)
        if args.kwarg is not None:
            params.append(args.kwarg)
        return params

    def _visit_signature_exprs(self, args: ast.arguments):
        """Visit defaults and annotations; Python evaluates them in the enclosing scope."""
        args.defaults = [self.visit(d) for d in args.defaults]
        args.kw_defaults = [d if d is None else self.visit(d) for d in args.kw_defaults]
        for param in self._parameters(args):
            if param.annotation is not None:
                param.annotation = self.visit(param.annotation)

    def _bind_parameters(self, args: ast.arguments):
        """Bind and rename every parameter; call after pushing the function scope."""
        for param in self._parameters(args):
            self.add_binding(param.arg)
            self.visit(param)  # visit_arg renames in place

    def _visit_type_params(self, node):
        """Visit PEP 695 type parameters when the node has them (Python 3.12+)."""
        if getattr(node, "type_params", None):
            node.type_params = [self.visit(t) for t in node.type_params]

    def _visit_function(self, node):
        """Shared logic for ``def`` and ``async def``; each child is visited once."""
        node.decorator_list = [self.visit(d) for d in node.decorator_list]
        self._visit_signature_exprs(node.args)
        if node.returns is not None:
            node.returns = self.visit(node.returns)
        self._visit_type_params(node)
        self.add_binding(node.name)
        self.push_scope()
        self._bind_parameters(node.args)
        node.body = [self.visit(stmt) for stmt in node.body]
        self.pop_scope()
        return node

    def _visit_comprehension_scope(self, node, *result_fields: str):
        """Visit a comprehension in its own scope.

        Generators are visited first so their targets are bound before the
        result expression(s) named by *result_fields* are renamed.
        """
        self.push_scope()
        node.generators = [self.visit(g) for g in node.generators]
        for field in result_fields:
            setattr(node, field, self.visit(getattr(node, field)))
        self.pop_scope()
        return node

    def _keywords_are_data(self, call) -> bool:
        """True if *call* is a call to one of the names in _DATA_KEYWORD_CALLS."""
        return (
            isinstance(call, ast.Call)
            and isinstance(call.func, ast.Name)
            and call.func.id in self._DATA_KEYWORD_CALLS
        )
    
    
    # ---------- Visitors (Generated by GPT-5) ----------
    def visit_Module(self, node: ast.Module):
        self.push_scope()
        new = self.generic_visit(node)
        self.pop_scope()
        return new

    def visit_FunctionDef(self, node: ast.FunctionDef):
        return self._visit_function(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
        return self._visit_function(node)

    def visit_Lambda(self, node: ast.Lambda):
        """Lambdas get their own scope so builtin-named parameters stay consistent."""
        self._visit_signature_exprs(node.args)
        self.push_scope()
        self._bind_parameters(node.args)
        node.body = self.visit(node.body)
        self.pop_scope()
        return node

    def visit_ClassDef(self, node: ast.ClassDef):
        """Header expressions use the enclosing scope; the body gets a class scope."""
        node.decorator_list = [self.visit(d) for d in node.decorator_list]
        node.bases = [self.visit(b) for b in node.bases]
        node.keywords = [self.visit(k) for k in node.keywords]
        self._visit_type_params(node)
        self.add_binding(node.name)
        self.push_scope(class_body=True)
        node.body = [self.visit(stmt) for stmt in node.body]
        self.pop_scope()
        return node

    def visit_Assign(self, node: ast.Assign):
        # The value is evaluated before the targets are bound.
        node.value = self.visit(node.value)
        new_targets = []
        for target in node.targets:
            self._collect_target_bindings(target)
            new_targets.append(self.visit(target))
        node.targets = new_targets
        return node

    def visit_AnnAssign(self, node: ast.AnnAssign):
        node.annotation = self.visit(node.annotation)
        if node.value is not None:
            node.value = self.visit(node.value)
        # Attribute/Subscript targets (e.g. ``self.x: int``) are still visited
        # so names inside them are renamed.
        self._collect_target_bindings(node.target)
        node.target = self.visit(node.target)
        return node

    def visit_For(self, node: ast.For):
        node.iter = self.visit(node.iter)
        self._collect_target_bindings(node.target)
        node.target = self.visit(node.target)
        node.body = [self.visit(n) for n in node.body]
        node.orelse = [self.visit(n) for n in node.orelse]
        return node

    visit_AsyncFor = visit_For

    def visit_With(self, node: ast.With):
        node.items = [self.visit(i) for i in node.items]
        node.body = [self.visit(n) for n in node.body]
        return node

    visit_AsyncWith = visit_With

    def visit_withitem(self, node: ast.withitem):
        node.context_expr = self.visit(node.context_expr)
        if node.optional_vars:
            self._collect_target_bindings(node.optional_vars)
            node.optional_vars = self.visit(node.optional_vars)
        return node

    def visit_NamedExpr(self, node: ast.NamedExpr):
        node.value = self.visit(node.value)
        if isinstance(node.target, ast.Name):
            self.add_binding(node.target.id)
            node.target = self.visit(node.target)
        return node

    def visit_ListComp(self, node: ast.ListComp):
        return self._visit_comprehension_scope(node, "elt")

    def visit_SetComp(self, node: ast.SetComp):
        return self._visit_comprehension_scope(node, "elt")

    def visit_DictComp(self, node: ast.DictComp):
        return self._visit_comprehension_scope(node, "key", "value")

    def visit_GeneratorExp(self, node: ast.GeneratorExp):
        return self._visit_comprehension_scope(node, "elt")

    def visit_comprehension(self, node: ast.comprehension):
        node.iter = self.visit(node.iter)
        self._collect_target_bindings(node.target)
        node.target = self.visit(node.target)
        node.ifs = [self.visit(i) for i in node.ifs]
        return node

    def visit_ExceptHandler(self, node: ast.ExceptHandler):
        """``except E as name``: the name is a plain string, so rename it here."""
        if node.type is not None:
            node.type = self.visit(node.type)
        if node.name is not None:
            node.name = self._bind_declared(node.name)
        node.body = [self.visit(stmt) for stmt in node.body]
        return node

    def visit_Global(self, node: ast.Global):
        """Bind the names at module level and rename the declaration itself."""
        if self.scope_stack:
            self.scope_stack[0].update(node.names)
        node.names = [self._renamed(n) for n in node.names]
        return node

    def visit_Nonlocal(self, node: ast.Nonlocal):
        """Bind the names in the nearest enclosing non-class scope and rename them."""
        for scope in reversed(self.scope_stack[:-1]):
            if not isinstance(scope, self._ClassScope):
                scope.update(node.names)
                break
        node.names = [self._renamed(n) for n in node.names]
        return node

    def visit_alias(self, node: ast.alias):
        """Keep imported names in sync with their (renamed) references.

        ``import os`` with ``{"os": "o"}`` becomes ``import os as o``.
        A dotted ``import a.b`` without ``as`` binds ``a`` and cannot be
        re-aliased without changing what is bound, so it is left unchanged.
        """
        if node.name == "*":
            return node
        bound = node.asname or node.name.partition(".")[0]
        new = self._bind_declared(bound)
        if new != bound and (node.asname or "." not in node.name):
            node.asname = new
        return node

    def visit_MatchAs(self, node):
        if node.pattern is not None:
            node.pattern = self.visit(node.pattern)
        if node.name is not None:
            node.name = self._bind_declared(node.name)
        return node

    def visit_MatchStar(self, node):
        if node.name is not None:
            node.name = self._bind_declared(node.name)
        return node

    def visit_MatchMapping(self, node):
        node.keys = [self.visit(k) for k in node.keys]
        node.patterns = [self.visit(p) for p in node.patterns]
        if node.rest is not None:
            node.rest = self._bind_declared(node.rest)
        return node

    def visit_arg(self, node: ast.arg):
        """Rename a parameter name (annotations are visited by the caller)."""
        if node.arg in self.mapping:
            node.arg = self.mapping[node.arg]
        return node

    def visit_Call(self, node: ast.Call):
        """Visit a call; keyword names of dict/list/set calls are left as data.

        The check runs after ``node.func`` is visited. A shadowed, renamed
        ``dict`` is therefore treated like any other callable.
        """
        node.func = self.visit(node.func)
        node.args = [self.visit(a) for a in node.args]
        keywords_are_data = self._keywords_are_data(node)
        new_keywords = []
        for kw in node.keywords:
            if keywords_are_data:
                kw.value = self.visit(kw.value)
                new_keywords.append(kw)
            else:
                new_keywords.append(self.visit(kw))
        node.keywords = new_keywords
        return node

    def visit_keyword(self, node: ast.keyword):
        # ``parent`` exists only if a caller annotated the tree; ordinary calls
        # are filtered in visit_Call before reaching this method.
        parent = getattr(node, "parent", None)
        if node.arg in self.mapping and not self._keywords_are_data(parent):
            node.arg = self.mapping[node.arg]
        node.value = self.visit(node.value)
        return node

    def visit_Name(self, node: ast.Name):
        node.id = self._renamed(node.id)
        return node