import ast
import re
import textwrap

from rapidfuzz import fuzz, utils

def parse_node_content(node_content):
    """Pull the id, type, name and state out of one accessibility-tree line.

    A typical line looks like ``[12] link 'Home', expanded=True``. Any piece
    that cannot be found is reported as ``None``.

    Args:
        node_content: A single tree line (the caller strips indentation).

    Returns:
        Tuple ``(node_id, node_type, node_name, node_state)``:

        - ``node_id``: the digits of the first ``[<digits>]``.
        - ``node_type``: the ASCII letters following the first ``]`` that is
          followed by optional whitespace and at least one letter.
        - ``node_name``: the text between the first pair of single quotes.
        - ``node_state``: the first ``expanded=…``, ``selected=…`` or
          ``disabled=…`` token, returned whole.
    """
    def _capture(pattern, group):
        # First match of `pattern` anywhere in the line, or None when absent.
        hit = re.search(pattern, node_content)
        return None if hit is None else hit.group(group)

    return (
        _capture(r'\[(\d+)\]', 1),
        _capture(r'\](\s*)([a-zA-Z]+)', 2),
        _capture(r"'([^']*)'", 1),
        _capture(r"(expanded|selected|disabled)=[a-zA-Z]+", 0),
    )

def parse_tree_to_json(tree_string):
    """Parse an indented accessibility-tree string into nested node dicts.

    Every non-blank line becomes a dict with ``id``, ``type``, ``name`` and
    ``state`` (from ``parse_node_content``) plus a ``children`` list. A line
    becomes a child of the nearest preceding line with strictly smaller
    indentation. The first non-blank line is the root. A later line that is
    not indented deeper than any open ancestor is attached directly to the
    root; such a line previously raised ``IndexError``.

    Args:
        tree_string: Multi-line tree text. Indentation is measured in leading
            whitespace characters, so a tab and a space each count as one.

    Returns:
        The root node dict, or ``{}`` when ``tree_string`` has no non-blank
        lines.
    """
    root = None
    # (indentation, node) pairs along the path from the root to the most
    # recently added node.
    stack = []

    for line in tree_string.splitlines():
        content = line.strip()
        if not content:
            continue

        indent = len(line) - len(line.lstrip())
        node_id, node_type, node_name, node_state = parse_node_content(content)
        node = {
            "id": node_id,
            "type": node_type,
            "name": node_name,
            "state": node_state,
            "children": [],
        }

        if root is None:
            root = node
            stack.append((indent, node))
            continue

        # Drop siblings and deeper nodes; whatever remains on top is the parent.
        while stack and stack[-1][0] >= indent:
            stack.pop()
        # An emptied stack means this line is not deeper than the root, so
        # attach it to the root instead of crashing on stack[-1].
        parent = stack[-1][1] if stack else root
        parent["children"].append(node)
        stack.append((indent, node))

    return root if root is not None else {}


def json_to_ax_tree(node, indent_level=0):
    """Render a node dict and its descendants as indented tree lines.

    Each node is rendered as ``[id] type 'name', state``. Each part is
    included only when present: ``id``, ``type`` and ``state`` must be
    truthy, while ``name`` is included whenever it is not ``None``.

    Args:
        node: Node dict with optional ``id``, ``type``, ``name``, ``state`` and
            ``children`` keys, or ``None``.
        indent_level: Depth of ``node``; each level prefixes two spaces.

    Returns:
        A list of lines, one per node, in pre-order. ``None`` nodes, including
        ``None`` entries inside a ``children`` list, contribute no lines. The
        result is always a list, so callers can ``extend``/``join`` it safely.
    """
    if node is None:
        return []

    line = ""
    if node.get("id"):
        line += f"[{node['id']}] "
    if node.get("type"):
        line += f"{node['type']} "
    if node.get("name") is not None:
        line += f"'{node['name']}'"
    if node.get("state"):
        line += f", {node['state']}"

    # strip() removes the trailing space left when later parts are absent.
    lines = ["  " * indent_level + line.strip()]

    # Treat a missing or None children entry as "no children".
    for child in node.get("children") or []:
        lines.extend(json_to_ax_tree(child, indent_level + 1))

    return lines

def json_to_ax_tree_string(json_tree):
    """Render a node tree as one indented, newline-separated text block.

    Args:
        json_tree: Root node dict (or ``None``) understood by
            ``json_to_ax_tree``.

    Returns:
        The lines produced by ``json_to_ax_tree(json_tree)``, joined with
        newline characters.
    """
    return "\n".join(json_to_ax_tree(json_tree))

def filter_tree_json(tree, id_list):
    """
    Return a pruned copy of ``tree`` keeping only nodes whose id is in ``id_list``.

    A node survives if its own id is listed or if at least one of its
    descendants survives. A surviving node whose own id is not listed is kept
    only as a structural placeholder, with its ``id`` replaced by ``'NaN'``.
    A node without a ``children`` key is treated as a leaf and kept only when
    its id is listed.

    Ids are compared as strings (``str(node['id'])`` against ``str(x)`` for
    each ``x`` in ``id_list``), so string and integer ids both work.

    Args:
        tree: Root node dict shaped like those built by ``parse_tree_to_json``,
            or ``None``.
        id_list: Iterable of node ids to keep.

    Returns:
        The pruned root node dict, or ``None`` if nothing survives. Surviving
        nodes are shallow copies; the input tree is not modified.
    """
    # Build the lookup once: O(1) membership tests and a uniform str form.
    wanted = {str(node_id) for node_id in id_list}

    def filter_node(node):
        if node is None:
            return None

        is_wanted = str(node.get('id')) in wanted

        if 'children' not in node:
            # Leaf without a children list: keep it only if it is listed.
            return dict(node) if is_wanted else None

        # Filter each child exactly once. Calling filter_node once to test and
        # again to collect would double the work at every tree level.
        kept_children = []
        for child in node['children'] or []:
            kept = filter_node(child)
            if kept is not None:
                kept_children.append(kept)

        if not is_wanted and not kept_children:
            return None

        pruned = dict(node)
        pruned['children'] = kept_children
        if not is_wanted:
            # Kept only to preserve the path to listed descendants.
            pruned['id'] = 'NaN'
        return pruned

    return filter_node(tree)

