import ast
import re
from itertools import combinations
from functools import reduce
from unidiff import PatchSet


def extract_imports(code):
    """Return the names imported anywhere inside ``code``.

    ``code`` is parsed with ``ast.parse`` and every node of the tree is
    visited, so imports nested in functions or classes are included too.

    * ``import a.b`` contributes ``"a.b"`` (aliases are ignored).
    * ``from m import x`` contributes ``"m.x"``.
    * ``from . import x`` (no module part) contributes just ``"x"``.

    Returns:
        A list of dotted names in the order ``ast.walk`` reaches the import
        statements; duplicates are kept.

    Raises:
        SyntaxError: propagated from ``ast.parse`` when ``code`` is not
            valid Python.
    """
    found = []
    for stmt in ast.walk(ast.parse(code)):
        if isinstance(stmt, ast.Import):
            found.extend(alias.name for alias in stmt.names)
        elif isinstance(stmt, ast.ImportFrom):
            # A relative ``from . import x`` has no module part to prepend.
            prefix = f"{stmt.module}." if stmt.module else ""
            found.extend(prefix + alias.name for alias in stmt.names)
    return found


def get_possible_names(problem_statement):
    """Collect candidate Python names mentioned in a problem statement.

    Names are gathered from three places:

    1. Imports found in fenced code blocks (```` ``` ```` ... ```` ``` ````)
       that parse as Python, via ``extract_imports``.
    2. Identifiers (optionally dotted) inside inline code spans
       (single backticks).
    3. Dotted names (``pkg.mod.attr``) anywhere in the text.

    Any name that is a substring of another collected name is dropped, so
    only the most specific spellings are returned.

    Args:
        problem_statement: Free-form text, typically Markdown.

    Returns:
        A set of name strings; empty when nothing is found.
    """
    code_pat = re.compile(r"```(?:[a-z]+)?(.*?)```", re.DOTALL)
    # The lookbehind keeps fence backticks from being read as inline code
    # while still allowing an inline span at the very start of the text, and
    # keeps the character before the span out of the scanned snippet.
    inline_pat = re.compile(r"(?<!`)`[^`]+`")
    # Every segment, including the first, is a full identifier so that names
    # such as ``h5py.File`` are not truncated at the digit.
    name_pat = re.compile(
        r"[a-zA-Z_][a-zA-Z_0-9]*(?:\.[a-zA-Z_][a-zA-Z_0-9]*)*"
    )  # matches any valid (ASCII) python name
    qual_name_pat = re.compile(
        r"[a-zA-Z_][a-zA-Z_0-9]*(?:\.[a-zA-Z_][a-zA-Z_0-9]*)+"
    )  # matches period separated names
    all_names = set()
    for snippet in code_pat.findall(problem_statement):
        try:
            # extract_imports returns a list, so use update() rather than
            # ``|=`` (set |= list raises TypeError).
            all_names.update(extract_imports(snippet))
        except (SyntaxError, ValueError, RecursionError, MemoryError):
            # Best effort: fenced blocks are often not Python (tracebacks,
            # shell output, ...); skip the ones the parser rejects.
            continue
    for snippet in inline_pat.findall(problem_statement):
        all_names.update(name_pat.findall(snippet))
    all_names.update(qual_name_pat.findall(problem_statement))
    # Drop every name that is contained in a different, longer name.
    return {
        name
        for name in all_names
        if not any(name != other and name in other for other in all_names)
    }


def extract_filenames(problem_statement):
    """Return the distinct ``.py`` paths mentioned in ``problem_statement``.

    A path is one or more ``/segment`` components (letters, digits, ``.``,
    ``-`` and ``_``) followed by ``.py``. Every match therefore starts with a
    slash: ``foo/bar.py`` yields ``"/bar.py"`` and a bare ``x.py`` yields
    nothing.
    """
    py_path = re.compile(r"(?:\/[A-Za-z0-9\.\-\_]+)+\.py")
    return {hit.group(0) for hit in py_path.finditer(problem_statement)}


def get_related_filenames(instance, repo_graph):
    """Return the closure of files related to an instance's problem statement.

    Candidate names are extracted from ``instance["problem_statement"]`` and
    matched against the keys of ``repo_graph.tree``: a key matches a name when
    it equals the name or ends with ``".<name>"``. The values of the matching
    keys are merged into one set, file paths mentioned in the text that are
    present in ``repo_graph.graph`` are added, and the result is expanded with
    ``repo_graph.get_closure``.

    Args:
        instance: Mapping with a ``"problem_statement"`` string.
        repo_graph: Object exposing ``tree``, ``graph`` and ``get_closure``.
            NOTE(review): ``tree`` is presumably a mapping from qualified
            names to sets of files -- confirm against the graph class.

    Returns:
        Whatever ``repo_graph.get_closure`` returns for the seed files.
    """
    problem_statement = instance["problem_statement"]
    possible_names = get_possible_names(problem_statement)
    # Compare each whole name with each key. (Iterating over a name string
    # itself would compare the keys against single characters.)
    possible_keys = {
        key
        for key in repo_graph.tree.keys()
        if any(key == name or key.endswith(f".{name}") for name in possible_names)
    }
    files = set()
    for key in possible_keys:
        files |= repo_graph.tree[key]
    file_names = extract_filenames(problem_statement)
    files |= {file_name for file_name in file_names if file_name in repo_graph.graph}
    return repo_graph.get_closure(files)


def get_gold_filenames(instance):
    """Return the ``.py`` source paths touched by ``instance["patch"]``.

    The patch text is parsed with ``PatchSet``. Entries whose ``source_file``
    is ``None`` or does not end in ``.py`` are ignored. The first two
    characters of each remaining path (the leading ``a/``) are dropped.

    Returns:
        A set of path strings.
    """
    patch_set = PatchSet(instance["patch"])
    return {
        patched.source_file[2:]  # remove leading 'a/'
        for patched in patch_set
        if patched.source_file is not None and patched.source_file.endswith(".py")
    }


def get_gold_names(instance):
    """Return the names imported by lines added in ``instance["patch"]``.

    The patch text is parsed with ``PatchSet``. Each added line is parsed on
    its own with ``extract_imports``, so only imports written on a single
    line are recognised. Files whose ``source_file`` is ``None`` are skipped.

    Returns:
        A set of dotted import names; empty when no added line is a
        complete import statement.
    """
    patch = PatchSet(instance["patch"])
    gold_names = set()
    for diff in patch:
        if diff.source_file is None:
            continue
        for hunk in diff:
            for line in hunk:
                if not line.is_added:
                    continue
                try:
                    # Strip the line first: leading indentation alone makes
                    # ast.parse raise IndentationError.
                    gold_names.update(extract_imports(line.value.strip()))
                except (SyntaxError, ValueError, RecursionError, MemoryError):
                    # Most added lines are not complete statements on their
                    # own (``if x:``, one line of a multi-line call, ...);
                    # they cannot be imports we can read, so skip them.
                    continue
    return gold_names
