import ast
import os
import subprocess
import tempfile

import jedi
import unidiff
from .repo_graph import get_definition_lines


def apply_patch_using_unix(original_str, patch_str):
    """Apply a diff to a string using the system ``patch`` utility.

    The original text and the patch are written to files inside a private
    temporary directory, ``patch`` is run on them, and the patched file is
    read back.

    Args:
        original_str: Text of the file before patching.
        patch_str: Patch text to apply to ``original_str``.

    Returns:
        The patched text.

    Raises:
        subprocess.CalledProcessError: If ``patch`` exits with a non-zero
            status (``check=True``).
        FileNotFoundError: If the ``patch`` executable cannot be found.
    """
    # A private directory (instead of two loose temp files) guarantees that
    # everything is removed on exit: both inputs, even when writing one of
    # them fails, and any side files (e.g. ``.orig`` / ``.rej``) that
    # ``patch`` may create next to its target.
    with tempfile.TemporaryDirectory() as workdir:
        original_filename = os.path.join(workdir, "original")
        patch_filename = os.path.join(workdir, "changes.patch")
        with open(original_filename, "w") as original_file:
            original_file.write(original_str)
        with open(patch_filename, "w") as patch_file:
            patch_file.write(patch_str)
        # The target file is named explicitly, so the file names in the patch
        # header are not used to locate it.
        subprocess.run(
            ["patch", original_filename, patch_filename],
            check=True,
            capture_output=True,
        )
        with open(original_filename, "r") as f:
            return f.read()


def find_all_local_defs(code, filepath):
    """Collect the definitions jedi reports for ``code`` that live in ``filepath``.

    Args:
        code: Source text to analyze.
        filepath: Path handed to jedi for ``code``; also used to filter the
            locations returned by ``get_definition_lines``.

    Returns:
        A dict mapping each name's ``full_name`` to a tuple
        ``(start_line_no, end_line_no, source_text)``, where ``source_text``
        is the slice of ``code`` for those lines (line numbers are treated as
        1-based and inclusive). Returns an empty dict if the jedi script
        cannot be created.
    """
    try:
        script = jedi.Script(code=code, path=filepath)
        code_lines = code.split("\n")
    except Exception as e:
        print(f"Failed to parse file {filepath}: {str(e)}")
        return {}
    results = {}
    names = [
        name
        for name in script.get_names(
            all_scopes=False, definitions=True, references=False
        )
        if not name.in_builtin_module()
    ]
    for name in names:
        try:
            # Keep only the definition spans located in the file under analysis.
            line_nos = [
                (start_line_no, end_line_no)
                for module_path, start_line_no, end_line_no in get_definition_lines(
                    name
                )
                if module_path == filepath
            ]
            if line_nos:
                full_name = name.full_name
                # Usually there is a single span; when there are several,
                # only the first one is used.
                start_line_no, end_line_no = line_nos[0]
                results[full_name] = (
                    start_line_no,
                    end_line_no,
                    "\n".join(code_lines[start_line_no - 1 : end_line_no]),
                )
        except Exception:
            # Best-effort: skip names that cannot be resolved, but do not
            # swallow KeyboardInterrupt / SystemExit as a bare ``except`` would.
            continue
    return results


def get_detailed_source_lines(source_defs, patched_defs):
    """Return the source line numbers of definitions that the patch touched.

    A definition from ``source_defs`` counts as touched when it is missing
    from ``patched_defs`` or when the last element of its entry (its source
    text) differs between the two mappings.

    Args:
        source_defs: Mapping of name -> ``(start_line, end_line, text)`` for
            the original file.
        patched_defs: Mapping of the same shape for the patched file.

    Returns:
        A set with every line number from ``start_line`` to ``end_line``
        (inclusive) of each touched definition, using the original file's
        line numbers.
    """
    touched_names = [
        def_name
        for def_name, entry in source_defs.items()
        if def_name not in patched_defs or entry[-1] != patched_defs[def_name][-1]
    ]
    touched_lines = set()
    for def_name in touched_names:
        first_line, last_line, _ = source_defs[def_name]
        touched_lines.update(range(first_line, last_line + 1))
    return touched_lines


