import json
import os
from typing import Any, Dict, List, Optional, Union

import yaml

from yaml_converter.qwq.template import template


class YAMLConverter:
    """
    Converts a graph YAML definition to executable langgraph code using a template.

    The YAML structure should contain:
    - nodes: List of node definitions with name, function, and optional config
    - edges: List of edge strings in format "from_node -> to_node"
    - conditional_edges: List of conditional edges with routing logic

    Node ``function`` names and edge ``condition`` expressions are pasted
    verbatim into the generated Python, so the YAML must be trusted like
    source code. Node names, config keys/values and route keys are emitted as
    escaped literals, so arbitrary text in them cannot break the output.
    """

    # Project-specific convention: a conditional edge leaving this node with
    # exactly this condition gets an extra "max_iterations" route to END, and
    # its routing function checks the iteration counter before the condition.
    _SUB_VERIFIER_NODE = "sub_verifier"
    _SUB_VERIFIED_CONDITION = "state.get('sub_verified')"
    _MAX_ITERATIONS_ROUTE = "max_iterations"

    def __init__(self):
        # Generated source of helper functions for the current conversion.
        self.wrapper_functions = []
        self.routing_functions = []

    def convert(self, graph_yaml: Dict[str, Any]) -> str:
        """
        Convert a graph YAML definition to executable langgraph code.

        Args:
            graph_yaml: Dictionary containing nodes, edges, and conditional_edges.
                ``None`` (an empty YAML document) is treated as an empty graph.

        Returns:
            String containing the executable langgraph code
        """
        # Reset state for new conversion
        self.wrapper_functions = []
        self.routing_functions = []

        # An empty document loads as None, and an empty key (``edges:``) as None.
        graph_yaml = graph_yaml or {}

        blocks_code = self._generate_blocks(graph_yaml.get("nodes") or [])
        edges_code = self._generate_edges(graph_yaml.get("edges") or [])
        conditional_edges_code = self._generate_conditional_edges(
            graph_yaml.get("conditional_edges") or []
        )

        return template.format(
            blocks=blocks_code,
            edges=edges_code,
            conditional_edges=conditional_edges_code,
        )

    def convert_from_file(self, yaml_file_path: str) -> str:
        """
        Convert a graph YAML file to executable langgraph code.

        Args:
            yaml_file_path: Path to the YAML file (read as UTF-8, parsed with
                ``yaml.safe_load``)

        Returns:
            String containing the executable langgraph code
        """
        with open(yaml_file_path, "r", encoding="utf-8") as f:
            graph_yaml = yaml.safe_load(f)
        return self.convert(graph_yaml)

    def save_to_file(self, graph_yaml: Dict[str, Any], output_path: str):
        """
        Convert graph YAML and save the result to a Python file.

        Args:
            graph_yaml: Dictionary containing the graph definition
            output_path: Path where to save the generated Python code (UTF-8)
        """
        code = self.convert(graph_yaml)
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(code)

    # ------------------------------------------------------------------ #
    # Small code-emission helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _quote(text: str) -> str:
        """Return *text* as a double-quoted Python string literal.

        JSON escapes (\\", \\\\, \\n, \\t, \\uXXXX, ...) have the same meaning
        in Python, so this equals ``f'"{text}"'`` for ordinary names but stays
        valid when the text contains quotes, backslashes or newlines.
        """
        return json.dumps(text, ensure_ascii=False)

    @classmethod
    def _node_ref(cls, node: Any) -> str:
        """Return the code referring to *node*: START/END bare, other names quoted."""
        if node in ("START", "END"):
            return node
        return cls._quote(str(node))

    @staticmethod
    def _to_identifier(name: Any) -> str:
        """Replace every character that cannot continue a Python identifier with ``_``."""
        return "".join(ch if f"a{ch}".isidentifier() else "_" for ch in str(name))

    @staticmethod
    def _unique_name(base: str, taken: set) -> str:
        """Return *base* (or *base* plus ``_2``, ``_3``, ...) not yet in *taken*, and record it."""
        name, suffix = base, 2
        while name in taken:
            name = f"{base}_{suffix}"
            suffix += 1
        taken.add(name)
        return name

    @staticmethod
    def _parse_edge(edge: Any) -> Optional[tuple]:
        """Split ``"from -> to"`` into stripped node names; None if not in that form."""
        if not isinstance(edge, str) or " -> " not in edge:
            return None
        source, target = edge.split(" -> ", 1)
        return source.strip(), target.strip()

    # ------------------------------------------------------------------ #
    # Section generators
    # ------------------------------------------------------------------ #

    def _generate_blocks(self, nodes: List[Dict[str, Any]]) -> str:
        """Generate the blocks section: wrapper functions, then ``builder.add_node`` calls.

        A node with a non-empty ``config`` is registered through a generated
        wrapper that injects that config; other nodes register ``function``
        directly.
        """
        lines = []
        wrapper_names = set()

        for node in nodes:
            name = node["name"]
            function = node["function"]
            config = node.get("config", {})
            quoted_name = self._quote(str(name))

            if config:
                # Sanitize so names like "web-search" still give a valid def name,
                # and de-duplicate so two nodes never share one wrapper.
                wrapper_name = self._unique_name(
                    f"{self._to_identifier(name)}_wrapper", wrapper_names
                )
                self.wrapper_functions.append(
                    self._create_wrapper_function(wrapper_name, function, config)
                )
                lines.append(f"builder.add_node({quoted_name}, {wrapper_name})")
            else:
                lines.append(f"builder.add_node({quoted_name}, {function})")

        # Wrapper definitions must precede the add_node calls that use them.
        result = []
        if self.wrapper_functions:
            result.extend(self.wrapper_functions)
            result.append("")  # Empty line for separation

        if lines:
            result.append("# Add nodes")
            result.extend(lines)

        return "\n".join(result)

    def _create_wrapper_function(self, wrapper_name: str, function: str, config: Dict[str, Any]) -> str:
        """Create a wrapper function that merges *config* into ``config["configurable"]``.

        The generated wrapper copies both the runnable config and its nested
        ``configurable`` dict before writing, so the caller's config is never
        mutated. String values are emitted as escaped literals; other values
        via ``repr``.
        """
        lines = [
            f"def {wrapper_name}(state: OverallState, config):",
            "    # Inject configuration into the config",
            "    modified_config = dict(config) if config else {}",
            "    # Copy the nested dict too, so the caller's config is never mutated",
            "    modified_config[\"configurable\"] = dict(modified_config.get(\"configurable\") or {})",
        ]

        for key, value in config.items():
            literal = self._quote(value) if isinstance(value, str) else repr(value)
            lines.append(
                f'    modified_config["configurable"][{self._quote(str(key))}] = {literal}'
            )

        lines.append(f"    return {function}(state, modified_config)")
        lines.append("")  # Empty line after function

        return "\n".join(lines)

    def _generate_edges(self, edges: List[str]) -> str:
        """Generate ``builder.add_edge`` calls from ``"from -> to"`` strings.

        Entries not in arrow notation are skipped. If no edge leaves START, an
        edge from START to the source of the first valid edge is added.
        """
        if not edges:
            return ""

        parsed = [pair for pair in map(self._parse_edge, edges) if pair is not None]
        lines = ["# Add edges"]

        if parsed and not any(source == "START" for source, _ in parsed):
            lines.append(f"builder.add_edge(START, {self._node_ref(parsed[0][0])})")

        for source, target in parsed:
            # START/END map to the langgraph constants on either side of the edge.
            lines.append(
                f"builder.add_edge({self._node_ref(source)}, {self._node_ref(target)})"
            )

        return "\n".join(lines)

    def _generate_conditional_edges(self, conditional_edges: List[Dict[str, Any]]) -> str:
        """Generate routing functions and ``builder.add_conditional_edges`` calls.

        Raises:
            KeyError: if an entry lacks ``from``, ``condition`` or ``routes``.
            ValueError: if an entry ends up with no routes at all.
        """
        if not conditional_edges:
            return ""

        lines = []
        function_names = set()

        for i, edge in enumerate(conditional_edges):
            from_node = edge["from"]
            condition = edge["condition"]
            # _parse_routes returns a fresh dict, so the caller's YAML is left untouched.
            routes_dict = self._parse_routes(edge["routes"])

            if from_node == self._SUB_VERIFIER_NODE and condition == self._SUB_VERIFIED_CONDITION:
                routes_dict[self._MAX_ITERATIONS_ROUTE] = "END"

            if not routes_dict:
                raise ValueError(f"Conditional edge from {from_node!r} defines no routes")

            # Unique per edge: a repeated source node must not shadow an earlier router.
            function_name = self._unique_name(
                f"route_from_{self._to_identifier(from_node)}", function_names
            )
            self.routing_functions.append(
                self._create_routing_function(function_name, condition, routes_dict)
            )

            lines.extend([
                "builder.add_conditional_edges(",
                f"    {self._node_ref(from_node)},",
                f"    {function_name},",
                f"    {self._format_routes_dict(routes_dict)},",
                ")",
            ])
            if i < len(conditional_edges) - 1:  # Empty line between conditional edges
                lines.append("")

        # Routing functions must be defined before the calls that reference them.
        result = []
        if self.routing_functions:
            result.extend(self.routing_functions)
            result.append("")  # Empty line for separation

        if lines:
            result.append("# Add conditional edges")
            result.extend(lines)

        return "\n".join(result)

    def _parse_routes(self, routes: Union[List[Dict[str, str]], Dict[str, str]]) -> Dict[str, str]:
        """Normalise routes to a new dict (never the caller's object).

        Accepts a dict, or a list of single-entry dicts / ``"key: value"``
        strings, e.g. ``[{"false": "searcher"}, {"true": "final_verifier"}]``.
        Other list items are ignored; ``None`` yields an empty dict.
        """
        if isinstance(routes, dict):
            return dict(routes)

        routes_dict = {}
        for route_item in routes or []:
            if isinstance(route_item, dict):
                routes_dict.update(route_item)
            elif isinstance(route_item, str) and ": " in route_item:
                key, value = route_item.split(": ", 1)
                routes_dict[key.strip()] = value.strip()

        return routes_dict

    def _create_routing_function(self, function_name: str, condition: str, routes: Dict[str, str]) -> str:
        """Create a routing function that returns one of the route keys.

        With True/False (or true/false) keys, the condition picks the key.
        Otherwise true/false keys get an ``if`` each, other keys get a TODO
        comment, and the first key is the fallback when no true/false key exists.
        """
        # YAML booleans load as bool keys; the generated code compares str keys.
        string_routes = {str(key): value for key, value in routes.items()}
        lines = [f"def {function_name}(state: OverallState) -> str:"]

        # Only emit the guard when a "max_iterations" route exists to receive it.
        if (
            self._SUB_VERIFIER_NODE in function_name
            and condition == self._SUB_VERIFIED_CONDITION
            and self._MAX_ITERATIONS_ROUTE in string_routes
        ):
            lines.extend([
                "    # Check if max iteration count reached - route to END",
                "    from agents.qwq.config import Configuration",
                "    config = Configuration()",
                "    current_iteration = state.get('current_sub_question_iteration', 0)",
                "    if current_iteration >= config.max_iteration_count:",
                "        return \"max_iterations\"",
                "",
            ])

        if "True" in string_routes and "False" in string_routes:
            # Boolean condition - return the route key, not the route value
            lines.append(f"    return \"False\" if not ({condition}) else \"True\"")
        elif "true" in string_routes and "false" in string_routes:
            lines.append(f"    return \"false\" if not ({condition}) else \"true\"")
        else:
            has_return = False
            for route_key, route_value in string_routes.items():
                if route_key in ("true", "True"):
                    lines.append(f"    if {condition}:")
                    lines.append(f"        return {self._quote(route_key)}")
                    has_return = True
                elif route_key in ("false", "False"):
                    lines.append(f"    if not ({condition}):")
                    lines.append(f"        return {self._quote(route_key)}")
                    has_return = True
                else:
                    # repr keeps newlines in keys/values from escaping the comment.
                    lines.append(
                        f"    # TODO: Add condition for route {route_key!r} -> {route_value!r}"
                    )

            if not has_return:
                first_key = list(string_routes)[0]
                lines.append(f"    return {self._quote(first_key)}  # Default route key")

        lines.append("")  # Empty line after function
        return "\n".join(lines)

    def _format_routes_dict(self, routes: Dict[str, str]) -> str:
        """Format the path map for ``add_conditional_edges`` (END/START as constants)."""
        items = [
            f"        {self._quote(str(key))}: {self._node_ref(value)}"
            for key, value in routes.items()
        ]
        if not items:
            # "{\n,\n    }" would not be valid Python.
            return "{}"
        return "{\n" + ",\n".join(items) + ",\n    }"


# Convenience function for quick conversion
def convert_graph_yaml(graph_yaml: Dict[str, Any]) -> str:
    """
    Turn an in-memory graph definition into langgraph source in a single call.

    A brand-new YAMLConverter handles each call, so nothing carries over
    between conversions.

    Args:
        graph_yaml: Dictionary containing the graph definition

    Returns:
        String containing the executable langgraph code
    """
    return YAMLConverter().convert(graph_yaml)


def convert_graph_file(yaml_file_path: str, output_path: Optional[str] = None) -> str:
    """
    Convert a graph YAML file to executable langgraph code.

    Args:
        yaml_file_path: Path to the YAML file
        output_path: Optional path to save the generated code; when falsy
            (None or ""), nothing is written

    Returns:
        String containing the executable langgraph code
    """
    converter = YAMLConverter()
    code = converter.convert_from_file(yaml_file_path)

    if output_path:
        # Explicit UTF-8: open()'s default encoding is locale-dependent, which
        # can fail on non-ASCII node names or config values.
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(code)

    return code