"""Code validation utilities shared across the project.

This module provides common code validation functions to avoid duplication.
"""

import ast
from typing import Optional, Tuple

from .agent_utils import print_with_color, write_to_file


def validate_python_syntax(code: str, log_path: Optional[str] = None) -> Tuple[bool, str]:
    """Validate Python code syntax using AST parsing.

    The code is only parsed with ``ast.parse``; it is never executed.
    The outcome is also printed via ``print_with_color`` (green on success,
    red on failure).

    Args:
        code: The Python code string to validate
        log_path: Optional location for a detailed error report. When truthy
            and a syntax error is found, the report (error message plus the
            full code) is handed to ``write_to_file`` with
            ``file_path=log_path`` and ``file_name='syntax_error.txt'``.
            A failure while writing this report is printed but does not
            change the returned result.

    Returns:
        Tuple of (is_valid, error_message):
            - is_valid: True if code has valid syntax, False otherwise
            - error_message: Empty string if valid, error description if invalid

    Examples:
        >>> is_valid, msg = validate_python_syntax("print('hello')")
        >>> assert is_valid

        >>> is_valid, msg = validate_python_syntax("print('hello'")
        >>> assert not is_valid
        >>> assert "Syntax Error" in msg
    """
    try:
        ast.parse(code)
        print_with_color('✅ RPA code syntax validation passed', 'green')
        return True, ""
    except SyntaxError as e:
        # e.text is the offending source line (may be None); drop only the
        # newline so leading indentation stays intact for the caret below.
        text = e.text.strip('\n') if e.text else ''

        error_msg_lines = [
            "❌ Syntax Error in Code:",
            f"  Line {e.lineno}: {e.msg}",
            f"  Text: {text}",
        ]
        # e.offset is a 1-based column; it can be None or 0 when unknown.
        if e.offset and e.offset > 0:
            error_msg_lines.append("  Offset:")
            error_msg_lines.append(' ' * (e.offset - 1) + '^')

        error_msg = "\n".join(error_msg_lines)
        print_with_color(error_msg, 'red')

        # Optionally save detailed error to file
        if log_path:
            detailed_error = (
                f"{error_msg}\n\n"
                f"Full Code:\n"
                f"{'=' * 60}\n"
                f"{code}\n"
                f"{'=' * 60}\n"
            )
            # Writing the report is best-effort: an exception raised here is
            # inside an ``except`` handler and would NOT be caught by the
            # generic handler below, so guard it explicitly to keep the
            # documented (bool, str) return contract.
            try:
                write_to_file(
                    file_path=log_path,
                    file_name='syntax_error.txt',
                    content=detailed_error
                )
            except Exception as log_err:
                print_with_color(
                    f"❌ Failed to save syntax error details: {log_err}",
                    'red'
                )

        return False, error_msg

    except Exception as e:
        # Anything else raised while parsing/printing (e.g. a non-string
        # argument) is reported as a validation failure rather than raised.
        error_msg = f"❌ Unexpected error during syntax validation: {e}"
        print_with_color(error_msg, 'red')
        return False, error_msg


def extract_function_names(code: str) -> list[str]:
    """Extract all function names defined in the code.

    Both ``def`` and ``async def`` definitions are reported. The whole tree
    is walked, so nested functions and methods inside classes are included
    as well; names are returned in ``ast.walk`` order and may repeat if the
    same name is defined more than once.

    Args:
        code: Python code string

    Returns:
        List of function names found in the code. An empty list is returned
        when the code cannot be parsed.

    Examples:
        >>> code = "def foo():\\n    pass\\nasync def bar():\\n    pass"
        >>> names = extract_function_names(code)
        >>> assert 'foo' in names and 'bar' in names
    """
    try:
        tree = ast.parse(code)
    except Exception:
        # Best-effort helper: unparsable input simply yields no names.
        return []

    # ``async def`` produces ast.AsyncFunctionDef, which is NOT a subclass of
    # ast.FunctionDef, so both node types must be matched explicitly.
    function_nodes = (ast.FunctionDef, ast.AsyncFunctionDef)
    return [
        node.name
        for node in ast.walk(tree)
        if isinstance(node, function_nodes)
    ]


def has_dangerous_imports(code: str) -> Tuple[bool, list[str]]:
    """Check if code contains potentially dangerous imports.

    Only static ``import x`` / ``from x import y`` statements are inspected
    (anywhere in the code, including inside functions). A module is matched
    by its top-level package, so ``import os.path`` and
    ``from os.path import join`` are both reported as ``os``. Relative
    imports (``from .os import x``) refer to project-local modules and are
    ignored. Dynamic imports such as ``__import__('os')`` or
    ``importlib.import_module`` are NOT detected.

    Args:
        code: Python code string

    Returns:
        Tuple of (has_dangerous, dangerous_modules):
            - has_dangerous: True if dangerous imports found
            - dangerous_modules: List of dangerous module names, one entry
              per offending import (a module imported twice appears twice)

        ``(False, [])`` is also returned when the code cannot be parsed.
        NOTE(review): that is fail-open, so callers should validate syntax
        separately before relying on a negative result.
    """
    dangerous_modules = {'os', 'sys', 'subprocess', 'shutil', 'pathlib'}

    try:
        tree = ast.parse(code)
    except Exception:
        return False, []

    found_dangerous = []
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                # alias.name is the full dotted path ("os.path"); compare
                # its root package against the deny-list.
                root = alias.name.split('.')[0]
                if root in dangerous_modules:
                    found_dangerous.append(root)
        elif isinstance(node, ast.ImportFrom):
            # level > 0 means a relative import; node.module is None for
            # "from . import x". Neither refers to a stdlib module.
            if node.level == 0 and node.module:
                root = node.module.split('.')[0]
                if root in dangerous_modules:
                    found_dangerous.append(root)

    return len(found_dangerous) > 0, found_dangerous

