import os
import re
from logging import Logger
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from cresearcher.utils.types import SymbolDefinition
from cresearcher.utils.faultLocalisationUtils import TagsHandler, FileSystemManager, count_tokens
from cresearcher.prompts.localizePrompts import *
from cresearcher.utils.modelHandlers import getModelHandler
from cresearcher.utils.chunk import FilesChunker

# Upper bound, in tokens, for one complete prompt.  localiseFunctions() derives
# the data budget from it by subtracting the crash report's token count and a
# 5000-token reserve for the rest of the prompt and the response.
MAX_PROMPT_LENGTH = 50000

class FunctionSelector(FilesChunker):
    """Select the functions most likely related to a crash from a set of files.

    Extends ``FilesChunker`` with two LLM-backed steps: filtering the symbols
    of a chunk down to the relevant ones (``_filter_symbols_with_llm``) and
    ranking the surviving functions (``_prune_functions_with_ranking``).  The
    public ``prune_*`` methods return ``self`` so that calls can be chained;
    the final result is read with ``get_result_and_token_counts``.

    NOTE(review): ``self.title``, ``self.crash_report``, ``self.logger`` and
    ``_prune_files_to_functions`` are not defined in this class and are
    presumably provided by ``FilesChunker`` -- confirm against that class.
    """

    def __init__(self, bug_dict:Dict, root_dir: str, tags_file: str, logger:Logger,  filter_per_file:bool = True, final_function_count:int = 5 , token_budget:int = 40000):
        """Set up the chunker base class and the LLM handler.

        Args:
            bug_dict: Bug description, forwarded to ``FilesChunker``.
            root_dir: Repository root, forwarded to ``FilesChunker``.
            tags_file: Path of the tags file, forwarded to ``FilesChunker``.
            logger: Logger, forwarded to ``FilesChunker``.
            filter_per_file: Forwarded to ``FilesChunker``.
            final_function_count: Number of functions requested from, and
                kept after, the ranking step.
            token_budget: Token budget, forwarded to ``FilesChunker``.
        """
        super().__init__(
            bug_dict=bug_dict,
            root_dir=root_dir,
            tags_file=tags_file,
            logger=logger,
            filter_per_file=filter_per_file,
            token_budget=token_budget
        )
        self.final_function_count = final_function_count
        # (symbol, reasoning) pairs produced by the pruning steps; the
        # (misspelled) attribute name is kept for backward compatibility.
        self.function_reasoning_paris = None

        self.modelHandler = getModelHandler(
            modelName = "gpt-4o",
            systemPrompt = STANDARD_SYSTEM_PROMPT,
            temperature = 0,
            conversational=False
        )
        # Running totals over every LLM call made through this instance.
        self.promptTokens = 0
        self.completionTokens = 0

    def get_result_and_token_counts(self) -> Tuple[List[Tuple[SymbolDefinition, str]], int, int]:
        """
        Returns the current (function, reasoning) pairs together with the prompt
        and completion tokens consumed so far by this selector.
        """
        return self.function_reasoning_paris, self.promptTokens, self.completionTokens

    def get_functions_pruned(self):
        # NOTE(review): ``functions_pruned`` is never assigned in this class;
        # presumably it is set by FilesChunker -- verify, otherwise this raises
        # AttributeError.
        return self.functions_pruned

    def prune_functions_with_ranking(self) -> 'FunctionSelector':
        """Rank the current (function, reasoning) pairs in place and return ``self`` for chaining."""
        self.function_reasoning_paris = self._prune_functions_with_ranking(self.function_reasoning_paris)
        return self

    def log_and_generate_response(self, prompt:str, prompt_name:str) -> str:
        """Send ``prompt`` to the LLM, log prompt and response, and return the first response.

        The prompt/completion token counts reported by the model handler are
        added to ``self.promptTokens`` / ``self.completionTokens``.
        """
        separator = "=" * 80
        self.logger.debug(f"Generated {prompt_name} prompt")
        self.logger.debug("PROMPT:\n" + separator + "\n" + prompt + "\n" + separator)

        # Get response from LLM
        responses, tokenCounts = self.modelHandler.get_responses(prompt, return_token_count=True)
        self.promptTokens += tokenCounts['prompt_tokens']
        self.completionTokens += tokenCounts['completion_tokens']
        response = responses[0]
        self.logger.debug("Received response from LLM")
        self.logger.debug("RESPONSE:\n" + separator + "\n" + response + "\n" + separator)
        return response

    @staticmethod
    def _parse_functions_response(response: str) -> Optional[List[str]]:
        """Extract the ranked function names from a ``<functions>`` block.

        Every line of the first ``<functions>...</functions>`` block that
        contains a colon contributes the text between its first and second
        colon.  Returns ``None`` when the block is missing or yields no names,
        so callers can tell a failed parse from a successful one.
        """
        blocks = re.findall(r'<functions>(.*?)</functions>', response, re.DOTALL)
        if not blocks:
            return None
        names = [line.split(":")[1].strip() for line in blocks[0].splitlines() if ":" in line]
        return names or None

    def _prune_functions_with_ranking(self, function_reasonging_pairs: List[Tuple[SymbolDefinition, str]]) -> List[Tuple[SymbolDefinition, str]]:
        """Ask the LLM to rank the candidates and keep the top ``final_function_count``.

        Args:
            function_reasonging_pairs: Candidate (symbol, reasoning) pairs.

        Returns:
            The pairs named by the LLM, in the order the LLM listed them,
            without duplicates and truncated to ``self.final_function_count``.

        Raises:
            RuntimeError: If no function list can be parsed from the response
                even after the retries performed by ``_retry_prompt``.
        """
        self.logger.info(f"Started pruning {len(function_reasonging_pairs)} functions with ranking")
        function_reasoning_pairs_str = "".join(
            f"function: {symbol.name}\n"
            f"reason: {reason}\n"
            f"body of the function {symbol.name} is\n"
            f"{symbol.body}\n\n\n"
            for symbol, reason in function_reasonging_pairs
        )
        function_names_str = "\n".join(symbol.name for symbol, _ in function_reasonging_pairs)

        # Used to map the names returned by the LLM back to the original pairs.
        func_to_func_reason_pairs = {
            pair[0].name: pair for pair in function_reasonging_pairs
        }

        prompt = FILTER_SYMBOLS_WITH_RANKING_PROMPT.format(
            title = self.title,
            crash_report = self.crash_report,
            function_reasoning_pairs_str = function_reasoning_pairs_str,
            function_names_str = function_names_str,
            final_functions_count = self.final_function_count
        )
        response = self.log_and_generate_response(prompt, "FILTER_SYMBOLS_WITH_RANKING_PROMPT")
        functions = self._parse_functions_response(response)
        if functions is None:
            self.logger.info("Failed to extract functions from the ranking response, entering retry loop")
            functions = self._retry_prompt(prompt)

        filtered_func_reason_pairs = []
        # dict.fromkeys removes duplicates while preserving the ranking order
        # (a plain set would discard the order the LLM produced).
        for function in dict.fromkeys(functions):
            pair = func_to_func_reason_pairs.get(function)
            if pair is None:
                self.logger.error(f"Function {function} not found in the original function reasoning pairs")
                continue
            filtered_func_reason_pairs.append(pair)

        self.logger.info(f"Ranker returned {len(filtered_func_reason_pairs)} functions after ranking")
        self.logger.info(f"Returning top {self.final_function_count} functions")
        return filtered_func_reason_pairs[:self.final_function_count]

    def _retry_prompt(self, prompt:str, max_retries:int = 5) -> List[str]:
        """Re-send the ranking prompt until a function list can be parsed.

        Args:
            prompt: The prompt to re-send.
            max_retries: Maximum number of additional attempts.

        Returns:
            The parsed function names of the first parsable response.

        Raises:
            RuntimeError: If none of the ``max_retries`` responses can be parsed.
        """
        for retry_count in range(1, max_retries + 1):
            self.logger.debug(f"retrying the prompt {retry_count} time")
            response = self.log_and_generate_response(prompt, "FILTER_SYMBOLS_WITH_RANKING_PROMPT")
            functions = self._parse_functions_response(response)
            if functions is not None:
                return functions
        self.logger.critical(f"Retried function parsing prompt {max_retries} times and failed, cannot continue execution without parsing properly, raise exception")
        raise RuntimeError("Failed to extract functions from the response")

    def prune_files_to_functions(self, files: List[str]) -> 'FunctionSelector':
        """Reduce ``files`` to (function, reasoning) pairs and return ``self`` for chaining."""
        self.function_reasoning_paris = self._prune_files_to_functions(files)
        return self

    def _filter_symbols_with_llm(self,symbolDefList: List[SymbolDefinition], file_contents_truncated) -> List[Tuple[SymbolDefinition,str]]:
        """Prompt the LLM with the candidate symbols and return the ones it selects.

        Args:
            symbolDefList: Candidate symbols; entries named "EOF" are ignored.
            file_contents_truncated: File contents inserted into the prompt.

        Returns:
            (symbol, reasoning) pairs for every ``<function>`` tag of the
            response whose name matches a candidate, paired positionally with
            the ``<reasoning>`` tags.  Empty if there are no real candidates.
        """
        self.logger.info(f"Filtering {len(symbolDefList)} symbols using llm")

        # "EOF" entries are not real symbols; with nothing else left there is
        # nothing to ask the LLM about.
        symbolDefList = [symbol for symbol in symbolDefList if symbol.name != "EOF"]
        if not symbolDefList:
            return []

        function_string = "".join(f"{symbol.name}\n" for symbol in symbolDefList)
        function_to_symboldef = {
            symbol.name: symbol for symbol in symbolDefList
        }
        prompt = FILTER_SYMBOLS_PROMPT.format(
            title = self.title,
            crash_report = self.crash_report,
            file_contents_truncated = file_contents_truncated,
            function_string = function_string
        )
        response = self.log_and_generate_response(prompt, "FILTER_SYMBOLS_PROMPT")

        try:
            names = [name.strip() for name in re.findall(r'<function>(.*?)</function>', response, re.DOTALL)]
            reasonings = [reasoning.strip() for reasoning in re.findall(r'<reasoning>(.*?)</reasoning>', response, re.DOTALL)]
            if len(names) != len(reasonings):
                # zip() below silently drops the surplus entries.
                self.logger.warning(
                    f"Response contains {len(names)} <function> tags but {len(reasonings)} <reasoning> tags"
                )

            filteredSymbolsAndReasonings = []
            for name, reasoning in zip(names, reasonings):
                symbol = function_to_symboldef.get(name)
                if symbol is None:
                    self.logger.warning(f"Symbol not found in the symbolDefList: {name}")
                    continue
                filteredSymbolsAndReasonings.append((symbol, reasoning))
            self.logger.info(f"The original list contains these many elements: {len(symbolDefList)}")
            self.logger.info(f"filtered list contains these many elements: {len(filteredSymbolsAndReasonings)}")
            return filteredSymbolsAndReasonings
        except Exception:
            self.logger.critical("Error in parsing response in _filter_symbols_with_llm in FunctionSelector, cant continue, raising exception")
            raise
    

