from typing import Union, List, Set, Dict
import os
from subprocess import run
import tiktoken


def count_tokens(text: str, model: str = "gpt-4o") -> int:
    """
    Count the number of tokens in the given text using the tokenizer of an OpenAI model.

    Parameters:
    - text (str): The input text to be tokenized.
    - model (str): The OpenAI model whose tokenizer is used (default is "gpt-4o").

    Returns:
    - int: The number of tokens in the text.

    Note: the encoding is looked up with ``tiktoken.encoding_for_model``; any error it raises for a
    model name it cannot resolve propagates to the caller unchanged.
    """
    # Resolve the tokenizer that matches the requested model
    encoding = tiktoken.encoding_for_model(model)
    # The token count is simply the length of the encoded sequence
    return len(encoding.encode(text))


class FileSystemManager():
    """
    Lists and reads files located beneath a single root directory.

    Paths passed to the read helpers are joined onto ``root_dir``; the listing helpers return paths
    relative to ``root_dir``. When ``valid_extensions`` is provided, listings keep only the files
    whose path ends with one of those suffixes.
    """

    def __init__(self, root_dir:str, valid_extensions:Union[List[str], Set[str], None]= None):
        """
        Parameters:
        - root_dir (str): Directory that every relative path is resolved against.
        - valid_extensions (list/set of str, optional): Suffixes (e.g. ".py") used to filter listings.
          ``None`` or an empty collection disables filtering.
        """
        self.root_dir = root_dir
        self.extensions = valid_extensions

    @staticmethod
    def _get_nth_level_dir(file_path: str, levels_up:int) -> str:
        """
        Computes the nth-level parent directory of a given file path. For example, with levels_up=2, returns the
        grandparent directory. A levels_up of 0 is treated as 1, and levels_up is capped at one less than the
        number of ``os.sep``-separated components so the walk never goes past the first component. A negative
        levels_up returns file_path unchanged. The path is not normalized; the result is whatever repeated
        ``os.path.dirname`` calls yield.
        """
        if levels_up == 0:
            levels_up = 1
        # Number of path components; computed once and reused for the cap
        depth = len(file_path.split(os.sep))
        if levels_up >= depth:
            levels_up = depth - 1
        # Move up levels in the directory tree
        dir_path = file_path
        for _ in range(levels_up):
            dir_path = os.path.dirname(dir_path)
        return dir_path

    def _filter_files_by_extensions(self, files:Set[str]) -> Set[str]:
        """
        Filters a set of file paths to include only those matching specified extensions. Case-sensitive comparison
        is performed on the file suffixes. Returns the input set itself (not a copy) if no extensions are provided.
        """
        if not self.extensions:
            return files
        # str.endswith accepts a tuple of suffixes, so one call covers every extension
        suffixes = tuple(self.extensions)
        return {file for file in files if file.endswith(suffixes)}

    def _get_all_files_in_directory(self, directory:str, level:int) -> Set[str]:
        """
        Collects files from a directory located under root_dir. With level == 0 only the immediate children
        are listed; any other level performs a full recursive walk of the subtree. Returns paths relative to
        root_dir, filtered by the configured extensions.
        """
        if level == 0:
            return self._filter_files_by_extensions(self._get_files_in_directory(directory))
        all_files = {
            os.path.relpath(os.path.join(dirpath, name), self.root_dir)
            for dirpath, _dirnames, filenames in os.walk(os.path.join(self.root_dir, directory))
            for name in filenames
        }
        return self._filter_files_by_extensions(all_files)

    def _get_files_in_directory(self, directory:str) -> Set[str]:
        """
        Lists immediate children files (non-recursive) of a directory located under root_dir. Returns paths
        relative to root_dir, excluding subdirectories. No extension filtering is applied here.
        """
        target_dir = os.path.join(self.root_dir, directory)
        return {
            os.path.relpath(os.path.join(target_dir, entry), self.root_dir)
            for entry in os.listdir(target_dir)
            if os.path.isfile(os.path.join(target_dir, entry))
        }

    def get_file_length(self, file_path: str) -> int:
        """Returns the number of lines in the file at file_path (relative to root_dir)."""
        with open(os.path.join(self.root_dir, file_path), 'r') as file:
            # Count lines while streaming instead of materializing the whole list
            return sum(1 for _ in file)

    def read_file_contents(self, file_path: str) -> List[str]:
        """Returns the lines of the file at file_path (relative to root_dir), line endings included."""
        with open(os.path.join(self.root_dir, file_path), 'r') as file:
            return file.readlines()

    def _get_body_content(self, file_path: str, start: int, end: int) -> str:
        """
        Returns lines start..end (1-based, inclusive) of the file at file_path as a single string.

        Out-of-range bounds are clamped to the file: a start below 1 begins at the first line and an end
        past the last line stops at the end of the file. An end below start yields an empty string.
        """
        lines = self.read_file_contents(file_path)
        # Clamp both bounds: a negative slice index would wrap around and select lines
        # counted from the end of the file instead of the requested range.
        first = max(start - 1, 0)
        last = max(end, 0)
        return "".join(lines[first:last])



