import os
import json
from tree_sitter import Language, Parser
import networkx as nx
import sys
from utils.utils import CONSTANTS, iterate_repository_file
import re

class ImportGraphBuilder:
    def __init__(self, project_root, project_name):
        self.project_root = project_root
        self.project_name = project_name
        self.index = {}  # {file_path: {functions: {}, classes: {}}}
        self.import_graph = nx.DiGraph()
        self.language_map = {
            '.py': 'python',
            '.java': 'java'
        }

        Language.build_library('./my-languages.so', ['./tree-sitter-python', './tree-sitter-java'])

    def _get_language(self, file_path):
        ext = os.path.splitext(file_path)[1]
        return self.language_map.get(ext)

    def build_index(self,repo_name):
        print("Building index..." + repo_name)
        code_files = iterate_repository_file(CONSTANTS.repo_base_dir, repo_name)
        repo_base_dir_len = len(CONSTANTS.repo_base_dir.split('/'))
        for file in code_files:
            file_path=file[repo_base_dir_len:]
            language = self._get_language(file_path)

            if not language:
                continue

            with open(file_path, 'r', encoding='utf-8') as f:
                code = f.read()

            if language == 'python':
                code = re.sub(r'#.*$', '', code, flags=re.MULTILINE)  # 单行注释
                code = re.sub(r'(\'\'\'[\s\S]*?\'\'\'|\"\"\"[\s\S]*?\"\"\")', '', code)  # 文档注释
            elif language == 'java':
                code = re.sub(r'//.*$', '', code, flags=re.MULTILINE)  # 单行注释
                code = re.sub(r'/\*[\s\S]*?\*/', '', code)  # 多行注释


            self._parse_file_index(file_path, code, language)

    def _parse_file_index(self, file_path, code, language):

        lang = Language('./my-languages.so', language)
        parser = Parser()
        parser.set_language(lang)
        tree = parser.parse(bytes(code, "utf8"))

        functions = {}
        classes = {}

        if language == 'python':

            query = lang.query("""
            (function_definition
                name: (identifier) @name) @function
            (class_definition
                name: (identifier) @name) @class
            """)

            for node, tag in query.captures(tree.root_node):
                if tag in ('function', 'method'):
                    name_node = node.child_by_field_name('name')
                    name = code[name_node.start_byte:name_node.end_byte]
                    start_line = node.start_point[0]
                    end_line = node.end_point[0]
                    functions[name] = {
                        'start_line': start_line,
                        'end_line': end_line,
                        'source': code.split('\n')[start_line:end_line + 1]
                    }
                elif tag == 'class':
                    name_node = node.child_by_field_name('name')
                    name = code[name_node.start_byte:name_node.end_byte]
                    start_line = node.start_point[0]
                    end_line = node.end_point[0]
                    class_info = {
                        'start_line': start_line,
                        'end_line': end_line,
                        'source': code.split('\n')[start_line:end_line + 1],
                        'methods': {},
                        'class_vars':{}
                    }

                    class_body = node.child_by_field_name('body')

                    for child in class_body.children:

                        if child.type == 'function_definition':
                            method_name_node = child.child_by_field_name('name')
                            method_name = code[method_name_node.start_byte:method_name_node.end_byte]


                            params_node = child.child_by_field_name('parameters')
                            params = code[params_node.start_byte:params_node.end_byte] if params_node else ''


                            decorators = []
                            decorator_nodes = [d for d in child.children if d.type == 'decorator']
                            for decorator in decorator_nodes:
                                decorator_code = code[decorator.start_byte:decorator.end_byte]
                                decorators.append(decorator_code)

                            # method_body = child.child_by_field_name('body')
                            # method_lines = code[method_body.start_byte:method_body.end_byte].split('\n')

                            function_definition_node = child  # 假设 child 是函数定义节点
                            start_byte = function_definition_node.start_byte
                            end_byte = function_definition_node.end_byte
                            complete_method_code = code[start_byte:end_byte]

                            complete_method_lines = complete_method_code.split('\n')

                            class_info['methods'][method_name] = {
                                'decorators': decorators,
                                'params': params,
                                'source': complete_method_lines,
                                'start_line': child.start_point[0],
                                'end_line': child.end_point[0]
                            }


                        elif child.type == 'expression_statement':
                            assignment = child.child_by_field_name('value')
                            if assignment and assignment.type == 'assignment':
                                left_nodes = assignment.children_by_field_name('left')
                                var_names = []
                                for left in left_nodes:
                                    if left.type == 'identifier':
                                        var_names.append(code[left.start_byte:left.end_byte])


                                right_node = assignment.child_by_field_name('right')
                                var_value = code[right_node.start_byte:right_node.end_byte] if right_node else ''

                                for var in var_names:
                                    class_info['class_vars'][var] = {
                                        'value': var_value,
                                        'source': code[assignment.start_byte:assignment.end_byte].split('\n'),
                                        'start_line': assignment.start_point[0],
                                        'end_line': assignment.end_point[0]
                                    }

                    classes[name] = class_info

        elif language == 'java':
            query = lang.query("""
            (class_declaration
                name: (identifier) @class_name
                body: (class_body) @class_body) @class
    
            (method_declaration
                name: (identifier) @method_name) @method
    
            (field_declaration
                (variable_declarator
                    name: (identifier) @field_name)) @field
             """)
            current_class = None
            for node, tag in query.captures(tree.root_node):
                # 类定义处理
                if tag == 'class':
                    name_node = node.child_by_field_name('name')
                    class_name = code[name_node.start_byte:name_node.end_byte]
                    # class_name = self._get_node_text(node.child_by_field_name('name'), code)
                    classes[class_name] = {
                        'start_line': node.start_point[0],
                        'end_line': node.end_point[0],
                        'source': code.split('\n')[node.start_point[0]:node.end_point[0] + 1],
                        'methods': {},
                        'fields': {}
                    }
                    current_class = class_name

                elif tag == 'method' and current_class:
                    name_node = node.child_by_field_name('name')
                    method_name = code[name_node.start_byte:name_node.end_byte]
                    # method_name = self._get_node_text(node.child_by_field_name('name'), code)
                    classes[current_class]['methods'][method_name] = {
                        'start_line': node.start_point[0],
                        'end_line': node.end_point[0],
                        'source': code.split('\n')[node.start_point[0]:node.end_point[0] + 1]
                    }

                elif tag == 'field' and current_class:
                    declarators = [
                        child for child in node.children
                        if child.type == 'variable_declarator'
                    ]

                    for declarator in declarators:
                        name_node = declarator.child_by_field_name('name')
                        if not name_node or name_node.is_missing:
                            continue

                        try:
                            field_name = code[name_node.start_byte:name_node.end_byte]
                        except IndexError:
                            print(f"字段名越界：{name_node}")
                            continue

                        start_line = declarator.start_point[0]
                        end_line = declarator.end_point[0]

                        classes[current_class]['fields'][field_name] = {
                            'start_line': start_line,
                            'end_line': end_line,
                            'source': code.split('\n')[start_line:end_line + 1]
                        }


        self.index[file_path] = {
            'functions': functions,
            'classes': classes
        }

    def parse_imports(self,repo_name):
        """Parse the import relationships of all files"""
        print("Building import relation.." + repo_name)
        code_files = iterate_repository_file(CONSTANTS.repo_base_dir, repo_name)
        repo_base_dir_len = len(CONSTANTS.repo_base_dir.split('/'))
        for file in code_files:
            file_path=file[repo_base_dir_len:]
            language = self._get_language(file_path)

            if not language:
                continue

            with open(file_path, 'r', encoding='utf-8') as f:
                code = f.read()

            rel_path = os.path.relpath(file_path, self.project_root)
            self.import_graph.add_node(rel_path)

            imports = self._parse_imports(code, language, file_path)
            for imp in imports:
                self._process_import(rel_path, imp)


    def _parse_imports(self, code, language, file_path):
        """Parse the import statement of a single file"""
        imports = []
        lang = Language('./my-languages.so', language)
        parser = Parser()
        parser.set_language(lang)
        tree = parser.parse(bytes(code, "utf8"))

        if language == 'python':
            query = lang.query("""
            (import_statement) @import
            (import_from_statement) @import_from
            """)
        elif language == 'java':
            query = lang.query("""
            (import_declaration) @import
            """)

        for node, _ in query.captures(tree.root_node):
            imp_info = self._parse_import_node(node, code, language, file_path)
            if imp_info:
                imports.append(imp_info)

        return imports

    def _parse_import_node(self, node, code, language, current_path):
        if language == 'python':
            # print(node.type)
            if node.type == 'import_statement':
                return self._parse_python_import(node, code, current_path)
            elif node.type == 'import_from_statement':
                return self._parse_python_from_import(node, code, current_path)
        elif language == 'java':

            return self._parse_java_import(node, code)
        return None

    def _parse_python_import(self, node, code, current_path):
        """
        import module
        import package.module as alias
        """
        module_parts = []
        alias = None
        is_third_party = False
        for child in node.children:
            if child.type == 'dotted_name':
                module_parts = self._get_dotted_name(child, code)
            elif child.type == 'aliased_import':
                alias_node = child.children[-1]
                alias = code[alias_node.start_byte:alias_node.end_byte]

        # import_line = code[node.start_byte:node.end_byte].strip()
        full_module = '.'.join(module_parts)
        is_third_party = self._is_third_party_module(full_module, current_path)

        return {
            'module': full_module,
            'alias': alias,
            'is_third_party': is_third_party,
            'import_type': 'module',
            'imported_items': [{
                'name': full_module or alias ,
                'type': 'module',
                'import_line':[code[node.start_byte:node.end_byte]],
                'source_lines': []  # 模块级导入无具体代码
            }]
        }

    def _parse_python_from_import(self, node, code, current_path):
        """
        from package import module
        from package.submodule import func as f
        from . import relative_module
        """
        # 获取from后面的模块路径
        module_parts = []
        for child in node.children:
            if child.type == 'dotted_name':
                module_parts=self._get_dotted_name(child, code)
            elif child.type == 'relative_import':
                module_parts=self._get_relative_import(child, code, current_path)
            elif child.type == 'import':
                break

        imported_items = []
        import_clause = node.child_by_field_name('name')
        if import_clause:
            for item_node in import_clause.children:
                #print(item_node.type)
                if item_node.type == 'aliased_import':
                    name_node = item_node.children[0]
                    alias_node = item_node.children[-1]
                    item_name = code[name_node.start_byte:name_node.end_byte]
                    alias = code[alias_node.start_byte:alias_node.end_byte]
                else:
                    item_name = code[item_node.start_byte:item_node.end_byte]
                    alias = None


                item_info = self._find_imported_item(
                    module_parts,
                    item_name,
                    current_path
                )


                if item_info.get('type', 'unknown') == 'class':
                    imported_items.append({
                        'name': item_name,
                        'alias': alias,
                        'type': item_info.get('type', 'unknown'),
                        'import_line': [code[node.start_byte:node.end_byte]],
                        'source_lines': item_info.get('source', []),
                        'methods': item_info.get('methods', []),
                        'class_vals':item_info.get('class_vals', []),
                    })
                else:
                    imported_items.append({
                        'name': item_name,
                        'alias': alias,
                        'type': item_info.get('type', 'unknown'),
                        'import_line': [code[node.start_byte:node.end_byte]],
                        'source_lines': item_info.get('source', []),
                    })

        full_module = '.'.join(module_parts)
        is_third_party = self._is_third_party_module(full_module, current_path)

        return {
            'module': full_module,
            'is_third_party': is_third_party,
            'import_type': 'from_import',
            'imported_items': imported_items
        }

    def _is_third_party_module(self, module_name, current_path):
        """Determine whether it is a third-party library"""
        if module_name in sys.stdlib_module_names:
            return True

        module_path = module_name.replace('.', '/')
        possible_paths = [
            os.path.join(os.path.dirname(current_path), f"{module_path}.py"),
            os.path.join(self.project_root, module_path.replace('.', '/') + '.py')
        ]

        for path in possible_paths:
            if os.path.exists(path):
                return False


        if hasattr(self, 'dependency_dirs') and self.dependency_dirs:

            top_level_module = module_name.split('.', 1)[0]

            for dep_dir in self.dependency_dirs:
                module_py_path = os.path.join(dep_dir, f"{top_level_module}.py")
                module_dir_path = os.path.join(dep_dir, top_level_module)

                if os.path.exists(module_py_path) or os.path.isdir(module_dir_path):
                    return True

        return True

    def _find_imported_item(self, module_parts, item_name, current_path):
        """Search for detailed information about the imported item in the index"""
        module_path = '/'.join(module_parts) + '.py'

        abs_path = os.path.join(
            os.path.dirname(current_path),
            module_path
        )

        target_file = self.index.get(abs_path)
        if not target_file:
            return {}

        if item_name in target_file['functions']:
            func = target_file['functions'][item_name]
            return{
                'type': 'function',
                'source': func['source']
            }

        if item_name in target_file['classes']:
            cls = target_file['classes'][item_name]

            return{
                'type': 'class',
                'source': cls['source'],
                'methods': cls['methods'],
                'class_vars': cls['class_vars'],
            }
            # results.append(result)

        if '.' in item_name:
            cls_name, method_name = item_name.split('.', 1)
            if cls_name in target_file['classes']:
                return{
                    'type': 'method',
                    'source': target_file['classes'][cls_name]['source']
                }

        return {'type':'none'}

    def _get_dotted_name(self, node, code):
        """Parse point-split expressions"""
        parts = []
        current = node
        while True:
            if current.type == 'dotted_name':
                for child in current.children:
                    if child.type == 'identifier':
                        parts.append(code[child.start_byte:child.end_byte])
                break
            elif current.type == 'identifier':
                parts.append(code[current.start_byte:current.end_byte])
                break
            current = current.children[0]
        return parts

    def _get_relative_import(self, node, code, current_path):
        relative_level = 0
        parts=[]
        for child in node.children:
            if child.type == '.':
                relative_level += 1
            elif child.type == 'dotted_name':
                parts = self._get_dotted_name(child, code)
            elif child.type == 'identifier':
                parts = [code[child.start_byte:child.end_byte]]

        # Convert to an absolute path
        current_dir = os.path.dirname(current_path)
        base_dir = os.path.abspath(os.path.join(
            self.project_root,
            *['..'] * relative_level
        ))

        abs_path = os.path.abspath(os.path.join(
            base_dir,
            *parts
        )).replace(self.project_root, '').lstrip('/')

        return abs_path.split('.')[0].split('/')

    def _parse_java_import(self, node, code):
        try:
            # 解析基础信息
            is_static = False
            is_wildcard = False
            import_path = []
            import_range = (node.start_byte, node.end_byte)

            # 遍历AST节点
            for child in node.children:
                if child.type == 'static':
                    is_static = True
                elif child.type == 'scoped_identifier':
                    import_path = self._get_java_scoped_identifier(child, code)
                elif child.type == '*' and child.text.decode() == '*':
                    is_wildcard = True


            full_path = '.'.join(import_path)
            if is_wildcard:
                full_path += '.*'


            import_type = self._determine_java_import_type(
                full_path,
                is_static,
                is_wildcard
            )


            imported_items = self._get_java_imported_items(
                import_path,
                is_static,
                is_wildcard,
                code[import_range[0]:import_range[1]]
            )

            return {
                'module': full_path,
                'is_third_party': self._is_java_third_party(full_path),
                'import_type': import_type,
                'imported_items': imported_items
            }
        except Exception as e:
            print(f"Java导入解析错误: {str(e)}")
            return None

    def _get_java_scoped_identifier(self, node, code):
        """Enhanced scope resolution"""
        parts = []
        current = node

        while current and current.type in ('scoped_identifier', 'generic_type'):

            if current.type == 'scoped_identifier':
                name_node = current.child_by_field_name('name')
                parts.append(code[name_node.start_byte:name_node.end_byte])
                current = current.child_by_field_name('scope')

            elif current.type == 'generic_type':
                type_node = current.child_by_field_name('type')
                parts.append(code[type_node.start_byte:type_node.end_byte])
                current = current.child_by_field_name('type_arguments')

        if current and current.type == 'identifier':
            parts.append(code[current.start_byte:current.end_byte])

        return list(reversed(parts))

    def _is_java_third_party(self, full_path):
        """Intelligent judgment of third-party libraries"""
        if full_path.startswith(('java.', 'javax.', 'sun.', 'com.sun.')):
            return True

        package_path = full_path.replace('.', '/')
        possible_paths = [
            os.path.join(self.project_root, f"{package_path}.java"),
            os.path.join(self.project_root, "src/main/java", package_path + ".java"),
            os.path.join(self.project_root, "src/test/java", package_path + ".java")
        ]

        return not any(os.path.exists(p) for p in possible_paths)

    def _determine_java_import_type(self, full_path, is_static, is_wildcard):
        """Determine the type of the imported item"""
        if is_wildcard:
            return 'wildcard_import'
        if is_static:
            return 'static_import'
        if full_path.endswith('.*'):
            return 'package_import'
        return 'class_import'

    def _get_java_imported_items(self, import_path, is_static, is_wildcard, import_stmt):
        """Obtain the details of the imported items from the index"""
        items = []

        target_class = import_path[-1] if import_path else ''
        found_files = [
            fpath for fpath, data in self.index.items()
            if target_class in data['classes']
        ]

        if not found_files:
            return items

        target_file = found_files[0]
        class_info = self.index[target_file]['classes'].get(target_class, {})

        if is_wildcard:
            package = '.'.join(import_path[:-1])
            return [{
                'name': '*',
                'type': 'wildcard',
                'source_lines': [import_stmt]
            }]

        if is_static and len(import_path) > 1:
            member_name = import_path[-1]
            member_type = 'static_method' if member_name in class_info.get('methods', {}) else 'static_field'
            source = class_info.get(member_type.split('_')[-1] + 's', {}).get(member_name, {}).get('source', [])

            items.append({
                'name': member_name,
                'type': member_type,
                'import_line': [import_stmt],
                'source_lines': source,
                'target_file': target_file
            })

        else:
            items.append({
                'name': target_class,
                'type': 'class',
                'import_line': [import_stmt],
                'source_lines': class_info.get('source', []),
                'methods': class_info.get('methods', {}),
                'fields': class_info.get('fields', {}),
                'target_file': target_file
            })
        return items

    def _process_import(self, source_file, imp_info):
        """Handle import relationships"""
        target_module = imp_info['module']

        if imp_info['is_third_party']:
            self.import_graph.add_edge(source_file, target_module)
        else:

            target_path = self._resolve_local_path(target_module, source_file)
            #print(target_path)
            if target_path:
                self.import_graph.add_edge(source_file, target_path)

        for item in imp_info['imported_items']:
            self.import_graph.nodes[source_file].setdefault(
                        'imports', []).append(item)

    def _resolve_local_path(self, target_module, source_file):
        """Parse the imported module into the file path within the project"""

        source_dir = os.path.dirname(os.path.join(self.project_root, source_file))

        ext = os.path.splitext(source_file)[1]
        if ext == '.py':
            return self._resolve_python_path(target_module, source_dir)
        elif ext == '.java':
            return self._resolve_java_path(target_module, source_dir)
        return None

    def _resolve_python_path(self, module_name, source_dir):
        """Parse the path of the Python module"""
        if module_name.startswith('.'):
            return self._resolve_relative_import(module_name, source_dir)

        module_path = module_name.replace('.', '/')

        search_paths = [
            os.path.join(self.project_root, module_path + '.py'),
            os.path.join(self.project_root, module_path, '__init__.py'),
            os.path.join(source_dir, module_path + '.py'),
            os.path.join(source_dir, module_path, '__init__.py')
        ]

        for path in search_paths:
            if os.path.isfile(path):
                return os.path.relpath(path, self.project_root)

        package_dir = os.path.join(self.project_root, module_path)
        if os.path.isdir(package_dir):
            return os.path.relpath(package_dir, self.project_root)

        return None

    def _resolve_relative_import(self, module_name, source_dir):
        """Parse Python relative import"""
        levels = 0
        while module_name.startswith('.'):
            levels += 1
            module_name = module_name[1:]

        base_dir = os.path.abspath(os.path.join(source_dir, *['..'] * levels))

        if not base_dir.startswith(self.project_root):
            return None

        module_path = module_name.replace('.', '/')
        full_path = os.path.join(base_dir, module_path)

        possible_paths = [
            full_path + '.py',
            os.path.join(full_path, '__init__.py')
        ]

        for path in possible_paths:
            if os.path.isfile(path):
                return os.path.relpath(path, self.project_root)

        if os.path.isdir(full_path):
            return os.path.relpath(full_path, self.project_root)

        return None

    def _resolve_java_path(self, full_class_name, source_dir):

        class_path = full_class_name.replace('.', '/') + '.java'

        search_paths = [
            os.path.join(self.project_root, class_path),
            os.path.join(source_dir, class_path)
        ]

        maven_src_path = os.path.join(
            self.project_root,
            'src/main/java',
            class_path
        )
        search_paths.insert(0, maven_src_path)

        for path in search_paths:
            if os.path.isfile(path):
                return os.path.relpath(path, self.project_root)

        return None

    def save_to_jsonl(self, output_dir):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir, exist_ok=True)

        output_path = os.path.join(output_dir, f"{self.project_name}.jsonl")

        with open(output_path, 'w') as f:
            for node in self.import_graph.nodes():
                imports = self.import_graph.nodes[node].get('imports', [])
                entry = {
                    'project': self.project_name,
                    'file_path': node,
                    'import_info': imports
                }
                f.write(json.dumps(entry) + '\n')


if __name__ == '__main__':
    # Build, resolve and export the import graph of every configured repository.
    # NOTE(review): project_root is the shared repo_base_dir rather than the
    # individual repository directory; confirm this is intended.
    for repo_name in CONSTANTS.repos:
        print(f'Processing repo {repo_name}')
        graph_builder = ImportGraphBuilder(CONSTANTS.repo_base_dir , repo_name)
        # The definition index must exist before imports are resolved against it.
        graph_builder.build_index(repo_name)
        graph_builder.parse_imports(repo_name)
        graph_builder.save_to_jsonl('./output1')