import difflib
import random
from difflib import SequenceMatcher
from unidiff import PatchSet
from unidiff.patch import Hunk
from localize.language_specify import LanguageManager
from tree_sitter import Node
from tree_sitter_language_pack import get_language, get_parser

import re
from difflib import unified_diff

def build_unified_diff(original_lines, modified_lines, source_start, target_start, filename, context=3):
    """
    Build a unified diff with context whose hunk headers use absolute line numbers.

    ``difflib.unified_diff`` numbers each hunk relative to the sequences it is
    given (1-based). The inputs here are excerpts of a larger file, so every
    ``@@`` header is shifted by the position of the excerpt in that file.

    Args:
        original_lines: Lines of the "before" excerpt. Content lines are
            emitted unchanged, so each one should carry its own line ending.
        modified_lines: Lines of the "after" excerpt (same expectation).
        source_start: 1-based line number of the first original line.
        target_start: 1-based line number of the first modified line.
        filename: Path used in the ``--- a/...`` / ``+++ b/...`` file headers.
        context: Number of context lines kept around each change.

    Returns:
        The patch text, or an empty string when both sequences are identical.
    """
    raw_patch = ''.join(unified_diff(
        list(original_lines),
        list(modified_lines),
        fromfile=f"a/{filename}",
        tofile=f"b/{filename}",
        lineterm='\n',
        n=context,
    ))

    def shift_header(match):
        # difflib omits the ",count" part when a range covers exactly one
        # line, so a missing count means 1.
        old_start = int(match.group(1))
        old_count = int(match.group(2)) if match.group(2) is not None else 1
        new_start = int(match.group(3))
        new_count = int(match.group(4)) if match.group(4) is not None else 1
        # Relative numbers are 1-based: relative line 1 maps onto *_start.
        old_start_fixed = source_start + old_start - 1
        new_start_fixed = target_start + new_start - 1
        return f"@@ -{old_start_fixed},{old_count} +{new_start_fixed},{new_count} @@"

    # Content lines are always prefixed with ' ', '-' or '+', so only real
    # hunk headers can match at the start of a line.
    return re.sub(
        r'^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@',
        shift_header,
        raw_patch,
        flags=re.MULTILINE,
    )

