import json
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Tuple, Set

def load_classification_ids(evaluation_file: str) -> Tuple[Set[str], Set[str], Set[str]]:
    """
    Load classification IDs from a patch evaluation file.

    The file is expected to be a JSON object with (optional) list fields
    ``empty_patch_ids``, ``resolved_ids`` and ``unresolved_ids``. Missing or
    ``null`` fields are treated as empty lists.

    Args:
        evaluation_file: Path to patch evaluation JSON file (may be "" when no
            evaluation file could be inferred; this is reported as a warning).

    Returns:
        Tuple[Set[str], Set[str], Set[str]]: (empty_patch_ids, resolved_ids, unresolved_ids).
        On any load/format problem a warning is printed and three empty sets are returned.
    """
    field_names = ('empty_patch_ids', 'resolved_ids', 'unresolved_ids')
    try:
        with open(evaluation_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        if not isinstance(data, dict):
            raise ValueError(f"expected a JSON object, got {type(data).__name__}")
        # `or []` so that an explicit JSON null behaves like a missing field.
        empty_patch_ids, resolved_ids, unresolved_ids = (
            set(data.get(name) or []) for name in field_names
        )
    # OSError covers FileNotFoundError, IsADirectoryError, PermissionError;
    # ValueError covers json.JSONDecodeError; TypeError covers non-iterable fields.
    except (OSError, ValueError, KeyError, TypeError) as e:
        print(f"Warning: Unable to load classification file {evaluation_file}: {e}")
        return set(), set(), set()
    return empty_patch_ids, resolved_ids, unresolved_ids

def detect_stuck_in_loop(messages: List[Dict], threshold: int = 3) -> bool:
    """
    Detect whether a trajectory is stuck in a loop.

    Counts identical (whitespace-stripped) assistant message contents; if any
    single content appears at least ``threshold`` times, the trajectory is
    considered stuck in a loop.

    Args:
        messages: Message list (dicts with ``role`` and ``content``). ``content``
            may be ``None`` (e.g. assistant messages that only carry tool calls)
            or a non-string structure such as a list of content parts.
        threshold: Minimum number of identical assistant replies that counts as
            a loop (default 3).

    Returns:
        bool: Whether stuck in loop
    """
    # Fewer messages than the threshold cannot contain `threshold` repeats.
    if not messages or len(messages) < threshold:
        return False

    seen: Dict[str, int] = {}
    for msg in messages:
        if msg.get('role') != 'assistant':
            continue
        content = msg.get('content')
        if content is None:
            continue
        # Non-string content (e.g. a list of parts) is unhashable; use its repr
        # so identical structured replies still compare equal.
        text = (content if isinstance(content, str) else str(content)).strip()
        if not text:
            continue
        seen[text] = seen.get(text, 0) + 1
        if seen[text] >= threshold:
            return True
    return False

def count_action_observation_turns(messages: List[Dict]) -> int:
    """
    Return the number of action-observation turns in a trajectory.

    Each assistant-role message is treated as one turn.

    Args:
        messages: Message list; ``None`` or empty yields 0.

    Returns:
        int: Number of assistant messages.
    """
    assistant_replies = 0
    for message in messages or ():
        if message.get('role') == 'assistant':
            assistant_replies += 1
    return assistant_replies

def calculate_tool_call_accuracy(messages: List[Dict]) -> Tuple[int, int, float]:
    """
    Calculate tool call accuracy.

    Primary method: every assistant message that looks like a tool call (a
    non-empty ``tool_calls`` field, or content containing one of the tool
    keywords) counts as a call. The call is "correct" when the following
    user/tool/system message contains none of the error patterns.

    Fallback (no tool calls detected): every user/tool message after the first
    one (the first is assumed to be the task prompt) is treated as an
    observation. Observations without error patterns count as correct.

    Args:
        messages: Message list. ``content`` may be ``None`` or non-string.

    Returns:
        Tuple[int, int, float]: (total tool calls, correct tool calls, accuracy
        percentage in [0, 100]; 0 when there are no tool calls).
    """
    def _text(msg: Dict) -> str:
        # Assistant messages that only carry tool calls have content None.
        content = msg.get('content')
        if content is None:
            return ''
        return content if isinstance(content, str) else str(content)

    messages = messages or []
    tool_keywords = ('function_calls', 'tool_calls', '<function', 'action')
    observation_error_patterns = (
        'error', 'failed', 'exception', 'traceback',
        'command not found', 'no such', 'permission denied',
    )

    total_tool_calls = 0
    correct_tool_calls = 0

    # Check assistant's tool calls and the message that immediately follows.
    for current_msg, next_msg in zip(messages, messages[1:]):
        if current_msg.get('role') != 'assistant':
            continue
        content = _text(current_msg).lower()
        is_tool_call = bool(current_msg.get('tool_calls')) or any(
            keyword in content for keyword in tool_keywords
        )
        if not is_tool_call:
            continue
        total_tool_calls += 1
        if next_msg.get('role') in ('user', 'tool', 'system'):
            next_content = _text(next_msg).lower()
            if not any(pattern in next_content for pattern in observation_error_patterns):
                correct_tool_calls += 1

    # If no obvious tool calls were found, fall back to scoring observations.
    if total_tool_calls == 0:
        fallback_error_patterns = ('command not found', 'no such', 'permission denied', 'error', 'failed')
        observations = [msg for msg in messages if msg.get('role') in ('user', 'tool')]
        # Skip the first observation (the task prompt) for BOTH the denominator
        # and the numerator; scoring it too let accuracy exceed 100%.
        scored = observations[1:]
        total_tool_calls = len(scored)
        correct_tool_calls = sum(
            1 for msg in scored
            if not any(pattern in _text(msg).lower() for pattern in fallback_error_patterns)
        )

    accuracy = (correct_tool_calls / total_tool_calls * 100) if total_tool_calls > 0 else 0
    return total_tool_calls, correct_tool_calls, accuracy

def extract_token_usage(data: Dict) -> Tuple[int, int]:
    """
    Extract token usage information from JSON data (adapted for Moatless format).

    Looks through several known nested sections (``model_stats``, ``usage``,
    ``token_usage``, ``statistics``, ``metrics``) in order and returns the first
    one that reports a positive count. If none does, it falls back to top-level
    ``prompt_tokens``/``input_tokens`` and ``completion_tokens``/``output_tokens``.
    JSON ``null`` values are treated as 0.

    Args:
        data: JSON data

    Returns:
        Tuple[int, int]: (input tokens, output tokens)
    """
    possible_fields = (
        ('model_stats', 'tokens_sent', 'tokens_received'),
        ('usage', 'prompt_tokens', 'completion_tokens'),
        ('token_usage', 'input_tokens', 'output_tokens'),
        ('statistics', 'input_tokens', 'output_tokens'),
        ('metrics', 'prompt_tokens', 'completion_tokens'),
    )

    for parent_field, input_field, output_field in possible_fields:
        section = data.get(parent_field)
        if not isinstance(section, dict):
            continue
        # `or 0` guards against explicit nulls, which would break the comparison.
        tokens_sent = section.get(input_field) or 0
        tokens_received = section.get(output_field) or 0
        if tokens_sent > 0 or tokens_received > 0:
            return tokens_sent, tokens_received

    # No explicit section had counts: try top-level fields, falling through on null/0.
    tokens_sent = data.get('prompt_tokens') or data.get('input_tokens') or 0
    tokens_received = data.get('completion_tokens') or data.get('output_tokens') or 0
    return tokens_sent, tokens_received

def analyze_trajectory(jsonl_file: str, empty_patch_ids: Set[str], resolved_ids: Set[str], unresolved_ids: Set[str]) -> Tuple[int, Dict, List[Dict]]:
    """
    Analyze trajectory information for each JSON object in a JSONL file, with classification statistics.

    Every non-blank line counts as exactly one record. Lines that are not valid
    JSON objects are recorded with ``parse_error: True`` in the ``other``
    category and contribute zero to every metric.

    Args:
        jsonl_file: JSONL file path, each line is an independent JSON object
        empty_patch_ids: Set of empty patch IDs
        resolved_ids: Set of resolved IDs
        unresolved_ids: Set of unresolved IDs

    Returns:
        Tuple[int, Dict, List[Dict]]: (total record count, overall statistics,
        per-record statistics in file order)
    """
    metric_keys = ('count', 'tokens_sent', 'tokens_received', 'turns',
                   'stuck_in_loop', 'tool_calls', 'correct_tool_calls')
    category_stats = {
        category: dict.fromkeys(metric_keys, 0)
        for category in ('empty_patch', 'resolved', 'unresolved', 'other')
    }

    total_jsons = 0
    total_tokens_sent = 0
    total_tokens_received = 0
    total_turns = 0
    stuck_in_loop_count = 0
    total_tool_calls = 0
    total_correct_tool_calls = 0
    json_stats = []

    with open(jsonl_file, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            if not line.strip():  # Skip empty lines
                continue

            # Count each record exactly once (previously parse errors were counted twice).
            total_jsons += 1

            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                data = None

            if not isinstance(data, dict):
                # Unparseable line or a JSON value that is not an object.
                category_stats['other']['count'] += 1
                json_stats.append({
                    "line_number": line_num,
                    "instance_id": f'parse_error_line_{line_num}',
                    "category": 'other',
                    "tokens_sent": 0,
                    "tokens_received": 0,
                    "total_tokens": 0,
                    "messages_count": 0,
                    "turns": 0,
                    "stuck_in_loop": False,
                    "tool_calls": 0,
                    "correct_tool_calls": 0,
                    "tool_call_accuracy": 0,
                    "parse_error": True
                })
                continue

            tokens_sent, tokens_received = extract_token_usage(data)
            # `or []` so an explicit "messages": null is treated as no messages.
            messages = data.get('messages') or []
            is_stuck_in_loop = detect_stuck_in_loop(messages)
            turns = count_action_observation_turns(messages)
            tool_calls, correct_tool_calls, tool_accuracy = calculate_tool_call_accuracy(messages)

            total_tokens_sent += tokens_sent
            total_tokens_received += tokens_received
            total_turns += turns
            total_tool_calls += tool_calls
            total_correct_tool_calls += correct_tool_calls
            if is_stuck_in_loop:
                stuck_in_loop_count += 1

            instance_id = data.get('instance_id', data.get('id', f'line_{line_num}'))

            # Precedence: empty patch > resolved > unresolved > other.
            if instance_id in empty_patch_ids:
                category = 'empty_patch'
            elif instance_id in resolved_ids:
                category = 'resolved'
            elif instance_id in unresolved_ids:
                category = 'unresolved'
            else:
                category = 'other'

            bucket = category_stats[category]
            bucket['count'] += 1
            bucket['tokens_sent'] += tokens_sent
            bucket['tokens_received'] += tokens_received
            bucket['turns'] += turns
            bucket['tool_calls'] += tool_calls
            bucket['correct_tool_calls'] += correct_tool_calls
            if is_stuck_in_loop:
                bucket['stuck_in_loop'] += 1

            json_stats.append({
                "line_number": line_num,
                "instance_id": instance_id,
                "category": category,
                "tokens_sent": tokens_sent,
                "tokens_received": tokens_received,
                "total_tokens": tokens_sent + tokens_received,
                "messages_count": len(messages),
                "turns": turns,
                "stuck_in_loop": is_stuck_in_loop,
                "tool_calls": tool_calls,
                "correct_tool_calls": correct_tool_calls,
                "tool_call_accuracy": round(tool_accuracy, 2)
            })

    # Averages are taken over all records, including parse-error records.
    def _per_record(value):
        return value / total_jsons if total_jsons > 0 else 0

    total_stats = {
        "total_tokens_sent": total_tokens_sent,
        "total_tokens_received": total_tokens_received,
        "total_tokens_combined": total_tokens_sent + total_tokens_received,
        "avg_tokens_sent": _per_record(total_tokens_sent),
        "avg_tokens_received": _per_record(total_tokens_received),
        "avg_tokens_combined": _per_record(total_tokens_sent + total_tokens_received),
        "stuck_in_loop_count": stuck_in_loop_count,
        "stuck_in_loop_percentage": _per_record(stuck_in_loop_count) * 100,
        "total_turns": total_turns,
        "avg_turns": _per_record(total_turns),
        "total_tool_calls": total_tool_calls,
        "total_correct_tool_calls": total_correct_tool_calls,
        "overall_tool_accuracy": (total_correct_tool_calls / total_tool_calls * 100) if total_tool_calls > 0 else 0,
        "category_stats": category_stats
    }

    return total_jsons, total_stats, json_stats

def _metric_lists(rows: List[Dict]) -> Dict[str, List]:
    """Project per-instance stat rows into the parallel lists written to the report."""
    return {
        "instance_ids": [row["instance_id"] for row in rows],
        "turns_list": [row["turns"] for row in rows],
        "stuck_in_loop_list": [1 if row["stuck_in_loop"] else 0 for row in rows],
        "tool_accuracy_list": [row["tool_call_accuracy"] for row in rows],
        "prompt_tokens_list": [row["tokens_sent"] for row in rows],
        "completion_tokens_list": [row["tokens_received"] for row in rows],
    }


def _rounded_mean(values: List, scale: float = 1) -> float:
    """Return ``round(mean(values) * scale, 2)``, or 0 for an empty list."""
    if not values:
        return 0
    return round(sum(values) / len(values) * scale, 2)


def generate_metrics_report(jsonl_file: str, output_dir, empty_patch_ids: Set[str], resolved_ids: Set[str], unresolved_ids: Set[str]):
    """
    Generate a metrics report for Moatless trajectories.

    Analyzes ``jsonl_file`` with ``analyze_trajectory``, writes the report to
    ``<output_dir>/<input file stem>.json`` and prints a human-readable summary.
    The per-instance lists and category breakdown exclude parse-error lines.
    The summary averages come from ``analyze_trajectory``, which also counts
    parse-error lines in its denominators.

    Args:
        jsonl_file: Path to the trajectory JSONL file.
        output_dir: Output directory, as ``str`` or ``Path`` (created if missing).
        empty_patch_ids: Instance IDs classified as empty patch.
        resolved_ids: Instance IDs classified as resolved.
        unresolved_ids: Instance IDs classified as unresolved.
    """
    _, total_stats, json_stats = analyze_trajectory(jsonl_file, empty_patch_ids, resolved_ids, unresolved_ids)

    # Accept plain strings as well as Path objects.
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / f"{Path(jsonl_file).stem}.json"

    valid_stats = [stat for stat in json_stats if not stat.get('parse_error', False)]
    categories = ['empty_patch', 'resolved', 'unresolved', 'other']

    # Per-category lists plus their averages (all zeros for an empty category).
    category_data = {}
    for category in categories:
        lists = _metric_lists([stat for stat in valid_stats if stat.get('category') == category])
        category_data[category] = {
            **lists,
            "averages": {
                "avg_turns": _rounded_mean(lists["turns_list"]),
                "avg_stuck_in_loop_rate": _rounded_mean(lists["stuck_in_loop_list"], 100),
                "avg_tool_accuracy": _rounded_mean(lists["tool_accuracy_list"]),
                "avg_prompt_tokens": _rounded_mean(lists["prompt_tokens_list"]),
                "avg_completion_tokens": _rounded_mean(lists["completion_tokens_list"])
            }
        }

    stats = {
        "summary": {
            "total_instances": len(valid_stats),
            "avg_turns": round(total_stats["avg_turns"], 2),
            "stuck_in_loop_percentage": round(total_stats["stuck_in_loop_percentage"], 2),
            "overall_tool_accuracy": round(total_stats["overall_tool_accuracy"], 2),
            "total_prompt_tokens": total_stats["total_tokens_sent"],
            "total_completion_tokens": total_stats["total_tokens_received"],
            "avg_prompt_tokens": round(total_stats["avg_tokens_sent"], 2),
            "avg_completion_tokens": round(total_stats["avg_tokens_received"], 2),
            "classification_counts": {cat: total_stats["category_stats"][cat]["count"] for cat in categories}
        },
        **_metric_lists(valid_stats),
        "category_breakdown": category_data,
        "timestamp": datetime.now().isoformat(),
        "agent_type": "moatless"
    }

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=4, ensure_ascii=False)

    print("Moatless trajectory analysis results:")
    print(f"Total instances: {len(valid_stats)}")
    print()
    print("Classification statistics:")
    for category in categories:
        cat_stats = total_stats["category_stats"][category]
        count = cat_stats["count"]
        if count > 0:
            tool_acc = cat_stats['correct_tool_calls'] / cat_stats['tool_calls'] * 100 if cat_stats['tool_calls'] > 0 else 0
            print(f"  {category}: {count} instances")
            print(f"    Average turns: {cat_stats['turns']/count:.2f}")
            print(f"    Loop rate: {cat_stats['stuck_in_loop']/count*100:.2f}%")
            print(f"    Tool accuracy: {tool_acc:.2f}%")
            print(f"    Average input tokens: {cat_stats['tokens_sent']/count:.2f}")
            print(f"    Average output tokens: {cat_stats['tokens_received']/count:.2f}")
    print()
    print("Key metrics statistics:")
    print(f"  Average turns (Avg Turns): {total_stats['avg_turns']:.2f}")
    print(f"  Loop percentage (% Stuck in Loop): {total_stats['stuck_in_loop_percentage']:.2f}%")
    print(f"  Tool call accuracy (% Tool Accuracy): {total_stats['overall_tool_accuracy']:.2f}%")
    print(f"  Total input tokens (Total Prompt Tokens): {total_stats['total_tokens_sent']}")
    print(f"  Total output tokens (Total Completion Tokens): {total_stats['total_tokens_received']}")
    print(f"  Average input tokens (Avg Prompt Tokens): {total_stats['avg_tokens_sent']:.2f}")
    print(f"  Average output tokens (Avg Completion Tokens): {total_stats['avg_tokens_received']:.2f}")
    print()
    print("Data list lengths:")
    print(f"  Instance IDs: {len(stats['instance_ids'])}")
    print(f"  Turns List: {len(stats['turns_list'])}")
    print(f"  Stuck in Loop List: {len(stats['stuck_in_loop_list'])}")
    print(f"  Tool Accuracy List: {len(stats['tool_accuracy_list'])}")
    print(f"  Prompt Tokens List: {len(stats['prompt_tokens_list'])}")
    print(f"  Completion Tokens List: {len(stats['completion_tokens_list'])}")
    print()
    print(f"Detailed statistics saved to: {output_file}")

def get_evaluation_file_for_jsonl(jsonl_file: str) -> str:
    """
    Infer the patch-evaluation JSON path for a trajectory JSONL file.

    The file stem is first matched case-sensitively against explicit mappings.
    After that, case-insensitive keyword heuristics are tried in a fixed order:
    openhands, moatless, deepseek, qwen2.5, qwen3, mistral, devstral.

    Returns:
        str: A path under ``data/patch-evaluation/``, or ``""`` when nothing
        matches, including a ``qwen3`` stem with no recognised size.
    """
    stem = Path(jsonl_file).stem
    lowered = stem.lower()

    def _eval_path(evaluation_name: str) -> str:
        return f"data/patch-evaluation/{evaluation_name}.json"

    # Explicit (case-sensitive) substring mappings take precedence.
    explicit = {'moatless-deepseek-v3': 'moatless-deepseek-v3'}
    for needle, target in explicit.items():
        if needle in stem:
            return _eval_path(target)

    if 'openhands' in lowered:
        return _eval_path('openhands-deepseek-v3')
    if 'moatless' in lowered:
        return _eval_path('moatless-deepseek-v3')
    if 'deepseek' in lowered:
        return _eval_path('deepseek-r1' if 'r1' in lowered else 'deepseek-v3')
    if 'qwen2.5' in lowered:
        for size in ('32b', '14b', '7b'):
            if size in lowered:
                return _eval_path(f'qwen2.5-coder-{size}')
        return _eval_path('qwen2.5-coder-32b')
    if 'qwen3' in lowered:
        for size in ('32b', '14b'):
            if size in lowered:
                return _eval_path(f'qwen3-{size}')
        # A qwen3 stem with an unknown size does not try later heuristics.
        return ""
    if 'mistral' in lowered:
        return _eval_path('mistral')
    if 'devstral' in lowered:
        return _eval_path('devstral')

    # No mapping or heuristic matched.
    return ""

def main():
    """
    Analyze the configured Moatless trajectory files and write one metrics report per file.

    Reports go to ``data/traj-evaluation`` (created if missing). If none of the
    configured files exist, ``*moatless*.jsonl`` files in the current directory
    are used instead. If none are found there either, a hint is printed and the
    function returns.
    """
    out_dir = Path("data/traj-evaluation")
    out_dir.mkdir(parents=True, exist_ok=True)

    # Moatless trajectory file list - adjust to the actual data location.
    candidates = ['data/trajectory_original/moatless-deepseek-v3.jsonl']

    if not candidates or not any(Path(p).exists() for p in candidates):
        discovered = list(Path(".").glob("*moatless*.jsonl"))
        if not discovered:
            print("No Moatless trajectory files found. Please run Evaluate_Trajectory_By_Rule/split_traj_by_patch_correctness/generate-parameter-of-trajectory2report.py firstly, and specify file paths in jsonl_file_list")
            return
        candidates = [str(p) for p in discovered]
        print(f"Automatically discovered {len(candidates)} Moatless trajectory files")

    for trajectory_path in candidates:
        if not Path(trajectory_path).exists():
            print(f"File does not exist: {trajectory_path}")
            continue

        print(f"Analyzing Moatless file: {trajectory_path}")

        evaluation_file = get_evaluation_file_for_jsonl(trajectory_path)
        print(f"Using evaluation file: {evaluation_file}")

        empty_ids, resolved, unresolved = load_classification_ids(evaluation_file)
        print(f"Loaded classifications: empty_patch({len(empty_ids)}), resolved({len(resolved)}), unresolved({len(unresolved)})")

        generate_metrics_report(trajectory_path, out_dir, empty_ids, resolved, unresolved)
        print("\n" + "=" * 50 + "\n")


if __name__ == "__main__":
    main()