from typing import Dict, List, Optional, Tuple
from cresearcher.utils.types import SymbolDefinition
from cresearcher.utils.faultLocalisationUtils import count_tokens, FileSystemManager, TagsHandler
from logging import Logger
from abc import ABC, abstractmethod

"""
FilesChunker chunks the files into smaller parts ( constraint : token_budget ), at the function boundaries and then allows filtering of these functions using LLM
"""
class FilesChunker(ABC):
    """Split files into token-bounded chunks at function boundaries and filter them.

    Function symbols of each file are collected through ``TagsHandler``, grouped
    into consecutive chunks whose source text stays below ``max_token_budget``
    tokens, and every chunk is handed to ``_filter_symbols_with_llm`` (implemented
    by subclasses) together with the source text it covers.

    A chunk's text runs from the end line of the symbol that precedes the chunk
    (or from the start of the file when there is none) up to the end line of the
    chunk's last symbol, so the code between functions is covered as well.

    NOTE(review): ``fs_manager.read_file_contents`` is assumed to return a
    sequence of line strings (it is sliced by ``SymbolDefinition.end`` and joined
    with ``""``) -- confirm against ``FileSystemManager``.
    """

    # Model name handed to ``count_tokens``; subclasses may override it.
    token_model = 'gpt-4o'

    def __init__(self, bug_dict: Dict, root_dir: str, tags_file: str, logger: Logger, filter_per_file: bool = True, token_budget: int = 40000):
        """Store the bug description and set up file-system and tags access.

        Args:
            bug_dict: Bug description; must contain the keys ``'title'`` and
                ``'crash_report_data'``.
            root_dir: Passed to ``FileSystemManager``.
            tags_file: Passed to ``TagsHandler`` together with the file-system manager.
            logger: Logger used for progress messages.
            filter_per_file: When true, each file's symbols are chunked and
                filtered on their own; otherwise the symbols of all files are
                chunked as one sequence.
            token_budget: Exclusive upper bound on the token count of one chunk.
        """
        self.bug_dict = bug_dict
        self.title = bug_dict['title']
        self.crash_report = bug_dict['crash_report_data']
        self.root_dir = root_dir
        self.tags_file = tags_file
        self.logger = logger
        self.filter_per_file = filter_per_file
        self.max_token_budget = token_budget
        self.fs_manager = FileSystemManager(root_dir)
        self.tags_handler = TagsHandler(tags_file, self.fs_manager)

    def _prune_files_to_functions(self, files: List[str]) -> List[Tuple["SymbolDefinition", str]]:
        """Collect the function symbols of ``files`` and return the filtered ones.

        Depending on ``filter_per_file`` the symbols are filtered file by file or
        all together in a single pass after every file has been scanned.
        """
        self.logger.info(f"Started pruning functions from {len(files)} files")
        pending_symbols = []
        filtered_function_symbols = []
        for file in files:
            function_symbols_in_file = self._getfunctoinSymbolsFromFile(file)
            if self.filter_per_file:
                filtered_function_symbols.extend(self.FilterSymbolsInChunks(function_symbols_in_file))
            else:
                pending_symbols.extend(function_symbols_in_file)
        if not self.filter_per_file:
            filtered_function_symbols = self.FilterSymbolsInChunks(pending_symbols)
        self.logger.info(f"Pruned {len(filtered_function_symbols)} functions from {len(files)} files using llm")
        return filtered_function_symbols

    def _getfunctoinSymbolsFromFile(self, file: str) -> List["SymbolDefinition"]:
        """Return the function symbols of ``file`` sorted by end line.

        When the file has at least one function, a synthetic symbol named
        ``"EOF"`` is added that starts at the largest function end line and ends
        at the file length, so the text after the last function is covered too.
        A file without functions yields an empty list.

        (The method name is misspelled; it is kept because callers use it.)
        """
        symbols = [
            SymbolDefinition(
                filePath=match['file'],
                start=match['start'],
                end=match['end'],
                name=match['name'],
                body=match['body'],
            )
            for match in self.tags_handler._findFunctionsInFile(file)
        ]
        if symbols:
            symbols.append(SymbolDefinition(
                filePath=file,
                start=max(symbol.end for symbol in symbols),
                end=self.fs_manager.get_file_length(file),
                name="EOF",
                body="",
            ))
        # The sort is stable, so the EOF marker stays behind symbols sharing its end line.
        symbols.sort(key=lambda symbol: symbol.end)
        return symbols

    def FilterSymbolsInChunks(self, symbolDefList: List["SymbolDefinition"]) -> List[Tuple["SymbolDefinition", str]]:
        """Chunk ``symbolDefList`` under the token budget and filter every chunk.

        Returns the concatenation of what ``filter_symbols`` returns for each
        chunk, in chunk order. An empty input yields an empty list.
        """
        self.logger.info(f"FilterSymbolsInChunks: Filtering {len(symbolDefList)} functions in chunks")
        filteredSymbolDefList = []
        for symbol_group, previous_group_end in self._plan_symbol_chunks(symbolDefList):
            filteredSymbolDefList.extend(self.filter_symbols(symbol_group, previous_group_end))
        self.logger.info(f"FilterSymbolsInChunks: Retruning Filtered list with {len(filteredSymbolDefList)} elements after filtering in chunks")
        return filteredSymbolDefList

    def _plan_symbol_chunks(self, symbolDefList: List["SymbolDefinition"]) -> List[Tuple[List["SymbolDefinition"], Optional["SymbolDefinition"]]]:
        """Partition symbols into consecutive groups that stay below the budget.

        Each entry is ``(group, previous_end)`` where ``previous_end`` is the
        symbol whose end line marks where the group's text starts (``None`` means
        the start of the first symbol's file).

        A symbol is skipped when its body alone, or the text span leading up to
        its end line, reaches the budget. Skipping closes the pending group and
        moves the boundary behind the skipped symbol, so its text does not leak
        into the following group; the text between the previous boundary and the
        skipped symbol is dropped with it.
        """
        chunks = []
        group = []
        group_tokens = 0
        group_previous_end = None
        boundary = None  # last symbol whose end line has already been accounted for
        for symbolDef in symbolDefList:
            span_tokens = None
            if self.symbol_token_count(symbolDef) < self.max_token_budget:
                # Tokens of the text between the boundary and this symbol's end.
                # Group totals are sums of these spans, so they are accumulated
                # instead of recounting the whole group for every symbol.
                span_tokens = self.symbol_group_token_count([symbolDef], boundary)
            if span_tokens is None or span_tokens >= self.max_token_budget:
                self.logger.info(f"FilterSymbolsInChunks: Skipping symbol {symbolDef.name} as it exceeds token budget")
                if group:
                    chunks.append((group, group_previous_end))
                    group, group_tokens = [], 0
                boundary = symbolDef
                continue
            if group and group_tokens + span_tokens >= self.max_token_budget:
                # Adding this symbol would reach the budget: close the group first.
                chunks.append((group, group_previous_end))
                group, group_tokens = [], 0
            if not group:
                group_previous_end = boundary
            group.append(symbolDef)
            group_tokens += span_tokens
            boundary = symbolDef
        if group:
            chunks.append((group, group_previous_end))
        return chunks

    def filter_symbols(self, symbolGroup: List["SymbolDefinition"], previousGroupEnd: Optional["SymbolDefinition"]) -> List[Tuple["SymbolDefinition", str]]:
        """Resolve the text covered by ``symbolGroup`` and filter the group with the LLM.

        Args:
            symbolGroup: Consecutive symbols forming one chunk.
            previousGroupEnd: Symbol whose end line marks the start of the chunk's
                text, or ``None`` to start at the beginning of the first symbol's file.

        Returns:
            Whatever ``_filter_symbols_with_llm`` returns; an empty list for an
            empty group (no LLM call is made).
        """
        if not symbolGroup:
            return []
        self.logger.info("Resolving file contents for symbol group to filter using llm")
        file_contents_truncated = self._resolve_file_contents_for_symbol_group(symbolGroup, previousGroupEnd)
        return self._filter_symbols_with_llm(symbolGroup, file_contents_truncated)

    @abstractmethod
    def _filter_symbols_with_llm(self, symbolDefList: List["SymbolDefinition"], file_contents_truncated) -> List[Tuple["SymbolDefinition", str]]:
        """Prompt the LLM the right way and get the filtered symbols"""
        pass

    def _resolve_file_contents_for_symbol_group(self, symbolGroup: List["SymbolDefinition"], previousGroupEnd: Optional["SymbolDefinition"]) -> str:
        """Return the source text covered by ``symbolGroup``, prefixed per file.

        The text runs from ``previousGroupEnd.end`` (or the start of the first
        symbol's file when ``previousGroupEnd`` is ``None``) to the end line of
        the last symbol. Text is grouped per file in order of first appearance and
        every file section is introduced by ``"FILE: <path>\\n\\n"``. An empty
        group yields an empty string.
        """
        if not symbolGroup:
            return ""
        filenames_to_filelines = self._read_files_for_symbols(symbolGroup, previousGroupEnd)
        filenames_to_resolved_contents = {}
        if previousGroupEnd is None:
            # No earlier boundary: start from the top of the first symbol's file.
            filenames_to_resolved_contents = self._resolve_contents_for_symbol(symbolGroup[0], filenames_to_filelines)
        for pair in self._boundary_pairs(symbolGroup, previousGroupEnd):
            resolved_contents = self._resolve_contents_for_symbol_pair(pair, filenames_to_filelines)
            filenames_to_resolved_contents = self.merge_dicts_with_appending(filenames_to_resolved_contents, resolved_contents)
        self.logger.debug("Resolved contents for files: %s", list(filenames_to_resolved_contents))
        return "".join(
            f"FILE: {file}\n\n{contents}"
            for file, contents in filenames_to_resolved_contents.items()
        )

    def _read_files_for_symbols(self, symbols, previous_end):
        """Read every file referenced by ``symbols`` or ``previous_end``; map path -> contents."""
        paths = {symbol.filePath for symbol in symbols}
        if previous_end is not None:
            paths.add(previous_end.filePath)
        return {path: self.fs_manager.read_file_contents(path) for path in sorted(paths)}

    @staticmethod
    def _boundary_pairs(symbols, previous_end):
        """Return consecutive ``(earlier, later)`` symbol pairs, led by ``previous_end`` if given."""
        sequence = symbols if previous_end is None else [previous_end] + symbols
        return list(zip(sequence, sequence[1:]))

    @staticmethod
    def merge_dicts_with_appending(global_dict: dict[str, str], local_dict: dict[str, str]) -> dict[str, str]:
        """Append each value of ``local_dict`` to the value stored under the same key.

        Missing keys start from an empty string. ``global_dict`` is modified in
        place and also returned.
        """
        for key, value in local_dict.items():
            global_dict[key] = global_dict.get(key, "") + value
        return global_dict

    def _resolve_contents_for_symbol(self, symbolDef, filenames_to_filelines):
        """Return ``{file: text}`` for the text from the file start to ``symbolDef.end``."""
        return {
            f'{symbolDef.filePath}': "".join(filenames_to_filelines[symbolDef.filePath][:symbolDef.end])
        }

    def _resolve_contents_for_symbol_pair(self, pair, filenames_to_filelines):
        """Return ``{file: text}`` for the text between the end lines of a symbol pair.

        For symbols of the same file this is the slice between both end lines.
        For different files it is the remainder of the first file after
        ``pair[0].end`` plus the second file up to ``pair[1].end``.
        """
        start = pair[0].end
        end = pair[1].end
        start_file = pair[0].filePath
        end_file = pair[1].filePath

        if start_file == end_file:
            return {
                f'{start_file}': "".join(filenames_to_filelines[start_file][start:end])
            }
        return {
            f'{start_file}': "".join(filenames_to_filelines[start_file][start:]),
            f'{end_file}': "".join(filenames_to_filelines[end_file][:end])
        }

    def symbol_group_token_count(self, currentSymbolDefList, previousGroupEnd):
        """Return the token count of the text covered by ``currentSymbolDefList``.

        The covered text is the same as in
        ``_resolve_file_contents_for_symbol_group`` (without the ``FILE:``
        prefixes); the result is the sum of the token counts of the individual
        spans. An empty list counts as 0.
        """
        if not currentSymbolDefList:
            return 0
        filenames_to_filelines = self._read_files_for_symbols(currentSymbolDefList, previousGroupEnd)
        token_count = 0
        if previousGroupEnd is None:
            token_count = self.count_token_for_symbol(currentSymbolDefList[0], filenames_to_filelines)
        for pair in self._boundary_pairs(currentSymbolDefList, previousGroupEnd):
            token_count += self.count_tokens_for_pairs(pair, filenames_to_filelines)
        return token_count

    def _count_tokens(self, text):
        """Count the tokens of ``text`` with ``count_tokens`` using ``token_model``."""
        return count_tokens(text, model=self.token_model)

    def symbol_token_count(self, symbolDef):
        """Return the token count of the symbol's ``body``."""
        return self._count_tokens(symbolDef.body)

    def count_token_for_symbol(self, symbolDef, filenames_to_filelines):
        """Return the token count of the text from the file start to ``symbolDef.end``."""
        contents = "".join(self._resolve_contents_for_symbol(symbolDef, filenames_to_filelines).values())
        return self._count_tokens(contents)

    def count_tokens_for_pairs(self, pair, filenames_to_filelines):
        """Return the token count of the text between the end lines of a symbol pair.

        For a pair spanning two files both parts are concatenated (first file's
        remainder, then the second file's head) and counted as one string.
        """
        contents = "".join(self._resolve_contents_for_symbol_pair(pair, filenames_to_filelines).values())
        return self._count_tokens(contents)
