"""
Utility functions for parsing and applying the patch.

Inspired by:
https://github.com/gpt-engineer-org/gpt-engineer/blob/main/gpt_engineer/core/chat_to_files.py
"""

import re
from dataclasses import dataclass
from pprint import pformat
from tempfile import NamedTemporaryFile
from typing import TextIO
from app.data_structures import BugLocation
from app.search import search_utils
from pylint.lint import Run
from pylint.reporters.text import TextReporter


@dataclass
class Edit:
    filename: str
    before: str
    after: str
    before_start_line: int | None # newly added by 
    before_end_line: int | None # newly added by 

    def __str__(self):
        return f"filename: {self.filename}\nbefore_start_line: {self.before_start_line}\nbefore_end_line: {self.before_end_line}\nBefore:\n{pformat(self.before)}\nAfter:\n{pformat(self.after)}\n"

    def __repr__(self):
        return str(self)


def parse_edits(chat_str: str, gen_locs: list[BugLocation]) -> list[Edit]:
    """
    Build Edit objects from an LLM response and the locations it refers to.

    All ``<reason>...</reason>`` and ``<code>...</code>`` sections of the
    response are collected in order of appearance. The i-th reason, the i-th
    code section and the i-th location are then combined into one Edit;
    pairing stops at the shortest of the three sequences.

    Args:
        chat_str (str): The raw response text produced by the LLM.
        gen_locs (list[BugLocation]): Locations the generated code belongs to,
            in the same order as the code sections of the response.
    Returns:
        list[Edit]: One Edit per paired (reason, code, location) triple.
    """
    reason_re = re.compile(r"<reason>(.*?)</reason>", re.DOTALL)
    code_re = re.compile(r"<code>(.*?)</code>", re.DOTALL)
    reasons = reason_re.findall(chat_str)  # list of reasoning strings
    code_sections = code_re.findall(chat_str)  # list of code strings

    edits: list[Edit] = []
    # The reason text itself is not used, but it still takes part in the zip
    # and therefore limits how many edits are produced.
    for _reason, generated_code, loc in zip(reasons, code_sections, gen_locs):
        rel_path = loc.rel_file_path.strip()
        first_line = loc.start
        last_line = loc.end

        # Spaces are kept on purpose: removing them at the beginning or end
        # could change the indentation level of some lines. Only surrounding
        # newlines are removed from the original text, because extra newlines
        # that are absent from the file would prevent a match later on.
        # The end is passed as last_line + 1 so that the original also covers
        # the "raise NotImplementedError" line (per the original author's note;
        # the exact range semantics belong to search_utils.get_code_snippets).
        original_text = search_utils.get_code_snippets(
            loc.abs_file_path, first_line, last_line + 1, with_lineno=False
        ).strip("\n")

        # The replacement is the snippet for start..end followed by the
        # generated code, appended verbatim.
        kept_snippet = search_utils.get_code_snippets(
            loc.abs_file_path, first_line, last_line, with_lineno=False
        )
        patched_text = kept_snippet + generated_code

        edits.append(
            Edit(
                filename=rel_path,
                before=original_text,
                after=patched_text,
                before_start_line=first_line,
                before_end_line=last_line,
            )
        )
    return edits


def apply_edit(edit: Edit, file_path: str) -> str | None:
    """
    Apply one Edit to a file. This function reads the file, tries to match
    the before string (after stripping spaces in the original program and the
    before string improve the chance of matching), and then replaces the matched region with the after string.
    Returns:
        - Path to the file containing updated content if successful;
          None otherwise.
    """
    with open(file_path) as f:
        orig_prog_lines = f.readlines()
    
    filename = edit.filename
    before_lines = edit.before
    after_lines = edit.after
    start_line = edit.before_start_line - 1  # Convert to 0-based index
    
    # Convert before/after strings to lists of lines
    before_block = [line + "\n" if not line.endswith("\n") else line for line in "".join(before_lines).splitlines()]
    after_block = [line + "\n" if not line.endswith("\n") else line for line in "".join(after_lines).splitlines()]

    # Check that the before block matches the current file content
    file_block = orig_prog_lines[start_line:start_line + len(before_block)]
    if file_block != before_block:
        print(f"❌ ERROR: The content at line {start_line + 1} does not match the 'Before' block.")
        exit()
    # Replace the block
    new_lines = orig_prog_lines[:start_line] + after_block + orig_prog_lines[start_line + len(before_block):]

    # Write back to file
    with open(file_path, "w") as f:
        f.writelines(new_lines)

    return file_path


