import re
import json
import yaml
from typing import Tuple, Optional, List

from orchestrator.template import yaml_template
from agents.block_registry import SELECTION_REGISTRY, DEFINED_REGISTRY

def extract_graph(response: str) -> Tuple[bool, Optional[str]]:
    """
    Extract a workflow graph JSON string from a model response.

    First looks for a fenced ```json block (language tag matched
    case-insensitively) and returns its stripped content. Otherwise scans the
    text for brace-balanced ``{...}`` spans and returns the first one that
    mentions both ``"nodes"`` and ``"edges"``. The returned text is not parsed
    here; use validate_graph for that.

    Args:
        response: Raw model output.

    Returns:
        Tuple[bool, Optional[str]]: (True, json_text) when a candidate is
        found, otherwise (False, None).
    """
    json_pattern = r'```json\s*\n(.*?)\n```'
    match = re.search(json_pattern, response, re.DOTALL | re.IGNORECASE)
    if match:
        return True, match.group(1).strip()

    def find_json_objects(text):
        """Return outermost brace-balanced spans, ignoring braces inside JSON strings."""
        results = []
        length = len(text)
        start = text.find('{')
        while start != -1:
            depth = 0
            in_string = False
            escaped = False
            end = None
            for pos in range(start, length):
                ch = text[pos]
                if in_string:
                    # Inside a JSON string: only an unescaped quote ends it.
                    if escaped:
                        escaped = False
                    elif ch == '\\':
                        escaped = True
                    elif ch == '"':
                        in_string = False
                elif ch == '"':
                    in_string = True
                elif ch == '{':
                    depth += 1
                elif ch == '}':
                    depth -= 1
                    if depth == 0:
                        end = pos + 1
                        break
            if end is None:
                # Unbalanced (e.g. a stray '{' in prose): retry from the next
                # '{' rather than giving up on the rest of the text.
                start = text.find('{', start + 1)
            else:
                results.append(text[start:end])
                start = text.find('{', end)
        return results

    # Pick the first candidate that looks like a graph (has nodes and edges).
    for obj in find_json_objects(response):
        if '"nodes"' in obj and '"edges"' in obj:
            return True, obj.strip()

    return False, None

def list2yaml(block_list: List[str]) -> str:
    """
    Convert an ordered list of block names into a workflow YAML string.

    Fills the placeholders in ``yaml_template`` so the selected blocks form a
    linear chain: START -> first block -> ... -> last block, after which the
    template's own edge line sends the last block to ``sub_verifier``. Every
    occurrence of the "first building block" placeholder (including the
    verifier's ``false`` route) becomes the first block. Each block is
    emitted as a node whose ``name`` and ``function`` are both the block name.

    Args:
        block_list: Block names in execution order. May be empty, in which
            case START (and the ``false`` route) are wired straight to
            ``sub_verifier``.

    Returns:
        str: The filled-in YAML template.
    """
    if not block_list:
        # Nothing to orchestrate: drop the block/edge placeholders and route
        # START (and the verifier's "false" route) directly to sub_verifier.
        result = yaml_template.replace("**Your selected building blocks**,", "")
        result = result.replace("  - START -> **The first building block to execute when solving the question**", "")
        result = result.replace("  **Your orchestrated edges**", "")
        result = result.replace("  - **The last building block to execute when solving the question** -> sub_verifier", "  - START -> sub_verifier")
        result = result.replace("      - false: **The first building block to execute when solving the question**", "      - false: sub_verifier")
        return result

    # The template indents each placeholder line by two spaces, so the first
    # emitted line must not carry that indent while every later list item
    # needs it. Joining with "\n  " produces exactly that layout.
    blocks_yaml = "\n  ".join(
        f"- name: {name}\n    function: {name}" for name in block_list
    )
    # Sequential chain between consecutive blocks; empty for a single block.
    edges_yaml = "\n  ".join(
        f"- {src} -> {dst}" for src, dst in zip(block_list, block_list[1:])
    )

    # The blocks placeholder carries a trailing comma in the template, so the
    # comma is replaced together with it.
    result = yaml_template.replace("**Your selected building blocks**,", blocks_yaml)
    result = result.replace("**The first building block to execute when solving the question**", block_list[0])
    result = result.replace("**Your orchestrated edges**", edges_yaml)
    result = result.replace("**The last building block to execute when solving the question**", block_list[-1])
    return result

