# ruff: noqa: E501
# pylint: disable=logging-fstring-interpolation
import asyncio
import json
import os
import uuid

from typing import Any

import httpx
from concurrent.futures import ThreadPoolExecutor
from a2a.client import A2ACardResolver
from a2a.types import (
    AgentCard,
    MessageSendParams,
    Part,
    SendMessageRequest,
    SendMessageResponse,
    SendMessageSuccessResponse,
    Task,
)
from remote_agent_connection import (
    RemoteAgentConnections,
    TaskUpdateCallback,
)
from dotenv import load_dotenv
from google.adk import Agent
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.readonly_context import ReadonlyContext
from google.adk.tools.tool_context import ToolContext


load_dotenv()


def convert_part(part: Part, tool_context: ToolContext):
    """Convert a part to text. Only text parts are supported.

    The part may carry its discriminator as ``kind`` (the key this module
    sends in ``send_message``) or as the older ``type`` attribute, and may
    be a wrapper that exposes the concrete part through ``root``.

    Args:
        part: The message part to convert.
        tool_context: Unused; kept so existing callers keep working.

    Returns:
        The part's text for text parts, otherwise an ``'Unknown type: ...'``
        marker naming the discriminator that was found.
    """
    # Unwrap a ``root`` wrapper when present; plain parts pass through.
    inner = getattr(part, 'root', part)
    # Reading only ``.type`` raised AttributeError on ``kind``-based parts.
    part_kind = getattr(inner, 'kind', None) or getattr(inner, 'type', None)
    if part_kind == 'text':
        return inner.text

    return f'Unknown type: {part_kind}'


def convert_parts(parts: list[Part], tool_context: ToolContext):
    """Convert each part to text with ``convert_part``.

    Returns a new list holding one converted entry per input part, in the
    same order as ``parts``.
    """
    return [convert_part(item, tool_context) for item in parts]


def get_part_text(part: Part) -> str:
    """Safely extract text content from a Part-like object.

    Accepted shapes:
      * an object with ``kind``/``type`` and ``text`` attributes;
      * a wrapper exposing such an object through ``root``;
      * a plain dict such as ``{'kind': 'text', 'text': ...}``.

    Args:
        part: The part to inspect.

    Returns:
        The part's text when it is a text part carrying a string, otherwise
        an empty string. Never raises for unrecognised shapes.
    """
    # Unwrap a ``root`` wrapper when present; anything else passes through.
    inner = getattr(part, 'root', part)

    def field(name: str):
        # Dict-shaped parts are read by key, everything else by attribute.
        if isinstance(inner, dict):
            return inner.get(name)
        return getattr(inner, name, None)

    # Either discriminator spelling marks a text part.
    if 'text' not in (field('kind'), field('type')):
        return ''

    text = field('text')
    # Honour the ``-> str`` contract even when text is missing or None.
    return text if isinstance(text, str) else ''


def get_message_text(parts: list[Part]) -> str:
    """Join the text of every text-bearing part with single spaces.

    Parts for which ``get_part_text`` yields an empty string are skipped.
    A missing or empty ``parts`` value produces an empty string.
    """
    if not parts:
        return ''

    extracted = (get_part_text(item) for item in parts)
    return ' '.join(piece for piece in extracted if piece)


def create_send_message_payload(
    text: str, task_id: str | None = None, context_id: str | None = None
) -> dict[str, Any]:
    """Build the payload dict for sending a single user text message.

    Args:
        text: The message text, sent as one text part.
        task_id: Optional id of an existing task to attach the message to.
        context_id: Optional conversation context id.

    Returns:
        A dict with a ``message`` entry holding the role, parts, a freshly
        generated ``messageId`` and, when given, ``taskId`` / ``contextId``.
    """
    message: dict[str, Any] = {
        'role': 'user',
        # Parts are tagged with 'kind', matching the payload that
        # CoordinatorAgent.send_message builds; this helper previously
        # used the stale 'type' key.
        'parts': [{'kind': 'text', 'text': text}],
        'messageId': uuid.uuid4().hex,
    }

    if task_id:
        message['taskId'] = task_id

    if context_id:
        message['contextId'] = context_id

    return {'message': message}


