import ast
from typing import List, Dict, Any
from collections import defaultdict

def _extract_logic_paths(source_code: str, max_paths: int = 50) -> List[Dict[str, Any]]:
    """
    遍历 AST 中的 if/elif/else、for、while、try/except，
    提取所有可能的路径：
      - conditions: 该路径上每个分支的条件或循环/异常描述
      - return: 在该路径上执行的返回表达式（字符串）
    对于 if 的 else 分支，统一标记为 "else"。
    """
    tree = ast.parse(source_code)
    paths: List[Dict[str, Any]] = []

    def visit(node: ast.AST, conds: List[str]):
        if len(paths) >= max_paths:
            return

        if isinstance(node, ast.Return):
            # 记录当前路径及返回值表达式
            return_expr = ast.unparse(node.value).strip() if node.value else None
            paths.append({
                "conditions": list(conds),
                "return": return_expr
            })
            return

        if isinstance(node, ast.If):
            test = ast.unparse(node.test).strip()
            for child in node.body:
                visit(child, conds + [test])
            for child in node.orelse:
                visit(child, conds + ["else"])

        elif isinstance(node, ast.For):
            desc = f"for {ast.unparse(node.target).strip()} in {ast.unparse(node.iter).strip()}"
            exit_desc = f"exit {desc}"
            for child in node.body:
                visit(child, conds + [desc])
            for child in node.orelse:
                visit(child, conds + [exit_desc])

        elif isinstance(node, ast.While):
            test = ast.unparse(node.test).strip()
            exit_desc = f"exit {test}"
            for child in node.body:
                visit(child, conds + [test])
            for child in node.orelse:
                visit(child, conds + [exit_desc])

        elif isinstance(node, ast.Try):
            for child in node.body:
                visit(child, conds + ["try"])
            for handler in node.handlers:
                etype = ast.unparse(handler.type).strip() if handler.type else "Exception"
                for child in handler.body:
                    visit(child, conds + [f"except {etype}"])
            # 忽略 else 和 finally

        else:
            for child in ast.iter_child_nodes(node):
                visit(child, conds)

    visit(tree, [])
    return paths if paths else [{"conditions": [], "return": None}]


def generate_logic_summary(source_code: str) -> Dict[str, Any]:
    """
    Build an enhanced logic summary of ``source_code``.

    Args:
        source_code: Python source text to analyse.

    Returns:
        On success, a dict with:
          - ``paths``: the return paths from ``_extract_logic_paths`` (each
            path's condition list and returned expression);
          - ``path_conditions``: the sorted, de-duplicated set of every
            condition string appearing in ``paths`` (an overview);
          - ``definitions``: name -> line number of its most recent
            definition, covering function parameters of every kind (for both
            ``def`` and ``async def``) and names bound by plain, augmented
            or annotated (with value) assignments, including destructuring
            targets;
          - ``operations``: variable name -> full unparsed text of every such
            assignment statement that binds it, in traversal order;
          - ``calls``: unparsed call expression -> number of occurrences.
        If the source does not parse, a dict with a single ``"error"`` key
        describing the ``SyntaxError``.
    """
    try:
        paths = _extract_logic_paths(source_code)
    except SyntaxError as e:
        return {"error": f"SyntaxError at line {e.lineno}: {e.msg}"}

    # The same text already parsed successfully above, so this parse is safe.
    tree = ast.parse(source_code)
    definitions: Dict[str, int] = {}
    operations: Dict[str, List[str]] = defaultdict(list)
    calls: Dict[str, int] = defaultdict(int)

    def target_names(target: ast.AST) -> List[str]:
        """Return the plain variable names bound by an assignment target."""
        if isinstance(target, ast.Name):
            return [target.id]
        if isinstance(target, (ast.Tuple, ast.List)):
            return [name for elt in target.elts for name in target_names(elt)]
        if isinstance(target, ast.Starred):
            return target_names(target.value)
        # Attribute / subscript targets mutate an object; they bind no name.
        return []

    def record(target: ast.AST, node: ast.stmt) -> None:
        """Register every name bound by ``target`` in statement ``node``."""
        stmt = ast.unparse(node).strip()
        for var in target_names(target):
            definitions[var] = node.lineno
            operations[var].append(stmt)

    class Visitor(ast.NodeVisitor):
        def _visit_function(self, node: ast.AST) -> None:
            args = node.args
            # Parameters in signature order: posonly, regular, *args, kwonly, **kwargs.
            params = [
                *args.posonlyargs,
                *args.args,
                *([args.vararg] if args.vararg else []),
                *args.kwonlyargs,
                *([args.kwarg] if args.kwarg else []),
            ]
            for arg in params:
                definitions[arg.arg] = node.lineno
            self.generic_visit(node)

        visit_FunctionDef = _visit_function
        visit_AsyncFunctionDef = _visit_function

        def visit_Assign(self, node: ast.Assign) -> None:
            # Chained assignment (`a = b = 1`) has several targets.
            for tgt in node.targets:
                record(tgt, node)
            self.generic_visit(node)

        def visit_AugAssign(self, node: ast.AugAssign) -> None:
            record(node.target, node)
            self.generic_visit(node)

        def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
            # A bare annotation (`x: int`) declares but does not bind a value.
            if node.value is not None:
                record(node.target, node)
            self.generic_visit(node)

        def visit_Call(self, node: ast.Call) -> None:
            call_expr = ast.unparse(node).strip()
            calls[call_expr] += 1
            self.generic_visit(node)

    Visitor().visit(tree)

    # Collect every distinct condition description across all paths.
    unique_conditions = sorted({
        cond
        for path in paths
        for cond in path["conditions"]
    })

    return {
        "paths": paths,
        "path_conditions": unique_conditions,
        "definitions": definitions,
        "operations": dict(operations),
        "calls": dict(calls),
    }

if __name__ == "__main__":
    import json

    # Demo: summarize a trivial two-argument `add` function and pretty-print
    # the resulting summary as JSON (non-ASCII characters kept as-is).
    demo_source = "\ndef add(x, y):\n    return x + y\n"
    demo_summary = generate_logic_summary(demo_source)
    rendered = json.dumps(demo_summary, indent=2, ensure_ascii=False)
    print(rendered)