class CompletionTaskConstructor:
    """
    Build code-completion tasks from a unified diff.

    Two granularities are supported: whole newly added functions
    (``create_function_level_completion``) and a single contiguous run of
    added lines per hunk (``create_normal_completion``). Each task is a dict
    with a "task" patch that swaps code for a ``[TODO]`` comment and a
    "completion" patch that restores it.
    """

    def __init__(self, patch, language):
        """
        Args:
            patch: Unified diff text, possibly covering several files.
            language: Language name passed to ``LanguageManager``,
                ``get_parser`` and ``get_language``.
        """
        self.patch = patch
        self.language_manager = LanguageManager(language)
        self.parser = get_parser(language)
        self.lang = get_language(language)

    def _split_patch_by_file(self):
        """Split ``self.patch`` into chunks that each start at a ``--- `` file header."""
        lines = self.patch.splitlines(keepends=True)
        # A removed line whose content starts with "-- " also renders as
        # "--- ...", so a real file header must be followed by a "+++ " line.
        breaks = [
            i for i, line in enumerate(lines)
            if line.startswith('--- ')
            and i + 1 < len(lines)
            and lines[i + 1].startswith('+++ ')
        ]
        bounds = [0] + breaks + [len(lines)]
        chunks = (''.join(lines[start:end]) for start, end in zip(bounds, bounds[1:]))
        # Drop empty chunks (e.g. when the patch starts directly with a header).
        return [chunk for chunk in chunks if chunk]

    def _collect_candidates(self, extract):
        """
        Apply ``extract(file_name, hunk)`` to every valid hunk of every modified file.

        Added and removed files are skipped. ``extract`` must return a list of
        candidate dicts. A chunk that raises is reported and skipped; results
        gathered before the failure are kept.
        """
        collected = []
        for chunk in self._split_patch_by_file():
            try:
                for patched_file in PatchSet(chunk):
                    if patched_file.is_removed_file or patched_file.is_added_file:
                        continue
                    # Keep what follows the first '/' of the target path.
                    file_name = patched_file.target_file.split('/', 1)[-1]
                    for hunk in patched_file:
                        if hunk.is_valid():
                            collected.extend(extract(file_name, hunk))
            except Exception as e:
                # Best-effort boundary: one unparsable chunk must not abort
                # the rest of the patch.
                print(f"Skipping a patch due to parsing error: {e}")
        return collected

    def reconstruct_content_for_hunk(self, hunk: Hunk):
        """
        Rebuild the source-side and target-side text covered by a hunk.

        Returns:
            ``(source_content, target_content)``: the hunk without its added
            lines and without its removed lines, respectively. Returns
            ``(None, None)`` if an ``IndexError`` or ``ValueError`` is raised.
        """
        try:
            source_content = "".join(line.value for line in hunk if not line.is_added)
            target_content = "".join(line.value for line in hunk if not line.is_removed)
            return source_content, target_content
        except (IndexError, ValueError):
            return None, None

    def find_newly_added_functions_in_hunk(self, target_code: str, hunk: Hunk):
        """
        Find function nodes in ``target_code`` made up entirely of added lines.

        ``target_code`` is parsed with tree-sitter and queried with the
        language manager's function query; nodes captured as ``func`` are
        kept when every line they span is an added line of ``hunk``.
        """
        if not target_code:
            return []

        added_lines = {line.target_line_no for line in hunk if line.is_added}
        if not added_lines:
            return []

        tree = self.parser.parse(bytes(target_code, "utf8"))
        query = self.lang.query(self.language_manager.get_function_queries())
        captures = query.captures(tree.root_node)
        if not captures:
            return []

        new_functions = []
        # .get(): the query may capture other names without any 'func' entry.
        for node in captures.get('func', []):
            # Node rows are 0-based offsets into the hunk's target side.
            first_line = hunk.target_start + node.start_point[0]
            last_line = hunk.target_start + node.end_point[0]
            if added_lines.issuperset(range(first_line, last_line + 1)):
                new_functions.append(node)
        return new_functions

    def get_indentation(self, source_code: str, node: Node) -> str:
        """Return the leading whitespace of the line on which ``node`` starts."""
        line_start_byte = node.start_byte - node.start_point[1]
        before_node = source_code.encode('utf8')[line_start_byte:node.start_byte].decode('utf8', 'ignore')
        # Only the leading run counts as indentation: when the node does not
        # start its line (e.g. ``def f(): return 1``) the text before it
        # contains code, and spaces inside that code must be ignored.
        return before_node[:len(before_node) - len(before_node.lstrip())]

    @staticmethod
    def _leading_docstring(body_node):
        """Return the body's first statement if it is a string expression, else None."""
        if body_node.named_child_count == 0:
            return None
        first_child = body_node.named_children[0]
        if (first_child.type == 'expression_statement'
                and first_child.named_child_count > 0
                and first_child.named_children[0].type == 'string'):
            return first_child
        return None

    @staticmethod
    def _locate_header_end(func_node, body_node, target_bytes):
        """
        Find where the function header ends.

        Returns:
            ``(end_byte, line_idx)``: the byte offset just past the ``:`` that
            closes the header, and the 0-based line holding that ``:``.
        """
        # Prefer the ':' token that is a direct child of the function node;
        # colons inside annotations or default values belong to nested nodes.
        for child in func_node.children:
            if child.start_byte >= body_node.start_byte:
                break
            if child.type == ':':
                return child.end_byte, child.start_point[0]

        # Fallback: textual search between the parameter list and the body.
        parameters_node = func_node.child_by_field_name('parameters')
        search_start = parameters_node.end_byte if parameters_node else func_node.start_byte
        colon_pos = target_bytes.find(b':', search_start, body_node.start_byte)
        if colon_pos == -1:
            return body_node.start_byte, func_node.start_point[0]
        return colon_pos + 1, target_bytes.count(b'\n', 0, colon_pos)

    def create_function_completion_patches_for_file(self, filename: str, hunk: Hunk):
        """
        Build a completion task for every newly added function in ``hunk``.

        Only ``.py`` files are handled, and only functions that span more
        than 5 lines and have a ``body`` field.

        Returns:
            A list of dicts with keys ``task_patch`` (body replaced by a
            ``[TODO]`` comment), ``completion_patch`` (the reverse),
            ``file_name``, ``LOC`` (body lines minus docstring lines),
            ``signature`` and ``docstring``.
        """
        if not filename.endswith('.py'):
            return []

        _, target_code = self.reconstruct_content_for_hunk(hunk)
        if not target_code:
            return []

        # Signature and docstring are sliced from these lines so that their
        # original formatting is preserved.
        target_lines = target_code.splitlines(keepends=True)
        target_bytes = target_code.encode('utf8')

        all_patch_candidates = []
        for func_node in self.find_newly_added_functions_in_hunk(target_code, hunk):
            if func_node.end_point[0] - func_node.start_point[0] + 1 <= 5:
                continue
            body_node = func_node.child_by_field_name("body")
            if not body_node:
                continue

            header_end_byte, sig_end_line_idx = self._locate_header_end(
                func_node, body_node, target_bytes
            )
            sig_start_line_idx = func_node.start_point[0]
            signature = "".join(target_lines[sig_start_line_idx:sig_end_line_idx + 1]).rstrip()

            docstring_node = self._leading_docstring(body_node)
            if docstring_node is not None:
                # Keep the docstring; replace everything after it.
                start_replace_byte = docstring_node.end_byte
                doc_start_line_idx = docstring_node.start_point[0]
                doc_end_line_idx = docstring_node.end_point[0]
                docstring_lines = doc_end_line_idx - doc_start_line_idx + 1
                docstring_content = "".join(
                    target_lines[doc_start_line_idx:doc_end_line_idx + 1]
                ).rstrip()
            else:
                # No docstring: replace everything after the header's ':'.
                start_replace_byte = header_end_byte
                docstring_lines = 0
                docstring_content = ""

            # Lines of code in the body, excluding the docstring.
            total_body_lines = body_node.end_point[0] - body_node.start_point[0] + 1
            loc = total_body_lines - docstring_lines

            indent = self.get_indentation(target_code, body_node)
            if not indent:
                indent = self.get_indentation(target_code, func_node) + "    "

            prefix = target_bytes[:start_replace_byte].decode('utf8', 'ignore')
            suffix = target_bytes[func_node.end_byte:].decode('utf8', 'ignore')
            code_with_todo = (
                prefix
                + f"\n{indent}{self.language_manager.get_comment_str()} [TODO]"
                + suffix
            )
            todo_lines = code_with_todo.splitlines(keepends=True)

            # Both patches apply to the target side of the hunk, so both
            # offsets are hunk.target_start.
            task_patch = build_unified_diff(
                original_lines=target_lines,
                modified_lines=todo_lines,
                source_start=hunk.target_start,
                target_start=hunk.target_start,
                filename=filename
            )
            # The completion patch restores the function body.
            completion_patch = build_unified_diff(
                original_lines=todo_lines,
                modified_lines=target_lines,
                source_start=hunk.target_start,
                target_start=hunk.target_start,
                filename=filename
            )

            all_patch_candidates.append({
                "task_patch": task_patch,
                "completion_patch": completion_patch,
                "file_name": filename,
                "LOC": loc,
                "signature": signature,
                "docstring": docstring_content
            })

        return all_patch_candidates

    def create_function_level_completion(self):
        """
        Collect function-level completion tasks from the whole patch.

        Returns:
            A list with one task dict per qualifying newly added function in
            the modified files of the patch (empty if none). Picking one at
            random is left to the caller.
        """
        return self._collect_candidates(self.create_function_completion_patches_for_file)

    def create_normal_completion(self, min_chars: int = 10, max_lines: int = 50):
        """
        Collect tasks built from the single contiguous run of added lines in each hunk.

        Args:
            min_chars: Minimum total length of the added lines, counted after
                stripping trailing whitespace from each line.
            max_lines: Maximum number of added lines.

        Returns:
            A list of task dicts, one per qualifying hunk (empty if none).
        """
        def extract(file_name, hunk):
            candidate = self._extract_single_add_block(hunk, file_name, min_chars, max_lines)
            return [candidate] if candidate else []

        return self._collect_candidates(extract)

    def _extract_single_add_block(self, hunk: Hunk, filename: str,
                                  min_chars: int, max_lines: int):
        """
        Turn the hunk's single contiguous added block into a task dict.

        The added block is replaced by one ``[TODO]`` comment line; patches
        are generated with 3 lines of context. Returns None when the hunk has
        no added lines, when the added lines are not contiguous, or when the
        block fails the size filters.
        """
        # 1. Collect added lines with their target line numbers.
        add_lines = [(ln.target_line_no, ln.value) for ln in hunk if ln.is_added]
        if not add_lines:
            return None

        # 2. The added lines must form one consecutive range.
        starts = [no for no, _ in add_lines]
        if list(range(starts[0], starts[-1] + 1)) != starts:
            return None

        # 3. Reject hunks with a removed line numbered inside that range.
        # NOTE(review): this compares source-side line numbers with a
        # target-side range; confirm that this is the intended filter.
        for ln in hunk:
            if ln.is_removed and starts[0] <= ln.source_line_no <= starts[-1]:
                return None

        # 4. Size filters.
        seg_lines = [value for _, value in add_lines]
        total_chars = sum(len(seg.rstrip()) for seg in seg_lines)
        total_lines = len(seg_lines)
        if total_lines > max_lines or total_chars < min_chars:
            return None

        # 5. "Original" text is the target side of the hunk.
        target_lines_hunk = [ln.value for ln in hunk if not ln.is_removed]

        # 6. "Modified" text has the whole added block collapsed into one TODO line.
        modified_lines_hunk = []
        todo_done = False
        for ln in hunk:
            if ln.is_added:
                if not todo_done:
                    modified_lines_hunk.append(f"{self.language_manager.get_comment_str()} [TODO]\n")
                    todo_done = True
            elif not ln.is_removed:
                modified_lines_hunk.append(ln.value)

        # 7. Shift difflib's hunk-relative (1-based) headers to absolute line
        #    numbers; a missing count means a one-line range.
        def fix_header(m):
            old_start = int(m.group(1))
            old_cnt = int(m.group(2) or 1)
            new_start = int(m.group(3))
            new_cnt = int(m.group(4) or 1)
            abs_old_start = hunk.target_start + old_start - 1
            abs_new_start = hunk.target_start + new_start - 1
            return f"@@ -{abs_old_start},{old_cnt} +{abs_new_start},{new_cnt} @@"

        def make_patch(before, after):
            raw_patch = ''.join(unified_diff(
                list(before),
                list(after),
                fromfile=f"a/{filename}",
                tofile=f"b/{filename}",
                lineterm='\n',
                n=3
            ))
            return re.sub(r'^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@',
                          fix_header, raw_patch, flags=re.MULTILINE)

        # 8. Task patch inserts the TODO; completion patch is the reverse.
        return {
            "task_patch": make_patch(target_lines_hunk, modified_lines_hunk),
            "completion_patch": make_patch(modified_lines_hunk, target_lines_hunk),
            "file_name": filename,
            "LOC": total_lines,
            "signature": "",
            "docstring": ""
        }

