"""Post-processing LLM-generated Python code implemented using tree-sitter."""

import ast
import os
import pathlib
from typing import Any, Dict, Generator, List, Optional, Set, Tuple

from tqdm import tqdm

from bigcodebench.syncheck import syntax_check

try:
    from pqdm.processes import pqdm
except ImportError:  # pragma: no cover - optional dependency
    pqdm = None

try:
    import tree_sitter_python
    from tree_sitter import Language, Node, Parser
except ImportError:  # pragma: no cover - optional dependency
    tree_sitter_python = None
    Language = None
    Parser = None
    Node = Any

# True when the optional tree-sitter bindings imported above are available;
# when False, extract_target_code_or_empty uses the pure-``ast`` fallback.
_HAS_TREE_SITTER = tree_sitter_python is not None

# tree-sitter-python node type names used to classify top-level statements.
CLASS_TYPE = "class_definition"
FUNCTION_TYPE = "function_definition"
# `from __future__ import ...` has its own node type in tree-sitter-python; it
# must be kept so the sanitized code retains the same future semantics.
IMPORT_TYPE = [
    "import_statement",
    "import_from_statement",
    "future_import_statement",
]
IDENTIFIER_TYPE = "identifier"
ATTRIBUTE_TYPE = "attribute"
RETURN_TYPE = "return_statement"
EXPRESSION_TYPE = "expression_statement"
ASSIGNMENT_TYPE = "assignment"


def _node_lines(node: ast.AST, lines: List[str]) -> List[str]:
    lineno = getattr(node, "lineno", None)
    end_lineno = getattr(node, "end_lineno", None)
    if lineno is None or end_lineno is None:
        return []
    return lines[lineno - 1 : end_lineno]


def _function_has_return(node: ast.AST) -> bool:
    return any(isinstance(child, ast.Return) for child in ast.walk(node))


def _collect_ast_deps(node: ast.AST) -> Set[str]:
    deps: Set[str] = set()
    for child in ast.walk(node):
        if isinstance(child, ast.Name) and isinstance(child.ctx, ast.Load):
            deps.add(child.id)
    return deps


def _trim_after_entrypoint(output: str, entrypoint: Optional[str]) -> str:
    if not entrypoint or not output:
        return output
    lines = output.splitlines()
    outer_lines = []
    for i in range(len(lines) - 1, -1, -1):
        if lines[i].startswith(" "):
            break
        if entrypoint in lines[i]:
            outer_lines.append(i)
    if outer_lines:
        return "\n".join(lines[: outer_lines[-1]])
    return output


def _extract_target_code_or_empty_ast(code: str, entrypoint: Optional[str]) -> str:
    """Standard-library ``ast`` fallback for :func:`extract_target_code_or_empty`.

    Keeps every top-level import, then the first top-level definition of each
    name, in source order. Definitions are classes, functions that contain a
    ``return``, and assignments whose first target is a plain name. When
    *entrypoint* is given, only definitions reachable from it through name
    references are kept, and trailing top-level lines mentioning it are
    trimmed.

    Returns ``""`` if *code* cannot be parsed.
    """
    try:
        module = ast.parse(code)
    except (SyntaxError, ValueError):
        # ValueError: some Python versions reject NUL bytes this way rather
        # than with SyntaxError.
        return ""

    lines = code.splitlines()
    import_nodes: List[ast.AST] = []
    definition_nodes: List[Tuple[str, ast.AST]] = []
    # Classes, functions and variables share one namespace: the first
    # definition of a name wins and later redefinitions are dropped.
    seen: Set[str] = set()

    for child in module.body:
        if isinstance(child, (ast.Import, ast.ImportFrom)):
            import_nodes.append(child)
            continue
        if isinstance(child, ast.ClassDef):
            name = child.name
        elif isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
            # Functions without a `return` are skipped and do not claim the name.
            if not _function_has_return(child):
                continue
            name = child.name
        elif isinstance(child, ast.Assign) and isinstance(child.targets[0], ast.Name):
            # For chained `a = b = ...` the first target names the statement,
            # like the first identifier does on the tree-sitter path.
            name = child.targets[0].id
        elif isinstance(child, ast.AnnAssign) and isinstance(child.target, ast.Name):
            name = child.target.id
        else:
            continue
        if name in seen:
            continue
        seen.add(name)
        definition_nodes.append((name, child))

    # None means "keep every definition" (no entrypoint to filter by).
    reachable: Optional[Set[str]] = None
    if entrypoint:
        name2deps = {name: _collect_ast_deps(node) for name, node in definition_nodes}
        reachable = get_function_dependency(entrypoint, name2deps)

    chunks: List[str] = []
    for node in import_nodes:
        node_lines = _node_lines(node, lines)
        if node_lines:
            chunks.append("\n".join(node_lines))

    for name, node in definition_nodes:
        if reachable is not None and name not in reachable:
            continue
        node_lines = _node_lines(node, lines)
        if node_lines:
            chunks.append("\n".join(node_lines))

    return _trim_after_entrypoint("\n".join(chunks), entrypoint)


