import re
import subprocess
from logging import Logger
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from cresearcher.utils.ctags import getMarkdownShortForm
from cresearcher.utils.gitdiff import convertPatchToGitDiff
from cresearcher.utils.modelHandlers import getModelHandler
from cresearcher.utils.types import SupportedModels, GlobalCtxAgentResult


def generatePatchSystemPrompt(prompt_preamble: str, prompt_patch_gen_examples: str, llm: SupportedModels, repoName: str) -> str:
    """Build the system prompt for the patch-generation model.

    Args:
        prompt_preamble: Text placed at the very start of the prompt.
        prompt_patch_gen_examples: Example response(s) inserted after
            "Here is an example response:".
        llm: Model identifier. When it equals ``'o1'`` the sentence announcing
            the context explanation and the step-by-step reasoning paragraph
            are both omitted.
        repoName: Name of the repository under analysis, interpolated wherever
            the prompt refers to the codebase.

    Returns:
        The complete system prompt.
    """
    isO1 = llm == 'o1'
    # Both fragments are dropped for o1 (see generatePatchPrompt, which also
    # withholds the context explanation from that model).
    explanationNote = '' if isO1 else ' You will also be given a detailed explanation of why each piece of context is relevant to the bug and how different parts of the context are connected. The explanation will reveal how the person who collected this context understood the bug and all the code snippets and commits that are available to you.'
    reasoningNote = '' if isO1 else f'You should reach your hypothesis and patch through sound inferences from the crash report, the {repoName} repository code snippets, and the past commits. Focus on building up to the solution step-by-step, through careful reasoning steps.'

    return f"""{prompt_preamble}

The context contains definitions of relevant symbols (e.g. functions, structs, unions, global constants, macros), regex search results from the {repoName} codebase, regex search results from the historical commit messages and patches, and relevant context from certain files. One or more functions or other symbols may need to be edited to fix the bug.{explanationNote}

Your goal is to reason about the bug, generate a hypothesis about the root cause of the bug, and propose a patch to fix the bug.

You should reason from first principles about all possible root causes of the crash. You should understand the dataflow of variable and control flow through conditional statements, loops, jump statements and function calls. You should analyze this flow in reference to the crash report and the context provided. You should use the code search snippets to understand relevant patterns in the codebase, and the past commits to understand how code has evolved and how the developers usually write code.

You should then generate a specific and detailed hypothesis about what the bug is.

After generating the hypothesis, you should generate a patch based on the hypothesis. In the patch, you should provide the complete new rewrite of each symbol/function definition snippet that you want to modify. You should ensure that the starting line of your rewrite is the same as in the code given to you, and your rewrite is a complete replacement of the code snippet.

{reasoningNote}

You must provide your response in the following format:

<thoughts>
Write your analysis of the crash report, potentially buggy functions, symbol definitions, code snippets and past commits. You should reason about all possible root causes of the crash, and different ways to address them.
</thoughts>

<hypothesis>
Write a specific and detailed hypothesis about the root cause of the crash. Your hypothesis should be based on your analysis of the crash report, the context provided, and your understanding of the {repoName} codebase.
</hypothesis>

<patch>
Write the complete new rewrite of each symbol/function definition snippet that you want to modify. Each symbol definition should be wrapped in <symbol> tags with file, name and start attributes. The code should start from the same line as the symbol definition code given to you, and should be a complete replacement of the code snippet.
</patch>

Your patch should be in the following format:

<patch>
<symbol file="/path/to/file.c" name="function_name" start="start_line_number">
New function body with modifications and starting from the same line as the function code snippet given to you
</symbol>
<symbol file="/path/to/another_file.c" name="another_function_name" start="another_start_line_number">
New function body with modifications and starting from the same line as the function code snippet given to you
</symbol>
<symbol file="/path/to/another_file.c" name="struct_name" start="struct_start_line_number">
New struct definition with modifications and starting from the same line as the struct code snippet given to you
</symbol>
</patch>

Here is an example response:

{prompt_patch_gen_examples}


"""