### How to use

# Example: a patch string that changes several files.
sample_patch_multi_file = """
--- a/utils/helpers.py
+++ b/utils/helpers.py
@@ -5,2 +5,14 @@
 def existing_util():
-    return True
+    return False
+
+def format_user_data(user_id, name, email):
+    \"\"\"
+    Formats user data into a consistent string.
+    This is a new utility function added in this patch.
+    It ensures all user representations are standardized.
+    \"\"\"
+    user_str = f"ID: {user_id}, Name: {name}, Email: <{email}>"
+    # Log the formatted string for debugging purposes
+    print(f"Formatted: {user_str}")
+    return user_str
+
--- a/scripts/main_process.py
+++ b/scripts/main_process.py
@@ -54,3 +54,5 @@
     def main_process():
-        pass
+        print("This is the main process function.")
+    def run_main_process():
+        print("Running the main process...")
         return 0
@@ -100,6 +102,18 @@
 class MyClass:
     def existing_method(self):
-        pass
+        print("This is an existing method.")
+
+    def calculate_complex_value(self, x, y, z):
+        \"\"\"
+        This is a new complex function.
+        It performs some important calculations.
+        \"\"\"
+        # Step 1: initial calculation
+        result = (x * y) + z
+        # Step 2: apply a transformation
+        result = result ** 2 / (x + 1)
+        # Step 3: return the final value
+        return result
 
 def another_function():
     return "hello"
"""