def code_extract(text: str) -> str:
    """Return the longest syntactically valid run of consecutive lines in *text*.

    "Longest" is measured in non-blank lines. Among equally long runs, the one
    found first wins (smallest start, then smallest end). Single-line runs are
    candidates too. If no run is valid, the first line of *text* is returned.

    Each candidate is checked with ``syntax_check``, so the worst case is
    roughly cubic in the number of lines. Candidates that cannot beat the
    current best are skipped without being parsed.
    """
    lines = text.split("\n")
    n_lines = len(lines)

    # nonblank_before[k] == number of non-blank lines in lines[:k]
    nonblank_before = [0]
    for line in lines:
        nonblank_before.append(nonblank_before[-1] + (1 if line.strip() else 0))

    best_start, best_end = 0, 0
    best_count = 0
    for start in range(n_lines):
        # The remaining count only shrinks as `start` grows, so once even the
        # whole remainder cannot beat the best, no later start can either.
        if nonblank_before[n_lines] - nonblank_before[start] <= best_count:
            break
        # `end` starts at `start` so one-line snippets are considered as well.
        for end in range(start, n_lines):
            count = nonblank_before[end + 1] - nonblank_before[start]
            if count <= best_count:
                continue  # could not replace the current best anyway
            if syntax_check("\n".join(lines[start : end + 1])):
                best_count = count
                best_start, best_end = start, end

    return "\n".join(lines[best_start : best_end + 1])


def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]:
    """Map each definition name to the identifier texts found beneath its node.

    The search starts at each node's children (the node itself is not
    checked). It does not descend into identifier nodes, and it recurses
    into every other node. If a name appears twice in *nodes*, the later
    entry wins.
    """

    def collect_identifiers(root: Node) -> Set[str]:
        found: Set[str] = set()
        pending = list(root.children)
        while pending:
            current = pending.pop()
            if current.type == IDENTIFIER_TYPE:
                found.add(current.text.decode("utf8"))
            else:
                pending.extend(current.children)
        return found

    return {name: collect_identifiers(node) for name, node in nodes}


def get_function_dependency(
    entrypoint: str, call_graph: Dict[str, Set[str]]
) -> Set[str]:
    """Return every name reachable from *entrypoint* in *call_graph*.

    Breadth-first walk. *entrypoint* itself is always part of the result,
    and names that have no entry in *call_graph* are treated as leaves.
    """
    reached = {entrypoint}
    frontier = [entrypoint]
    # `frontier` grows while it is being iterated; list iteration picks up
    # the appended names, giving the same BFS order as a queue.
    for current in frontier:
        for neighbour in call_graph.get(current, ()):
            if neighbour not in reached:
                reached.add(neighbour)
                frontier.append(neighbour)
    return reached


def get_definition_name(node: Node) -> str:
    """Return the text of *node*'s first direct ``identifier`` child.

    Only immediate children are inspected. If none of them is an identifier,
    ``None`` is returned despite the ``str`` annotation.
    """
    identifiers = (
        child.text.decode("utf8")
        for child in node.children
        if child.type == IDENTIFIER_TYPE
    )
    return next(identifiers, None)


def traverse_tree(node: Node) -> Generator[Node, None, None]:
    """Yield *node* and all of its descendants in pre-order (document order).

    Uses an explicit stack over ``children``, so the walk cannot stop early
    and never visits *node*'s siblings or ancestors.
    """
    stack = [node]
    while stack:
        current = stack.pop()
        yield current
        # Push children reversed so the leftmost child is visited next.
        stack.extend(reversed(current.children))


def has_return_statement(node: Node) -> bool:
    """Report whether *node* or any node in its subtree is a return statement."""
    return any(
        descendant.type == RETURN_TYPE for descendant in traverse_tree(node)
    )


