import ast
import traceback
from enum import Enum
from typing import Dict, Generator, List, Optional, Set, Tuple, Any

import os
import re
import json
import xml.etree.ElementTree as ET

from tree_sitter import Node

from autoenv.engine.logs import logger

# Utilities for parsing and extracting content

def try_parse_nested(value):
    """
    Best-effort decode of a value that may hold JSON or a flat XML element.

    The value is first handed to ``json.loads``. If that raises, it is parsed
    as XML and every direct child of the root element is mapped from its tag
    to the recursively decoded child text (missing text counts as ""). If both
    attempts raise, the input is returned unchanged.

    Args:
        value: The raw value to decode, normally a string.

    Returns:
        The decoded JSON value, a dict keyed by child tag, or ``value`` itself
        when neither parser accepts it.
    """
    try:
        decoded = json.loads(value)
    except Exception:
        # Not JSON: fall through to the XML attempt.
        pass
    else:
        return decoded

    try:
        root = ET.fromstring(value)
        by_tag = {}
        for child in root:
            # A repeated tag overwrites the entry stored for the earlier child.
            by_tag[child.tag] = try_parse_nested(child.text or "")
    except Exception:
        # Neither JSON nor XML: hand the input back untouched.
        return value
    return by_tag

def extract_pddl_domain(text: str) -> Optional[str]:
    """
    Return the body of the first ``~~~pddl ... ~~~`` block found in ``text``.

    The body is stripped of surrounding whitespace. ``None`` is returned when
    the text contains no such block.
    """
    found = re.search(r"~~~pddl(.*?)~~~", text, re.DOTALL)
    return found.group(1).strip() if found else None

def extract_python_code(markdown_content: str):
    """
    Collect every ```python fenced block from markdown text.

    Args:
        markdown_content (str): The markdown text to search.

    Returns:
        str | None: The bodies of all python-fenced blocks joined by a blank
        line, or None when the text contains no such block.
    """
    blocks = re.findall(r"```python\s*(.*?)\s*```", markdown_content, re.DOTALL)
    return "\n\n".join(blocks) if blocks else None

class NodeType(Enum):
    """Node ``type`` strings that the syntax-tree helpers in this module compare against.

    Every member's value is a single type string, except ``IMPORT``, whose
    value is a list holding both import-statement type names (so it must be
    tested with ``in`` rather than ``==``).
    """

    # Definitions
    CLASS = "class_definition"
    FUNCTION = "function_definition"

    # Both import forms share one member.
    IMPORT = [
        "import_statement",
        "import_from_statement",
    ]

    # Names and attribute access
    IDENTIFIER = "identifier"
    ATTRIBUTE = "attribute"

    # Statements
    RETURN = "return_statement"
    EXPRESSION = "expression_statement"
    ASSIGNMENT = "assignment"


def traverse_tree(node: Node) -> Generator[Node, None, None]:
    """
    Traverse the tree structure starting from the given node, in pre-order.

    The walk uses a cursor obtained from ``node.walk()`` and keeps an explicit
    count of how many levels it has descended below the starting node. The
    traversal ends as soon as it has climbed back to the starting node, so
    every node of the subtree is yielded exactly once and nothing outside the
    subtree is ever visited.

    (The previous implementation incremented its depth counter at every leaf
    instead of on every descent, which could stop the walk early: after a
    chain of single-child nodes the counter hit zero while climbing and the
    remaining siblings were skipped.)

    :param node: The root node to start the traversal from.
    :return: A generator object that yields nodes in the tree.
    """
    cursor = node.walk()
    depth = 0  # number of levels below the starting node

    while True:
        yield cursor.node

        # Descend first: pre-order visits children before siblings.
        if cursor.goto_first_child():
            depth += 1
            continue

        # No children: climb until some ancestor has an unvisited sibling,
        # or until we are back at the starting node.
        while True:
            if depth == 0:
                return
            if cursor.goto_next_sibling():
                break
            if not cursor.goto_parent():
                return
            depth -= 1


