import json
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.events import EventQueue
from a2a.server.tasks import TaskUpdater
from a2a.types import (
    Part,
    TaskState,
    TextPart,
)
from a2a.utils import new_agent_text_message, new_task
from a2a.utils.message import get_message_text
from google.adk.artifacts import InMemoryArtifactService
from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
from google.adk.runners import Runner
from google.adk.sessions import InMemorySessionService
from google.genai import types

class ADKAgentExecutor(AgentExecutor):
    """A2A ``AgentExecutor`` that forwards requests to a Google ADK agent.

    Conversation state (query/response turns and the latest response) is kept
    in memory per A2A ``contextId``. Messages carrying ``referenceTaskIds`` are
    handled as follow-ups; all others are handled as initial requests.
    """

    # Follow-up texts that end the conversation. Compared after stripping
    # surrounding whitespace and trailing "." / "!", case-insensitively.
    _COMPLETION_SIGNALS = frozenset({'done', 'finished', 'complete', 'thank you', 'thanks'})

    def __init__(
        self,
        agent,
        status_message="Processing request...",
        artifact_name="response",
        enable_followup=True,
        debug_json=False,
    ):
        """Create the executor and an ADK ``Runner`` around ``agent``.

        Args:
            agent: ADK agent to run; its ``name`` is used as the ADK app name.
            status_message: Stored on the instance; not read within this class.
            artifact_name: Name given to every artifact this executor emits.
            enable_followup: If true, initial requests end in the
                ``input-required`` state instead of completing the task.
            debug_json: If true, print request/response payloads as JSON.
        """
        self.agent = agent
        self.status_message = status_message
        self.artifact_name = artifact_name
        self.enable_followup = enable_followup
        self.debug_json = debug_json

        # contextId -> list of {'query', 'response', 'timestamp'} turns.
        self._conversation_history: Dict[str, List[Dict[str, Any]]] = {}
        # contextId -> text of the most recent agent response.
        self._last_search_results: Dict[str, str] = {}

        self.runner = Runner(
            app_name=agent.name,
            agent=agent,
            artifact_service=InMemoryArtifactService(),
            session_service=InMemorySessionService(),
            memory_service=InMemoryMemoryService(),
        )

    def _build_a2a_task_history(self, context_id: str) -> List[Dict[str, Any]]:
        """Render the stored turns of ``context_id`` as A2A-style message dicts.

        Each turn yields a ``user`` message and, if the response is non-empty,
        an ``agent`` message. Message ids are freshly generated on every call.
        Returns an empty list for an unknown context.
        """
        if context_id not in self._conversation_history:
            return []

        a2a_history = []
        for item in self._conversation_history[context_id]:
            a2a_history.append({
                "role": "user",
                "parts": [{"kind": "text", "text": item['query']}],
                "messageId": str(uuid.uuid4()),
                "contextId": context_id,
                "kind": "message"
            })
            if item['response']:
                a2a_history.append({
                    "role": "agent",
                    "parts": [{"kind": "text", "text": item['response']}],
                    "messageId": str(uuid.uuid4()),
                    "contextId": context_id,
                    "kind": "message"
                })

        return a2a_history

    def _print_json(self, label: str, data: Any) -> None:
        """Print ``data`` as indented JSON prefixed by ``label`` when ``debug_json`` is set."""
        if self.debug_json:
            print(f"{label} {json.dumps(data, indent=2, ensure_ascii=False)}")

    async def execute(
        self,
        context: RequestContext,
        event_queue: EventQueue,
    ) -> None:
        """Process an incoming message and dispatch to the initial or follow-up handler."""
        raw_text = get_message_text(context.message) if context.message else ''

        if self.debug_json and context.message:
            request_data = {
                "method": "message/send",
                "params": {
                    "message": context.message.model_dump(exclude_none=True)
                }
            }
            self._print_json("📨 request", request_data)

        # A message that references earlier tasks continues that conversation.
        is_followup = bool(
            context.message and context.message.reference_task_ids
        )

        if is_followup:
            await self._handle_followup(context, raw_text, event_queue)
        else:
            await self._handle_initial(raw_text, context, event_queue)

    async def _handle_initial(
        self, raw_text: str, context: RequestContext, event_queue: EventQueue
    ) -> None:
        """Handle a message that does not reference an earlier task.

        If the client reused a ``contextId`` that already has history, that
        history is folded into the query sent to the agent.
        """
        # A2A protocol: use the contextId/taskId supplied by the client, or mint new ones.
        context_id = context.context_id or str(uuid.uuid4())
        task_id = context.task_id or str(uuid.uuid4())

        existing_history = self._conversation_history.get(context_id, [])
        existing_results = self._last_search_results.get(context_id, "")

        if existing_history:
            contextual_query = self._build_contextual_query(raw_text, existing_history, existing_results)
            print(f"[Context n] Using existing context {context_id} with {len(existing_history)} previous interactions")
        else:
            contextual_query = raw_text

        response_text = await self._process_doctor_query(contextual_query, context_id)
        self._record_turn(context_id, raw_text, response_text)

        updater = TaskUpdater(
            event_queue,
            task_id=task_id,
            context_id=context_id,
        )

        if not self.enable_followup:
            await updater.add_artifact(
                [Part(root=TextPart(text=response_text))],
                name=self.artifact_name
            )
            await updater.complete()
            return

        followup_message = "\n\n💬 Would you like to refine your search, ask about specific doctors, or search in other specialties? (Type 'done' when finished)"
        enhanced_response = response_text + followup_message
        await updater.add_artifact(
            [Part(root=TextPart(text=enhanced_response))],
            name=self.artifact_name
        )
        # Leave the task open (input-required) so the peer can send follow-ups.
        await updater.requires_input(final=True)

        if self.debug_json:
            response_data = {
                "id": task_id,
                "result": {
                    "id": task_id,
                    "contextId": context_id,
                    "status": {"state": "input-required"},
                    "artifacts": [{
                        "artifactId": "generated-id",
                        "name": self.artifact_name,
                        "parts": [{"kind": "text", "text": enhanced_response}]
                    }],
                    "history": self._build_a2a_task_history(context_id),
                    "kind": "task"
                }
            }
            self._print_json("📤 response", response_data)

    async def _handle_followup(
        self, context: RequestContext, raw_text: str, event_queue: EventQueue
    ) -> None:
        """Handle a follow-up message that references an existing task.

        A completion signal (see ``_COMPLETION_SIGNALS``) completes the task.
        Any other text is answered using the context's history, and the task
        is left in the ``input-required`` state.
        """
        task_id = context.task_id or str(uuid.uuid4())
        # Resolve the context id once so the TaskUpdater and the history
        # lookup agree even when the request carries no contextId.
        context_id = context.context_id or str(uuid.uuid4())

        if context.message and context.message.reference_task_ids:
            print(f"[A2A] Processing follow-up with referenceTaskIds: {context.message.reference_task_ids}")

        updater = TaskUpdater(
            event_queue,
            task_id=task_id,
            context_id=context_id,
        )

        if raw_text.strip().lower().rstrip('.!') in self._COMPLETION_SIGNALS:
            print(f'[DoctorSearch] Received completion signal - completing task {task_id}')
            await updater.complete()
            return

        history = self._conversation_history.get(context_id, [])
        last_result = self._last_search_results.get(context_id, "")
        contextual_query = self._build_contextual_query(raw_text, history, last_result)

        response_text = await self._process_doctor_query(contextual_query, context_id)
        self._record_turn(context_id, raw_text, response_text)

        await updater.add_artifact(
            [Part(root=TextPart(text=response_text))],
            name=self.artifact_name
        )
        await updater.requires_input(final=True)

        if self.debug_json:
            response_data = {
                "id": task_id,
                "result": {
                    "id": task_id,
                    "contextId": context_id,
                    "status": {"state": "input-required"},
                    "artifacts": [{
                        "artifactId": "generated-id",
                        "name": self.artifact_name,
                        "parts": [{"kind": "text", "text": response_text}]
                    }]
                }
            }
            self._print_json("📤 followup_response", response_data)

    def _record_turn(self, context_id: str, query: str, response: str) -> None:
        """Append one query/response turn to ``context_id``'s history and cache the response.

        ``setdefault`` appends to the live list, so a turn recorded by another
        request for the same context is never overwritten.
        """
        self._conversation_history.setdefault(context_id, []).append({
            'query': query,
            'response': response,
            'timestamp': datetime.now(timezone.utc).isoformat(),
        })
        self._last_search_results[context_id] = response

    def _build_contextual_query(self, current_query: str, history: List[Dict], last_result: str) -> str:
        """Prefix ``current_query`` with a plain-text summary of earlier turns.

        The summary lists the last three user queries and the last two
        query/response pairs, with each response truncated to its first 300
        characters. Returns ``current_query`` unchanged when ``history`` is
        empty. ``last_result`` is accepted for interface compatibility but is
        not currently used.
        """
        if not history:
            return current_query

        recent_queries = ' | '.join(item['query'] for item in history[-3:])
        lines = [
            "=== CONVERSATION CONTEXT ===",
            f"Previous queries: {recent_queries}",
            "Previous conversation details:",
        ]
        for turn_no, item in enumerate(history[-2:], 1):
            lines.append(f"Turn {turn_no}:")
            lines.append(f"  User: {item['query']}")
            if item['response']:
                # Only the first 300 characters of each earlier response are included.
                lines.append(f"  Assistant: {item['response'][:300]}")
        lines.append("=== CURRENT QUERY ===")
        lines.append(current_query)
        lines.append("Please consider the conversation context when processing this doctor search.")

        return '\n'.join(lines)

    async def _process_doctor_query(self, query: str, context_id: Optional[str]) -> str:
        """Run ``query`` through the ADK agent and return the final response text.

        The ADK session is keyed by ``context_id``. It is reused if it already
        exists, so repeated turns do not re-create (and collide with) the same
        session id. Any exception is reported as an error string rather than
        raised. An empty or whitespace-only agent reply yields a fixed
        "no results" message.
        """
        # One stable user id per context, so every turn of a conversation
        # resolves to the same ADK session.
        user_id = f"a2a_user_{context_id}" if context_id else "a2a_user"
        session_id = context_id or str(uuid.uuid4())

        try:
            session_service = self.runner.session_service
            session = await session_service.get_session(
                app_name=self.agent.name,
                user_id=user_id,
                session_id=session_id,
            )
            if session is None:
                session = await session_service.create_session(
                    app_name=self.agent.name,
                    user_id=user_id,
                    state={},
                    session_id=session_id,
                )

            content = types.Content(
                role="user", parts=[types.Part.from_text(text=query)]
            )

            chunks: List[str] = []
            async for event in self.runner.run_async(
                user_id=user_id, session_id=session.id, new_message=content
            ):
                if event.is_final_response() and event.content and event.content.parts:
                    chunks.extend(
                        part.text for part in event.content.parts
                        if getattr(part, "text", None)
                    )
        except Exception as e:
            # Best-effort boundary: surface the failure to the peer as text.
            return f"Error processing doctor search: {str(e)}"

        return "\n".join(chunks).strip() or "No doctor search results found."

    async def cancel(
        self,
        context: RequestContext,
        event_queue: EventQueue,
    ) -> None:
        """Cancel the task identified by ``context.task_id``.

        Does nothing if the request carries no task id.
        """
        if not context.task_id:
            return
        print(f'[DoctorSearch] Task {context.task_id} canceled on request of peer agent')
        updater = TaskUpdater(
            event_queue,
            task_id=context.task_id,
            context_id=context.context_id or str(uuid.uuid4()),
        )
        await updater.cancel()