## gpt-enhanced pretty good version
def analyze_file(content, detailed_lines=None):
    """Render a condensed, line-numbered outline of Python source.

    Lines considered important are printed as ``"<line_no> <text>"``; runs of
    other lines are collapsed into ``"... <count> lines ..."``. Important
    lines are the ones passed in ``detailed_lines`` plus: imports,
    module/class-level assignments and expression statements,
    ``global``/``nonlocal`` statements, ``def``/``class`` header lines and
    their decorators, and an estimated span for docstrings.

    Args:
        content: Python source text; it must be parseable by ``ast.parse``.
        detailed_lines: Optional iterable of 1-based line numbers that must
            always be shown. It is copied, never modified.

    Returns:
        The outline as a single newline-joined string.

    Raises:
        SyntaxError: If ``content`` is not valid Python (from ``ast.parse``).
    """
    tree = ast.parse(content)
    # Work on a private copy so the caller's collection is never mutated.
    marked = set(detailed_lines) if detailed_lines else set()

    # Walk the tree with an explicit stack rather than recursion so deeply
    # nested code cannot exhaust the interpreter's recursion limit. Visit
    # order is irrelevant because results are accumulated in a set.
    pending = [(tree, None)]
    while pending:
        node, parent = pending.pop()
        if isinstance(
            node, (ast.Import, ast.ImportFrom, ast.Global, ast.Nonlocal)
        ):
            marked.add(node.lineno)
        elif isinstance(node, (ast.Assign, ast.AnnAssign, ast.Expr)):
            # Only statements directly under a module or class are kept.
            if isinstance(parent, (ast.Module, ast.ClassDef)):
                marked.add(node.lineno)
        elif isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.AsyncFunctionDef)):
            marked.add(node.lineno)
            for decorator in getattr(node, "decorator_list", []):
                marked.add(decorator.lineno)
            docstring = ast.get_docstring(node)
            if docstring:
                # Heuristic span: starts at the header line and covers the
                # docstring's line count, +2 for the quotes.
                marked.update(
                    range(
                        node.lineno,
                        node.lineno + len(docstring.splitlines()) + 2,
                    )
                )
        pending.extend((child, node) for child in ast.iter_child_nodes(node))

    # Handle module-level docstring (same kind of heuristic, from line 1).
    module_docstring = ast.get_docstring(tree)
    if module_docstring:
        marked.update(range(1, 2 + len(module_docstring.splitlines())))

    lines = content.splitlines()
    total = len(lines)
    output_lines = []
    line_num = 1
    while line_num <= total:
        text = lines[line_num - 1]
        # A comment line is shown when it is the first line examined here;
        # comment lines inside a collapsed run are counted as part of the run.
        if line_num in marked or text.lstrip().startswith("#"):
            output_lines.append(f"{line_num} {text}")
            line_num += 1
            continue
        # Measure the run of unmarked lines starting at line_num.
        run_end = line_num + 1
        while run_end <= total and run_end not in marked:
            run_end += 1
        count = run_end - line_num
        if count == 1:
            # Collapsing a single line saves nothing, so print it verbatim.
            output_lines.append(f"{line_num} {text}")
        else:
            output_lines.append(f"... {count} lines ...")
        line_num = run_end
    return "\n".join(output_lines)


def get_summarized_files(instance, repo_path):
    """Build condensed outlines of the Python files touched by a patch.

    For every file in ``instance["patch"]`` whose source path ends in ``.py``,
    the on-disk original is read, the file's diff is applied to it, and the
    definitions that differ between the two versions are shown in full in the
    outline of the original file.

    Args:
        instance: Mapping with a ``"patch"`` entry holding the diff text.
        repo_path: Directory that the patch's file paths are relative to.

    Returns:
        A dict mapping each file's path (source path with its first two
        characters dropped) to the outline produced by ``analyze_file``.
    """
    summaries = {}
    for file_diff in unidiff.PatchSet(instance["patch"]):
        source_name = file_diff.source_file
        if source_name is None or not source_name.endswith(".py"):
            continue
        # Drop the leading two characters of the source path
        # (presumably the "a/" prefix of a git-style diff).
        relative_path = source_name[2:]
        absolute_path = os.path.join(repo_path, relative_path)
        with open(absolute_path) as infile:
            original_text = infile.read()
        patched_text = apply_patch_using_unix(original_text, str(file_diff))
        # Both versions are analyzed under the same path.
        defs_before = find_all_local_defs(original_text, absolute_path)
        defs_after = find_all_local_defs(patched_text, absolute_path)
        changed_lines = get_detailed_source_lines(defs_before, defs_after)
        summaries[relative_path] = analyze_file(original_text, changed_lines)
    return summaries