# Demo: build both kinds of completion task from the sample patch and print
# one randomly chosen result of each kind.
if __name__ == "__main__":
    constructor = CompletionTaskConstructor(sample_patch_multi_file, language='python')

    # Function-level completion: one candidate per newly added function.
    results = constructor.create_function_level_completion()
    result_dict = random.choice(results) if results else None

    if result_dict:
        print("--- 🪄 随机选择的函数补全 ---")
        print("\n📄 签名:")
        print(result_dict['signature'])
        print(f"\n 代码行数 (纯代码): {result_dict['LOC']}")
        print("\n💬 文档字符串:")
        print(result_dict['docstring'])

        print("\n--- 📝 任务补丁 (输入) ---")
        print("此补丁应用后, 会将函数体替换为 [TODO].")
        print(result_dict['task_patch'])

        print("\n--- ✅ 补全补丁 (目标) ---")
        print("此补丁可将 [TODO] 恢复为完整的函数体.")
        print(result_dict['completion_patch'])

    else:
        # NOTE(review): this message looks truncated; confirm the intended wording.
        print(" patch .")

    # Contiguous-block completion: one candidate per qualifying hunk.
    results = constructor.create_normal_completion(min_chars=10, max_lines=20)
    result_dict = random.choice(results) if results else None
    if result_dict:
        print("--- 🪄 随机选择的代码区间 ---")
        print("\n📄 签名:")
        print(result_dict['signature'])
        print(f"\n 代码行数 (纯代码): {result_dict['LOC']}")
        print(result_dict['docstring'])

        print("\n--- 📝 任务补丁 (输入) ---")
        print("此补丁应用后, 会将连续区间替换为 [TODO].")
        print(result_dict['task_patch'])

        print("\n--- ✅ 补全补丁 (目标) ---")
        print("此补丁可将 [TODO] 恢复为正常的连续区间.")
        print(result_dict['completion_patch'])

    else:
        # NOTE(review): this message looks truncated; confirm the intended wording.
        print(" patch .")
