import json
import os
from typing import Any, Dict, List, Optional

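# Reference copy of the template used to generate langgraph code. The live
# template is imported from json_converter.qwq.template below and may differ
# from this commented-out copy.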
# template = """from langgraph.graph import StateGraph
# from langgraph.graph import START, END
# from agents.openai.building_blocks import searcher, browser, summarizer, verifier, next_sub_question_writer, finalizer
# from agents.state import OverallState

# builder = StateGraph(OverallState)

# {blocks}

# {edges}

# {conditional_edges}

# graph = builder.compile(name="search-agent")
# """
from json_converter.qwq.template import template


class JSONConverter:
    """
    Converts a graph JSON definition to executable langgraph code using a template.

    The JSON structure should contain:
    - nodes: List of node definitions with name, function, and optional config
    - edges: List of direct edges between nodes, each a ``[from, to]`` pair
    - conditional_edges: List of conditional edges, each with ``from``,
      ``condition`` (a Python expression over ``state``) and ``routes``
      (a mapping of route key -> target node)

    The node names ``"START"`` and ``"END"`` are emitted as the langgraph
    ``START`` / ``END`` constants rather than as quoted node names.
    """

    # Node names that refer to langgraph's built-in START/END constants.
    _SPECIAL_NODES = ("START", "END")

    # Condition on the sub_verifier node that triggers the max-iteration guard.
    _SUB_VERIFIER_CONDITION = "state.get('sub_verified')"

    def __init__(self):
        # Generated source of per-node config wrappers (filled by _generate_blocks).
        self.wrapper_functions = []
        # Generated source of routing functions (filled by _generate_conditional_edges).
        self.routing_functions = []

    def convert(self, graph_json: Dict[str, Any]) -> str:
        """
        Convert a graph JSON definition to executable langgraph code.

        Args:
            graph_json: Dictionary containing nodes, edges, and conditional_edges.
                It is not modified.

        Returns:
            String containing the executable langgraph code

        Raises:
            KeyError: If a node or conditional edge lacks a required field.
            ValueError: If an edge is not a pair or a conditional edge has no routes.
        """
        # Reset state for new conversion
        self.wrapper_functions = []
        self.routing_functions = []

        # Generate code sections
        blocks_code = self._generate_blocks(graph_json.get("nodes", []))
        edges_code = self._generate_edges(graph_json.get("edges", []))
        conditional_edges_code = self._generate_conditional_edges(
            graph_json.get("conditional_edges", [])
        )

        # Fill template
        return template.format(
            blocks=blocks_code,
            edges=edges_code,
            conditional_edges=conditional_edges_code,
        )

    def convert_from_file(self, json_file_path: str) -> str:
        """
        Convert a graph JSON file to executable langgraph code.

        Args:
            json_file_path: Path to the JSON file (read as UTF-8)

        Returns:
            String containing the executable langgraph code
        """
        # JSON is UTF-8 by spec; don't depend on the platform locale encoding.
        with open(json_file_path, 'r', encoding='utf-8') as f:
            graph_json = json.load(f)
        return self.convert(graph_json)

    def save_to_file(self, graph_json: Dict[str, Any], output_path: str):
        """
        Convert graph JSON and save the result to a Python file.

        Args:
            graph_json: Dictionary containing the graph definition
            output_path: Path where to save the generated Python code (written as UTF-8)
        """
        code = self.convert(graph_json)
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(code)

    @staticmethod
    def _literal(text: str) -> str:
        """Return *text* as an escaped, double-quoted Python string literal."""
        # json.dumps escapes quotes, backslashes and control characters, and for
        # ordinary names yields exactly the plain "..." form, so typical output
        # is unchanged while unusual names can no longer break the generated code.
        return json.dumps(text, ensure_ascii=False)

    @classmethod
    def _node_ref(cls, node: str) -> str:
        """Return the code referring to *node*: the START/END constant or a quoted name."""
        return node if node in cls._SPECIAL_NODES else cls._literal(node)

    @staticmethod
    def _to_identifier(name: str) -> str:
        """Replace every character that cannot appear inside a Python identifier with '_'."""
        return "".join(ch if ("_" + ch).isidentifier() else "_" for ch in name)

    def _generate_blocks(self, nodes: List[Dict[str, Any]]) -> str:
        """Generate the blocks section with node definitions and wrapper functions."""
        lines = []

        for node in nodes:
            name = node["name"]
            function = node["function"]
            config = node.get("config", {})

            if config:
                # Bake the static config into a wrapper function for this node.
                wrapper_name = f"{self._to_identifier(name)}_wrapper"
                if not wrapper_name.isidentifier():
                    # Names starting with a digit need a prefix to be a valid def name.
                    wrapper_name = "_" + wrapper_name
                self.wrapper_functions.append(
                    self._create_wrapper_function(wrapper_name, function, config)
                )
                target = wrapper_name
            else:
                # Regular node without config
                target = function
            lines.append(f"builder.add_node({self._literal(name)}, {target})")

        # Combine wrapper functions and node additions
        result = []
        if self.wrapper_functions:
            result.extend(self.wrapper_functions)
            result.append("")  # Empty line for separation

        if lines:
            result.append("# Add nodes")
            result.extend(lines)

        return "\n".join(result)

    def _create_wrapper_function(self, wrapper_name: str, function: str, config: Dict[str, Any]) -> str:
        """
        Create a wrapper function for a node with configuration.

        The generated wrapper copies the runtime ``config`` and its nested
        ``configurable`` dict before injecting the static values, so the
        config object passed in by the caller is never mutated.
        """
        lines = [
            f"def {wrapper_name}(state: OverallState, config):",
            "    # Inject configuration into a copy of the config",
            "    modified_config = dict(config) if config else {}",
            "    # Copy the nested dict as well: dict() above is only a shallow copy",
            "    modified_config[\"configurable\"] = dict(modified_config.get(\"configurable\") or {})",
        ]

        # Add each config parameter
        for key, value in config.items():
            # Strings become escaped literals; for other JSON values (numbers,
            # booleans, null, lists, dicts) repr() is Python source, except
            # non-finite floats.
            value_code = self._literal(value) if isinstance(value, str) else repr(value)
            lines.append(f"    modified_config[\"configurable\"][{self._literal(key)}] = {value_code}")

        lines.append(f"    return {function}(state, modified_config)")
        lines.append("")  # Empty line after function

        return "\n".join(lines)

    def _generate_edges(self, edges: List[List[str]]) -> str:
        """Generate the edges section with direct connections (START/END mapped to constants)."""
        if not edges:
            return ""

        lines = ["# Add edges"]
        for from_node, to_node in edges:
            # Map each endpoint independently so e.g. START -> END is handled too.
            lines.append(f"builder.add_edge({self._node_ref(from_node)}, {self._node_ref(to_node)})")

        return "\n".join(lines)

    def _generate_conditional_edges(self, conditional_edges: List[Dict[str, Any]]) -> str:
        """Generate conditional edges with routing functions (input dicts are not modified)."""
        if not conditional_edges:
            return ""

        lines = []
        used_names = set()

        for i, edge in enumerate(conditional_edges):
            from_node = edge["from"]
            condition = edge["condition"]
            # Copy so the sub_verifier special case never mutates the caller's JSON.
            routes = dict(edge["routes"])

            # Special handling for sub_verifier - add max_iterations route to END
            is_sub_verifier = (
                from_node == "sub_verifier" and condition == self._SUB_VERIFIER_CONDITION
            )
            if is_sub_verifier:
                routes["max_iterations"] = "END"

            # Several conditional edges may leave the same node; keep routing function
            # names unique so a later def cannot silently replace an earlier one.
            base_name = f"route_from_{self._to_identifier(from_node)}"
            function_name, suffix = base_name, 2
            while function_name in used_names:
                function_name = f"{base_name}_{suffix}"
                suffix += 1
            used_names.add(function_name)

            self.routing_functions.append(
                self._create_routing_function(
                    function_name, condition, routes, check_max_iterations=is_sub_verifier
                )
            )

            # Create conditional edge
            lines.extend([
                "builder.add_conditional_edges(",
                f"    {self._node_ref(from_node)},",
                f"    {function_name},",
                f"    {self._format_routes_dict(routes)},",
                ")",
            ])
            if i < len(conditional_edges) - 1:  # Add empty line between conditional edges
                lines.append("")

        # Combine routing functions and conditional edges
        result = []
        if self.routing_functions:
            result.extend(self.routing_functions)
            result.append("")  # Empty line for separation

        if lines:
            result.append("# Add conditional edges")
            result.extend(lines)

        return "\n".join(result)

    def _create_routing_function(
        self,
        function_name: str,
        condition: str,
        routes: Dict[str, str],
        check_max_iterations: Optional[bool] = None,
    ) -> str:
        """
        Create a routing function for conditional edges.

        Args:
            function_name: Name of the generated function.
            condition: Python expression over ``state`` used for true/false routing.
            routes: Route key -> target node mapping; the function returns a key.
            check_max_iterations: Whether to emit the sub_verifier max-iteration
                guard that returns "max_iterations". ``None`` infers it from the
                function name and condition.

        Raises:
            ValueError: If ``routes`` is empty (no key could ever be returned).
        """
        if not routes:
            raise ValueError(f"{function_name}: conditional edge has no routes")
        if check_max_iterations is None:
            check_max_iterations = (
                function_name == "route_from_sub_verifier"
                and condition == self._SUB_VERIFIER_CONDITION
            )

        lines = [
            f"def {function_name}(state: OverallState) -> str:",
        ]

        # Special handling for sub_verifier routing - check max iterations first
        if check_max_iterations:
            lines.append("    # Check if max iteration count reached - route to END")
            lines.append("    from agents.qwq.config import Configuration")
            lines.append("    config = Configuration()")
            lines.append("    current_iteration = state.get('current_sub_question_iteration', 0)")
            lines.append("    if current_iteration >= config.max_iteration_count:")
            lines.append("        return \"max_iterations\"")
            lines.append("")

        if "true" in routes and "false" in routes:
            # Boolean condition - return the route key, not the route value
            lines.append(f"    return \"false\" if not ({condition}) else \"true\"")
        else:
            for route_key in routes:
                if route_key == "true":
                    lines.append(f"    if {condition}:")
                    lines.append("        return \"true\"")
                elif route_key == "false":
                    lines.append(f"    if not ({condition}):")
                    lines.append("        return \"false\"")
                else:
                    # Custom route: emitted only as a stub for manual completion
                    lines.append(f"    # Custom route for {self._literal(route_key)}")
                    lines.append(f"    # return {self._literal(route_key)}")
            # Here at most one of "true"/"false" exists and custom routes are only
            # commented stubs, so the chain can fall through: always end with a
            # real return instead of implicitly returning None.
            lines.append(f"    return {self._literal(next(iter(routes)))}  # Default route key")

        lines.append("")  # Empty line after function
        return "\n".join(lines)

    def _format_routes_dict(self, routes: Dict[str, str]) -> str:
        """Format the routes mapping as a dict literal; START/END targets use the constants."""
        entries = "".join(
            f"        {self._literal(key)}: {self._node_ref(target)},\n"
            for key, target in routes.items()
        )
        return "{\n" + entries + "    }"


# Convenience function for quick conversion
def convert_graph_json(graph_json: Dict[str, Any]) -> str:
    """
    Turn an in-memory graph definition into langgraph source code.

    A new JSONConverter is built for every call, so separate conversions
    never share wrapper or routing-function state.

    Args:
        graph_json: Dictionary containing the graph definition

    Returns:
        String containing the executable langgraph code
    """
    return JSONConverter().convert(graph_json)


def convert_graph_file(json_file_path: str, output_path: Optional[str] = None) -> str:
    """
    Convert a graph JSON file to executable langgraph code.

    Args:
        json_file_path: Path to the JSON file
        output_path: Optional path to save the generated code; nothing is
            written when it is None or empty. The file is written as UTF-8 so
            the result does not depend on the platform's locale encoding.

    Returns:
        String containing the executable langgraph code
    """
    code = JSONConverter().convert_from_file(json_file_path)

    if output_path:
        # Explicit encoding: generated code may contain non-ASCII config strings.
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(code)

    return code