def _numberedSnippetLines(body: str, start: int, filePath: str, relevantLines: List[Tuple[str, str, Optional[str]]]) -> List[str]:
    """Prefix each line of *body* with its line number, counting from *start*.

    A line is annotated as important when some entry of *relevantLines* has
    this file path as its first element and the line number (as a string) as
    its third element.
    """
    rendered = []
    for lineNo, text in enumerate(body.splitlines(), start=start):
        lineKey = str(lineNo)
        if any(filePath == entry[0] and lineKey == entry[2] for entry in relevantLines):
            rendered.append(f"{lineNo}| {text} // IMPORTANT LINE: This line is mentioned in the bug crash report. Pay attention to it while generating your hypothesis and patch.")
        else:
            rendered.append(f"{lineNo}| {text}")
    return rendered


def generatePatchPrompt(bugDict: Dict[str, Any], globalCtx: GlobalCtxAgentResult, fileCtx: str, relevantLines: List[Tuple[str, str, Optional[str]]], llm: SupportedModels, repoName: str) -> str:
    """Assemble the user prompt for patch generation.

    Sections are emitted in this order, each followed by a separator line:
    introduction, bug title and crash report, file context, search queries
    with results, symbol definitions, potentially buggy functions, and the
    context explanation. A closing instruction paragraph ends the prompt.
    Optional sections are skipped when their source is empty.

    Args:
        bugDict: Must contain the keys ``'title'`` and ``'crash_report_data'``.
        globalCtx: Provides ``relevantQueries``, ``relevantDefinitions``,
            ``relevantFunctions`` and ``relevanceExplanation``.
        fileCtx: File-level context text; omitted when empty.
        relevantLines: Entries whose first element is a file path and whose
            third element is a line number string; matching snippet lines are
            flagged as mentioned in the crash report.
        llm: Model identifier; the context explanation is withheld for ``"o1"``.
        repoName: Repository name used in the introduction.

    Returns:
        The prompt, with sections joined by newlines.
    """
    separator = "\n" + "="*80 + "\n"
    sections = [
        f"You are analyzing a {repoName} crash report. Your task is to reason about all the context given to you, generate a hypothesis about the root cause of the bug, and propose a patch to fix the bug. Here is all the context given to you.",
        separator,
        f"BUG TITLE: {bugDict['title']}",
        "CRASH REPORT:",
        bugDict['crash_report_data'],
        separator,
    ]

    if fileCtx:
        sections += ["FILE CONTEXT:", fileCtx, separator]

    if globalCtx.relevantQueries:
        sections.append("RELEVANT SEARCH QUERIES AND RESULTS:")
        for query, results in globalCtx.relevantQueries:
            # Drop placeholder results and commit-search header lines.
            kept = [r for r in results if (r != "No matches found") and ("===Commits matching in" not in r)]
            if not kept:
                continue
            sections.append(f"Query: {query}")
            sections.append("Results:")
            for result in kept:
                sections += ["", result, "", "==="]
        sections.append(separator)

    if globalCtx.relevantDefinitions:
        sections.append("RELEVANT SYMBOL DEFINITIONS:")
        for symbols in globalCtx.relevantDefinitions.values():
            for s in symbols:
                sections.append(f"File: {s.filePath}\t\tSymbol: {s.name}\t\tLines {s.start} to {s.end}")
                sections.append(f"```{getMarkdownShortForm(s.filePath)}")
                sections += _numberedSnippetLines(s.body, s.start, s.filePath, relevantLines)
                sections.append("```")
                sections.append("")
        sections.append(separator)

    if globalCtx.relevantFunctions:
        sections.append("POTENTIALLY BUGGY FUNCTIONS:")
        for func in globalCtx.relevantFunctions:
            sections.append(f"File: {func.filePath}\t\tFunction: {func.name}\t\tLines {func.start} to {func.end}")
            sections.append(f"```{getMarkdownShortForm(func.filePath)}")
            sections += _numberedSnippetLines(func.body, func.start, func.filePath, relevantLines)
            sections.append("```")
            sections.append("")
        sections.append(separator)

    if globalCtx.relevanceExplanation and llm != "o1":
        sections += ["CONTEXT EXPLANATION:", globalCtx.relevanceExplanation, separator]

    sections.append("You should examine the crash report and all the context given to you, reason about all possible root causes of the crash, and propose a patch to fix the bug. You should place your thoughts inside <thoughts> tags, generate a specific and detailed hypothesis about the root cause of the crash in <hypothesis> tags, and write the patch to fix the bug in <patch> tags. Your patch should contain a complete new rewrite of each symbol/function definition snippet that you want to modify. Each symbol definition should be wrapped in <symbol> tags with file, name and start attributes. The code should start from the same line as the symbol definition code given to you, and should be a complete replacement of the code snippet.\n\n")

    return "\n".join(sections)