class TagsHandler():
    """
    Queries a ctags tags file through the ``readtags`` command-line tool and turns the results into
    dictionaries describing where each symbol lives, including the source text of its body.
    """

    def __init__(self, tags_file:str, fs_manager:FileSystemManager):
        """
        Parameters:
        - tags_file (str): Path to the tags file handed to ``readtags -t``.
        - fs_manager (FileSystemManager): Used to read the source lines of each matched symbol.
        """
        self.tags_file = tags_file
        self.fs_manager = fs_manager

    def _run_readtags(self, filterExpr: str, formatExpr: str) -> str:
        """
        Runs ``readtags`` with the given filter (-Q) and format (-F) expressions and returns its stdout.

        The exit status is not checked, so a failing invocation simply yields whatever was written to
        stdout (typically nothing).
        """
        # Resolve the tags file to an absolute path first: the command runs with cwd set to the
        # tags file's directory, so a relative path would otherwise be resolved a second time
        # against that directory (and a bare file name would produce an invalid empty cwd).
        tags_path = os.path.abspath(self.tags_file)
        result = run(
            ["readtags", "-t", tags_path, "-Q", filterExpr, "-F", formatExpr, "-l"],
            cwd=os.path.dirname(tags_path), capture_output=True, text=True
        )
        return result.stdout

    def _symbolsInFile(self, filePath: str) -> str:
        """
        Find all functions, macros, structs, typedefs, and unions in the file.

        Returns the raw readtags output: one "name,kind" line per matching tag.
        """
        formatExpr = '(list $name "," $kind #t)'
        # NOTE(review): filePath is interpolated into the expression unescaped; a path containing
        # a double quote would break the filter. Confirm the expected path alphabet.
        filterExpr = (
            f'(and (eq? $input "{filePath}") '
            '(or (eq? $kind "function") '
            '(eq? $kind "macro") '
            '(eq? $kind "struct") '
            '(eq? $kind "typedef") '
            '(eq? $kind "union")) '
            '($line) ($end))'
        )
        return self._run_readtags(filterExpr, formatExpr)

    def _findFunctionsInFile(self,filePath: str) -> List[Dict[str, Union[str, int]]]:
        """
        Find all functions in the given file, along with their start and end lines.

        Returns one dictionary per function with the keys 'file', 'start', 'end', 'name' and 'body'.
        """
        # Format expression emitting "file,name,start,end" for every matching tag
        formatExpr = '(list $input "," $name "," $line "," $end #t)'
        # NOTE(review): filePath is interpolated unescaped here as well.
        filterExpr = f'(and (eq? $input "{filePath}") (eq? $kind "function") ($line) ($end))'
        return self._parseReadtagsOutput(self._run_readtags(filterExpr, formatExpr), True)

    def populate_bodies(self, matches):
        """
        Adds a 'body' entry to every match, holding the source text between its 'start' and 'end' lines.

        The dictionaries are updated in place and the same list is returned.
        """
        for m in matches:
            m['body'] = self.fs_manager._get_body_content(m['file'], m['start'], m['end'])
        return matches

    def _parseReadtagsOutput(self, output: str, add_name:bool = False) -> List[Dict[str, Union[str, int]]]:
        """
        Parses "file,name,start,end" lines produced by readtags into match dictionaries.

        Blank lines and lines that do not split into exactly four comma-separated fields are skipped.
        Duplicate entries are dropped, keeping the first occurrence; the name only takes part in the
        duplicate check when add_name is True. Each surviving match gets its 'body' populated.

        Raises ValueError if a start or end field is not an integer.
        """
        matches = []
        seen = set()
        for raw_line in output.split("\n"):
            stripped = raw_line.strip()
            if not stripped:
                continue
            fields = stripped.split(',')
            if len(fields) != 4:
                continue
            source, name, start, end = fields
            d = {
                'file': source,
                'start': int(start),
                'end': int(end)
            }
            if add_name:
                d['name'] = name
            # Hashable stand-in for the dictionary so duplicates are detected in constant time
            key = (d['file'], d['start'], d['end'], d.get('name'))
            if key in seen:
                continue
            seen.add(key)
            matches.append(d)
        return self.populate_bodies(matches)
