"""
Base agent class for all LLM-based agents in the articulated object generation pipeline.

This module provides a unified architecture for all agents (LinkerGenerator, ShapeGenerator, 
ArticulationGenerator) that interact with various LLM providers (OpenAI, Google).

Key Features:
- Unified API abstraction across different LLM providers
- Automatic retry mechanisms with exponential backoff
- Comprehensive metrics collection (tokens, cost, timing)
- Standardized error handling and logging
- Consistent input/output formatting

Architecture:
    BaseAgent (Abstract)
    ├── LinkerGeneratorAgent    # Generates link descriptions from captions  
    ├── ShapeGeneratorAgent     # Generates Three.js code from descriptions
    └── ArticulationGeneratorAgent  # Generates joint specifications

Usage:
    class MyAgent(BaseAgent):
        def _load_system_prompt(self) -> str:
            return "Your system prompt here"
        
        def _format_user_prompt(self, input_data) -> str:
            return f"Process: {input_data['content']}"
        
        def parse_response(self, response) -> Any:
            return parse_your_format(response)
"""

import os
import time
import json
import yaml
import logging
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, Tuple, List
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type

from providers import SmartProviderFactory
from utils.output_parser import OutputFormatError


class BaseAgent(ABC):
    """Base class for all agents providing common functionality"""
    
    def __init__(self, config_manager, agent_type: str):
        """
        Initialize base agent
        
        Args:
            config_manager: Configuration manager instance
            agent_type: Type of agent ('linker_generator', 'shape_generator', 'articulation_generator')
        """
        self.config = config_manager
        self.agent_type = agent_type
        self.logger = logging.getLogger(f"{self.__class__.__name__}")
        
        # Get model configuration for this agent
        self.model_name = self.config.get_model_for_agent(agent_type)
        
        # Get API keys
        api_keys = {
            'openai': self.config.get_api_key('openai'),
            'google': self.config.get_api_key('google'),
            'anthropic': self.config.get_api_key('anthropic')
        }
        
        # Create provider using smart factory
        provider_config = self.config.get_model_config(self.model_name)
        self.provider = SmartProviderFactory.create_provider(
            self.model_name, 
            provider_config,
            api_keys
        )
        
        # Load system prompt
        self.system_prompt = self._load_system_prompt()
        
        # Get retry configuration
        self.retry_config = self.config.get_retry_config()
        
        # Initialize metrics
        self.last_metrics = {}
    
    @abstractmethod
    def _load_system_prompt(self) -> str:
        """
        Load system prompt for this agent
        
        Returns:
            System prompt string
        """
        pass
    
    @abstractmethod
    def _format_user_prompt(self, input_data: Dict[str, Any]) -> str:
        """
        Format user prompt based on input data
        
        Args:
            input_data: Input data for prompt generation
            
        Returns:
            Formatted user prompt
        """
        pass
    
    @abstractmethod
    def parse_response(self, response: str) -> Any:
        """
        Parse LLM response into structured data
        
        Args:
            response: Raw LLM response
            
        Returns:
            Parsed structured data
            
        Raises:
            OutputFormatError: If response format is invalid
        """
        pass
    
    def generate_with_retry(
        self, 
        input_data: Dict[str, Any], 
        log_path: Optional[str] = None,
        output_folder: Optional[str] = None
    ) -> Tuple[Any, Dict[str, Any]]:
        """
        Generate output with automatic retry on failures
        
        Args:
            input_data: Input data for generation
            log_path: Optional path to log file
            
        Returns:
            Tuple of (parsed_result, metrics)
            
        Raises:
            Exception: If all retries fail
        """
        # Skip logging if no log_path provided
        if log_path is None:
            log_path = None
        
        # Create retry decorator with config
        @retry(
            stop=stop_after_attempt(self.retry_config['max_retries']),
            wait=wait_exponential(
                multiplier=self.retry_config.get('multiplier', 1),
                min=self.retry_config.get('min_wait', 4),
                max=self.retry_config.get('max_wait', 10)
            ),
            retry=retry_if_exception_type(OutputFormatError),
            reraise=True
        )
        def _generate():
            # Format prompts
            user_prompt = self._format_user_prompt(input_data)
            
            # Log input
            if log_path:
                self._log_llm_input(log_path, self.system_prompt, user_prompt)
            
            # Invoke provider
            start_time = time.time()
            response_text, provider_metrics = self.provider.invoke(
                prompt=user_prompt,
                system_prompt=self.system_prompt
            )
            end_time = time.time()
            
            # IMMEDIATELY save raw output to stage folder (before parsing)
            if output_folder and response_text:
                self._save_stage_output_immediate(output_folder, response_text, input_data, 
                                                 {'time_cost': end_time - start_time, **provider_metrics})
            
            # Log raw output
            if log_path:
                self._log_llm_output(log_path, response_text)
            
            # Check for empty response
            if not response_text or not response_text.strip():
                raise OutputFormatError("LLM returned empty response")
            
            # Parse response
            try:
                parsed_result = self.parse_response(response_text)
            except OutputFormatError as e:
                if log_path:
                    self._log_error(log_path, f"Parse error: {e}")
                raise
            
            # Build complete metrics with separate input/output costs
            metrics = {
                'time_cost': end_time - start_time,
                'model': self.model_name,
                'agent_type': self.agent_type,
                **provider_metrics
            }

            # Calculate separate input and output costs if tokens are available
            if 'input_tokens' in metrics and 'output_tokens' in metrics:
                from providers.utils import CostCalculator

                # Determine which pricing table to use based on provider
                if 'gpt' in self.model_name.lower() or 'o1' in self.model_name.lower() or 'o3' in self.model_name.lower():
                    metrics['input_cost'] = CostCalculator.calculate_openai_cost(self.model_name, metrics['input_tokens'], 0)
                    metrics['output_cost'] = CostCalculator.calculate_openai_cost(self.model_name, 0, metrics['output_tokens'])
                elif 'claude' in self.model_name.lower():
                    metrics['input_cost'] = CostCalculator.calculate_anthropic_cost(self.model_name, metrics['input_tokens'], 0)
                    metrics['output_cost'] = CostCalculator.calculate_anthropic_cost(self.model_name, 0, metrics['output_tokens'])
                elif 'gemini' in self.model_name.lower():
                    metrics['input_cost'] = CostCalculator.calculate_google_cost(self.model_name, metrics['input_tokens'], 0)
                    metrics['output_cost'] = CostCalculator.calculate_google_cost(self.model_name, 0, metrics['output_tokens'])
                else:
                    # Default fallback
                    metrics['input_cost'] = metrics.get('cost', 0) * (metrics['input_tokens'] / metrics.get('total_tokens', 1))
                    metrics['output_cost'] = metrics.get('cost', 0) * (metrics['output_tokens'] / metrics.get('total_tokens', 1))

                # Rename 'cost' to 'total_cost' for clarity
                metrics['total_cost'] = metrics.get('cost', metrics.get('input_cost', 0) + metrics.get('output_cost', 0))
            
            self.last_metrics = metrics
            return parsed_result, metrics, response_text
        
        # Execute with retry
        try:
            result, metrics, raw_response = _generate()
            return result, metrics, raw_response
        except Exception as e:
            self.logger.error(f"Generation failed after {self.retry_config['max_retries']} attempts: {e}")
            raise
    
    def generate(
        self,
        *args,
        output_folder: Optional[str] = None,
        log_path: Optional[str] = None,
        stage_num: Optional[int] = None,
        **kwargs
    ) -> Tuple[Any, bool, Dict[str, Any], str]:
        """
        Main generation method to be called by pipeline
        
        Args:
            *args: Agent-specific arguments
            output_folder: Output folder path
            log_path: Log file path
            stage_num: Pipeline stage number (1, 2, or 3)
            **kwargs: Additional arguments
            
        Returns:
            Tuple of (result, success, metrics, raw_llm_output)
        """
        try:
            # Prepare input data (to be implemented by subclasses)
            input_data = self._prepare_input_data(*args, **kwargs)
            
            # Set log path to pipeline_logs/pipeline_run.log
            if log_path is None and output_folder:
                log_path = os.path.join(output_folder, 'pipeline_logs', 'pipeline_run.log')
            
            # Generate with retry
            result, metrics, raw_response = self.generate_with_retry(input_data, log_path, output_folder)
            
            # Save LLM interaction to stage-specific folder
            if output_folder and raw_response:
                self._save_stage_output(output_folder, raw_response, input_data, metrics, stage_num)
            
            # Save output if folder provided
            if output_folder:
                # Pass input_data to save_output for access to original input
                if hasattr(self, '_save_output_with_input'):
                    self._save_output_with_input(result, output_folder, input_data, metrics)
                else:
                    self.save_output(result, output_folder, metrics)
            
            return result, True, metrics, raw_response
            
        except Exception as e:
            self.logger.error(f"Generation failed: {e}")
            # Save error output for debugging
            if output_folder:
                self._save_error_output(output_folder, str(e), stage_num)
            return None, False, {}, ""
    
    @abstractmethod
    def _prepare_input_data(self, *args, **kwargs) -> Dict[str, Any]:
        """
        Prepare input data from arguments
        
        Returns:
            Dictionary of input data
        """
        pass
    
    @abstractmethod
    def save_output(self, result: Any, output_folder: str, metrics: Dict[str, Any] = None):
        """
        Save generation output to files
        
        Args:
            result: Generation result
            output_folder: Output folder path
            metrics: Generation metrics
        """
        pass
    
    def _log_llm_input(self, log_path: str, system_prompt: str, user_prompt: str):
        """Log LLM input to unified pipeline log"""
        os.makedirs(os.path.dirname(log_path), exist_ok=True)
        # Don't log full prompts to main log, just summary
        pass
    
    def _log_llm_output(self, log_path: str, response: str):
        """Log LLM output summary to unified pipeline log"""
        # Don't log full response to main log, just summary
        pass
    
    def _log_error(self, log_path: str, error_msg: str):
        """Log error to unified pipeline log"""
        os.makedirs(os.path.dirname(log_path), exist_ok=True)
        with open(log_path, 'a', encoding='utf-8') as f:
            timestamp = time.strftime('%Y-%m-%d %H:%M:%S')
            f.write(f"[{timestamp}] ERROR in {self.agent_type}: {error_msg}\n")
    
    def get_token_count(self, text: str) -> int:
        """
        Get token count for text
        
        Args:
            text: Text to count tokens for
            
        Returns:
            Number of tokens
        """
        return self.provider.count_tokens(text)
    
    def get_last_metrics(self) -> Dict[str, Any]:
        """Get metrics from last generation"""
        return self.last_metrics.copy()
    
    def _save_stage_output(self, output_folder: str, raw_response: str, input_data: Dict[str, Any], 
                          metrics: Dict[str, Any], stage_num: Optional[int] = None):
        """
        Save LLM interaction to stage-specific folder
        
        Args:
            output_folder: Output folder path
            raw_response: Raw LLM response text
            input_data: Input data used for generation
            metrics: Generation metrics
            stage_num: Pipeline stage number
        """
        # Determine stage folder name
        if stage_num:
            stage_name = {
                1: 'stage_1_linker',
                2: 'stage_2_shape',
                3: 'stage_3_articulation'
            }.get(stage_num, f'stage_{stage_num}')
        else:
            stage_name = f'stage_{self.agent_type}'
        
        # Create stage folder
        stage_folder = os.path.join(output_folder, 'pipeline_logs', stage_name)
        os.makedirs(stage_folder, exist_ok=True)
        
        # Save LLM interaction
        interaction_file = os.path.join(stage_folder, 'llm_interaction.txt')
        with open(interaction_file, 'w', encoding='utf-8') as f:
            f.write("=" * 80 + "\n")
            f.write("LLM INTERACTION LOG\n")
            f.write(f"Agent: {self.agent_type}\n")
            f.write(f"Model: {self.model_name}\n")
            f.write(f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write("=" * 80 + "\n\n")
            
            # Save abbreviated system prompt (first 10 lines) for readability
            system_prompt_lines = self.system_prompt.split('\n')
            abbreviated_prompt = '\n'.join(system_prompt_lines[:10])
            if len(system_prompt_lines) > 10:
                abbreviated_prompt += f"\n\n... (truncated, showing 10/{len(system_prompt_lines)} lines)"
                abbreviated_prompt += "\n[Full prompt saved in system_prompt.txt]"
            
            f.write("SYSTEM PROMPT (abbreviated):\n")
            f.write("=" * 40 + "\n")
            f.write(abbreviated_prompt)
            f.write("\n\n")
            
            f.write("USER PROMPT:\n")
            f.write("=" * 40 + "\n")
            f.write(self._format_user_prompt(input_data))
            f.write("\n\n")
            
            f.write("LLM RESPONSE:\n")
            f.write("=" * 40 + "\n")
            f.write(raw_response)
            f.write("\n\n")
            
            # Try to parse and pretty print the response
            try:
                parsed = self.parse_response(raw_response)
                f.write("PARSED OUTPUT:\n")
                f.write("=" * 40 + "\n")
                if isinstance(parsed, (dict, list)):
                    f.write(json.dumps(parsed, indent=2, ensure_ascii=False))
                else:
                    f.write(str(parsed))
                f.write("\n")
            except:
                pass
        
        # Save full system prompt to separate file
        system_prompt_file = os.path.join(stage_folder, 'system_prompt.txt')
        with open(system_prompt_file, 'w', encoding='utf-8') as f:
            f.write("FULL SYSTEM PROMPT\n")
            f.write("=" * 80 + "\n")
            f.write(f"Agent: {self.agent_type}\n")
            f.write(f"Model: {self.model_name}\n")
            f.write(f"Length: {len(self.system_prompt)} characters, {len(system_prompt_lines)} lines\n")
            f.write("=" * 80 + "\n\n")
            f.write(self.system_prompt)
        
        # Save metadata
        metadata_file = os.path.join(stage_folder, 'metadata.json')
        metadata_dict = {
            'agent_type': self.agent_type,
            'model': self.model_name,
            'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
            'metrics': metrics,
            'system_prompt_length': len(self.system_prompt),
            'system_prompt_lines': len(system_prompt_lines)
        }

        # Only include input_data for stage_1 (linker), exclude it for stage_2 (shape) and stage_3 (articulation)
        if self.agent_type not in ['shape_generator', 'articulation_generator']:
            metadata_dict['input_data'] = input_data

        with open(metadata_file, 'w', encoding='utf-8') as f:
            json.dump(metadata_dict, f, indent=2, ensure_ascii=False)
        
        # Log summary to main pipeline log
        pipeline_log = os.path.join(output_folder, 'pipeline_logs', 'pipeline_run.log')
        os.makedirs(os.path.dirname(pipeline_log), exist_ok=True)
        with open(pipeline_log, 'a', encoding='utf-8') as f:
            timestamp = time.strftime('%H:%M:%S')
            f.write(f"[{timestamp}] {stage_name.upper()}: Completed\n")
            f.write(f"  Model: {self.model_name} | ")
            f.write(f"Tokens: {metrics.get('total_tokens', 'N/A')} | ")
            f.write(f"Cost: ${metrics.get('total_cost', 0):.4f} | ")
            f.write(f"Duration: {metrics.get('time_cost', 0):.2f}s\n")
            f.write("-" * 60 + "\n")
    
    def _save_stage_output_immediate(self, output_folder: str, raw_response: str, input_data: Dict[str, Any], metrics: Dict[str, Any]):
        """Save raw output immediately to existing stage folder structure."""
        stage_num = {'linker_generator': 1, 'shape_generator': 2, 'articulation_generator': 3}.get(self.agent_type, 0)
        stage_name = {1: 'stage_1_linker', 2: 'stage_2_shape', 3: 'stage_3_articulation'}.get(stage_num, f'stage_{self.agent_type}')
        
        stage_folder = os.path.join(output_folder, 'pipeline_logs', stage_name)
        os.makedirs(stage_folder, exist_ok=True)
        
        # Write to llm_interaction.txt immediately
        interaction_file = os.path.join(stage_folder, 'llm_interaction.txt')
        with open(interaction_file, 'w', encoding='utf-8') as f:
            f.write("=" * 80 + "\n")
            f.write("LLM INTERACTION LOG (SAVED IMMEDIATELY)\n")
            f.write(f"Agent: {self.agent_type}\n")
            f.write(f"Model: {self.model_name}\n")
            f.write(f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write("=" * 80 + "\n\n")
            f.write("RAW LLM RESPONSE:\n")
            f.write("=" * 40 + "\n")
            f.write(raw_response)
    
    def _save_error_output(self, output_folder: str, error_msg: str, stage_num: Optional[int] = None):
        """
        Save error information to stage folder
        
        Args:
            output_folder: Output folder path
            error_msg: Error message
            stage_num: Pipeline stage number
        """
        # Determine stage folder name
        if stage_num:
            stage_name = {
                1: 'stage_1_linker',
                2: 'stage_2_shape',
                3: 'stage_3_articulation'
            }.get(stage_num, f'stage_{stage_num}')
        else:
            stage_name = f'stage_{self.agent_type}'
        
        # Create stage folder
        stage_folder = os.path.join(output_folder, 'pipeline_logs', stage_name)
        os.makedirs(stage_folder, exist_ok=True)
        
        error_file = os.path.join(stage_folder, 'error.txt')
        with open(error_file, 'w', encoding='utf-8') as f:
            f.write("=" * 80 + "\n")
            f.write(f"Agent Type: {self.agent_type}\n")
            f.write(f"Model: {self.model_name}\n")
            f.write(f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write("=" * 80 + "\n\n")
            
            f.write("ERROR:\n")
            f.write("-" * 40 + "\n")
            f.write(error_msg)
        
        # Log to main pipeline log
        pipeline_log = os.path.join(output_folder, 'pipeline_logs', 'pipeline_run.log')
        os.makedirs(os.path.dirname(pipeline_log), exist_ok=True)
        with open(pipeline_log, 'a', encoding='utf-8') as f:
            timestamp = time.strftime('%H:%M:%S')
            f.write(f"[{timestamp}] {stage_name.upper()}: FAILED\n")
            f.write(f"  Error: {error_msg[:100]}...\n")
            f.write("-" * 60 + "\n")