import subprocess
import os
import re
from pathlib import Path
from logging import Logger
from typing import List, Dict, Optional
from cresearcher.utils.types import SymbolDefinition

def updateRepoTagsFile(repoPath: Path, tagsFilePath: Path, langs:str, logger: Logger) -> bool:
    """Regenerate the ctags file for the repository directory.

    Any existing file at ``tagsFilePath`` is deleted first. Then ``ctags -R`` is run
    inside ``repoPath`` for the given ``langs`` (passed to ``--languages``) with
    ``--fields=+neK``, writing its output to ``tagsFilePath``.

    Returns:
        True if ctags ran and exited with status 0. False if ctags could not be
        launched (e.g. not on PATH, or ``repoPath`` is not an existing directory)
        or if it exited non-zero. The error is logged in both cases.
    """
    logger.info("Updating ctags...")
    if tagsFilePath.exists():
        logger.info("Tags file for current repository state already exists. Removing...")
        # missing_ok: tolerate the file disappearing between the check and the unlink
        tagsFilePath.unlink(missing_ok=True)

    # Absolute path: ctags runs with cwd=repoPath, so a relative path would land inside the repo.
    outputPath = str(tagsFilePath.resolve())
    command = ["ctags", "-R", "--fields=+neK", f"--languages={langs}", "-f", outputPath]
    logger.debug(f"Running command: ctags -R --fields=+neK --languages={langs} -f {outputPath} in directory {repoPath}")
    try:
        result = subprocess.run(command, cwd=repoPath, capture_output=True, text=True)
    except OSError as e:
        # Raised when the executable is missing or cwd does not exist; honour the bool contract.
        logger.error(f"Failed to run ctags: {e}")
        return False

    if result.returncode != 0:
        logger.error(f"Failed to generate tags: {result.stderr}")
        return False

    logger.debug("Tags updated successfully")
    return True

def parseReadtagsOutput(output: str) -> List[Dict[str,str | int]]:
    """ Get matches from readtags output as a list of dicts with keys: file, name, start, end """
    matches = []
    matchStrings = output.split("\n")
    matchStrings = [ m.strip().split(',') for m in matchStrings if m.strip() ]
    for m in matchStrings:
        if len(m) == 4:
            d = {
                'file': m[0],
                'name': m[1],
                'start': int(m[2]),
                'end': int(m[3])
            }
            if d not in matches:
                matches.append(d)
    return matches

def _fileVariantsRegex(filePaths: set[str]) -> str:
    """Build one regex group matching each path, its .c/.h/.py siblings, and include/-prefixed forms of all of them."""
    candidates = set(filePaths)
    for f in filePaths:
        # splitext leaves extension-less paths unchanged, so no separate '.' check is needed
        base = os.path.splitext(f)[0]
        candidates.update(f"{base}{ext}" for ext in (".c", ".h", ".py"))
    candidates |= {f"include/{c}" for c in candidates}
    # sorted() so the regex text is identical across runs (set order depends on the string hash seed)
    return f'({"|".join(re.escape(c) for c in sorted(candidates))})'


def getAllRegexes(bugDict: Dict[str, any], buggyFunctions: List[SymbolDefinition],  openDefinitions: Dict[str, List[SymbolDefinition]]) -> List[str]:
    """Return file-path regexes for a hierarchical symbol search, narrowest tier first.

    Tiers, each included only when its input is non-empty:
      1. Files of ``buggyFunctions`` (their ``filePath``), plus .c/.h/.py siblings and include/ variants.
      2. Files keyed in ``openDefinitions``, with the same variants.
      3. .c/.h/.py files under each directory in ``bugDict['subsystems']``, plus include/ variants.
      4. Any .c/.h/.py file (always present, last).

    The regexes are fed to readtags' ``string->regexp`` by the caller.
    NOTE(review): the patterns are not anchored, so e.g. ``foo\\.c`` also matches
    ``xfoo.cpp``. Confirm this looseness is intended.
    """
    regexes = []

    if buggyFunctions:
        regexes.append(_fileVariantsRegex({f.filePath for f in buggyFunctions}))

    if openDefinitions:
        regexes.append(_fileVariantsRegex(set(openDefinitions)))

    subsystems = bugDict.get('subsystems')
    if subsystems:
        # Subsystem entries are treated as directory prefixes; only the name itself is escaped.
        patterns = {re.escape(s) + "/.*\\.(c|h|py)" for s in subsystems}
        patterns |= {f"include/{p}" for p in patterns}
        regexes.append(f'({"|".join(sorted(patterns))})')

    regexes.append('(.*\\.(c|h|py))')
    return regexes

def _runReadtagsQuery(tagsFilePath: Path, symbolName: str, filterExpr: str, formatExpr: str, logger: Logger) -> List[Dict[str, str | int]]:
    """Run one readtags lookup of symbolName with the given -Q/-F expressions and return the parsed matches.

    Raises OSError if the readtags executable cannot be launched.
    """
    logger.debug(f"Running command: readtags -t {tagsFilePath.name} -Q {filterExpr} -F {formatExpr} - {symbolName}")
    result = subprocess.run(
        ["readtags", "-t", tagsFilePath.name, "-Q", filterExpr, "-F", formatExpr, "-", symbolName],
        cwd=tagsFilePath.parent, capture_output=True, text=True
    )
    if result.returncode != 0:
        # Previously ignored; stdout is still parsed below so behaviour on partial output is unchanged.
        logger.error(f"readtags exited with status {result.returncode}: {result.stderr.strip()}")
    logger.debug(f"readtags result: {result.stdout}")
    matches = parseReadtagsOutput(result.stdout)
    logger.debug(f"Found {len(matches)} matching symbols")
    logger.debug(f"Matches: {matches}")
    return matches


