import re
import jedi
import json
from pathlib import Path
from collections import defaultdict

from typing import Union, Iterable
# from ..utils import is_test
import logging

# Configure logging at import time: INFO level, timestamped "<time> <level> <message>" lines.
# NOTE(review): calling basicConfig in a library module also affects the importing
# application's root logger -- confirm this is intended.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
# Module-level logger; the jedi helpers below use it to report parse failures as warnings.
logger = logging.getLogger(__name__)


def is_test(name, test_phrases=None):
    """
    Report whether ``name`` contains a test-related word.

    ``name`` is lower-cased and split on spaces, underscores, slashes and
    dots. The result is True when at least one entry of ``test_phrases``
    is exactly equal to one of the resulting tokens.

    Args:
        name: The string to inspect (typically a file path).
        test_phrases: Words that mark a name as test-related. Defaults to
            ``["test", "tests", "testing"]`` when None.

    Returns:
        True if any phrase matches a whole token of ``name``, else False.
    """
    phrases = ["test", "tests", "testing"] if test_phrases is None else test_phrases
    tokens = set(re.split(r" |_|\/|\.", name.lower()))
    for phrase in phrases:
        if phrase in tokens:
            return True
    return False


def get_definition_lines(name):
    """
    Locate where ``name`` is ultimately defined.

    Follows imports via ``name.goto(follow_imports=True)`` and keeps only
    the origins that have a module path on disk.

    Args:
        name: A jedi name object exposing ``goto``.

    Returns:
        A list of ``(posix_module_path, start_line, end_line)`` tuples, one
        per origin. When jedi cannot report a definition range for an
        origin, both lines fall back to 0, the same whole-file marker
        ``find_all_name_defs`` uses for the module entry.
    """
    definitions = []
    for origin in name.goto(follow_imports=True):
        module_path = origin.module_path
        if not module_path:
            # No backing file (e.g. builtins / compiled modules): nothing to record.
            continue
        # Both getters are Optional in jedi's API: they return None when the
        # origin has no syntax-tree name. Unpacking None would raise TypeError
        # and make the caller discard every origin of this name.
        start_pos = origin.get_definition_start_position()
        end_pos = origin.get_definition_end_position()
        start_line = start_pos[0] if start_pos is not None else 0
        end_line = end_pos[0] if end_pos is not None else 0
        definitions.append((module_path.as_posix(), start_line, end_line))
    return definitions


def find_all_name_defs(filepath):
    """
    Map the names defined in a Python file to their definition locations.

    Args:
        filepath: Path of the Python file to analyse (``str`` or ``Path``).

    Returns:
        A dict whose keys are ``"<module_name>.<name>"`` strings and whose
        values are the lists produced by ``get_definition_lines``, i.e.
        ``(posix_path, start_line, end_line)`` tuples. The module itself is
        included under its full name with the single entry
        ``(posix_path, 0, 0)``. An empty dict is returned when jedi cannot
        create a script for the file.
    """
    # Accept plain strings too; ``as_posix`` below requires a Path.
    filepath = Path(filepath)
    try:
        script = jedi.Script(path=filepath)
    except Exception as e:
        logger.warning(f"Failed to parse file {filepath}: {str(e)}")
        return {}
    results = {script.get_context().full_name: [(filepath.as_posix(), 0, 0)]}
    names = [
        name
        for name in script.get_names(
            all_scopes=False, definitions=True, references=False
        )
        if not name.in_builtin_module()
    ]
    for name in names:
        # Seed the label before the try block: if building the dotted name
        # itself fails, the handler would otherwise hit an unbound variable
        # (first iteration) or report the previous name (later iterations).
        full_name = name.name
        try:
            full_name = ".".join([name.module_name, name.name])
            results[full_name] = get_definition_lines(name)
        except Exception as e:
            # Best effort: one unresolved name must not abort the whole file.
            logger.warning(f"Failed to parse definition {full_name}: {str(e)}")
    return results