def validate_list(block_list: List[str]) -> Tuple[bool, str]:
    """
    Check a proposed list of building-block names before it is turned into a workflow.

    Rules, applied in order:
      1. An empty list is accepted (it yields the minimal workflow).
      2. Every name must appear as a ``block_name`` in SELECTION_REGISTRY or
         DEFINED_REGISTRY.
      3. No name may appear more than once.
      4. At least one selected block must have ``block_type`` "searcher",
         "browser" or "summarizer".

    Returns:
        Tuple[bool, str]: (is_valid, message) where message explains the
        first rule that failed, or confirms validity.
    """
    if not block_list:
        return True, "Empty block list is valid (creates minimal workflow)"

    known_names = {
        entry["block_name"]
        for registry in (SELECTION_REGISTRY, DEFINED_REGISTRY)
        for entry in registry
    }

    unknown = [name for name in block_list if name not in known_names]
    if unknown:
        return False, f"Invalid block names found: {', '.join(unknown)}. Available blocks: {', '.join(sorted(known_names))}"

    # Every repeat after the first occurrence is reported.
    seen = set()
    repeated = []
    for name in block_list:
        if name in seen:
            repeated.append(name)
        seen.add(name)

    if repeated:
        return False, f"Duplicate block names found: {', '.join(repeated)}"

    # Later registry entries win when a block_name appears in both registries.
    type_by_name = {
        entry["block_name"]: entry["block_type"]
        for entry in SELECTION_REGISTRY + DEFINED_REGISTRY
    }
    chosen_types = [type_by_name[name] for name in block_list]

    # Something must gather information (searcher/browser) or work with what
    # is already available (summarizer).
    if not any(kind in ("searcher", "browser", "summarizer") for kind in chosen_types):
        return False, "Workflow must include at least one searcher, browser, or summarizer block"

    return True, "Block list is valid"


def extract_yaml_graph(response: str) -> Tuple[bool, Optional[str]]:
    """
    Extract a workflow graph YAML string from a model response.

    YAML counterpart of extract_graph. First looks for a fenced ```yaml or
    ```yml block (tag matched case-insensitively) and returns its stripped
    content. Otherwise scans for runs of lines that start at a ``nodes:``,
    ``edges:`` or ``conditional_edges:`` header and continue through
    indented lines, ``-`` list items, blank lines and further headers. The
    first such run that has both a ``nodes:`` key and a real ``edges:`` key
    (not merely ``conditional_edges:``) is returned.

    Args:
        response: Raw model output.

    Returns:
        Tuple[bool, Optional[str]]: (True, yaml_text) when a candidate is
        found, otherwise (False, None).
    """
    yaml_pattern = r'```(?:yaml|yml)\s*\n(.*?)\n```'
    match = re.search(yaml_pattern, response, re.DOTALL | re.IGNORECASE)
    if match:
        return True, match.group(1).strip()

    headers = ('nodes:', 'edges:', 'conditional_edges:')

    def find_yaml_blocks(text):
        """Collect header-started runs of YAML-looking lines."""
        results = []
        lines = text.split('\n')
        i = 0
        while i < len(lines):
            if not lines[i].strip().startswith(headers):
                i += 1
                continue
            block_lines = [lines[i]]
            i += 1
            # Keep consuming headers, indented lines, list items and blanks.
            while i < len(lines):
                current = lines[i]
                if (current.strip().startswith(headers)
                        or current.startswith(('  ', '\t', '-'))
                        or current.strip() == ''):
                    block_lines.append(current)
                    i += 1
                else:
                    break
            if len(block_lines) > 1:  # more than just the header line
                results.append('\n'.join(block_lines))
        return results

    def has_key(block, key):
        # Anchor at line start so "edges:" does not match inside "conditional_edges:".
        return re.search(rf'^[ \t]*{key}:', block, re.MULTILINE) is not None

    for block in find_yaml_blocks(response):
        if has_key(block, 'nodes') and has_key(block, 'edges'):
            return True, block.strip()

    return False, None

