import contextlib
import contextvars
from dataclasses import dataclass
from typing import List, Dict, Any, Optional, Type

from pydantic import BaseModel, Field

from mcp_agent.workflows.llm.augmented_llm_openai import (
    OpenAIAugmentedLLM,
    ChatCompletionMessageParam,
    ChatCompletionMessage,
)
from mcp_agent.workflows.llm.augmented_llm import (
    RequestParams,
    ModelT,
    MessageTypes,
)
from mcp_agent.logging.logger import get_logger
from mcp_agent.tracing.telemetry import get_tracer


@dataclass
class ToolStep:
    """Represents a single tool execution step in the plan"""
    # 1-based position/identifier of the step; other steps refer to it in
    # `dependencies` and via "$step_N" argument placeholders.
    step_number: int
    # Name of the tool to invoke.
    tool_name: str
    # Arguments passed to the tool; string values may contain "$step_N" references.
    tool_args: Dict[str, Any]
    # Human-readable explanation of what the step accomplishes.
    description: str
    # What the planner expects the step to produce (informational only).
    expected_output: str
    # Step numbers this step depends on; None means "no dependencies".
    dependencies: Optional[List[int]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view of the step.

        ``dependencies`` is normalized to a list (``[]`` when unset) and is
        copied, so mutating the returned list does not alter this step.
        """
        return {
            "step_number": self.step_number,
            "tool_name": self.tool_name,
            "tool_args": self.tool_args,
            "description": self.description,
            "expected_output": self.expected_output,
            "dependencies": list(self.dependencies or []),
        }


class ExecutionPlan(BaseModel):
    """Structured execution plan generated by the planner"""
    # NOTE: pydantic includes the class docstring and the Field descriptions in
    # this model's JSON schema. The model is passed to generate_structured, so
    # presumably that schema guides the planner LLM -- edit them with care.
    query: str = Field(description="The original user query")
    analysis: str = Field(description="Analysis of what needs to be done")
    # Raw step dicts as produced by the LLM; see `tool_steps` for typed access.
    steps: List[Dict[str, Any]] = Field(description="List of execution steps")
    expected_final_output: str = Field(description="Expected final output")

    @property
    def tool_steps(self) -> List[ToolStep]:
        """Convert the raw step dicts into ``ToolStep`` objects.

        Keys that are not ``ToolStep`` fields (extra metadata an LLM may add)
        are ignored instead of raising ``TypeError``. A step that is missing a
        required ``ToolStep`` field still raises ``TypeError``.
        """
        from dataclasses import fields

        known_keys = {f.name for f in fields(ToolStep)}
        return [
            ToolStep(**{key: value for key, value in step.items() if key in known_keys})
            for step in self.steps
        ]


class PlanValidation(BaseModel):
    """Validation result for an execution plan"""
    # Filled in by the validator LLM through generate_structured. pydantic puts
    # the docstring and Field descriptions into this model's JSON schema, so
    # they are intentionally left unchanged.
    is_valid: bool = Field(description="Whether the plan is valid")
    # Problems reported by the validator (empty list when none were given).
    issues: List[str] = Field(default_factory=list, description="List of issues found")
    # Improvement suggestions (empty list when none were given).
    suggestions: List[str] = Field(default_factory=list, description="Suggestions for improvement")
    # Replacement plan, or None when the validator did not provide one.
    revised_plan: Optional[ExecutionPlan] = Field(default=None, description="Revised plan if needed")


class PlanThenExecuteLLM(OpenAIAugmentedLLM):
    """
    Plan-then-Execute implementation of AugmentedLLM.

    This approach:
    1. Analyzes the query and available tools
    2. Creates a complete execution plan
    3. Validates the plan (optional) and adopts the validator's revision when
       the plan is judged invalid and a revision is supplied
    4. Executes the plan step by step, in listed order, skipping steps whose
       declared dependencies did not succeed
    5. Stops executing early when a step fails with a "critical" error and
       synthesizes the final answer from the partial results

    NOTE(review): automatic re-planning is not implemented yet;
    ``max_replan_attempts`` and ``enable_parallel_execution`` are stored but
    currently unused (steps always run sequentially).

    The planning, validation, structuring and final-answer LLM calls are made
    through the parent ``OpenAIAugmentedLLM``. While such a parent call is in
    flight, the overridden ``generate`` / ``generate_str`` /
    ``generate_structured`` delegate straight to the parent, so a parent
    implementation that calls those methods on ``self`` cannot recurse back
    into planning.
    """

    # Task-local set of ids of instances that are currently inside a parent
    # (plain LLM) call. A ContextVar keeps this state per asyncio task, so
    # concurrent requests on the same instance do not affect each other.
    _parent_delegation: contextvars.ContextVar = contextvars.ContextVar(
        "plan_then_execute_parent_delegation", default=frozenset()
    )

    # Case-insensitive substrings that mark a tool exception as critical.
    _CRITICAL_ERROR_MARKERS = ("not found", "permission denied", "invalid", "failed")

    def __init__(self, *args, **kwargs):
        """Create the LLM wrapper.

        Keyword options consumed here (not forwarded to the parent):
            max_replan_attempts: Stored; currently unused (default 2).
            enable_plan_validation: Run the plan validator before execution
                (default True).
            enable_parallel_execution: Stored; currently unused (default True).

        All other positional/keyword arguments go to ``OpenAIAugmentedLLM``.
        """
        # Pop our own options first so the parent constructor, which may not
        # accept them, never receives them.
        max_replan_attempts = kwargs.pop("max_replan_attempts", 2)
        enable_plan_validation = kwargs.pop("enable_plan_validation", True)
        enable_parallel_execution = kwargs.pop("enable_parallel_execution", True)

        super().__init__(*args, **kwargs)
        self.logger = get_logger(__name__)
        self.max_replan_attempts = max_replan_attempts
        self.enable_plan_validation = enable_plan_validation
        self.enable_parallel_execution = enable_parallel_execution

    # ------------------------------------------------------------------
    # Parent-delegation guard
    # ------------------------------------------------------------------

    @contextlib.contextmanager
    def _delegating_to_parent(self):
        """Mark this instance as making a plain parent LLM call for the current task."""
        active = self._parent_delegation.get()
        token = self._parent_delegation.set(active | {id(self)})
        try:
            yield
        finally:
            self._parent_delegation.reset(token)

    def _is_delegating_to_parent(self) -> bool:
        """Return True while this instance is inside a parent LLM call in this task."""
        return id(self) in self._parent_delegation.get()

    # ------------------------------------------------------------------
    # Main entry point
    # ------------------------------------------------------------------

    async def generate(
        self,
        message: MessageTypes,
        request_params: RequestParams | None = None,
    ) -> List[ChatCompletionMessage]:
        """
        Plan-then-Execute generation process:
        1. Generate execution plan
        2. Validate plan (optional)
        3. Execute plan
        4. Synthesize the final response from the execution results

        When called re-entrantly from one of this class's own parent LLM calls,
        this behaves exactly like ``OpenAIAugmentedLLM.generate``.
        """
        if self._is_delegating_to_parent():
            return await super().generate(message, request_params)

        tracer = get_tracer(self.context)
        with tracer.start_as_current_span(
            f"{self.__class__.__name__}.{self.name}.generate"
        ) as span:
            span.set_attribute("execution_mode", "plan_then_execute")

            params = self.get_request_params(request_params)

            # Step 1: Generate execution plan
            self.logger.info("Generating execution plan...")
            plan = await self._generate_plan(message, params)
            span.set_attribute("plan.steps_count", len(plan.steps))

            # Step 2: Validate plan (optional)
            if self.enable_plan_validation:
                self.logger.info("Validating execution plan...")
                validation = await self._validate_plan(plan, params)

                if not validation.is_valid:
                    self.logger.warning(f"Plan validation failed: {validation.issues}")
                    if validation.revised_plan:
                        plan = validation.revised_plan
                        self.logger.info("Using revised plan")

            # Step 3: Execute plan
            self.logger.info(f"Executing plan with {len(plan.steps)} steps...")
            results = await self._execute_plan(plan, params)

            # Step 4: Generate final response
            return await self._generate_final_response(plan, results, message, params)

    # ------------------------------------------------------------------
    # Planning and validation
    # ------------------------------------------------------------------

    async def _generate_plan(
        self,
        message: MessageTypes,
        params: RequestParams
    ) -> ExecutionPlan:
        """Ask the LLM for a structured ``ExecutionPlan`` over the agent's tools."""
        tools_response = await self.agent.list_tools()
        tools_info = [
            {
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.inputSchema
            }
            for tool in tools_response.tools
        ]

        planning_prompt = f"""You are a strategic planner. Analyze the user's query and create a detailed execution plan.

Available tools:
{self._format_tools_for_prompt(tools_info)}

User query: {message}

Create a step-by-step execution plan. For each step, specify:
1. The tool to use
2. The exact arguments for the tool
3. What the step accomplishes
4. Expected output
5. Dependencies on previous steps (if any)

Think carefully about:
- The logical order of operations
- Which tools are most appropriate
- How to combine tool outputs effectively
- Potential error cases

Return a structured plan that can be executed sequentially or in parallel where possible."""

        # Plain parent call: must not re-enter plan-then-execute.
        with self._delegating_to_parent():
            plan = await super().generate_structured(
                planning_prompt,
                ExecutionPlan,
                params
            )

        self.logger.debug(f"Generated plan: {plan.model_dump_json(indent=2)}")
        return plan

    async def _validate_plan(
        self,
        plan: ExecutionPlan,
        params: RequestParams
    ) -> PlanValidation:
        """Ask the LLM to review ``plan`` and optionally return a revised one."""
        validation_prompt = f"""You are a plan validator. Review this execution plan for potential issues:

Plan:
{plan.model_dump_json(indent=2)}

Check for:
1. Logical consistency
2. Correct tool usage
3. Proper dependencies
4. Missing steps
5. Redundant operations
6. Potential failure points

If the plan has issues, provide a revised version."""

        # Plain parent call: must not re-enter plan-then-execute.
        with self._delegating_to_parent():
            return await super().generate_structured(
                validation_prompt,
                PlanValidation,
                params
            )

    # ------------------------------------------------------------------
    # Execution
    # ------------------------------------------------------------------

    async def _execute_plan(
        self,
        plan: ExecutionPlan,
        params: RequestParams
    ) -> Dict[int, Any]:
        """Execute the plan's steps sequentially, in the order they are listed.

        A step is skipped (and recorded as an error) when any declared
        dependency has not run or did not succeed. Execution stops early when
        a step raises an error that ``_should_replan`` classifies as critical.

        Args:
            plan: The plan to execute.
            params: Request parameters (currently unused here).

        Returns:
            Mapping of step number to the raw tool result, or to
            ``{"error": <message>}`` for steps that raised or were skipped.
        """
        results: Dict[int, Any] = {}

        for step in plan.tool_steps:
            unmet = [
                dep
                for dep in (step.dependencies or [])
                if dep not in results or self._is_failed_result(results[dep])
            ]
            if unmet:
                for dep in unmet:
                    self.logger.error(
                        f"Dependency {dep} not satisfied for step {step.step_number}"
                    )
                # Running the step anyway would feed it unresolved "$step_N"
                # placeholders or error text, so record it as skipped instead.
                results[step.step_number] = {
                    "error": f"Skipped: unsatisfied dependencies {unmet}"
                }
                continue

            self.logger.info(f"Executing step {step.step_number}: {step.description}")

            try:
                # Substitute any references to previous results
                processed_args = self._process_arguments(step.tool_args, results)
                result = await self.agent.call_tool(
                    name=step.tool_name,
                    arguments=processed_args
                )
            except Exception as e:
                self.logger.error(f"Step {step.step_number} failed: {e}")
                results[step.step_number] = {"error": str(e)}

                if self._should_replan(step, e):
                    # NOTE(review): re-planning is not implemented; execution
                    # stops here and the final response uses partial results.
                    self.logger.info("Re-planning required due to execution failure")
                    break
                continue

            results[step.step_number] = result
            if self._is_failed_result(result):
                # MCP tools may report failure via `isError` instead of raising.
                self.logger.warning(f"Step {step.step_number} returned an error result")
            else:
                self.logger.debug(f"Step {step.step_number} completed successfully")

        return results

    async def _generate_final_response(
        self,
        plan: ExecutionPlan,
        results: Dict[int, Any],
        original_message: MessageTypes,
        params: RequestParams
    ) -> List[ChatCompletionMessage]:
        """Synthesize the user-facing answer from the plan and its execution results."""
        results_summary = self._format_results(results)

        final_prompt = f"""Based on the execution plan and results, provide a comprehensive response to the user.

Original query: {original_message}

Execution plan:
{plan.model_dump_json(indent=2)}

Execution results:
{results_summary}

Synthesize the results into a clear, helpful response that directly addresses the user's query."""

        # Plain parent call for the final answer.
        with self._delegating_to_parent():
            return await super().generate(final_prompt, params)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _format_tools_for_prompt(self, tools: List[Dict]) -> str:
        """Format tool name/description/parameter info for the planning prompt."""
        lines = []
        for tool in tools:
            # Tool descriptions may be None; avoid printing the literal "None".
            lines.append(f"- {tool['name']}: {tool.get('description') or ''}")
            if tool.get('parameters'):
                lines.append(f"  Parameters: {tool['parameters']}")
        return "\n".join(lines) or "(no tools available)"

    def _process_arguments(
        self,
        args: Dict[str, Any],
        results: Dict[int, Any]
    ) -> Dict[str, Any]:
        """Return a copy of ``args`` with ``$step_N`` references resolved.

        A string of the form ``$step_N`` (optionally followed by ``_suffix``,
        e.g. ``$step_2_output``) is replaced by the extracted value of step N's
        result when step N has a result. References are also resolved inside
        nested dicts and lists. Malformed or unresolvable references are passed
        through unchanged.
        """
        return {
            key: self._substitute_step_references(value, results)
            for key, value in (args or {}).items()
        }

    def _substitute_step_references(self, value: Any, results: Dict[int, Any]) -> Any:
        """Recursively resolve ``$step_N`` strings inside ``value``."""
        if isinstance(value, str):
            step_num = self._parse_step_reference(value)
            if step_num is not None and step_num in results:
                return self._extract_result_value(results[step_num])
            return value
        if isinstance(value, dict):
            return {k: self._substitute_step_references(v, results) for k, v in value.items()}
        if isinstance(value, list):
            return [self._substitute_step_references(v, results) for v in value]
        return value

    @staticmethod
    def _parse_step_reference(value: str) -> Optional[int]:
        """Return N for ``"$step_N"`` / ``"$step_N_..."``, else None."""
        prefix = "$step_"
        if not value.startswith(prefix):
            return None
        number = value[len(prefix):].split("_", 1)[0]
        # isdecimal (not isdigit) guarantees int() will accept the token.
        return int(number) if number.isdecimal() else None

    def _extract_result_value(self, result: Any) -> Any:
        """Extract text from a tool result.

        Joins the ``text`` of every content item that has one; falls back to
        ``str(result)`` when there is no textual content.
        """
        content = getattr(result, "content", None)
        if content and not isinstance(content, str):
            texts = [
                item.text for item in content
                if isinstance(getattr(item, "text", None), str)
            ]
            if texts:
                return "\n".join(texts)
        return str(result)

    @staticmethod
    def _is_failed_result(result: Any) -> bool:
        """True for recorded ``{"error": ...}`` entries and results flagged ``isError``."""
        if isinstance(result, dict):
            return "error" in result
        return bool(getattr(result, "isError", False))

    def _should_replan(self, step: ToolStep, error: Exception) -> bool:
        """Heuristically decide whether ``error`` should abort plan execution."""
        error_str = str(error).lower()
        return any(marker in error_str for marker in self._CRITICAL_ERROR_MARKERS)

    def _format_results(self, results: Dict[int, Any]) -> str:
        """Format execution results, one line per step, for the final prompt."""
        if not results:
            return "No steps were executed."
        lines = []
        for step_num, result in results.items():
            if isinstance(result, dict) and "error" in result:
                lines.append(f"Step {step_num}: ERROR - {result['error']}")
            elif self._is_failed_result(result):
                lines.append(f"Step {step_num}: ERROR - {self._extract_result_value(result)}")
            else:
                lines.append(f"Step {step_num}: {self._extract_result_value(result)}")
        return "\n".join(lines)

    # ------------------------------------------------------------------
    # Convenience overrides
    # ------------------------------------------------------------------

    async def generate_str(
        self,
        message: MessageTypes,
        request_params: RequestParams | None = None,
    ) -> str:
        """Generate a string response using plan-then-execute.

        Joins the non-empty ``content`` of the returned messages with newlines.
        Re-entrant calls from this class's own parent LLM calls go straight to
        the parent implementation.
        """
        if self._is_delegating_to_parent():
            return await super().generate_str(message, request_params)

        responses = await self.generate(message, request_params)
        return "\n".join(
            response.content
            for response in responses
            if hasattr(response, 'content') and response.content
        )

    async def generate_structured(
        self,
        message: MessageTypes,
        response_model: Type[ModelT],
        request_params: RequestParams | None = None,
    ) -> ModelT:
        """Generate a structured response using plan-then-execute.

        Runs the full plan-then-execute flow to get a text answer, then asks
        the parent implementation to coerce that text into ``response_model``.
        Re-entrant calls from this class's own parent LLM calls go straight to
        the parent implementation.
        """
        if self._is_delegating_to_parent():
            return await super().generate_structured(message, response_model, request_params)

        response_str = await self.generate_str(message, request_params)

        structuring_prompt = f"Please structure the following response according to the required format:\n\n{response_str}"

        # The delegation guard (not the explicit parent call alone) is what
        # prevents recursion if the parent routes through self.generate_str.
        with self._delegating_to_parent():
            return await super().generate_structured(
                structuring_prompt,
                response_model,
                request_params
            )