def find_all_names(filepath):
    """
    Map the names defined in a Python file to the files they originate from.

    The returned dict always contains the module's own full name mapped to
    the file itself. Every other key is ``"<module_name>.<name>"`` and maps
    to the posix paths of the modules where that name is defined after
    following imports. Names without any file-backed origin are left out.
    An empty dict is returned when jedi cannot create a script for the file.
    """
    try:
        script = jedi.Script(path=filepath)
    except Exception as e:
        logger.warning(f"Failed to parse file {filepath}: {str(e)}")
        return {}

    results = {script.get_context().full_name: [filepath.as_posix()]}
    candidates = script.get_names(all_scopes=False, definitions=True, references=False)
    names = [candidate for candidate in candidates if not candidate.in_builtin_module()]

    for name in names:
        try:
            origin_files = []
            for origin in name.goto(follow_imports=True):
                module_path = origin.module_path
                if module_path:
                    origin_files.append(module_path.as_posix())
            if not origin_files:
                # Nothing file-backed to record for this name.
                continue
            results[".".join([name.module_name, name.name])] = origin_files
        except Exception as e:
            logger.warning(f"Failed to find origin of name '{name.name}': {str(e)}")
    return results


class RepoGraph:
    """
    A file-level dependency graph of a Python repository.

    Currently only files are tracked. If this proves too inefficient,
    functions and definitions can be tracked specifically instead.
    """

    def __init__(self, repo_path: Union[str, Path], include_tests: bool = False, build_graph: bool = True):
        """
        Args:
            repo_path: Root directory of the repository.
            include_tests: When False, files whose relative path looks
                test-related (see ``is_test``) are left out of the graph.
            build_graph: When True the graph is built immediately; pass
                False to create an empty instance (as ``from_json`` does).
        """
        self.repo_path = Path(repo_path).absolute().resolve()
        self.include_tests = include_tests
        self.test_phrases = ["test", "tests", "testing"]
        # name -> set of repo-relative files in which that name originates
        self.tree = {}
        # repo-relative file -> set of repo-relative files it depends on
        self.graph = {}
        if build_graph:
            self._build_graph()

    @classmethod
    def from_json(cls, json_path: Union[str, Path]):
        """
        Rebuild an instance from a file containing the output of ``to_json``.

        The repository is not re-analysed; ``tree`` and ``graph`` are taken
        from the file, with their JSON lists converted back to sets.
        """
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        instance = cls(data["repo_path"], data["include_tests"], build_graph=False)
        # ``to_json`` writes the phrases, so restore them; files without the
        # key keep the defaults set by ``__init__``.
        instance.test_phrases = list(data.get("test_phrases", instance.test_phrases))
        instance.tree = {k: set(v) for k, v in data["tree"].items()}
        instance.graph = {k: set(v) for k, v in data["graph"].items()}
        return instance

    def _build_graph(self):
        """
        Analyse every ``*.py`` file under the repository root and populate
        ``self.tree`` and ``self.graph`` with repo-relative posix paths.
        """
        root = self.repo_path.as_posix() + "/"
        start_ix = len(root)  # strip this prefix to get repo-relative keys
        all_definitions = defaultdict(set)
        graph = {}
        for file in Path(root).rglob("*.py"):
            name = file.as_posix()[start_ix:]
            if not self.include_tests and is_test(name, self.test_phrases):
                continue
            node = set()
            for varname, filenames in find_all_names(file).items():
                for filename in filenames:
                    if not filename.startswith(root):
                        continue  # ignore external imports
                    file_key = filename[start_ix:]
                    if not self.include_tests and is_test(file_key, self.test_phrases):
                        continue
                    all_definitions[varname].add(file_key)
                    node.add(file_key)
            graph[name] = node
        self.tree = dict(all_definitions)
        self.graph = graph

    def get_closure(self, start: Union[str, Iterable[str]]):
        """
        Return every node reachable from ``start`` (inclusive).

        Args:
            start: A single graph key or an iterable of graph keys.

        Returns:
            The set of all nodes visited while following ``self.graph``
            edges. Nodes that are not keys of the graph are kept in the
            result and treated as having no outgoing edges.
        """
        stack = [start] if isinstance(start, str) else list(start)
        visited = set()
        while stack:
            node = stack.pop()
            if node in visited:
                continue
            visited.add(node)
            # A dependency can point at a file that was never analysed
            # (so it has no entry of its own); don't fail on it.
            stack.extend(self.graph.get(node, ()))
        return visited

    def to_json(self):
        """
        Serialise the graph to an indented JSON string.

        Set values are written as sorted lists so the output is stable
        from run to run.
        """
        data = {
            "tree": {k: sorted(v) for k, v in self.tree.items()},
            "graph": {k: sorted(v) for k, v in self.graph.items()},
            "repo_path": self.repo_path.as_posix(),
            "include_tests": self.include_tests,
            "test_phrases": self.test_phrases,
        }
        return json.dumps(data, indent=4)