class FileSelector():
    """Narrow a repository down to the files most likely related to a crash.

    The selection runs in three LLM-backed stages that can be chained because
    every public ``prune_*`` method returns ``self``:

    1. ``prune_repo_to_N_files``: pick ``after_1`` files from the files that
       live close to the paths mentioned in the crash report.
    2. ``prune_files_with_symbols``: re-select using the symbols of each file.
    3. ``prune_files_with_ranking``: rank the files and keep ``after_3``.

    The result and the accumulated token usage are read with
    ``get_result_and_token_counts``.
    """

    def __init__(self, bug_dict:Dict, root_dir:str, tags_file:str, after_1:int, after_2:int, after_3:int, logger:Logger, token_budget:int = 40000, valid_extensions:Optional[Union[List[str], Set[str]]] = None):
        """Store the configuration and create the file system, tags and LLM helpers.

        Args:
            bug_dict: Must contain the keys ``'crash_report_data'`` and ``'title'``.
            root_dir: Repository root passed to ``FileSystemManager``.
            tags_file: Tags file passed to ``TagsHandler``.
            after_1: Number of files requested from the first stage.
            after_2: Stored but currently not used by any stage.
            after_3: Number of files kept after the ranking stage.
            logger: Logger used for progress and debug output.
            token_budget: Maximum token count of the candidate file list.
            valid_extensions: File extensions handed to ``FileSystemManager``;
                defaults to ``['.c', '.h']``.
        """
        # A fresh list per instance avoids sharing one mutable default object.
        if valid_extensions is None:
            valid_extensions = ['.c', '.h']
        self.bug_dict = bug_dict
        self.crash_report = bug_dict['crash_report_data']
        self.title = bug_dict['title']
        self.root_dir = root_dir
        self.extensions = valid_extensions
        self.max_token_budget = token_budget
        self.after_1 = after_1
        self.after_2 = after_2
        self.after_3 = after_3
        self.logger = logger
        self.fs_manager = FileSystemManager(root_dir, valid_extensions)
        self.tags_handler = TagsHandler(tags_file, self.fs_manager)
        self.files_pruned = None

        self.modelHandler = getModelHandler(
            modelName = "gpt-4o",
            systemPrompt = STANDARD_SYSTEM_PROMPT,
            temperature = 0.1,
            conversational=False
        )
        # Running totals over every LLM call made through this instance.
        self.promptTokens = 0
        self.completionTokens = 0

    def log_and_generate_response(self, prompt:str, prompt_name:str) -> str:
        """Send ``prompt`` to the LLM, log prompt and response, and return the first response.

        The prompt/completion token counts reported by the model handler are
        added to ``self.promptTokens`` / ``self.completionTokens``.
        """
        separator = "=" * 80
        self.logger.debug(f"Generated {prompt_name} prompt")
        self.logger.debug("PROMPT:\n" + separator + "\n" + prompt + "\n" + separator)

        # Get response from LLM
        responses, tokenCounts = self.modelHandler.get_responses(prompt, return_token_count=True)
        self.promptTokens += tokenCounts['prompt_tokens']
        self.completionTokens += tokenCounts['completion_tokens']
        response = responses[0]
        self.logger.debug("Received response from LLM")
        self.logger.debug("RESPONSE:\n" + separator + "\n" + response + "\n" + separator)
        return response

    def get_files_pruned(self) -> Set[str]:
        """
        Returns the set of pruned files after the pruning process.
        """
        return self.files_pruned

    def get_result_and_token_counts(self) -> Tuple[Set[str], int, int]:
        """
        Returns the set of pruned files, prompt tokens, and completion tokens used during the pruning process.
        """
        return self.files_pruned, self.promptTokens, self.completionTokens

    def _get_files_from_crash_report(self) -> Set[str]:
        """
        Extracts the unique ``.c`` / ``.h`` file paths mentioned in the crash report,
        including any directory prefix written in front of the file name.

        The trailing ``\\b`` prevents longer extensions from being truncated into
        false hits (``foo.cpp`` is not reported as ``foo.c``, ``index.html`` not
        as ``index.h``).
        """
        # An optional directory prefix ending in "/", then a file name, then
        # the extension.  A single optional prefix matches the same strings as
        # a repeated "dir/" group but cannot backtrack exponentially.
        pattern = re.compile(r'(?:[\w\-/]+/)?[\w\-]+\.(?:c|h)\b')

        list_of_matches = []
        for line in self.crash_report.splitlines():
            list_of_matches.extend(pattern.findall(line))
        self.logger.info(f"Found {len(list_of_matches)} files in the crash report")
        return set(list_of_matches)

    def _get_close_to_far_files(self, files_in_crash_report:List[str], fill_budget:bool = True) -> List[str]:
        """
        Collects candidate files starting from the paths found in the crash report.

        For levels 0 to 3 the directory returned by
        ``fs_manager._get_nth_level_dir(path, level)`` is computed for every crash
        path, and the files of those directories are added (deepest directories
        first) as long as the string form of the selection stays within
        ``self.max_token_budget`` tokens.  Single-component directories are not
        expanded; they are remembered as subsystems and only used by
        ``_fill_budget_remains`` when ``fill_budget`` is true.

        Returns the selected files as a sorted list.
        """
        # Work on relative paths and ignore bare file names without a directory.
        crash_paths = {path.lstrip('/') for path in files_in_crash_report}
        crash_paths = {path for path in crash_paths if len(path.split(os.sep)) > 1}

        selected = set()
        subsystems = set()
        self.logger.debug(f"Fill budget is set to {fill_budget}")

        def finalise(files):
            # Optionally top up with subsystem files, then return a stable order.
            if fill_budget:
                files = self._fill_budget_remains(files, subsystems)
            return sorted(files)

        for level in range(4):
            self.logger.debug(f"Level {level}: {len(selected)} files selected so far")
            current_level = set()
            relevant_dirs = {self.fs_manager._get_nth_level_dir(f, level) for f in crash_paths}
            # Deepest directories first; the name breaks ties deterministically.
            relevant_dirs = sorted(relevant_dirs, key=lambda d: (-len(d.split(os.sep)), d))
            self.logger.debug(f"Relevant directories at level {level}: {relevant_dirs}")
            for directory in relevant_dirs:
                if len(directory.split(os.sep)) == 1:
                    subsystems.add(directory)
                    continue
                dir_files = self.fs_manager._get_all_files_in_directory(directory, level = level)
                if count_tokens(str(selected | current_level | dir_files)) > self.max_token_budget:
                    # Adding this directory would exceed the budget: stop here.
                    return finalise(selected | current_level)
                current_level.update(dir_files)

            selected = selected | current_level

        return finalise(selected)

    def _fill_budget_remains(self, filled_files, subsystems) -> Set[str]:
        """
        Tops up ``filled_files`` with files from ``subsystems`` until the token budget is reached.

        Subsystems are visited in descending order of how many already selected
        files start with the subsystem name.  Their files are added in chunks of
        50; the first chunk that would push the selection to or beyond
        ``self.max_token_budget`` ends the process.  Returns the expanded set.
        """
        self.logger.debug("Filling budget")
        if not subsystems:
            return filled_files

        subsystem_freq = {
            subsystem: sum(1 for file in filled_files if file.startswith(subsystem))
            for subsystem in subsystems
        }
        # Most frequent first; the name breaks ties deterministically.
        ordered_subsystems = sorted(subsystem_freq, key=lambda name: (-subsystem_freq[name], name))

        chunk_size = 50
        for subsystem in ordered_subsystems:
            self.logger.debug(f"Filling budget for {subsystem}")
            # Sorted so that the chunks are reproducible between runs.
            files_in_subsystem = sorted(self.fs_manager._get_all_files_in_directory(subsystem, level = 1))
            for start in range(0, len(files_in_subsystem), chunk_size):
                chunk = set(files_in_subsystem[start:start + chunk_size])
                if count_tokens(str(filled_files | chunk)) >= self.max_token_budget:
                    return filled_files
                filled_files = filled_files | chunk
        return filled_files

    @staticmethod
    def _parse_files_response(response: str) -> Optional[List[str]]:
        """Extract the file list from the first ``<files>...</files>`` block of a response.

        Every line of the block that contains a colon contributes the text
        between its first and second colon.  Returns ``None`` when the block is
        missing or yields no entries, so callers can tell a failed parse from a
        successful one.
        """
        blocks = re.findall(r'<files>(.*?)</files>', response, re.DOTALL)
        if not blocks:
            return None
        files = [line.split(":")[1].strip() for line in blocks[0].splitlines() if ":" in line]
        return files or None

    def _request_file_list(self, prompt: str, prompt_name: str, caller: str) -> List[str]:
        """Send ``prompt`` and return the parsed file list, retrying when parsing fails."""
        response = self.log_and_generate_response(prompt, prompt_name)
        files = self._parse_files_response(response)
        if files is None:
            self.logger.info(f"Failed to extract files from the response in {caller}, entering retry loop")
            files = self._retry_prompt(prompt)
            self.logger.info("retry loop succeeded in extracting files from the response")
        self.logger.debug(f"Extracted files: {files}")
        return files

    def prune_repo_to_N_files(self) -> 'FileSelector':
        """Run the first selection stage and return ``self`` for chaining."""
        self.files_pruned = self._prune_repo_to_N_files()
        return self

    def _prune_repo_to_N_files(self) -> Set[str]:
        """Ask the LLM to pick ``self.after_1`` files from the close-to-far candidates.

        Raises:
            RuntimeError: If no file list can be parsed even after retrying.
        """
        self.logger.info("Started pruning the repository to N files")
        relevant_files = self._get_close_to_far_files(
            files_in_crash_report= self._get_files_from_crash_report(),
            fill_budget=True)

        self.logger.info(f"Found {len(relevant_files)} files using close to far algo. on the repository")

        prompt = LOCALIZE_N_FILES_PROMPT.format(
            title = self.title,
            crash_report = self.crash_report,
            relavent_files = "".join(f"{file}\n" for file in relevant_files),
            NUM_FILES_TO_LOCALIZE = self.after_1
        )
        files = set(self._request_file_list(prompt, "LOCALIZE_N_FILES_PROMPT", "prune_repo_to_N_files"))

        self.logger.info(f"Pruned the repository to {len(files)} files")
        return files

    def prune_files_with_symbols(self) -> 'FileSelector':
        """Run the symbol-based selection stage on the current files and return ``self``."""
        self.files_pruned = self._prune_files_with_symbols(self.files_pruned)
        return self

    def _prune_files_with_symbols(self, files: Set[str]) -> Set[str]:
        """Ask the LLM to re-select files given the symbols of each candidate file.

        Raises:
            RuntimeError: If no file list can be parsed even after retrying.
        """
        self.logger.info(f"Started pruning {len(files)} files with symbols")
        # Deduplicate and sort so that the prompt is identical between runs.
        ordered_files = sorted(set(files))

        files_str = "".join(f"{file}\n" for file in ordered_files)
        files_and_symbols = "".join(
            f"File: {file}\n" + f"Symbols related to {file}\n" + self.tags_handler._symbolsInFile(file)
            for file in ordered_files
        )
        prompt = FINE_LOCALIZE_N_FILES_PROMPT.format(
            title = self.title,
            crash_report = self.crash_report,
            files_and_symbols = files_and_symbols,
            files_str = files_str,
            # NOTE(review): hard-coded while ``self.after_2`` is never used --
            # presumably this was meant to be ``self.after_2``; confirm before
            # changing, as it alters how many files this stage returns.
            NUM_FILES_TO_LOCALIZE = 5
        )
        selected = set(self._request_file_list(prompt, "FINE_LOCALIZE_N_FILES_PROMPT", "prune_files_with_symbols"))

        self.logger.info(f"Pruned the files to {len(selected)} files using symbols")
        return selected

    def prune_files_with_ranking(self) -> 'FileSelector':
        """Run the ranking stage on the current files and return ``self`` for chaining."""
        self.files_pruned = self._prune_files_with_ranking(self.files_pruned)
        return self

    def _prune_files_with_ranking(self, files: Set[str]) -> Set[str]:
        """Ask the LLM to rank ``files`` and keep the ``self.after_3`` best ranked ones.

        Raises:
            RuntimeError: If no file list can be parsed even after retrying.
        """
        self.logger.info(f"Started ranking {len(files)} files")
        # Deduplicate and sort so that the prompt is identical between runs.
        files_to_str = "".join(f"{file}\n" for file in sorted(set(files)))
        prompt = RANK_N_FILES_PROMPT.format(
            title = self.title,
            crash_report = self.crash_report,
            files_list = (files_to_str),
        )
        # dict.fromkeys removes duplicates but keeps the ranking order, so the
        # slice below really keeps the best ranked files.
        ranked = list(dict.fromkeys(
            self._request_file_list(prompt, "RANK_N_FILES_PROMPT", "prune_files_with_ranking")
        ))

        self.logger.info(f"Pruned the files to {len(ranked)} files using ranking")
        self.logger.info(f"Returning {self.after_3} files")
        return set(ranked[:self.after_3])

    def _retry_prompt(self, prompt:str, max_retries:int = 5) -> List[str]:
        """Re-send ``prompt`` until a file list can be parsed from the response.

        Args:
            prompt: The prompt to re-send.
            max_retries: Maximum number of additional attempts.

        Returns:
            The parsed file list of the first parsable response.

        Raises:
            RuntimeError: If none of the ``max_retries`` responses can be parsed.
        """
        for retry_count in range(1, max_retries + 1):
            self.logger.info(f"retrying prompt {retry_count} time")
            response = self.log_and_generate_response(prompt, "RETRY_PROMPT")
            files = self._parse_files_response(response)
            if files is not None:
                self.logger.info("success: extracted files from the response")
                return files
            self.logger.info("failed again")
        self.logger.critical(f"Retried function parsing prompt {max_retries} times and failed, cannot continue execution without parsing properly, raise exception")
        raise RuntimeError("Failed to extract files from the response")