def syntax_check(code, verbose=False):
    """
    Report whether ``code`` parses as Python source.

    Args:
        code: The source text to parse with ``ast.parse``.
        verbose: When True, print the traceback of the parse failure.

    Returns:
        bool: True if ``ast.parse`` accepts the text, False if it raises
        SyntaxError, MemoryError, ValueError or RecursionError.
    """
    try:
        ast.parse(code)
    except (SyntaxError, MemoryError, ValueError, RecursionError):
        # ValueError covers source containing NUL bytes on interpreters that
        # do not report it as SyntaxError; RecursionError covers pathologically
        # nested input. Both mean "not usable code" for callers that probe
        # arbitrary text, so they must not escape as crashes.
        if verbose:
            traceback.print_exc()
        return False
    return True


def code_extract(text: str) -> str:
    """
    Return the run of consecutive lines in ``text`` that forms the largest
    syntactically valid Python snippet.

    "Largest" is measured in non-blank lines. Every contiguous span of lines
    is a candidate, including single-line spans (the previous implementation
    skipped those, so a one-line snippet surrounded by prose was never found).
    On ties the span encountered first (smallest start, then smallest end)
    wins. If no span with at least one non-blank line parses, the first line
    of ``text`` is returned.

    Args:
        text: Arbitrary text that may embed Python code.

    Returns:
        str: The selected lines joined with newlines.
    """
    lines = text.split("\n")

    # nonblank_before[k] == number of non-blank lines among lines[:k]; lets us
    # size a span in O(1) and skip parsing spans that cannot beat the best.
    nonblank_before = [0]
    for line in lines:
        nonblank_before.append(nonblank_before[-1] + (1 if line.strip() else 0))

    best_span = (0, 0)
    best_size = 0
    for start in range(len(lines)):
        # `end` begins at `start` so that single-line snippets are considered.
        for end in range(start, len(lines)):
            size = nonblank_before[end + 1] - nonblank_before[start]
            if size <= best_size:
                continue  # cannot improve on the current best; avoid the parse
            if syntax_check("\n".join(lines[start : end + 1])):
                best_size = size
                best_span = (start, end)

    first, last = best_span
    return "\n".join(lines[first : last + 1])


def get_definition_name(node: Node) -> str:
    """
    Return the UTF-8 decoded text of the first direct child of ``node`` whose
    type is the identifier node type.

    Only direct children are inspected. When no child is an identifier the
    function returns None.
    """
    identifier_children = (
        child for child in node.children if child.type == NodeType.IDENTIFIER.value
    )
    first_identifier = next(identifier_children, None)
    if first_identifier is None:
        return None
    return first_identifier.text.decode("utf8")


def has_return_statement(node: Node) -> bool:
    """
    Tell whether any node produced by ``traverse_tree(node)`` is a return
    statement. The search stops at the first match.
    """
    return any(
        visited.type == NodeType.RETURN.value for visited in traverse_tree(node)
    )


def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]:
    """
    Map each named node to the set of identifier texts found beneath it.

    For every ``(name, node)`` pair the subtree under ``node`` is searched
    (the node itself is not tested) and the UTF-8 decoded text of every
    identifier node is collected. Identifier nodes are not descended into.
    A name that occurs more than once keeps the set of its last node.

    :param nodes: Pairs of a definition name and its syntax node.
    :return: A dict from each name to the identifiers referenced under its node.
    """
    identifier_type = NodeType.IDENTIFIER.value
    name2deps = {}
    for name, root in nodes:
        found = set()
        # Explicit stack in place of recursion; visiting order does not matter
        # because the results are accumulated in a set.
        pending = [root]
        while pending:
            current = pending.pop()
            for child in current.children:
                if child.type == identifier_type:
                    found.add(child.text.decode("utf8"))
                else:
                    pending.append(child)
        name2deps[name] = found
    return name2deps