class RepoDefGraph(RepoGraph):
    """
    Variant of ``RepoGraph`` that records definition ranges.

    Values in ``tree`` and ``graph`` are sets of
    ``(repo_relative_file, start_line, end_line)`` tuples rather than
    plain file names; the keys of ``graph`` are still repo-relative files.
    """

    def _build_graph(self):
        """
        Analyse every ``*.py`` file under the repository root and populate
        ``self.tree`` and ``self.graph`` with definition tuples.
        """
        root = self.repo_path.as_posix() + "/"
        start_ix = len(root)  # strip this prefix to get repo-relative keys
        all_definitions = defaultdict(set)
        graph = {}
        for file in Path(root).rglob("*.py"):
            name = file.as_posix()[start_ix:]
            if not self.include_tests and is_test(name, self.test_phrases):
                continue
            node = set()
            for varname, filename_defs in find_all_name_defs(file).items():
                if not varname:
                    continue
                for filename, start, end in filename_defs:
                    if not filename.startswith(root):
                        continue  # ignore external imports
                    file_key = filename[start_ix:]
                    if not self.include_tests and is_test(file_key, self.test_phrases):
                        continue
                    value = (file_key, start, end)
                    all_definitions[varname].add(value)
                    node.add(value)
            graph[name] = node
        self.tree = dict(all_definitions)
        self.graph = graph

    def get_closure(self, start: Union[str, Iterable[str]]):
        """
        Return every node reachable from ``start`` (inclusive).

        The inherited implementation looks up each visited node in
        ``self.graph``, but here the visited nodes are definition tuples
        while the graph is keyed by file, so that lookup raises KeyError.
        This override follows each tuple into the file it lives in.

        Args:
            start: A single file key or an iterable of file keys.

        Returns:
            A set containing the start entries plus every
            ``(file, start_line, end_line)`` tuple reachable from them.
            Files without a graph entry contribute no further edges.
        """
        stack = [start] if isinstance(start, str) else list(start)
        visited = set()
        expanded_files = set()
        while stack:
            node = stack.pop()
            if node in visited:
                continue
            visited.add(node)
            file_key = node if isinstance(node, str) else node[0]
            if file_key in expanded_files:
                continue  # this file's edges are already on the stack
            expanded_files.add(file_key)
            stack.extend(self.graph.get(file_key, ()))
        return visited

    @classmethod
    def from_json(cls, json_path: Union[str, Path]):
        """
        Rebuild an instance from a file containing the output of ``to_json``.

        JSON has no tuple type, so each stored definition (a list) is
        converted back to a tuple before being placed in a set.
        """
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        instance = cls(data["repo_path"], data["include_tests"], build_graph=False)
        # Restore the phrases written by ``to_json``; keep defaults if absent.
        instance.test_phrases = list(data.get("test_phrases", instance.test_phrases))
        instance.tree = {k: set(map(tuple, v)) for k, v in data["tree"].items()}
        instance.graph = {k: set(map(tuple, v)) for k, v in data["graph"].items()}
        return instance