def generatePatch(prompt_preamble:str, prompt_patch_gen_examples:str, bugDict: Dict[str, Any], repoPath: Path, globalCtx: GlobalCtxAgentResult, fileCtx: str, relevantLines: List[Tuple[str, str, Optional[str]]], logger: Logger, llm: SupportedModels, repoName: str, langs:str) -> Tuple[str, str, str, int, int, int]:
    """Generate a patch for a bug, given global context, file context, and relevant lines.

    The model is queried up to three times, with the temperature raised by 0.3
    on each retry. An attempt succeeds when the response converts to a
    non-empty diff that passes ``git apply --check`` in ``repoPath``.

    Returns:
        ``(thoughts, hypothesis, patch, cachedPromptTokens, promptTokens,
        completionTokens)``. Token counts are summed over all attempts. When
        every attempt fails, ``patch`` is ``""`` and ``thoughts`` and
        ``hypothesis`` come from the last attempt.
    """
    logger.info("Generating patch for bug")
    prompt = generatePatchPrompt(
        bugDict=bugDict,
        globalCtx=globalCtx,
        fileCtx=fileCtx,
        relevantLines=relevantLines,
        llm=llm,
        repoName=repoName,
    )
    # The system prompt does not depend on the attempt, so build it once.
    systemPrompt = generatePatchSystemPrompt(prompt_preamble, prompt_patch_gen_examples, llm, repoName)
    cachedPromptTokens = 0
    promptTokens = 0
    completionTokens = 0
    # Bound up front so the failure return never relies on loop-local names.
    thoughts, hypothesis = "", ""

    # Retry mechanism
    max_retries = 3
    for attempt in range(max_retries):
        logger.info(f"Attempt {attempt + 1} to generate patch")
        modelHandler = getModelHandler(
            modelName=llm,
            systemPrompt=systemPrompt,
            temperature=attempt*0.3,
            conversational=False,
            pinFirstN=0,
        )
        logger.info("PROMPT:\n" + prompt)
        responses, tokenCounts = modelHandler.get_responses(prompt, return_token_count=True)
        response = responses[0]
        logger.info("RESPONSE:\n" + response)
        if tokenCounts['prompt_tokens_details']:
            cachedPromptTokens += tokenCounts['prompt_tokens_details'].cached_tokens
        promptTokens += tokenCounts['prompt_tokens']
        completionTokens += tokenCounts['completion_tokens']

        thoughts, hypothesis, symbols = parsePatchResponse(response, logger)

        patch, err = convertPatchToGitDiff(
            repoPath=repoPath,
            changes=symbols,
            langs=langs,
            logger=logger,
        )

        if not patch:
            # Nothing to validate. Skipping git here also avoids handing it an
            # empty (or missing) stdin payload, which would only produce a
            # misleading "failed to apply" error.
            logger.error(f"Failed to create patch on attempt {attempt + 1}.\nMessage: {err}\n")
            continue

        # Check if patch can be applied with git apply
        logger.debug(f"Running git apply --check on patch:\n{patch}")
        proc = subprocess.run(
            ['git', 'apply', '--check'],
            input=patch,
            cwd=repoPath,
            capture_output=True,
            text=True,
            encoding='utf-8',
        )
        logger.debug(f"git apply --check output: {proc.stdout}")
        logger.debug(f"git apply --check error: {proc.stderr}")
        logger.debug(f"git apply --check return code: {proc.returncode}")
        if proc.returncode == 0:
            logger.info("New patch generated.Patch:\n" + patch)
            return thoughts, hypothesis, patch, cachedPromptTokens, promptTokens, completionTokens
        logger.error(f"Patch failed to apply on attempt {attempt + 1}.\nMessage: {proc.stderr}\n")

    logger.error(f"Failed to create patch after {max_retries} attempts.")
    return thoughts, hypothesis, "", cachedPromptTokens, promptTokens, completionTokens

