"""
LLM Judge for evaluating orchestrator performance.

This module provides an LLM-based judge that scores orchestrator logs on six
workload and outcome dimensions, each rated on a 1-5 Likert scale.
"""

import argparse
import json
import logging
import re
from typing import Any, Dict, List, Optional, Union

from generators.model import GPT5, Mistral_24B_Instruct, Message


def setup_logging(level=logging.INFO, log_file: str = 'llm_judge_TO.log'):
    """
    Configure root logging to write to both the console and a log file.

    Args:
        level: Minimum level for the root logger (default: logging.INFO).
        log_file: Path of the file handler's output (default: 'llm_judge_TO.log').

    Note:
        ``logging.basicConfig`` does nothing if the root logger already has
        handlers, so only the first effective call configures logging.
    """
    logging.basicConfig(
        level=level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(),
            # Explicit UTF-8: logged prompts/logs contain non-ASCII text (e.g. "–")
            # that the platform's default file encoding may not be able to represent.
            logging.FileHandler(log_file, encoding='utf-8'),
        ],
    )


class LLMJudge:
    """
    LLM-based judge for evaluating orchestrator performance.

    Evaluates orchestrator logs on six dimensions, each scored 1-5:
    - Mental Demand
    - Temporal Demand
    - Performance
    - Effort
    - Frustration
    - Adaptivity

    ``evaluate_log`` embeds the log in a fixed rubric prompt, sends it to the
    configured LLM, parses the JSON reply and validates it. Token usage is
    accumulated per instance and exposed via ``get_token_usage``.
    """

    # Keys the LLM's JSON reply must contain; matches the schema in the prompt.
    REQUIRED_DIMENSIONS = (
        "Mental Demand", "Temporal Demand", "Performance",
        "Effort", "Frustration", "Adaptivity",
    )

    def __init__(self, model_name: str = "gpt-5"):
        """
        Initialize the LLM judge.

        Args:
            model_name: Name of the LLM model to use for evaluation (default: gpt-5).
                Supported models: "gpt-5", "mistral-24b-instruct".

        Raises:
            ValueError: If ``model_name`` is not a supported model.
        """
        self.model_name = model_name
        self.llm = self._initialize_llm()
        self.evaluation_prompt = self._get_evaluation_prompt()

        # Cumulative token usage across all evaluate_log calls on this instance.
        self.total_prompt_tokens = 0
        self.total_completion_tokens = 0
        self.total_requests = 0

    def _initialize_llm(self):
        """
        Instantiate the model wrapper selected by ``self.model_name``.

        Raises:
            ValueError: If the model name is not supported.
        """
        if self.model_name == "gpt-5":
            return GPT5()
        elif self.model_name == "mistral-24b-instruct":
            return Mistral_24B_Instruct()
        else:
            raise ValueError(f"Unsupported model: {self.model_name}. Supported models: gpt-5, mistral-24b-instruct")

    def _get_evaluation_prompt(self) -> str:
        """
        Return the evaluation prompt template.

        The template is meant for ``str.format``: ``{log}`` is the only
        placeholder, and the literal braces of the JSON schema are doubled.
        """
        return """You are an impartial evaluator. 
Your task is to assess how well the orchestrator performed in recommending *policies*, which then generated the actual learning tasks for the student. 
The evaluation is based on the orchestrator log, which contains:
- The policies chosen by the orchestrator at each step,
- The tasks/questions recommended by the selected policies (with difficulty and skill info),
- The student's responses (correct/incorrect) and timing. 

The log may be in text format with timestamps, student states, policy reasoning, and step-by-step interactions. Look for patterns in:
- Policy selection reasoning and choices
- Student mastery level changes over time
- Question difficulty and success rates
- Response times and performance metrics
- Failed questions and gap recovery

You must rate the orchestrator on six workload and outcome dimensions, each scored from 1 (very low) to 5 (very high). 
For each dimension, provide both a score and a short justification grounded in the log. 

## Guidance
- Use patterns in correctness and timing to infer mastery gains, workload, and emotional states. 
- Rate Adaptivity by checking if policy choices respond to changes in student performance. 
- Use extreme ratings (1 or 5) only when strongly supported by the log. 
- Keep justifications concise (1–2 sentences) and consistent across dimensions. 
- Ensure ratings across dimensions are logically coherent. 
- Output must be valid JSON only.


## Dimensions and Scoring Rubric

1. **Mental Demand**  
   - Definition: Mental demand is about task complexity.  
   - Low (1): Very simple, repetitive, predictable questions with little reasoning.  
   - High (5): Highly complex, requiring intensive reasoning, multi-step problem solving, or integration of many concepts.  

2. **Temporal Demand**  
   - Definition: The time pressure experienced due to the orchestrator's pacing through its policy choices and resulting tasks.  
   - Low (1): Pace feels slow, leisurely, or inefficient.  
   - High (5): Pace feels rushed or frantic, leaving insufficient time for comprehension.  

3. **Performance**  
   - Definition: The extent to which the orchestrator's chosen policies helped the student achieve mastery improvement.  
   - Low (1): Student makes little progress on mastery levels, irrelevant recommendations.  
   - High (5): Student's mastery level clearly improves, recommendations are effective.  

4. **Effort**  
   - Definition: The amount of mental effort required from the student to complete the sequence of tasks generated by the orchestrator's policy choices. It refers to the student's *perceived workload* and sustained concentration. Even if tasks are simple, high repetition or long sequences may result in high Effort.    
   - Low (1): Student barely needs to try; tasks feel effortless.  
   - High (5): Student must exert extreme sustained effort; tasks feel exhausting.  

5. **Frustration**  
   - Definition: The emotional experience of the student (e.g., irritation, discouragement, or satisfaction) while following the orchestrator's chosen policies.  
   - Low (1): Student is relaxed, motivated, satisfied.  
   - High (5): Student shows signs of discouragement, annoyance, or strong negative emotions.  

6. **Adaptivity**  
   - Definition: How well the orchestrator adapts policy choices to the student's evolving state (mastery, errors, pace).  
   - Low (1): Repeatedly selects poorly matched policies; no evidence of adjustment to student performance.  
   - High (5): Clearly adjusts policy choices to student's progress, difficulty needs, and error patterns.  



### Orchestrator Log
{log}

### Output Format
Return your evaluation **strictly in JSON** with this schema:

```json
{{
  "Mental Demand": {{
    "score": <int from 1 to 5>,
    "reason": "<short explanation>"
  }},
  "Temporal Demand": {{
    "score": <int from 1 to 5>,
    "reason": "<short explanation>"
  }},
  "Performance": {{
    "score": <int from 1 to 5>,
    "reason": "<short explanation>"
  }},
  "Effort": {{
    "score": <int from 1 to 5>,
    "reason": "<short explanation>"
  }},
  "Frustration": {{
    "score": <int from 1 to 5>,
    "reason": "<short explanation>"
  }},
  "Adaptivity": {{
    "score": <int from 1 to 5>,
    "reason": "<short explanation>"
  }}
}}
```"""

    def evaluate_log(self, log_data: Union[str, Dict, List]) -> Dict[str, Any]:
        """
        Evaluate an orchestrator log using the LLM judge.

        The full prompt and raw response are logged at INFO level, and the
        token usage of this call is added to the instance totals.

        Args:
            log_data: The orchestrator log data (can be string, dict, or list)

        Returns:
            Dictionary containing the evaluation scores and reasons for each dimension

        Raises:
            ValueError: If the LLM call fails, the response is not valid JSON,
                or it is missing required fields. The original error is
                chained as ``__cause__``.
        """
        # Convert log data to string format
        log_str = self._format_log_data(log_data)

        # Only {log} is substituted; braces inside log_str are inserted verbatim.
        full_prompt = self.evaluation_prompt.format(log=log_str)

        # Generate evaluation using LLM
        messages = [Message(role="user", content=full_prompt)]

        try:
            # Log the prompt being sent
            logging.info("="*80)
            logging.info("EVALUATION PROMPT")
            logging.info("="*80)
            logging.info(f"Model: {self.model_name}")
            logging.info(f"Prompt length: {len(full_prompt)} characters")
            logging.info(f"Prompt content:\n{full_prompt}")
            logging.info("="*80)

            # token_usage_callback exposes running totals; the before/after delta
            # attributes usage to this call (assumes no concurrent LLM calls).
            from generators.model import token_usage_callback
            initial_prompt_tokens = token_usage_callback.total_prompt_tokens
            initial_completion_tokens = token_usage_callback.total_completion_tokens
            initial_requests = token_usage_callback.request_count

            response = self.llm.generate_chat(messages, max_tokens=4096, temperature=0.0)

            # Track token usage after the call
            prompt_tokens_used = token_usage_callback.total_prompt_tokens - initial_prompt_tokens
            completion_tokens_used = token_usage_callback.total_completion_tokens - initial_completion_tokens
            requests_made = token_usage_callback.request_count - initial_requests

            # Update our tracking
            self.total_prompt_tokens += prompt_tokens_used
            self.total_completion_tokens += completion_tokens_used
            self.total_requests += requests_made

            # Log token usage for this evaluation
            logging.info(f"Evaluation token usage: {prompt_tokens_used} prompt + {completion_tokens_used} completion = {prompt_tokens_used + completion_tokens_used} total")

            # Extract content from response
            if isinstance(response, dict) and "content" in response:
                content = response["content"]
            else:
                content = str(response)

            # Log the response received
            logging.info("="*80)
            logging.info("LLM RESPONSE")
            logging.info("="*80)
            logging.info(f"Response length: {len(content)} characters")
            logging.info(f"Response content:\n{content}")
            logging.info("="*80)

            # Parse JSON response
            evaluation = self._parse_json_response(content)

            # Log the parsed evaluation
            logging.info("="*80)
            logging.info("PARSED EVALUATION")
            logging.info("="*80)
            logging.info(f"Evaluation result:\n{json.dumps(evaluation, indent=2)}")
            logging.info("="*80)

            # Validate the evaluation
            self._validate_evaluation(evaluation)

            return evaluation

        except Exception as e:
            logging.error(f"Error during LLM evaluation: {e}")
            raise ValueError(f"Failed to evaluate log: {e}") from e

    def _format_log_data(self, log_data: Union[str, Dict, List]) -> str:
        """
        Format log data into a readable string for the LLM.

        Args:
            log_data: The log data to format (can be text log, JSON string, dict, or list)

        Returns:
            Formatted string representation of the log: strings unchanged,
            dicts/lists as indented JSON, anything else via ``str()``.
        """
        if isinstance(log_data, str):
            # If it's already a string, return as-is (handles both text logs and JSON strings)
            return log_data
        elif isinstance(log_data, (dict, list)):
            return json.dumps(log_data, indent=2)
        else:
            return str(log_data)

    def _parse_json_response(self, content: str) -> Dict[str, Any]:
        """
        Parse the JSON object in an LLM response, tolerating common noise.

        Strategy, in order:
        1. Decode the first complete JSON object starting at the first '{'.
           This handles markdown fences and commentary before or after the
           object, even when that commentary itself contains braces.
        2. Fall back to the span from the first '{' to the last '}' (or the
           whole content if there is no such span) and ``json.loads`` it.
        3. If that fails, heuristically repair a truncated reply (close an
           unterminated string, drop a trailing comma, append missing '}')
           and try once more.

        Args:
            content: The raw content from the LLM

        Returns:
            Parsed JSON value (expected to be a dict; callers validate it)

        Raises:
            ValueError: If JSON parsing fails
        """
        start = content.find('{')
        if start != -1:
            try:
                parsed, _ = json.JSONDecoder().raw_decode(content, start)
                return parsed
            except json.JSONDecodeError:
                pass  # Malformed/truncated object: fall through to the heuristics.

        # Greedy span from the first '{' to the last '}'.
        json_match = re.search(r'\{.*\}', content, re.DOTALL)
        json_str = json_match.group(0) if json_match else content

        try:
            return json.loads(json_str)
        except json.JSONDecodeError as e:
            repaired = self._repair_truncated_json(json_str)
            if repaired != json_str:
                logging.info("Attempting to fix JSON by completing truncated output")
                try:
                    return json.loads(repaired)
                except json.JSONDecodeError:
                    pass  # Repair did not help; report the original error below.

            logging.error(f"Failed to parse JSON response: {e}")
            logging.error(f"Raw content: {content}")
            raise ValueError(f"Invalid JSON response from LLM: {e}") from e

    @staticmethod
    def _repair_truncated_json(json_str: str) -> str:
        """
        Best-effort completion of a JSON object cut off mid-output.

        Heuristic only: quote and brace counts include characters inside
        string values, and escaped quotes are not accounted for.
        """
        repaired = json_str.rstrip()
        # An odd number of quotes means a string literal was left open.
        if repaired.count('"') % 2 == 1:
            repaired += '"'
        # A dangling comma before the closing brace(s) is invalid JSON.
        repaired = repaired.rstrip().rstrip(',')
        missing_braces = repaired.count('{') - repaired.count('}')
        if missing_braces > 0:
            repaired += '}' * missing_braces
        return repaired

    def _validate_evaluation(self, evaluation: Dict[str, Any]) -> None:
        """
        Validate that the evaluation contains all required fields and valid scores.

        Every dimension in ``REQUIRED_DIMENSIONS`` must map to a dict with an
        integer ``score`` in [1, 5] and a non-empty string ``reason``.
        Extra keys are ignored.

        Args:
            evaluation: The parsed evaluation dictionary

        Raises:
            ValueError: If validation fails
        """
        if not isinstance(evaluation, dict):
            raise ValueError(
                f"Invalid evaluation format: expected dict, got {type(evaluation).__name__}"
            )

        for dimension in self.REQUIRED_DIMENSIONS:
            if dimension not in evaluation:
                raise ValueError(f"Missing dimension: {dimension}")

            dim_data = evaluation[dimension]
            if not isinstance(dim_data, dict):
                raise ValueError(f"Invalid format for {dimension}: expected dict")

            if "score" not in dim_data or "reason" not in dim_data:
                raise ValueError(f"Missing 'score' or 'reason' for {dimension}")

            score = dim_data["score"]
            # bool is a subclass of int, so JSON true/false must be rejected explicitly.
            if isinstance(score, bool) or not isinstance(score, int) or not 1 <= score <= 5:
                raise ValueError(f"Invalid score for {dimension}: {score} (must be 1-5)")

            if not isinstance(dim_data["reason"], str) or not dim_data["reason"].strip():
                raise ValueError(f"Invalid reason for {dimension}: must be non-empty string")

    def get_token_usage(self) -> Dict[str, Union[int, float]]:
        """
        Get token usage statistics for this judge instance.

        Returns:
            Dictionary with integer totals and float per-request averages
            (averages are 0.0 when no requests have been made).
        """
        return {
            "total_prompt_tokens": self.total_prompt_tokens,
            "total_completion_tokens": self.total_completion_tokens,
            "total_tokens": self.total_prompt_tokens + self.total_completion_tokens,
            "total_requests": self.total_requests,
            "average_prompt_tokens": self.total_prompt_tokens / max(self.total_requests, 1),
            "average_completion_tokens": self.total_completion_tokens / max(self.total_requests, 1),
        }

    def print_token_usage(self) -> None:
        """Log token usage statistics at INFO level (via the logging module, not print)."""
        usage = self.get_token_usage()
        logging.info("\n" + "="*50)
        logging.info("TOKEN USAGE STATISTICS")
        logging.info("="*50)
        logging.info(f"Model: {self.model_name}")
        logging.info(f"Total Requests: {usage['total_requests']}")
        logging.info(f"Prompt Tokens: {usage['total_prompt_tokens']:,}")
        logging.info(f"Completion Tokens: {usage['total_completion_tokens']:,}")
        logging.info(f"Total Tokens: {usage['total_tokens']:,}")
        if usage['total_requests'] > 0:
            logging.info(f"Average Prompt Tokens/Request: {usage['average_prompt_tokens']:.1f}")
            logging.info(f"Average Completion Tokens/Request: {usage['average_completion_tokens']:.1f}")
        logging.info("="*50)

    def save_evaluation_with_usage(self, evaluation: Dict[str, Any], filepath: str) -> None:
        """
        Save evaluation results with token usage to a JSON file.

        The file is overwritten if it exists. Its top-level keys are
        "evaluation", "token_usage" and "model".

        Args:
            evaluation: The evaluation dictionary to save
            filepath: Path to save the file
        """
        result = {
            "evaluation": evaluation,
            "token_usage": self.get_token_usage(),
            "model": self.model_name
        }
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(result, f, indent=2)
        logging.info(f"Evaluation with token usage saved to {filepath}")