def get_function_dependency(entrypoint: str, call_graph: Dict[str, str]) -> Set[str]:
    """
    Compute every name reachable from ``entrypoint`` through ``call_graph``.

    :param entrypoint: The name the search starts from; always part of the result.
    :param call_graph: Maps a name to an iterable of the names it refers to.
        Names missing from the mapping are treated as having no outgoing edges.
    :return: The set containing ``entrypoint`` and all names reachable from it.
    """
    seen = {entrypoint}
    frontier = [entrypoint]
    while frontier:
        current = frontier.pop()
        for callee in call_graph.get(current, ()):
            if callee in seen:
                continue
            seen.add(callee)
            frontier.append(callee)
    return seen


def sanitize(code: str, entrypoint: Optional[str] = None) -> str:
    """
    Sanitize and extract relevant parts of the given Python code.

    The input is first narrowed with ``code_extract`` and the result is then
    passed to ``fallback_sanitize_with_ast``, which keeps the import
    statements and top-level definitions (restricted to what the entrypoint
    depends on when one is given).

    :param code: The input text containing Python code.
    :param entrypoint: Optional name of a definition to use as the root for dependency analysis.
    :return: The sanitized code, or the text produced by ``code_extract`` if
        sanitizing raises.
    """
    extracted = code_extract(code)

    try:
        # The ast-based path is used directly to avoid the warnings/errors
        sanitized = fallback_sanitize_with_ast(extracted, entrypoint)
    except Exception as e:
        print(f"ERROR in sanitize: {str(e)}")
        # Sanitizing failed: fall back to the extracted (unsanitized) code
        return extracted
    return sanitized

def _assigned_names(targets) -> Tuple[str, ...]:
    """Return the distinct plain names bound by assignment ``targets``, in order."""
    names = []
    for target in targets:
        for sub in ast.walk(target):
            # Only names being stored count: in ``obj.attr = 1`` or
            # ``seq[0] = 1`` the name is merely read, nothing new is bound.
            if (
                isinstance(sub, ast.Name)
                and isinstance(sub.ctx, ast.Store)
                and sub.id not in names
            ):
                names.append(sub.id)
    return tuple(names)


def fallback_sanitize_with_ast(code: str, entrypoint: Optional[str] = None) -> str:
    """
    Sanitize code with Python's built-in ``ast`` module instead of tree-sitter.

    Keeps the top-level import statements followed by the top-level
    definitions: functions (sync and async), classes, and assignments that
    bind at least one plain name (including annotated, chained and
    tuple-unpacking assignments). All other top-level statements are dropped.
    Each kept statement is emitted exactly once, re-rendered by ``ast.unparse``.

    When ``entrypoint`` is given, only definitions reachable from it are kept,
    where a definition "depends on" another if its body mentions the other's
    name. Imports are always kept. If the entrypoint is not a top-level
    definition, no definitions are kept.

    Args:
        code: Python source text.
        entrypoint: Optional name of the top-level definition to start the
            dependency search from.

    Returns:
        str: Imports and kept definitions joined by newlines, or ``code``
        unchanged if it cannot be parsed or processed.
    """
    try:
        tree = ast.parse(code)
        imports = []
        # (names bound by the statement, statement node), in source order.
        definitions = []

        for node in tree.body:
            if isinstance(node, (ast.Import, ast.ImportFrom)):
                imports.append(ast.unparse(node))
            elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                definitions.append(((node.name,), node))
            elif isinstance(node, (ast.Assign, ast.AnnAssign)):
                targets = node.targets if isinstance(node, ast.Assign) else [node.target]
                names = _assigned_names(targets)
                if names:
                    # One entry per statement, so ``a = b = 1`` is emitted once.
                    definitions.append((names, node))

        if entrypoint:
            # name -> names of other top-level definitions it mentions
            dependencies = {name: set() for names, _ in definitions for name in names}
            for names, node in definitions:
                mentioned = {
                    sub.id
                    for sub in ast.walk(node)
                    if isinstance(sub, ast.Name) and sub.id in dependencies
                }
                for name in names:
                    dependencies[name] |= mentioned

            # Iterative reachability search from the entrypoint.
            reachable = set()
            pending = [entrypoint] if entrypoint in dependencies else []
            while pending:
                name = pending.pop()
                if name in reachable:
                    continue
                reachable.add(name)
                pending.extend(dependencies[name] - reachable)

            kept = [node for names, node in definitions if reachable.intersection(names)]
        else:
            # No entrypoint: keep every collected definition.
            kept = [node for _, node in definitions]

        return "\n".join(imports + [ast.unparse(node) for node in kept])
    except Exception as e:
        print(f"AST fallback failed: {str(e)}")
        return code  # Return original code if all else fails
    