def extract_target_code_or_empty(code: str, entrypoint: Optional[str] = None) -> str:
    """Reduce LLM output to its imports plus the definitions *entrypoint* needs.

    Steps:
    1. Narrow *code* to its largest syntactically valid run of lines
       (:func:`code_extract`).
    2. Keep every top-level import and the first top-level definition of
       each name: classes (decorated or not), functions that contain a
       ``return``, and assignments.
    3. With an *entrypoint*, keep only definitions reachable from it through
       identifier references, then trim trailing top-level lines that
       mention it.

    Uses tree-sitter when installed, otherwise the ``ast`` fallback. The
    result may be empty.
    """
    code = code_extract(code.strip())
    if not _HAS_TREE_SITTER:
        return _extract_target_code_or_empty_ast(code, entrypoint)

    code_bytes = bytes(code, "utf8")
    parser = Parser(Language(tree_sitter_python.language()))
    root_node = parser.parse(code_bytes).root_node

    # Classes, functions and variables share one namespace: the first
    # definition of a name wins and later redefinitions are dropped.
    seen = set()
    import_nodes = []
    definition_nodes = []

    for child in root_node.children:
        if child.type in IMPORT_TYPE:
            import_nodes.append(child)
            continue

        # `@decorator` + def/class is wrapped in a "decorated_definition" node.
        # Classify by the inner definition, but keep the whole wrapper so the
        # decorators are emitted and their names count as dependencies.
        definition = child
        if child.type == "decorated_definition":
            definition = child.child_by_field_name("definition")
            if definition is None:
                continue

        if definition.type == CLASS_TYPE:
            name = get_definition_name(definition)
            if name not in seen:
                definition_nodes.append((name, child))
                seen.add(name)
        elif definition.type == FUNCTION_TYPE:
            name = get_definition_name(definition)
            # Functions without a `return` are skipped and do not claim the name.
            if name not in seen and has_return_statement(definition):
                definition_nodes.append((name, child))
                seen.add(name)
        elif (
            definition.type == EXPRESSION_TYPE
            and definition.children
            and definition.children[0].type == ASSIGNMENT_TYPE
        ):
            assignment = definition.children[0]
            name = get_definition_name(assignment)
            if name not in seen:
                definition_nodes.append((name, assignment))
                seen.add(name)

    # None means "keep every definition" (no entrypoint to filter by).
    reachable = None
    if entrypoint:
        reachable = get_function_dependency(entrypoint, get_deps(definition_nodes))

    chunks = [code_bytes[node.start_byte : node.end_byte] for node in import_nodes]
    chunks.extend(
        code_bytes[node.start_byte : node.end_byte]
        for name, node in definition_nodes
        if reachable is None or name in reachable
    )
    sanitized_output = b"\n".join(chunks).decode("utf8")

    # Ad-hoc cleanup: drop trailing top-level lines that mention the
    # entrypoint (typically example calls). A no-op without an entrypoint.
    return _trim_after_entrypoint(sanitized_output, entrypoint)


def sanitize(code: str, entrypoint: Optional[str] = None) -> str:
    """Return the sanitized form of *code*.

    Prefers the imports-plus-definitions extraction, whitespace-stripped. If
    that yields nothing, falls back to the largest syntactically valid run
    of lines in the raw *code*.
    """
    extracted = extract_target_code_or_empty(code, entrypoint).strip()
    return extracted if extracted else code_extract(code)