def validate_graph(graph_json: str) -> Tuple[bool, str]:
    """
    Validate a workflow graph given as JSON text.

    The JSON must be an object with ``nodes`` (list of dicts with ``name``
    and ``function``), ``edges`` (list of ``[from, to]`` pairs) and
    ``conditional_edges`` (list of dicts with ``from``, ``condition`` and a
    ``routes`` dict mapping a route key to a target node). ``START`` and
    ``END`` are implicit nodes.

    Checks, in order: parseability and top-level shape, node definitions,
    edge/route endpoints, isolated nodes, duplicate (from, to) connections,
    an edge out of START and one into END, and reachability of END from
    START. Cycles are allowed as long as END stays reachable.

    Args:
        graph_json: JSON text of the graph.

    Returns:
        Tuple[bool, str]: (is_valid, message); message describes the first
        problem found, or is "Graph is valid".
    """
    from collections import deque

    try:
        graph_data = json.loads(graph_json)
    except (json.JSONDecodeError, TypeError) as e:
        # TypeError: input is not str/bytes (e.g. None from a failed extraction).
        return False, f"Invalid JSON format: {str(e)}"

    # A list/string/number would otherwise slip through the `in` checks below
    # (a string does substring matching) and crash on indexing.
    if not isinstance(graph_data, dict):
        return False, "JSON content must be a dictionary"

    for field in ("nodes", "edges", "conditional_edges"):
        if field not in graph_data:
            return False, f"Missing '{field}' field in graph JSON"

    nodes = graph_data["nodes"]
    edges = graph_data["edges"]
    conditional_edges = graph_data["conditional_edges"]

    if not isinstance(nodes, list):
        return False, "Nodes must be a list"

    node_names = set()
    for i, node in enumerate(nodes):
        if not isinstance(node, dict):
            return False, f"Node at index {i} must be a dictionary"
        if "name" not in node:
            return False, f"Node at index {i} missing 'name' field"
        if "function" not in node:
            return False, f"Node at index {i} missing 'function' field"
        node_name = node["name"]
        if not isinstance(node_name, str):
            return False, f"Node at index {i} 'name' must be a string"
        if node_name in node_names:
            return False, f"Duplicate node name: {node_name}"
        node_names.add(node_name)

    all_nodes = node_names | {"START", "END"}

    def is_known(name):
        # The isinstance guard keeps unhashable JSON values (lists/objects)
        # out of the set lookup, which would otherwise raise TypeError.
        return isinstance(name, str) and name in all_nodes

    if not isinstance(edges, list):
        return False, "Edges must be a list"

    # (from, to) -> descriptions of each edge/route producing that connection.
    edge_sources = {}
    for i, edge in enumerate(edges):
        if not isinstance(edge, list) or len(edge) != 2:
            return False, f"Edge at index {i} must be a list of 2 elements [from, to]"
        from_node, to_node = edge
        if not is_known(from_node):
            return False, f"Edge {i}: Unknown source node '{from_node}'"
        if not is_known(to_node):
            return False, f"Edge {i}: Unknown target node '{to_node}'"
        edge_sources.setdefault((from_node, to_node), []).append(f"edge {i}")

    if not isinstance(conditional_edges, list):
        return False, "Conditional edges must be a list"

    for i, cond_edge in enumerate(conditional_edges):
        if not isinstance(cond_edge, dict):
            return False, f"Conditional edge at index {i} must be a dictionary"
        for field in ("from", "condition", "routes"):
            if field not in cond_edge:
                return False, f"Conditional edge {i} missing '{field}' field"
        from_node = cond_edge["from"]
        if not is_known(from_node):
            return False, f"Conditional edge {i}: Unknown source node '{from_node}'"
        routes = cond_edge["routes"]
        if not isinstance(routes, dict):
            return False, f"Conditional edge {i}: Routes must be a dictionary"
        for route_key, target_node in routes.items():
            if not is_known(target_node):
                return False, f"Conditional edge {i}: Unknown target node '{target_node}' in route '{route_key}'"
            edge_sources.setdefault((from_node, target_node), []).append(
                f"conditional edge {i} route '{route_key}'"
            )

    # Nodes that appear on no edge or route at all.
    connected = {endpoint for pair in edge_sources for endpoint in pair}
    isolated_nodes = node_names - connected
    if isolated_nodes:
        return False, f"Isolated nodes found (no connections): {', '.join(sorted(isolated_nodes))}"

    # The same (from, to) pair must not be produced by more than one edge/route.
    duplicates = [
        f"'{src}' -> '{dst}' ({', '.join(sources)})"
        for (src, dst), sources in edge_sources.items()
        if len(sources) > 1
    ]
    if duplicates:
        return False, f"Duplicate edges found: {'; '.join(duplicates)}"

    if not any(src == "START" for src, _ in edge_sources):
        return False, "No edge from START node found"
    if not any(dst == "END" for _, dst in edge_sources):
        return False, "No edge to END node found"

    # BFS over all edges and routes: cycles are fine, but END must be reachable.
    adjacency = {}
    for src, dst in edge_sources:
        adjacency.setdefault(src, []).append(dst)
    visited = {"START"}
    queue = deque(["START"])
    while queue:
        current = queue.popleft()
        if current == "END":
            return True, "Graph is valid"
        for neighbor in adjacency.get(current, []):
            if neighbor not in visited:
                visited.add(neighbor)
                queue.append(neighbor)

    return False, "END node is not reachable from START (potential dead loop without progress)"