from typing import Dict
import re

def parse_xml_content(content: str, tag: str) -> dict:
    """
    Extract the text enclosed by every ``<tag>...</tag>`` pair in ``content``.

    The tag name is matched literally (regex metacharacters in ``tag`` such as
    ``.`` are escaped), tags with attributes are not matched, and each
    extracted value is stripped of surrounding whitespace.

    Args:
        content (str): The string containing XML-like data.
        tag (str): The tag name to search for.

    Returns:
        dict: A single-entry dictionary keyed by ``tag`` whose value is
        ``None`` when the tag does not occur, the stripped string when it
        occurs once, or a list of stripped strings when it occurs several times.
    """
    # Escape the tag so that names like "a.b" cannot match "<axb>".
    escaped_tag = re.escape(tag)
    matches = re.findall(rf"<{escaped_tag}>(.*?)</{escaped_tag}>", content, re.DOTALL)
    values = [match.strip() for match in matches]

    if not values:
        return {tag: None}
    if len(values) == 1:
        return {tag: values[0]}
    return {tag: values}

def read_file_content(file_path):
    """
    Return the full text of a file, decoded as UTF-8.

    Args:
        file_path (str): The path to the file.

    Returns:
        str: The content of the file as a string.
    """
    with open(file_path, mode='r', encoding='utf-8') as source:
        text = source.read()
    return text
    

def write_file_content(file_path, content):
    """
    Store ``content`` in a file as UTF-8 text, replacing any existing content.

    Args:
        file_path (str): The path to the file.
        content (str): The text to write.

    Returns:
        None
    """
    with open(file_path, mode='w', encoding='utf-8') as sink:
        sink.write(content)


def get_env_paths(base_path: str) -> List[str]:
    """
    List the environment directories located directly inside ``base_path``.

    An environment directory is a sub-directory whose name starts with
    ``env_``. Entries are returned in the order ``os.listdir`` reports them.

    Args:
        base_path (str): Directory to scan.

    Returns:
        List[str]: ``base_path`` joined with each matching directory name;
        empty when ``base_path`` does not exist or is not a directory.
    """
    # isdir (rather than exists) so a path pointing at a regular file yields
    # an empty list instead of NotADirectoryError from os.listdir.
    if not os.path.isdir(base_path):
        return []

    env_paths = []
    for item in os.listdir(base_path):
        candidate = os.path.join(base_path, item)
        if item.startswith("env_") and os.path.isdir(candidate):
            env_paths.append(candidate)
    return env_paths