def _extractTaggedSection(response: str, tag: str, default: str, logger: Logger) -> str:
    """Return the stripped text of the first ``<tag>...</tag>`` pair in *response*.

    Logs an error and returns *default* when the pair is absent.
    """
    found = re.search(rf"<{tag}>(.*?)</{tag}>", response, re.DOTALL)
    if found:
        return found.group(1).strip()
    logger.error(f"Error: No <{tag}> section found in response")
    return default


def parsePatchResponse(response: str, logger: Logger) -> Tuple[str, str, List[Tuple[str, str, int, str]]]:
    """Split a model response into thoughts, hypothesis and symbol rewrites.

    Missing ``<thoughts>`` or ``<hypothesis>`` sections yield ``""``. When no
    ``<patch>`` section exists, the whole response is scanned for ``<symbol>``
    elements instead. Symbol elements that fail to parse are logged and
    skipped.

    Returns:
        ``(thoughts, hypothesis, symbols)`` where each symbol is
        ``(filePath, symbolName, startLine, newBody)``.
    """
    thoughts = _extractTaggedSection(response, "thoughts", "", logger)
    hypothesis = _extractTaggedSection(response, "hypothesis", "", logger)
    patchContent = _extractTaggedSection(response, "patch", response, logger)

    symbols = []
    for match in re.finditer(r'<symbol[^>]*>.*?</symbol>', patchContent, re.DOTALL):
        symbolXml = match.group(0)
        logger.info(f"Parsing symbol patch: {symbolXml}")
        try:
            filePath, symbolName, startLine, newBody = parseSymbolXML(symbolXml, logger)
        except Exception as e:
            # Best effort: one malformed element must not discard the others.
            logger.error(f"Failed to parse symbol patch: {e}")
            continue
        symbols.append((filePath, symbolName, startLine, newBody))

    if not symbols:
        logger.error("No valid symbol rewrites found")

    return thoughts, hypothesis, symbols

def parseSymbolXML(symbolXml: str, logger: Logger) -> Tuple[str, str, int, str]:
    """Parse a symbol patch from XML-like format.

    The ``file``, ``name`` and ``start`` attributes are read from the opening
    tag only (the text before the first ``>``), so attribute-looking text
    inside the replacement code is never mistaken for a tag attribute.
    ``start`` may be written with or without double quotes.

    Returns:
        ``(filePath, symbolName, startLine, newBody)``; ``newBody`` is the text
        between the opening tag and ``</symbol>`` with surrounding whitespace
        stripped.

    Raises:
        ValueError: If an attribute is missing or no ``</symbol>`` follows the
            opening tag. The error is logged before being re-raised.
    """
    try:
        # Everything before the first '>' is the opening tag. If there is no
        # '>', the whole string is searched and the body check below fails.
        openingTag = symbolXml.partition('>')[0]

        fileMatch = re.search(r'\bfile="([^"]+)"', openingTag)
        nameMatch = re.search(r'\bname="([^"]+)"', openingTag)
        # Prefer the quoted form, then accept an unquoted number.
        startMatch = (re.search(r'\bstart="(\d+)"', openingTag)
                      or re.search(r'\bstart=(\d+)', openingTag))

        if not fileMatch or not nameMatch or not startMatch:
            raise ValueError("Missing file, name or start attribute in symbol patch")

        # Extract symbol body
        bodyMatch = re.search(r'>\s*(.*?)\s*</symbol>', symbolXml, re.DOTALL)
        if not bodyMatch:
            raise ValueError("Missing or invalid symbol body")

        return (
            fileMatch.group(1),
            nameMatch.group(1),
            int(startMatch.group(1)),
            bodyMatch.group(1).strip(),
        )
    except Exception as e:
        logger.error(f"Error parsing symbol patch: {e}")
        raise