def validate_yaml_graph(graph_yaml: str) -> Tuple[bool, str]:
    """
    Validate a workflow graph written in YAML.

    Mirrors validate_graph for the YAML dialect:
      - ``edges`` are strings in arrow notation (``"from -> to"``, with
        single spaces around the arrow).
      - Each conditional edge's ``routes`` is a list of single-key mappings
        (``- false: node``) or ``"key: target"`` strings.
      - ``START`` and ``END`` are implicit nodes.

    Checks, in order:
      1. YAML parseability and top-level shape.
      2. Node definitions.
      3. Edge and route endpoints.
      4. Isolated nodes.
      5. Duplicate (from, to) connections.
      6. An edge out of START and one into END.
      7. Reachability of END from START.
      8. When no node uses a searcher/browser function, at least one
         summarizer-type node must be reachable from START through plain
         (non-conditional) edges alone.

    Args:
        graph_yaml: YAML text of the graph.

    Returns:
        Tuple[bool, str]: (is_valid, message); message describes the first
        problem found, or is "Graph is valid".
    """
    from collections import deque

    searcher_functions = {"searcher", "fast_searcher", "advanced_searcher"}
    browser_functions = {"browser", "fast_browser", "advanced_browser", "deep_browser"}
    summarizer_functions = {"summarizer", "advanced_summarizer"}

    def reachable_from(adjacency, start):
        """Return every node reachable from ``start`` (inclusive) via BFS."""
        seen = {start}
        queue = deque([start])
        while queue:
            current = queue.popleft()
            for neighbor in adjacency.get(current, []):
                if neighbor not in seen:
                    seen.add(neighbor)
                    queue.append(neighbor)
        return seen

    try:
        # safe_load does not construct arbitrary Python objects from the text.
        graph_data = yaml.safe_load(graph_yaml)
    except yaml.YAMLError as e:
        return False, f"Invalid YAML format: {str(e)}"

    if not isinstance(graph_data, dict):
        return False, "YAML content must be a dictionary"

    for field in ("nodes", "edges", "conditional_edges"):
        if field not in graph_data:
            return False, f"Missing '{field}' field in graph YAML"

    nodes = graph_data["nodes"]
    edges = graph_data["edges"]
    conditional_edges = graph_data["conditional_edges"]

    if not isinstance(nodes, list):
        return False, "Nodes must be a list"

    node_names = set()
    for i, node in enumerate(nodes):
        if not isinstance(node, dict):
            return False, f"Node at index {i} must be a dictionary"
        if "name" not in node:
            return False, f"Node at index {i} missing 'name' field"
        if "function" not in node:
            return False, f"Node at index {i} missing 'function' field"
        node_name = node["name"]
        # Edges are parsed from strings, so only string names can be referenced;
        # this also keeps unhashable YAML values (lists/maps) out of set lookups.
        if not isinstance(node_name, str):
            return False, f"Node at index {i} 'name' must be a string"
        if node_name in node_names:
            return False, f"Duplicate node name: {node_name}"
        node_names.add(node_name)

    all_nodes = node_names | {"START", "END"}

    if not isinstance(edges, list):
        return False, "Edges must be a list"

    # (from, to) -> descriptions of each edge/route producing that connection.
    edge_sources = {}
    # Adjacency over plain (non-conditional) edges only; used by the summarizer check.
    direct_graph = {}

    for i, edge in enumerate(edges):
        if not isinstance(edge, str):
            return False, f"Edge at index {i} must be a string in format 'from -> to'"
        if " -> " not in edge:
            return False, f"Edge at index {i} must use arrow notation 'from -> to': {edge}"
        parts = edge.split(" -> ")
        if len(parts) != 2:
            return False, f"Edge at index {i} must have exactly one arrow: {edge}"
        from_node, to_node = parts[0].strip(), parts[1].strip()
        if from_node not in all_nodes:
            return False, f"Edge {i}: Unknown source node '{from_node}'"
        if to_node not in all_nodes:
            return False, f"Edge {i}: Unknown target node '{to_node}'"
        edge_sources.setdefault((from_node, to_node), []).append(f"edge {i}")
        direct_graph.setdefault(from_node, []).append(to_node)

    if not isinstance(conditional_edges, list):
        return False, "Conditional edges must be a list"

    for i, cond_edge in enumerate(conditional_edges):
        if not isinstance(cond_edge, dict):
            return False, f"Conditional edge at index {i} must be a dictionary"
        for field in ("from", "condition", "routes"):
            if field not in cond_edge:
                return False, f"Conditional edge {i} missing '{field}' field"

        from_node = cond_edge["from"]
        # isinstance guard: an unhashable value would raise TypeError on lookup.
        if not isinstance(from_node, str) or from_node not in all_nodes:
            return False, f"Conditional edge {i}: Unknown source node '{from_node}'"

        routes = cond_edge["routes"]
        if not isinstance(routes, list):
            return False, f"Conditional edge {i}: Routes must be a list in YAML format"

        # Flatten the YAML route list into {route_key: target}; a repeated key
        # keeps its last target.
        routes_dict = {}
        for j, route_item in enumerate(routes):
            if isinstance(route_item, dict):
                routes_dict.update(route_item)
            elif isinstance(route_item, str):
                # String form such as "false: node_name".
                if ": " not in route_item:
                    return False, f"Conditional edge {i}, route {j}: Invalid route format '{route_item}'"
                key, value = route_item.split(": ", 1)
                routes_dict[key.strip()] = value.strip()
            else:
                return False, f"Conditional edge {i}, route {j}: Route must be dict or string"

        for route_key, target_node in routes_dict.items():
            if not isinstance(target_node, str) or target_node not in all_nodes:
                return False, f"Conditional edge {i}: Unknown target node '{target_node}' in route '{route_key}'"
            edge_sources.setdefault((from_node, target_node), []).append(
                f"conditional edge {i} route '{route_key}'"
            )

    # Nodes that appear on no edge or route at all.
    connected = {endpoint for pair in edge_sources for endpoint in pair}
    isolated_nodes = node_names - connected
    if isolated_nodes:
        return False, f"Isolated nodes found (no connections): {', '.join(sorted(isolated_nodes))}"

    # The same (from, to) pair must not be produced by more than one edge/route.
    duplicates = [
        f"'{src}' -> '{dst}' ({', '.join(sources)})"
        for (src, dst), sources in edge_sources.items()
        if len(sources) > 1
    ]
    if duplicates:
        return False, f"Duplicate edges found: {'; '.join(duplicates)}"

    if not any(src == "START" for src, _ in edge_sources):
        return False, "No edge from START node found"
    if not any(dst == "END" for _, dst in edge_sources):
        return False, "No edge to END node found"

    # Cycles are allowed, but END must be reachable over edges and routes.
    full_graph = {}
    for src, dst in edge_sources:
        full_graph.setdefault(src, []).append(dst)
    if "END" not in reachable_from(full_graph, "START"):
        return False, "END node is not reachable from START (potential dead loop without progress)"

    def function_of(node):
        # A non-string function value (e.g. a YAML list) cannot name a known block.
        func = node["function"]
        return func if isinstance(func, str) else ""

    function_names = {function_of(node) for node in nodes}
    has_searcher = not searcher_functions.isdisjoint(function_names)
    has_browser = not browser_functions.isdisjoint(function_names)
    has_summarizer = not summarizer_functions.isdisjoint(function_names)

    # Without any searching capability, a summarizer must process existing information.
    if not has_searcher and not has_browser:
        if not has_summarizer:
            return False, "Workflow without searcher or browser must include a summarizer to process existing information"

        # Every summarizer-type function counts, the same set used for has_summarizer.
        summarizer_nodes = [
            node["name"] for node in nodes if function_of(node) in summarizer_functions
        ]
        # At least one summarizer must be reachable via plain edges; otherwise the
        # workflow could only summarize after a conditional (failure) route fires.
        direct_reach = reachable_from(direct_graph, "START")
        if not any(name in direct_reach for name in summarizer_nodes):
            return False, (
                f"In workflow without searcher/browser, summarizer '{summarizer_nodes[0]}' should be "
                "directly reachable from START, not only through conditional failure paths"
            )

    return True, "Graph is valid"