def remove_non_ascii(text):
    """
    Replace non-ASCII characters with spaces, then normalize whitespace.

    Every maximal run of characters outside U+0000..U+007F is replaced by a
    single space. The result is then passed through ``clean_extra_space``,
    which trims both ends and collapses each internal whitespace run to one
    space.

    Args:
        text: The string to clean.

    Returns:
        An ASCII-only string with no leading, trailing or repeated whitespace.
    """
    non_ascii_run = r'[^\x00-\x7F]+'
    return clean_extra_space(re.sub(non_ascii_run, ' ', text))

def clean_extra_space(text: str):
    """Trim ``text`` and collapse every internal whitespace run to a single space."""
    # str.split() with no separator splits on whitespace runs and drops
    # leading and trailing whitespace; re-joining with one space normalizes it.
    return " ".join(text.split())


def format_policies(policies):
    """
    Split policies by source and render each group as a numbered list.

    Args:
        policies: Iterable of dicts with a ``'source'`` key. Policies whose
            source is ``'user'`` are personal and those whose source is
            ``'organization'`` are organizational. Any other source is
            ignored. ``'description'`` is read only for the two recognised
            sources.

    Returns:
        ``(organizational_policies, personal_policies)``: two strings whose
        lines read ``"<n>) <description>"``. Each group is numbered from 1
        independently, and an empty group yields ``""``.
    """
    personal, organizational = [], []
    for policy in policies:
        source = policy['source']
        if source == 'user':
            bucket = personal
        elif source == 'organization':
            bucket = organizational
        else:
            continue
        bucket.append(policy['description'])

    def numbered(descriptions):
        return "\n".join(
            f"{position}) {text}"
            for position, text in enumerate(descriptions, start=1)
        )

    return numbered(organizational), numbered(personal)

def remove_send_msg(input_string):
    """
    Delete every ``send_msg_to_user(<string literal>)`` call from the text.

    The argument must be a single ``'...'``, ``"..."`` or ``'''...'''``
    literal placed directly inside the parentheses. The pattern uses
    ``re.DOTALL``, so multi-line literals are matched. Leading and trailing
    whitespace of the result is stripped.

    NOTE(review): literal bodies are matched lazily with ``.*?`` across
    newlines. A call whose argument is not a lone literal
    (e.g. ``send_msg_to_user('a' + b)``) can therefore extend the match to the
    next quote that is directly followed by ``)``, removing any unrelated code
    in between. Confirm this is acceptable for the expected inputs.
    """
    send_msg_call = re.compile(
        r"send_msg_to_user\((?:'''.*?'''|\".*?\"|'.*?')\)", re.DOTALL
    )
    cleaned = send_msg_call.sub("", input_string)
    return cleaned.strip()

def _format_arguments(arguments) -> str:
    """Render an ``ast.arguments`` node as a comma-separated parameter-name list.

    Only names are emitted, together with the ``/``, ``*``, ``*name`` and
    ``**name`` markers; default values and annotations are omitted.
    """
    params = []
    # posonlyargs only exists on Python 3.8+ ASTs.
    positional_only = getattr(arguments, "posonlyargs", [])
    params.extend(arg.arg for arg in positional_only)
    if positional_only:
        params.append("/")
    params.extend(arg.arg for arg in arguments.args)
    if arguments.vararg is not None:
        params.append(f"*{arguments.vararg.arg}")
    elif arguments.kwonlyargs:
        # Keyword-only parameters without *args need a bare '*' separator.
        params.append("*")
    params.extend(arg.arg for arg in arguments.kwonlyargs)
    if arguments.kwarg is not None:
        params.append(f"**{arguments.kwarg.arg}")
    return ", ".join(params)


def extract_functions_with_docstrings(code: str) -> str:
    """
    Build signature-plus-docstring stubs for the top-level functions in ``code``.

    Only top-level ``def`` statements (``ast.FunctionDef``) are considered.
    ``async def`` functions, class methods and nested functions are skipped.
    Function bodies are NOT included in the output.

    Each function contributes:

    1. Its signature line, ``def name(params):``. Parameter names are listed
       with the ``/``, ``*``, ``*args`` and ``**kwargs`` markers; defaults and
       annotations are dropped.
    2. If it has a non-empty docstring, that docstring wrapped in triple
       double quotes, with every line indented by four spaces.
    3. A trailing blank line.

    Stubs are separated by an additional newline.

    Args:
        code: Python source text.

    Returns:
        The concatenated stubs, or ``""`` when ``code`` defines no top-level
        functions.

    Raises:
        SyntaxError: If ``code`` cannot be parsed by ``ast.parse``.
    """
    tree = ast.parse(code)
    functions = []

    for node in tree.body:
        if not isinstance(node, ast.FunctionDef):
            continue

        signature = f"def {node.name}({_format_arguments(node.args)}):"

        docstring = ast.get_docstring(node)
        if docstring:
            # get_docstring() dedents the text, so every line (not just the
            # first) must be re-indented to sit inside the def.
            indented = textwrap.indent(docstring, "    ")
            docstring_block = f'    """\n{indented}\n    """\n'
        else:
            docstring_block = ''

        functions.append(f"{signature}\n{docstring_block}\n")

    return '\n'.join(functions)