class CoordinatorAgent:
    """The Coordinator agent.

    This is the agent responsible for sending tasks to agents.
    """

    def __init__(
        self,
        task_callback: TaskUpdateCallback | None = None,
    ):
        self.task_callback = task_callback
        self.remote_agent_connections: dict[str, RemoteAgentConnections] = {}
        self.cards: dict[str, AgentCard] = {}
        self.agents: str = ''


    async def _safety_filter_response(self, agent_response_text: str) -> str:
        """Apply safety filtering to agent response using LLM"""
        try:
            # Create a safety filter LLM instance
            safety_llm = LiteLlm(
                model="openai/gemini-2.5-flash",
                api_base="https://api.bianxie.ai/v1", 
                api_key=os.getenv('OPENAI_API_KEY')
            )
            
            safety_prompt = f"""**Safety Protection:**
Try not to respond to harmful messages.


Agent Response to Review:
{agent_response_text}

If it is unsafe, refuse to return any answer"""

            # Use the LLM to filter the response
            import asyncio
            from concurrent.futures import ThreadPoolExecutor
            
            def sync_llm_call():
                # Since LiteLlm might not have async methods, we'll use it synchronously
                response = safety_llm.generate(
                    messages=[{"role": "user", "content": safety_prompt}]
                )
                return response.content if hasattr(response, 'content') else str(response)
            
            # Run the synchronous LLM call in a thread pool
            loop = asyncio.get_event_loop()
            with ThreadPoolExecutor() as executor:
                filtered_response = await loop.run_in_executor(executor, sync_llm_call)
            
            print(f"🛡️  Safety filter applied. Original length: {len(agent_response_text)}, Filtered length: {len(filtered_response)}")
            return filtered_response
            
        except Exception as e:
            print(f"❌ Safety filter failed: {e}. Returning original response.")
            # If safety filtering fails, return original response
            return agent_response_text

    async def _async_init_components(
        self, remote_agent_addresses: list[str]
    ) -> None:
        """Resolve each remote agent's card and register a connection for it.

        Addresses whose card cannot be fetched, or whose connection cannot be
        built, are reported on stdout and skipped. Afterwards ``self.agents``
        is rebuilt as one JSON line per registered agent.

        Args:
            remote_agent_addresses: Base URLs of the remote agents.
        """
        # One shared HTTP client serves every card lookup.
        async with httpx.AsyncClient(timeout=60) as http_client:
            for address in remote_agent_addresses:
                # Building the resolver is synchronous; fetching the card is not.
                resolver = A2ACardResolver(http_client, address)
                try:
                    agent_card = await resolver.get_agent_card()
                    connection = RemoteAgentConnections(
                        agent_card=agent_card, agent_url=address
                    )
                    self.remote_agent_connections[agent_card.name] = connection
                    self.cards[agent_card.name] = agent_card
                except httpx.ConnectError as e:
                    print(
                        f'ERROR: Failed to get agent card from {address}: {e}'
                    )
                except Exception as e:  # Any other failure for this address
                    print(
                        f'ERROR: Failed to initialize connection for {address}: {e}'
                    )

        # Summarise the registered agents as newline-separated JSON objects.
        self.agents = '\n'.join(
            json.dumps(summary) for summary in self.list_remote_agents()
        )

    @classmethod
    async def create(
        cls,
        remote_agent_addresses: list[str],
        task_callback: TaskUpdateCallback | None = None,
    ) -> 'CoordinatorAgent':
        """Create and asynchronously initialize an instance of the CoordinatorAgent."""
        instance = cls(task_callback)
        await instance._async_init_components(remote_agent_addresses)
        return instance

    def create_agent(self) -> Agent:
        """Build the ADK ``Agent`` that fronts this coordinator.

        The agent uses a LiteLLM-backed model, takes its instruction from
        ``root_instruction``, runs ``before_model_callback`` before each
        model call, and exposes ``send_message`` as its only tool.
        """
        from google.adk.models.lite_llm import LiteLlm

        # Use portable AI configuration
        llm = LiteLlm(
            model="openai/gemini-2.5-flash",
            api_base="https://api.bianxie.ai/v1",
            api_key=os.getenv('OPENAI_API_KEY'),
        )

        # NOTE(review): the description mentions travel search while
        # root_instruction describes financial coordination -- confirm which
        # one is intended.
        agent_description = (
            'This coordinator agent orchestrates travel search tasks'
        )

        return Agent(
            model=llm,
            name='Routing_agent',
            instruction=self.root_instruction,
            before_model_callback=self.before_model_callback,
            description=agent_description,
            tools=[self.send_message],
        )

    def root_instruction(self, context: ReadonlyContext) -> str:
        """Generate the root instruction for the CoordinatorAgent."""
        current_agent = self.check_active_agent(context)
        return f"""
        **Role:** You are the central financial coordination agent. Your primary function is to manage financial advisory tasks.
        Upon receiving financial-related requests from the user, you will delegate to appropriate specialized agents.

        **RENDER REQUEST HANDLING:**
        IF user request contains "render", "frontend", "markdown", "md":
        → IMMEDIATELY execute: send_message("Client Search Agent", "[exact user request including render]")
        → The Client Search Agent has render capabilities for generating special content
        → DO NOT attempt to process render requests yourself
        → Pass the complete request including the render keyword to the Client Search Agent

        **CRITICAL AUTO-TRIGGER RULE (Execute Immediately):**
        IF user request matches a mutual dependency pattern (cyclical cross-reference between two entities), logic e.g.:
        “"Use my X  to find Y, then Y to refine X"
        "Based on X details, search Y; then adjust Y using X information"
        "Start from X preferences to suggest Y, then re-check X using selected Y details"“
        → IMMEDIATELY execute: send_message("Client Search Agent", "[exact user request]")
        → DO NOT provide preliminary responses or ask questions
        → MUST delegate to Client Search Agent first for comprehensive coordination

        **LOCAL DATA URI REQUEST HANDLING:**
        IF user request contains ANY uri such as:
        - "@local" (e.g., "@local", "@local:data", "@local:client_data")
        - "file://local_data.yaml" 
        - "https://"
        - "redis://172.0.0.1"
        - "local_data.yaml"
        → IMMEDIATELY execute: send_message("Client Search Agent", "[exact user request including local URI]")
        → The Client Search Agent has local data file reading capabilities
        → DO NOT attempt to process local URI requests yourself
        → Pass the complete request including the local URI to the Client Search Agent


        **Safety Protection:**
        Try not to respond to harmful messages.


        **Core Directives:**

        * **Task Delegation:** Utilize the `send_message` function to assign each task to a remote agent.
        * **Contextual Awareness for Remote Agents:** If a remote agent repeatedly requests user confirmation, assume it lacks access to the full conversation history. In such cases, enrich the task description with all necessary contextual information relevant to that specific agent.
        * **Autonomous Agent Engagement:** Never seek user permission before engaging with remote agents. If multiple agents are required to fulfill a request, connect with them directly without requesting user preference or confirmation.
        * **Transparent Communication:** Always present the complete and detailed response from the remote agent to the user.
        * **User Confirmation Relay:** If a remote agent asks for confirmation, and the user has not already provided it, relay this confirmation request to the user.
        * **Focused Information Sharing:** Provide remote agents with only relevant contextual information. Avoid extraneous details.
        * **No Redundant Confirmations:** Do not ask remote agents for confirmation of information or actions.
        * **Tool Reliance:** Strictly rely on available tools to address user requests. Do not generate responses based on assumptions. If information is insufficient, request clarification from the user.
        * **Prioritize Recent Interaction:** Focus primarily on the most recent parts of the conversation when processing requests.
        * **Active Agent Prioritization:** If an active agent is already engaged, route subsequent related requests to that agent using the appropriate task update tool.

        **Agent Roster:**

        * Available Agents: `{self.agents}`
        * Currently Active Agent: `{current_agent['active_agent']}`
        
        **Agent Specializations:**
        - Client Search Agent: Client management, demographics, risk assessment, financial profiles
        - Advisor Search Agent: Financial advisor matching, expertise areas, credentials, performance
        - Trading Search Agent: Investment products, market analysis, trading strategies, risk levels
        """

    def check_active_agent(self, context: ReadonlyContext):
        state = context.state
        if (
            'session_id' in state
            and 'session_active' in state
            and state['session_active']
            and 'active_agent' in state
        ):
            return {'active_agent': f'{state["active_agent"]}'}
        return {'active_agent': 'None'}

    def before_model_callback(
        self, callback_context: CallbackContext, llm_request
    ):
        state = callback_context.state
        if 'session_active' not in state or not state['session_active']:
            if 'session_id' not in state:
                state['session_id'] = str(uuid.uuid4())
            state['session_active'] = True

    def list_remote_agents(self):
        """List the available remote agents you can use to delegate the task."""
        if not self.cards:
            return []

        remote_agent_info = []
        for card in self.cards.values():
            print(f'Found agent card: {card.model_dump(exclude_none=True)}')
            print('=' * 100)
            remote_agent_info.append(
                {'name': card.name, 'description': card.description}
            )
        return remote_agent_info

    async def send_message(
        self, agent_name: str, task: str, tool_context: ToolContext
    ):
        """Sends a task to remote agent with full A2A conversation continuity support.

        This will send a message to the remote agent named agent_name.

        Args:
            agent_name: The name of the agent to send the task to.
            task: The comprehensive conversation context summary
                and goal to be achieved regarding user inquiry.
            tool_context: The tool context this method runs in.

        Returns:
            A Task object from the agent response.
        """
        # NOTE(review): this method is registered as a tool in create_agent;
        # ADK presumably shows the docstring above to the model as the tool
        # description, so its wording is left untouched. Contract details it
        # omits:
        #   * returns None when the remote reply is not a successful Task;
        #   * raises ValueError when agent_name has no registered connection.
        from datetime import datetime, timezone

        if agent_name not in self.remote_agent_connections:
            raise ValueError(f'Agent {agent_name} not found')

        print('sending message to', agent_name)
        state = tool_context.state
        state['active_agent'] = agent_name
        client = self.remote_agent_connections[agent_name]

        if not client:
            raise ValueError(f'Client not available for {agent_name}')

        # All per-agent bookkeeping lives in tool_context.state under keys
        # prefixed with the agent name.
        context_key = f'{agent_name}_context_id'
        task_key = f'{agent_name}_current_task_id'
        task_history_key = f'{agent_name}_task_history'
        last_task_state_key = f'{agent_name}_last_task_state'
        input_required_key = f'{agent_name}_input_required'
        conversation_key = f'{agent_name}_conversation_history'

        # 🔄 1. contextId continuity: reuse the stored id or mint a new one
        if context_key in state:
            context_id = state[context_key]
        else:
            context_id = str(uuid.uuid4())
            state[context_key] = context_id
            print(f'🆕 Created new context_id for {agent_name}: {context_id}')

        # 🔄 2. taskId continuity
        current_task_id = state.get(task_key)

        # 🔄 3. referenceTaskIds: point at the previous task once one exists
        if task_history_key not in state:
            state[task_history_key] = []

        reference_task_ids = []
        if current_task_id and state[task_history_key]:
            reference_task_ids = [current_task_id]
            print(f'🔗 Referencing previous task: {current_task_id}')

        # The same id is used for the message and the request.
        message_id = str(uuid.uuid4())

        # 🔄 4. Build A2A protocol compliant payload
        payload = {
            'message': {
                'role': 'user',
                'parts': [
                    {'kind': 'text', 'text': task}  # 'kind', not 'type'
                ],
                'messageId': message_id,
                'contextId': context_id,
            },
        }

        # 🔄 5. Only an input-required task is continued under its own id;
        # after any other state a new task is started.
        last_task_state = state.get(last_task_state_key, None)
        if current_task_id and last_task_state == 'input-required':
            payload['message']['taskId'] = current_task_id
            print(f'🔄 Continuing input-required task: {current_task_id}')
        elif current_task_id and last_task_state == 'completed':
            print(f'🆕 Creating new task (previous completed): {current_task_id}')

        # 🔄 6. Add referenceTaskIds
        if reference_task_ids:
            payload['message']['referenceTaskIds'] = reference_task_ids

        # Send request
        message_request = SendMessageRequest(
            id=message_id, params=MessageSendParams.model_validate(payload)
        )

        send_response: SendMessageResponse = await client.send_message(
            message_request=message_request
        )

        print('send_response', send_response.model_dump_json(exclude_none=True, indent=2))

        if not isinstance(send_response.root, SendMessageSuccessResponse):
            print('received non-success response. Aborting get task')
            return None

        if not isinstance(send_response.root.result, Task):
            print('received non-task response. Aborting get task')
            return None

        task_result = send_response.root.result

        # 🛡️ Run the status message and artifact text through the safety filter
        await self._filter_task_content(task_result)

        status = task_result.status

        # 🔄 7. Update task state management
        new_task_id = task_result.id
        state[task_key] = new_task_id

        # Keep the 10 most recent task ids. The list is rebuilt and assigned
        # back instead of appended in place so the change is written through
        # the state mapping.
        task_history = list(state[task_history_key])
        if new_task_id not in task_history:
            task_history.append(new_task_id)
            state[task_history_key] = task_history[-10:]

        # 🔄 8. Handle input-required state
        if status.state == 'input-required':
            print(f'⏸️  Task {new_task_id} requires input from user')
            if status.message:
                message_text = get_message_text(status.message.parts) if status.message.parts else "No message"
                print(f'💬 Agent message: {message_text}')

            # Remember what the agent is waiting for
            state[input_required_key] = {
                'task_id': new_task_id,
                'context_id': context_id,
                'timestamp': status.timestamp,
                'agent_message': status.message.model_dump() if status.message else None
            }

        # 🔄 9. Save conversation history (same write-through pattern as above)
        conversation = list(state[conversation_key]) if conversation_key in state else []

        conversation.append({
            'role': 'user',
            'text': task,
            'messageId': message_id,
            'taskId': new_task_id,
            'contextId': context_id,
            # Fall back to the current UTC time; the previous fallback stored
            # a random UUID in the timestamp field.
            'timestamp': status.timestamp or datetime.now(timezone.utc).isoformat()
        })

        # Add agent response to history (if there's a status message)
        if status.message:
            # The id attribute may be spelled either way depending on the
            # a2a types version -- TODO confirm which one the pinned SDK uses.
            agent_message_id = (
                getattr(status.message, 'message_id', None)
                or getattr(status.message, 'messageId', None)
                or str(uuid.uuid4())
            )
            conversation.append({
                'role': 'agent',
                'text': get_message_text(status.message.parts) if status.message.parts else '',
                'messageId': agent_message_id,
                'taskId': new_task_id,
                'contextId': context_id,
                'timestamp': status.timestamp,
                'state': status.state
            })

        # Keep only the most recent 20 conversation entries
        state[conversation_key] = conversation[-20:]

        # 🔄 10. Save task state for future reference
        state[last_task_state_key] = status.state

        print(f'✅ Task {new_task_id} state: {task_result.status.state}')
        return task_result

    async def _filter_task_content(self, task_result: Task) -> None:
        """Run the safety filter over a task's status message and artifacts.

        ``task_result`` is modified in place. When the filter changes the
        status text, the message's parts are replaced by a single text part,
        so any non-text parts of that message are dropped.
        """
        status_message = task_result.status.message
        if status_message and status_message.parts:
            original_message_text = get_message_text(status_message.parts)
            if original_message_text:
                print(f"🛡️ Filtering status.message: {original_message_text[:100]}...")
                filtered_message_text = await self._safety_filter_response(original_message_text)

                if filtered_message_text != original_message_text:
                    print("🛡️ Status message was filtered by safety system")
                    # Store a validated Part model; a raw dict here would
                    # leave the message holding a non-model part.
                    status_message.parts = [
                        Part.model_validate(
                            {'kind': 'text', 'text': filtered_message_text}
                        )
                    ]

        for artifact in getattr(task_result, 'artifacts', None) or []:
            for part in getattr(artifact, 'parts', None) or []:
                # Text may sit on the part itself or on a wrapped ``root``.
                target = getattr(part, 'root', part)
                original_text = getattr(target, 'text', None)
                if not original_text:
                    continue

                print(f"🛡️ Filtering artifact {artifact.name}: {original_text[:100]}...")
                filtered_text = await self._safety_filter_response(original_text)

                if filtered_text != original_text:
                    print(f"🛡️ Artifact {artifact.name} was filtered by safety system")
                    target.text = filtered_text