class Writable(TextIO):
    """Dummy in-memory output stream for pylint.

    Every string passed to ``write`` is kept as its own entry in ``content``.
    """

    def __init__(self) -> None:
        # One entry per write() call, in call order.
        self.content: list[str] = []

    def write(self, s: str) -> int:
        """Record ``s`` and return its length."""
        self.content.append(s)
        written = len(s)
        return written

    def read(self, n: int = 0) -> str:
        """Return all recorded entries joined by newlines; ``n`` is ignored."""
        separator = "\n"
        return separator.join(self.content)


def lint_python_content(content: str) -> bool:
    """Report whether pylint finds no syntax error in the given source.

    The content is written to an unbuffered temporary file and pylint is run
    on it with ``--errors-only``. Only output entries ending with
    ``(syntax-error)`` count as a failure; other reported errors do not.

    Args:
        content: python file content

    Returns: True if the contents passes linting, False otherwise.

    """
    sink = Writable()
    text_reporter = TextReporter(sink)

    # buffering=0 makes the bytes reach the file before pylint opens it by name.
    with NamedTemporaryFile(buffering=0) as tmp_file:
        tmp_file.write(content.encode())

        # exit=False keeps pylint from terminating the interpreter when done.
        Run(["--errors-only", tmp_file.name], reporter=text_reporter, exit=False)

    for chunk in sink.content:
        if chunk.endswith("(syntax-error)"):
            return False
    return True




def fix_function_body_indentation(task, raw_response: str, indent_size: int = 4) -> str:
    """
    Align the generated function body with the indentation of the reference body.

    The expected indentation is taken from the first non-blank line of
    ``task.GT_list`` that comes after the ``task.prompt_list`` prefix. The
    actual indentation is taken from the first non-blank line of the first
    ``<code>...</code>`` section of the response. If the generated body is
    indented less than expected, every non-blank line of it is shifted right
    by the difference, so relative indentation inside the body is preserved.

    Blank lines of the body are dropped when the body is rewritten. A body that
    is already indented enough (or too much) is left alone.

    Args:
        task (RawLocalTask): Task providing ``GT_list`` and ``prompt_list``;
            presumably lists of source lines - confirm against the task type.
        raw_response (str): raw_response containing Generated body code (possibly badly indented)
        indent_size (int): Currently unused; kept for interface compatibility.

    Returns:
        str: response with Indentation-corrected function body. The response is
        returned unchanged when it has no ``<code>`` section, when that section
        has no non-blank line, or when no extra indentation is needed.
    """

    def leading_width(text: str) -> int:
        # Number of leading whitespace characters of a line.
        return len(text) - len(text.lstrip())

    code_match = re.search(r"<code>(.*?)</code>", raw_response, flags=re.DOTALL)
    if code_match is None:
        # Nothing to fix without a code section.
        return raw_response
    body_code = code_match.group(1)

    # Indentation of the first non-blank reference body line (0 if there is none).
    reference_body = task.GT_list[len(task.prompt_list):]
    expected_width = next(
        (leading_width(line) for line in reference_body if line.strip()), 0
    )

    body_lines = body_code.splitlines()
    actual_width = next(
        (leading_width(line) for line in body_lines if line.strip()), None
    )
    if actual_width is None:
        # The code section is empty or blank; replacing it would be meaningless.
        return raw_response

    missing_width = expected_width - actual_width
    if missing_width <= 0:  # if need to delete space, we can ignore
        return raw_response

    padding = " " * missing_width
    shifted_lines = [padding + line for line in body_lines if line.strip()]

    return raw_response.replace(body_code, "\n".join(shifted_lines))

    
    
        
        