from agents.state import OverallState, SearcherState, BrowserState
from agents.schemas import SearchResult, Reference
import json

def _sanitize_tool_calls(tool_calls) -> list:
    """Return a cleaned copy of serialized ``tool_calls`` for ``AIMessage``.

    Entries that are not dicts, or that lack an ``args`` key, are dropped.
    A non-dict ``args`` value (e.g. an empty string) is replaced by ``{}``.
    A non-list input yields ``[]``. The input dicts are copied, not mutated.
    """
    if not isinstance(tool_calls, list):
        return []
    return [
        {**call, 'args': call['args'] if isinstance(call['args'], dict) else {}}
        for call in tool_calls
        if isinstance(call, dict) and 'args' in call
    ]


def _deserialize_message(msg):
    """Rebuild a LangChain message from its serialized dict form.

    Dicts whose ``type`` is ``'human'`` or ``'ai'`` become ``HumanMessage`` /
    ``AIMessage``. Anything else (non-dicts such as already-built message
    objects, or dicts with another ``type``) is returned unchanged.
    """
    if not isinstance(msg, dict):
        return msg  # Keep as is if not a dict
    msg_type = msg.get('type')
    if msg_type not in ('human', 'ai'):
        return msg  # Keep as is if unknown type

    # Function-local import; presumably kept out of module scope so that
    # importing this module does not require langchain_core.
    from langchain_core.messages import AIMessage, HumanMessage

    # `or` turns an explicit None (e.g. JSON null) into the empty default
    # instead of passing None to the message constructor.
    kwargs = {
        'content': msg.get('content') or '',
        'additional_kwargs': msg.get('additional_kwargs') or {},
        'response_metadata': msg.get('response_metadata') or {},
    }
    # Only pass optional fields that are actually set.
    for key in ('name', 'id'):
        if msg.get(key) is not None:
            kwargs[key] = msg[key]

    if msg_type == 'human':
        return HumanMessage(**kwargs)

    if msg.get('tool_calls') is not None:
        kwargs['tool_calls'] = _sanitize_tool_calls(msg['tool_calls'])
    return AIMessage(**kwargs)


def _deserialize_searcher_state(searcher_data) -> SearcherState:
    """Build a ``SearcherState`` from its dict form (``None`` means empty).

    ``search_results`` entries that are dicts are converted to ``SearchResult``.
    Entries that are already ``SearchResult`` are kept. Anything else is dropped.
    """
    searcher_data = searcher_data or {}
    search_results = []
    for item in searcher_data.get("search_results") or []:
        if isinstance(item, dict):
            search_results.append(SearchResult(
                url=item.get("url", ""),
                snippet=item.get("snippet", "")
            ))
        elif isinstance(item, SearchResult):
            search_results.append(item)

    return SearcherState(
        search_count=searcher_data.get("search_count", 0),
        used_keywords=searcher_data.get("used_keywords", []),
        search_results=search_results,
        search_cache=searcher_data.get("search_cache", {})
    )


def _deserialize_browser_state(browser_data) -> BrowserState:
    """Build a ``BrowserState`` from its dict form (``None`` means empty).

    ``found_references`` entries that are dicts are converted to ``Reference``.
    Entries that are already ``Reference`` are kept. Anything else is dropped.
    """
    browser_data = browser_data or {}
    found_references = []
    for item in browser_data.get("found_references") or []:
        if isinstance(item, dict):
            found_references.append(Reference(
                url=item.get("url", ""),
                information_list=item.get("information_list", [])
            ))
        elif isinstance(item, Reference):
            found_references.append(item)

    return BrowserState(
        visit_count=browser_data.get("visit_count", 0),
        visited_urls=browser_data.get("visited_urls", []),
        found_references=found_references,
        visit_cache=browser_data.get("visit_cache", {})
    )


def load_state_dict_to_overall_state(result_dict: dict) -> OverallState:
    """
    Convert a dictionary to an OverallState object.

    Missing fields fall back to defaults:
    - an empty message list;
    - ``""`` for the text fields;
    - ``False`` for the verification flags;
    - empty searcher/browser sub-states.

    ``result_dict`` and the objects inside it are not modified.

    Args:
        result_dict: Dictionary containing state data. Nested values may be
            plain dicts or already-typed objects.

    Returns:
        OverallState object with properly typed fields.
    """
    # Messages: rebuild human/AI message objects; default to an empty list
    # when the key is missing, None, or empty.
    messages = [_deserialize_message(msg) for msg in result_dict.get('messages') or []]

    return OverallState(
        messages=messages,
        current_sub_question=result_dict.get("current_sub_question", ""),
        current_summary=result_dict.get("current_summary", ""),
        final_answer=result_dict.get("final_answer", ""),
        sub_verified=result_dict.get("sub_verified", False),
        final_verified=result_dict.get("final_verified", False),
        searcher_state=_deserialize_searcher_state(result_dict.get("searcher_state")),
        browser_state=_deserialize_browser_state(result_dict.get("browser_state"))
    )

def save_state(state, output_file):
    """Save the state to a JSON file (2-space indented).

    ``state`` is converted with ``dict(state)``. Nested values that ``json``
    cannot encode natively are converted as follows:
    - objects with ``model_dump()`` (Pydantic v2) or ``dict()`` (Pydantic v1)
      are dumped to dicts;
    - sets and frozensets become lists, sorted when their elements are
      orderable;
    - other objects with a ``__dict__`` are emitted as that dict;
    - anything else is emitted as ``str(obj)``.

    The whole document is serialized before the file is opened. A
    serialization error therefore leaves an existing file at ``output_file``
    untouched instead of truncated.

    Args:
        state: Mapping-like state (anything ``dict()`` accepts).
        output_file: Path of the JSON file to write; overwritten if present.

    Raises:
        TypeError: If the state still contains something ``json`` cannot
            encode, e.g. dict keys that are not str/int/float/bool/None.
    """
    def serialize_pydantic(obj):
        """``json`` ``default=`` hook for objects it cannot encode natively.

        ``json`` only calls this for non-native types, so lists, tuples,
        dicts and primitives never reach it.
        """
        # Pydantic v2: model_dump() replaces the deprecated dict().
        model_dump = getattr(obj, 'model_dump', None)
        if callable(model_dump):
            return model_dump()
        # Pydantic v1 models, or anything else exposing a dict() method.
        as_dict = getattr(obj, 'dict', None)
        if callable(as_dict):
            return as_dict()
        # Sets are not JSON-encodable; emit lists so they reload as lists.
        # Sort when possible so the file contents are stable across runs.
        if isinstance(obj, (set, frozenset)):
            try:
                return sorted(obj)
            except TypeError:
                return list(obj)
        if hasattr(obj, '__dict__'):  # Regular object
            return vars(obj)
        # Last resort for opaque objects: a readable string rather than a crash.
        return str(obj)

    payload = json.dumps(dict(state), indent=2, default=serialize_pydantic)
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(payload)
