import dataclasses
import json
import logging
import os
import re
import subprocess
import traceback
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from .globalContext import GlobalCtxAgent
from .patchGeneration import generatePatch
from .utils.ctags import updateRepoTagsFile
from .utils.repo import cloneRepo, checkoutCommit, extractRelevantLines
from .utils.types import SupportedModels, dictToSymbolDefinition, dictToGlobalCtxAgentResult

class CodeResearcher:
    """Drive crash analysis and patch generation for a single bug.

    On construction the target repository is cloned into the working
    directory, checked out at the bug's ``parent_of_fix_commit`` and, unless
    ``patchGenOnly`` is set, a tags file is (re)built for it. ``run()`` then
    performs ``numPatches`` independent iterations, each consisting of a
    ``GlobalCtxAgent`` analysis followed by ``generatePatch``.
    """

    def __init__(
        self,
        repoUrl: str,
        repoName: str,
        prompt_preamble: str,
        prompt_analysis_examples: str,
        prompt_patch_gen_examples: str,
        bugDict: Dict[str, Any],
        langs: str,
        maxGlobalCtxSteps: int = 10,
        workdirPath: Path = Path("workdir"),
        logDir: Optional[Path] = None,
        backportCommitsJsonPath: Optional[Path] = None,
        initRunDataPath: Optional[Path] = None,
        patchGenOnly: bool = False,
        patchGenllm: "SupportedModels" = "gpt-4o",
        numPatches: int = 5,
    ):
        """Validate the inputs and prepare the repository, logs and tags file.

        Args:
            repoUrl: Git URL handed to ``cloneRepo``.
            repoName: Name of the checkout directory inside ``workdirPath``;
                also used as the prefix of the tags file name.
            prompt_preamble: Forwarded to ``GlobalCtxAgent`` and
                ``generatePatch``.
            prompt_analysis_examples: Forwarded to ``GlobalCtxAgent``.
            prompt_patch_gen_examples: Forwarded to ``generatePatch``.
            bugDict: Bug description; see ``_validateBugDict`` for the keys
                that must be present.
            langs: Forwarded to ``updateRepoTagsFile`` and ``generatePatch``.
            maxGlobalCtxSteps: Passed to ``GlobalCtxAgent`` as ``maxSteps``.
            workdirPath: Directory holding the checkout and the tags file.
                Created (including parents) if missing.
            logDir: Directory for the log and run-data files. When omitted,
                ``<workdirPath>/logs`` is used.
            backportCommitsJsonPath: Forwarded to ``checkoutCommit``.
            initRunDataPath: JSON file whose content seeds ``self.runData``.
                In ``patchGenOnly`` mode each iteration reads its
                ``globalCtxResult`` from this data.
            patchGenOnly: Skip the tags file and the global-context analysis
                and only regenerate patches from the initial run data.
            patchGenllm: Passed to ``generatePatch`` as ``llm``.
            numPatches: Number of iterations performed by ``run()``.

        Raises:
            ValueError: If ``bugDict`` is missing a required key or the
                commit checkout fails.
            RuntimeError: If the working directory cannot be set up.
        """
        self.REPO_GIT_URL = repoUrl
        self.REPO_NAME = repoName
        self.TAGS_FILE_NAME = f"{self.REPO_NAME}_tags"
        self.logDir = logDir
        self.prompt_preamble = prompt_preamble
        self.prompt_analysis_examples = prompt_analysis_examples
        self.prompt_patch_gen_examples = prompt_patch_gen_examples
        self._validateBugDict(bugDict)
        self.bugDict = bugDict
        self.langs = langs
        self.maxGlobalCtxSteps = maxGlobalCtxSteps
        self.workdirPath = Path(workdirPath)
        # parents=True so that a nested working directory can be created too.
        self.workdirPath.mkdir(parents=True, exist_ok=True)
        self.repoPath = self.workdirPath / self.REPO_NAME
        self.tagsFilePath = self.workdirPath / self.TAGS_FILE_NAME
        self.backportCommitsJsonPath = backportCommitsJsonPath or None
        if initRunDataPath:
            with open(initRunDataPath, 'r') as f:
                self.runData = json.load(f)
        else:
            self.runData = {}
        self.patchGenOnly = patchGenOnly
        self.patchGenllm = patchGenllm
        self.numPatches = numPatches

        self._setupLogging(logDir)
        cloneRepo(self.repoPath, self.REPO_GIT_URL, self.logger)
        if not self._setupWorkDir(self.patchGenOnly):
            self.logger.error("Failed to setup working directory")
            raise RuntimeError("Failed to setup working directory")

    def _setupLogging(self, logDir: Optional[Path] = None):
        """Create the log directory and a per-run file logger.

        Sets ``self.logger`` and ``self.runDataFile``; both file names carry
        the same second-resolution timestamp.

        Args:
            logDir: Directory for the log files. When falsy, falls back to
                ``self.logDir`` resolved against the working directory, or to
                ``<workdir>/logs`` if that is unset as well.
        """
        if not logDir:
            # self.logDir may itself be None (the constructor default), so a
            # concrete default name is needed to avoid `Path / None`.
            logDir = self.workdirPath / (self.logDir if self.logDir else "logs")
        logDir = Path(logDir)
        logDir.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        logFile = logDir / f"cresearcher_{timestamp}.log"
        self.runDataFile = logDir / f"cresearcher_{timestamp}.json"

        self.logger = logging.getLogger(f"cresearcher_{timestamp}")
        self.logger.setLevel(logging.DEBUG)

        # File handler with detailed formatting
        fh = logging.FileHandler(logFile)
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
        self.logger.addHandler(fh)

    def _setupWorkDir(self, patchGenOnly: bool) -> bool:
        """Check out the bug's commit and, if needed, rebuild the tags file.

        Args:
            patchGenOnly: When true, the tags file step is skipped.

        Returns:
            True when only the checkout was required; otherwise the result of
            ``updateRepoTagsFile``.

        Raises:
            ValueError: If ``checkoutCommit`` reports failure.
        """
        commitId = self.bugDict['parent_of_fix_commit']
        checkedOut = checkoutCommit(
            repoPath=self.repoPath,
            commitId=commitId,
            backportCommitsJsonPath=self.backportCommitsJsonPath,
            logger=self.logger
        )
        if not checkedOut:
            error_msg = f"Failed to checkout and apply backports for commit {commitId}"
            self.logger.error(error_msg)
            raise ValueError(error_msg)

        if patchGenOnly:
            return True  # We don't need a repo tags file for just patch generation

        return updateRepoTagsFile(
            repoPath=self.repoPath,
            tagsFilePath=self.tagsFilePath,
            langs=self.langs,
            logger=self.logger
        )

    def _sanityCheckRepo(self) -> bool:
        """Check if there are any non-staged or untracked changes in the repo.

        Note: changes to be committed are allowed because they may be backport
        commits.

        Returns:
            False if ``git status`` fails or its output reports untracked
            files, unstaged changes or unpublished local commits; True
            otherwise.
        """
        result = subprocess.run(
            ['git', 'status'],
            cwd=self.repoPath,
            capture_output=True,
            text=True,
            # The checks below match English phrases, so force untranslated
            # git output regardless of the user's locale.
            env={**os.environ, "LC_ALL": "C"},
        )
        if result.returncode != 0:
            return False
        dirtyMarkers = (
            "Untracked files:",
            "Changes not staged for commit:",
            "publish your local commits",
        )
        return not any(marker in result.stdout for marker in dirtyMarkers)

    def _validateBugDict(self, bugDict: Dict[str, Any]):
        """Ensure ``bugDict`` carries every key this class reads.

        Raises:
            ValueError: If ``id``, ``crash_report_data``, ``title`` or
                ``parent_of_fix_commit`` is missing, or if ``crashes`` is
                missing, empty, or its first entry has no
                ``kernel-source-commit``.
        """
        # run() reads bugDict['id']; reject a missing id up front instead of
        # failing with a KeyError after the repository has been prepared.
        if 'id' not in bugDict:
            raise ValueError("Bug dictionary must contain id")
        if 'crash_report_data' not in bugDict:
            raise ValueError("Crash report cannot be empty")
        if 'title' not in bugDict:
            raise ValueError("Title cannot be empty")
        if 'parent_of_fix_commit' not in bugDict:
            raise ValueError("Bug dictionary must contain parent_of_fix_commit")
        if 'crashes' not in bugDict or not bugDict['crashes'] or 'kernel-source-commit' not in bugDict['crashes'][0]:
            raise ValueError("Bug dictionary must contain crashes with kernel-source-commit")

    def run(self) -> List[str]:
        """Run the agent on a bug. Returns a list of patches.

        Iterations that fail contribute no entry to the returned list; their
        error is recorded in the run data instead.
        """
        self.logger.info("Starting new run")
        self.logger.info(f"Bug ID: {self.bugDict['id']}\nRepo Commit: {self.bugDict['parent_of_fix_commit']}")
        self.runData['id'] = self.bugDict['id']
        self.runData['time'] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        return self._run()

    def _saveRunData(self) -> None:
        """Write ``self.runData`` as JSON to ``self.runDataFile`` and log it."""
        with open(self.runDataFile, 'w') as f:
            json.dump(self.runData, f, indent=2)
        self.logger.info(f"Run data saved to {self.runDataFile}")

    def _run(self) -> List[str]:
        """Extract the relevant lines, run every iteration and save the data.

        Returns:
            The patches of all successful iterations, or an empty list when
            the relevant lines could not be extracted.
        """
        try:
            relevantLines = extractRelevantLines(self.bugDict['crash_report_data'])
            self.logger.info(f"Relevant lines: {relevantLines}")
        except Exception as e:
            self.logger.error(f"An error occurred while extracting relevant lines: {e}")
            self.runData['error'] = str(e)
            self.runData['traceback'] = traceback.format_exc()
            self._saveRunData()
            return []

        patches = []
        for i in range(self.numPatches):
            try:
                patches.append(self._runIteration(i, relevantLines))
            except Exception as e:
                self.logger.error(f"An error occurred in iteration {i}: {e}")
                # The iteration entry may not exist yet (failure before it was
                # created, or no initial data in patchGenOnly mode); create it
                # so that recording the error cannot itself raise.
                iterData = self.runData.setdefault(f'iteration_{i}', {})
                iterData['error'] = str(e)
                iterData['traceback'] = traceback.format_exc()

        self._saveRunData()
        self.logger.info("Run complete")
        return patches

    def _runIteration(self, i: int, relevantLines) -> str:
        """Perform one analysis + patch-generation iteration.

        Results are stored under ``self.runData['iteration_<i>']``. Any
        exception propagates to the caller, which records it.

        Args:
            i: Zero-based iteration index.
            relevantLines: Result of ``extractRelevantLines`` for the crash
                report.

        Returns:
            The patch produced by ``generatePatch``.
        """
        iterKey = f'iteration_{i}'
        if not self._sanityCheckRepo():
            self.logger.error(f"Sanity check failed for iteration {i}. Checking out repo again.")
            if not self._setupWorkDir(self.patchGenOnly):
                self.logger.error("Failed to setup working directory")
                raise RuntimeError("Failed to setup working directory")
        self.logger.info(f"Processing iteration {i}")

        if not self.patchGenOnly:
            iterData = {}
            self.runData[iterKey] = iterData
            globalCtxAgent = GlobalCtxAgent(
                prompt_preamble=self.prompt_preamble,
                prompt_analysis_examples=self.prompt_analysis_examples,
                bugDict=self.bugDict,
                buggyFunctions=[],
                relevantLines=relevantLines,
                repoPath=self.repoPath,
                tagsFilePath=self.tagsFilePath,
                logger=self.logger,
                maxSteps=self.maxGlobalCtxSteps,
                repoName=self.REPO_NAME,
            )
            globalCtxResult = globalCtxAgent.run()
            iterData['globalCtxResult'] = dataclasses.asdict(globalCtxResult)
            iterData['globalCtxPromptTokens'] = globalCtxResult.promptTokens
            iterData['globalCtxCompletionTokens'] = globalCtxResult.completionTokens
        else:
            iterData = self.runData.get(iterKey)
            if not iterData or 'globalCtxResult' not in iterData:
                raise ValueError(
                    f"No globalCtxResult for {iterKey} in the initial run data; "
                    "it is required when patchGenOnly is set"
                )
            globalCtxResult = dictToGlobalCtxAgentResult(iterData['globalCtxResult'])
            # Delete all old keys named error, traceback or patch* to avoid confusion
            for key in list(iterData):
                if key in ('error', 'traceback') or key.startswith('patch'):
                    del iterData[key]

        thoughts, hypothesis, patch, patchCachedPromptTokens, patchPromptTokens, patchCompletionTokens = generatePatch(
            prompt_preamble=self.prompt_preamble,
            prompt_patch_gen_examples=self.prompt_patch_gen_examples,
            bugDict=self.bugDict,
            repoPath=self.repoPath,
            globalCtx=globalCtxResult,
            fileCtx="",
            relevantLines=relevantLines,
            logger=self.logger,
            llm=self.patchGenllm,
            repoName=self.REPO_NAME,
            langs=self.langs,
        )
        iterData['patchThoughts'] = thoughts
        iterData['patchHypothesis'] = hypothesis
        iterData['patch'] = patch
        iterData['patchCachedPromptTokens'] = patchCachedPromptTokens
        iterData['patchPromptTokens'] = patchPromptTokens
        iterData['patchCompletionTokens'] = patchCompletionTokens
        return patch