def findSymbolInTags(tagsFilePath: Path, symbolName: str, filePath: Optional[str], buggyFunctions: List[SymbolDefinition], openDefinitions: Dict[str, List[SymbolDefinition]], bugDict: Dict[str, any], logger: Logger, maxMatches: int = 5) -> List[Dict[str,str|int]]:
    """Find occurrences of a symbol in the tags file.

    If ``filePath`` is given, only tags whose input file equals it are returned.
    Otherwise the search walks the regex tiers from ``getAllRegexes`` in order and
    accumulates distinct matches until ``maxMatches`` are found.

    Returns:
        Up to ``maxMatches`` dicts with keys file, name, start, end. Returns an empty
        list if the tags file does not exist or the lookup fails.
    """
    if not tagsFilePath.exists():
        logger.info(f"Tags file {tagsFilePath} not found. Can't search for symbol.")
        return []

    logger.debug(f"Searching for symbol: {symbolName}{f' in {filePath}' if filePath else ''}")

    # Formatter emitting "file,name,line,end" per tag, as parseReadtagsOutput expects.
    formatExpr = '(list $input "," $name "," $line "," $end #t)'

    if filePath:
        # NOTE(review): filePath is interpolated unescaped; a '"' in it would break the expression.
        filterExpr = f'(and (eq? $input "{filePath}") ($line) ($end))'
        try:
            return _runReadtagsQuery(tagsFilePath, symbolName, filterExpr, formatExpr, logger)[:maxMatches]
        except Exception as e:
            logger.error(f"Error searching tags: {e}")
            return []

    # Hierarchically search first the files of the potentially buggy functions, then the files of
    # the open definitions, then the files in the bug's subsystems, and finally all files.
    matches = []
    for r in getAllRegexes(
        bugDict=bugDict,
        buggyFunctions=buggyFunctions,
        openDefinitions=openDefinitions
    ):
        filterExpr = f'(and ( (string->regexp "{r}") $input) ($line) ($end))'
        try:
            newMatches = _runReadtagsQuery(tagsFilePath, symbolName, filterExpr, formatExpr, logger)
        except OSError as e:
            # readtags could not be launched; every remaining tier would fail the same way.
            logger.error(f"Error searching tags: {e}")
            break
        except Exception as e:
            logger.error(f"Error searching tags: {e}")
            continue

        for m in newMatches:
            if m not in matches:
                matches.append(m)
        if len(matches) >= maxMatches:
            break

    return matches[:maxMatches]

def extractSymbolBody(filePath: Path, startLine: int, endLine: int, logger: Logger) -> Optional[str]:
    """Return lines startLine..endLine (1-based, inclusive) of filePath as one string.

    Every returned line ends with a newline; one is appended to a final line that
    lacks it. Returns None after logging an error in any of these cases: the file
    is missing or unreadable, or the requested range is inverted or falls outside
    the file.
    """
    logger.info(f"Extracting symbol body from {filePath} lines {startLine}-{endLine}")
    try:
        if not filePath.exists():
            logger.error(f"Error: File {filePath} not found")
            return None

        # Source trees can contain non-UTF-8 bytes (e.g. Latin-1 comments). Replace them
        # rather than failing the whole extraction with a UnicodeDecodeError.
        with open(filePath, 'r', encoding='utf-8', errors='replace') as f:
            lines = f.readlines()

        # startLine > endLine used to yield '' silently; reject it like other bad ranges.
        if startLine < 1 or endLine > len(lines) or startLine > endLine:
            logger.error(f"Error: Invalid line numbers {startLine}-{endLine} for file with {len(lines)} lines")
            return None

        symbolLines = lines[startLine-1:endLine]
        return ''.join(line if line.endswith('\n') else line + '\n' for line in symbolLines)

    except Exception as e:
        # Best-effort by design: callers treat None as "body unavailable".
        logger.error(f"Error extracting symbol body: {e}")
        return None

def createSymbolDefinition(filePath: str, repoPath: Path, symbolName: str, startLine: int, endLine: int, logger: Logger) -> Optional[SymbolDefinition]:
    """Build a SymbolDefinition for lines startLine..endLine of ``repoPath / filePath``.

    The stored ``filePath`` is the repo-relative path as given. Returns None whenever
    extractSymbolBody returns None for that range.
    """
    logger.debug(f"Creating SymbolDefinition object for {symbolName} in {filePath}")
    symbolBody = extractSymbolBody(repoPath / filePath, startLine, endLine, logger)
    if symbolBody is not None:
        return SymbolDefinition(
            filePath=filePath,
            name=symbolName,
            start=startLine,
            end=endLine,
            body=symbolBody,
        )
    return None

# Extension (case-sensitive, including the dot) -> markdown code-fence language tag.
_MARKDOWN_FENCE_LANGS = {
    '.c': 'c',
    '.h': 'c',
    '.cpp': 'cpp',
    '.py': 'py',
    '.S': 'asm',
    '.s': 'asm',
    '.asm': 'asm',
    '.cu': 'cuda',
    '.hpp': 'hpp',
    '.md': 'md',
}


def getMarkdownShortForm(filePath: str) -> str:
    """Return the markdown code-fence tag for filePath's extension, or 'txt' if unrecognised."""
    extension = os.path.splitext(filePath)[1]
    return _MARKDOWN_FENCE_LANGS.get(extension, 'txt')