def load_text_log(filepath: str) -> str:
    """
    Read a text log file decoded as UTF-8.

    Args:
        filepath: Path to the text log file

    Returns:
        Content of the log file as string

    Raises:
        ValueError: If the file cannot be opened/read or is not valid UTF-8.
            The underlying error is chained as ``__cause__``.
    """
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            return f.read()
    except (OSError, UnicodeError) as e:
        raise ValueError(f"Failed to load log file {filepath}: {e}") from e


def main():
    """
    Command-line entry point: evaluate one orchestrator log and save the result.

    Parses CLI arguments and configures logging. It then runs
    ``LLMJudge.evaluate_log`` on the text log and logs each dimension's score
    and reason. Finally it writes the evaluation plus token usage to the
    output JSON file and logs the token usage statistics.

    Returns:
        0 on success, 1 if any step failed (the error and its traceback are logged).
    """
    parser = argparse.ArgumentParser(description="LLM Judge for evaluating orchestrator performance")
    parser.add_argument(
        "--model",
        choices=["gpt-5", "mistral-24b-instruct"],
        default="gpt-5",
        help="LLM model to use for evaluation (default: gpt-5)"
    )
    parser.add_argument(
        "--log-file",
        default="Orchestrator_Claude_04/3O_tool_call_claude-3.7-sonnet-thinking_irt_2025-09-20_19-22-08/evaluation_clean.log",
        help="Path to the log file to evaluate"
    )
    parser.add_argument(
        "--output",
        default=None,
        help="Output file path (default: evaluation_result_TO_{model}.json)"
    )
    parser.add_argument(
        "--log-level",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        default="INFO",
        help="Logging level (default: INFO)"
    )

    args = parser.parse_args()

    # Setup logging with specified level
    log_level = getattr(logging, args.log_level.upper())
    setup_logging(level=log_level)

    # Derive the output path from the chosen model so a non-default model does
    # not write into a file named after gpt-5. For the default model this is
    # still "evaluation_result_TO_gpt-5.json".
    if args.output is None:
        args.output = f"evaluation_result_TO_{args.model}.json"

    try:
        # Initialize the judge with the specified model
        logging.info(f"Initializing LLM Judge with {args.model}...")
        judge = LLMJudge(model_name=args.model)

        # Load the text log
        logging.info(f"Loading log from: {args.log_file}")
        log_content = load_text_log(args.log_file)
        logging.info(f"Log loaded successfully ({len(log_content)} characters)")

        # Evaluate the log
        logging.info(f"Evaluating log with {args.model}...")
        evaluation = judge.evaluate_log(log_content)

        logging.info(f"Evaluation Results ({args.model}):")
        logging.info("=" * 60)
        for dimension, data in evaluation.items():
            logging.info(f"{dimension}: {data['score']}/5")
            logging.info(f"  Reason: {data['reason']}")

        # Save evaluation with token usage (the method logs the output path itself)
        judge.save_evaluation_with_usage(evaluation, args.output)

        # Display token usage
        judge.print_token_usage()

    except Exception as e:
        # Top-level CLI boundary: log the error with its traceback, report failure.
        logging.exception(f"Error: {e}")
        return 1

    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status (None/0 means success),
    # so shell scripts and pipelines can detect a failed evaluation.
    raise SystemExit(main())