def process_solution(
    sample_solution: Dict,
    dataset: Dict,
    entry_point: Dict,
    debug_task: Optional[str] = None,
    calibrate: bool = False,
    is_folder: bool = False,
    target_path: Optional[str] = None,
    samples: Optional[str] = None,
):
    """Sanitize one sample solution.

    Args:
        sample_solution: a record from ``load_solutions``. It must carry
            ``_identifier`` and either ``solution`` or ``completion``; it is
            skipped when ``task_id`` is missing or unknown.
        dataset: task_id -> problem dict (``complete_prompt`` is read).
        entry_point: task_id -> entry-point name passed to :func:`sanitize`.
        debug_task: if set, every other task is skipped.
        calibrate: for ``solution`` records, insert the task's
            ``complete_prompt`` after an opening "```python\\n    " fence.
        is_folder: whether samples come from a directory (log message only).
        target_path: output location (log message only).
        samples: the input samples path. It is used to map ``_identifier``
            onto *target_path* in the log message; optional for backward
            compatibility.

    Returns:
        ``{"task_id": ..., "solution": ...}``, or ``None`` if skipped.
    """
    task_id = sample_solution.get("task_id")
    if not task_id or task_id not in dataset:
        return None

    dbg_identifier = sample_solution["_identifier"]
    if debug_task is not None and task_id != debug_task:
        return None

    function_name = entry_point.get(task_id)
    old_code = sample_solution.get("solution")

    if old_code is None:
        assert "completion" in sample_solution, sample_solution
        old_code = (
            dataset[task_id]["complete_prompt"]
            + "\n"
            + sample_solution.get("completion")
        )
    elif calibrate:
        old_code = old_code.replace(
            "```python\n    ",
            "```python\n" + dataset[task_id]["complete_prompt"] + "    ",
        )

    new_code = sanitize(code=old_code, entrypoint=function_name)

    # if old code and new code are different, print msg
    if new_code != old_code:
        msg = "Sanitized: " + dbg_identifier
        if is_folder and target_path:
            # The input root is needed to map this file onto its output path;
            # without it, point at the output root instead.
            if samples:
                msg += " -> " + dbg_identifier.replace(samples, target_path)
            else:
                msg += " -> " + target_path
        print(msg)

    return {"task_id": task_id, "solution": new_code}


def script(
    samples: str,
    inplace: bool = False,
    debug_task: str = None,
    calibrate: bool = False,
    parallel: int = 32,
):
    """Sanitize every solution in *samples* and write the results.

    Args:
        samples: a solutions file or directory, as read by ``load_solutions``.
        inplace: overwrite *samples*. Otherwise the output goes next to it
            with a ``-sanitized`` (or ``-sanitized-calibrated``) suffix.
        debug_task: only process this task_id.
        calibrate: see :func:`process_solution`.
        parallel: upper bound on worker processes when ``pqdm`` is installed.
    """
    from bigcodebench.data import (
        get_bigcodebench,
        load_solutions,
        write_directory,
        write_jsonl,
    )

    dataset = {**get_bigcodebench()}
    # task_id -> entry_point
    entry_point = {
        task_id: problem["entry_point"] for task_id, problem in dataset.items()
    }

    # Outside inplace mode, write next to the input with a "-sanitized" suffix.
    is_folder = os.path.isdir(samples)
    target = pathlib.Path(samples)
    if not inplace:
        suffix = "-sanitized-calibrated" if calibrate else "-sanitized"
        if is_folder:
            new_name = target.name + suffix
        elif ".jsonl" in target.name:
            new_name = target.name.replace(".jsonl", suffix + ".jsonl")
        else:
            # Without ".jsonl" in the name the replace above is a no-op and
            # the input file would be overwritten; append the suffix instead.
            new_name = target.name + suffix + ".jsonl"
        target = target.parent / new_name
    target_path = str(target)

    parallel_arg_list = [
        {
            "sample_solution": sample_solution,
            "dataset": dataset,
            "entry_point": entry_point,
            "debug_task": debug_task,
            "calibrate": calibrate,
            "is_folder": is_folder,
            "target_path": target_path,
            "samples": samples,
        }
        for sample_solution in load_solutions(samples)
    ]

    if pqdm is None:
        results = [process_solution(**args) for args in parallel_arg_list]
    else:
        results = pqdm(
            parallel_arg_list,
            process_solution,
            # os.cpu_count() may return None when the count is undetermined.
            n_jobs=min(parallel, os.cpu_count() or 1),
            argument_type="kwargs",
        )

    new_solutions = []
    nsan = 0
    ntotal = 0
    for result in results:
        # With its default exception handling, pqdm returns worker exceptions
        # as result items; re-raise instead of handing them to the writer.
        if isinstance(result, BaseException):
            raise result
        if result is not None:
            new_solutions.append(result)
            nsan += 1
        ntotal += 1

    if is_folder:
        write_directory(target_path, new_solutions)
    else:
        write_jsonl(target_path, new_solutions)

    if nsan > 0:
        print(f"Sanitized {nsan} out of {ntotal} files.")
    else:
        print(f"All files seems valid -- no files are sanitized.")
    print(f"Check the sanitized files at {target_path}")


def main():
    """Command-line entry point: expose :func:`script` through python-fire."""
    import fire

    fire.Fire(script)


if __name__ == "__main__":
    main()