def localiseFunctions(bugDict: Dict[str, Any], repoPath: Path, logger: Logger, tagsFilePath: Path) -> Tuple[List[Tuple[SymbolDefinition, str]], Set[str], int, int]:
    """Localise the functions most likely related to the crash described by ``bugDict``.

    First a ``FileSelector`` narrows the repository down to a handful of files
    (repo -> N files -> symbol-based selection -> ranking), then a
    ``FunctionSelector`` narrows those files down to ranked functions.

    Args:
        bugDict: Bug description; ``bugDict['crash_report_data']`` is read here,
            the selectors read further keys.
        repoPath: Root of the repository to search.
        logger: Logger handed to both selectors.
        tagsFilePath: Path of the tags file handed to both selectors.

    Returns:
        A tuple ``(function_reasoning_pairs, files, prompt_tokens,
        completion_tokens)`` where the token counts are summed over both
        selectors.
    """
    # data_token_limit + crash_report + 5000 = MAX_PROMPT_LENGTH
    # The 5000 tokens are reserved for the rest of the prompt and the response;
    # data_token_limit is the budget left for the data (e.g. file contents).
    data_token_limit = MAX_PROMPT_LENGTH - count_tokens(bugDict['crash_report_data']) - 5000

    root_dir = str(repoPath.resolve())
    tags_file = str(tagsFilePath.resolve())

    file_selector = FileSelector(
        bug_dict = bugDict,
        root_dir = root_dir,
        tags_file = tags_file,
        after_1 = 20,
        after_2 = 10,
        after_3 = 5,
        logger = logger,
        token_budget = data_token_limit,
        valid_extensions = ['.c', '.h']
    )

    func_selector = FunctionSelector(
        bug_dict = bugDict,
        root_dir = root_dir,
        tags_file = tags_file,
        logger = logger,
        filter_per_file = True,
        final_function_count = 5,
        token_budget = min(data_token_limit, 20000)
    )

    files, file_prompt_tokens, file_completion_tokens = (
        file_selector
        .prune_repo_to_N_files()
        .prune_files_with_symbols()
        .prune_files_with_ranking()
        .get_result_and_token_counts()
    )

    function_reasoning_pairs, func_prompt_tokens, func_completion_tokens = (
        func_selector
        .prune_files_to_functions(files)
        .prune_functions_with_ranking()
        .get_result_and_token_counts()
    )

    prompt_tokens = file_prompt_tokens + func_prompt_tokens
    completion_tokens = file_completion_tokens + func_completion_tokens

    return function_reasoning_pairs, files, prompt_tokens, completion_tokens