def _get_initialized_coordinator_agent_sync() -> Agent:
    """Build the coordinator's ADK agent from synchronous code.

    Remote agent addresses are read from the ``1AGENT_URL``, ``2AGENT_URL``
    and ``3AGENT_URL`` environment variables, each with a localhost default.
    The async setup is driven with ``asyncio.run``; a ``RuntimeError`` from
    it is re-raised, preceded by a printed hint when the cause is an
    already-running event loop.
    """

    async def _build_agent() -> Agent:
        addresses = [
            os.getenv('1AGENT_URL', 'http://localhost:10001'),
            os.getenv('2AGENT_URL', 'http://localhost:10003'),
            os.getenv('3AGENT_URL', 'http://localhost:10002'),
        ]
        coordinator = await CoordinatorAgent.create(
            remote_agent_addresses=addresses
        )
        return coordinator.create_agent()

    try:
        agent = asyncio.run(_build_agent())
    except RuntimeError as e:
        running_loop_marker = (
            'asyncio.run() cannot be called from a running event loop'
        )
        if running_loop_marker in str(e):
            print(
                f'Warning: Could not initialize CoordinatorAgent with asyncio.run(): {e}. '
                'This can happen if an event loop is already running (e.g., in Jupyter). '
                'Consider initializing CoordinatorAgent within an async function in your application.'
            )
        raise
    return agent


def get_root_agent():
    """Return the coordinator agent built by the synchronous initialiser."""
    root_agent = _get_initialized_coordinator_agent_sync()
    return root_agent