def archive_files(env_folder_path: str, env_id: str = None) -> bool:
    """
    Run the project's archive script on an environment folder.

    The script ``scripts/run_archive_files.py`` (resolved three directory
    levels above this file) is executed with the current Python interpreter
    and the folder path as its only argument. When it exits with status 0 an
    empty ``done.txt`` marker is written into the folder. What the script
    itself does to the folder is not visible from this function.

    Args:
        env_folder_path (str): Path to the environment folder
        env_id (str, optional): Environment ID, used only in a log message

    Returns:
        bool: True if the script exited with 0 and the marker was written,
        False if it exited non-zero or any exception occurred.

    Raises:
        ValueError: If ``env_folder_path`` is empty.
    """
    if not env_folder_path:
        raise ValueError("env_folder_path cannot be empty")

    import subprocess
    import sys
    import logging

    # Note: this local logger shadows any module-level ``logger`` for the
    # duration of the call.
    logger = logging.getLogger(__name__)

    # Walk three directory levels up from this file to find the project root.
    project_root = __file__
    for _ in range(3):
        project_root = os.path.dirname(project_root)
    archive_script = os.path.join(project_root, "scripts", "run_archive_files.py")

    if env_id:
        logger.info(f"Archiving auxiliary files for environment: {env_id}")
    logger.info(f"Environment folder: {env_folder_path}")

    command = [sys.executable, archive_script, env_folder_path]
    try:
        result = subprocess.run(
            command,
            capture_output=True,
            text=True,
            cwd=project_root
        )

        # Failure path first: report the script's exit status and stderr.
        if result.returncode != 0:
            logger.error(f"Archive script failed with return code {result.returncode}")
            logger.error(f"Error output: {result.stderr}")
            return False

        logger.info("Directory cleanup completed successfully")
        logger.info(f"Archive output: {result.stdout}")

        # An empty done.txt marks the folder as processed.
        done_file_path = os.path.join(env_folder_path, "done.txt")
        write_file_content(done_file_path, "")
        logger.info(f"Created done.txt file: {done_file_path}")
        return True

    except Exception as e:
        logger.error(f"Error running archive script: {e}")
        return False


def parse_llm_action_response(resp: str) -> Dict[str, Any]:
    """Turn a raw LLM reply into an action dictionary.

    The JSON payload is located in this order:
    - the text after the first ```json marker, up to the next ``` (or to the
      end of the reply when the fence is never closed)
    - the text after the first ``` marker, delimited the same way
    - the whole reply, stripped

    A JSON list is reduced to its first element. Any failure yields a
    placeholder action carrying a ``_parse_error`` description: an empty reply
    gives action ``no_action``; undecodable JSON, an empty list, a payload
    that is not a dict with an ``action`` key, or an unexpected error gives
    action ``Invalid``.

    Args:
        resp: Raw LLM response string

    Returns:
        The decoded dict when it contains an 'action' key (returned as-is),
        otherwise a placeholder dict with 'action', 'params' and
        '_parse_error' keys.
    """
    try:
        # Nothing to parse at all
        if not resp:
            logger.warning("Received None or empty response from LLM")
            return {"action": "no_action", "params": {}, "_parse_error": "Empty LLM response"}

        # Prefer an explicit ```json fence, then any ``` fence.
        for fence in ('```json', '```'):
            start_idx = resp.find(fence)
            if start_idx == -1:
                continue
            start_idx += len(fence)
            end_idx = resp.find('```', start_idx)
            # An unterminated fence extends to the end of the reply.
            fenced = resp[start_idx:] if end_idx == -1 else resp[start_idx:end_idx]
            json_str = fenced.strip()
            break
        else:
            # No fence anywhere: treat the whole reply as the payload.
            json_str = resp.strip()

        try:
            action_data = json.loads(json_str)
        except Exception as e:
            # Keep the error detail so trajectory consumers can inspect it
            logger.warning(f"Failed to parse action JSON '{resp}': {e}. Using default action.")
            return {
                "action": "Invalid",
                "params": {},
                "_parse_error": f"{type(e).__name__}: {e}",
            }

        # A list of actions is reduced to its first entry
        if isinstance(action_data, list):
            if not action_data:
                logger.warning("LLM returned an empty list, using default action.")
                return {"action": "Invalid", "params": {}, "_parse_error": "Empty list returned by LLM"}
            logger.warning("LLM returned a list of actions; taking the first entry")
            action_data = action_data[0]

        # Only a dict carrying an 'action' key is acceptable
        if isinstance(action_data, dict) and "action" in action_data:
            return action_data

        logger.warning(f"Invalid action format: {action_data}. Using default action.")
        return {"action": "Invalid", "params": {}, "_parse_error": "Missing 'action' key or invalid dict"}
    except Exception as e:
        logger.warning(f"Unexpected error while parsing action: {e}. Using default action.")
        return {"action": "Invalid", "params": {}, "_parse_error": f"{type(e).__name__}: {e}"}
