import json
import json
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Tuple, Set

def load_classification_ids(evaluation_file: str) -> Tuple[Set[str], Set[str], Set[str]]:
    """
    Load classification IDs from patch evaluation file

    The file is read as a JSON object whose optional keys 'empty_patch_ids',
    'resolved_ids' and 'unresolved_ids' each hold a list of instance IDs.
    A missing key, or a key whose value is null, yields an empty set.

    Loading is best-effort: if the file cannot be opened or read, is not valid
    JSON, or does not have the shape described above, a warning is printed and
    three empty sets are returned instead of raising.

    Args:
        evaluation_file: Path to patch evaluation JSON file

    Returns:
        Tuple[Set[str], Set[str], Set[str]]: (empty_patch_ids, resolved_ids, unresolved_ids)
    """
    try:
        with open(evaluation_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # A top-level list/number/string has no .get(); report it through the
        # same warning path as any other unusable file.
        if not isinstance(data, dict):
            raise ValueError(f"expected a JSON object, got {type(data).__name__}")

        # `or []` also covers keys that are present but explicitly null.
        empty_patch_ids = set(data.get('empty_patch_ids') or [])
        resolved_ids = set(data.get('resolved_ids') or [])
        unresolved_ids = set(data.get('unresolved_ids') or [])

        return empty_patch_ids, resolved_ids, unresolved_ids
    except (OSError, ValueError, TypeError) as e:
        # OSError: missing file, directory path, permission problems.
        # ValueError: invalid JSON (JSONDecodeError), bad encoding, wrong top-level type.
        # TypeError: a value that is not iterable or holds unhashable items.
        print(f"Warning: Unable to load classification file {evaluation_file}: {e}")
        return set(), set(), set()

def extract_swe_bench_instance_id(trajectory_instance_id: str) -> str:
    """
    Reduce a trajectory instance_id to the SWE-bench instance ID it embeds

    Example input:  qwen2.5-coder:32b__t-0.00__p-1.00__c-0.00___swe_bench_lite_test_django__django-15695_django__django-15695
    Example output: django__django-15695

    Lookup order:
        1. An ``owner__repo-number`` token at the very end of the string,
           optionally followed by one more such token joined with ``_``
           (the first of the two is returned).
        2. Otherwise, the first ``owner__repo-number`` token found right
           after the ``swe_bench_lite_test_`` marker.
        3. Otherwise, the input is returned unchanged.

    Args:
        trajectory_instance_id: instance_id in trajectory

    Returns:
        str: Extracted SWE-bench instance ID
    """
    import re

    # Primary strategy: the ID (possibly repeated once) terminates the string.
    tail_match = re.search(
        r'([a-zA-Z0-9-]+__[a-zA-Z0-9-]+-\d+)(?:_[a-zA-Z0-9-]+__[a-zA-Z0-9-]+-\d+)?$',
        trajectory_instance_id,
    )
    if tail_match is not None:
        return tail_match.group(1)

    # Fallback strategy only applies when the dataset marker is present.
    marker = 'swe_bench_lite_test_'
    if marker not in trajectory_instance_id:
        return trajectory_instance_id

    # Text between the first marker and the next one (or the end of the string).
    after_marker = trajectory_instance_id.split(marker)[1]
    fallback_match = re.search(r'([a-zA-Z0-9-]+__[a-zA-Z0-9-]+-\d+)', after_marker)
    if fallback_match is None:
        return trajectory_instance_id
    return fallback_match.group(1)

def detect_stuck_in_loop(messages: List[Dict], min_repeats: int = 3) -> bool:
    """
    Detect if trajectory is stuck in loop

    Assistant message texts are tallied after stripping surrounding
    whitespace; the trajectory is considered stuck as soon as one identical
    text has been seen ``min_repeats`` times. The repeats do not have to be
    consecutive. Messages whose content is empty, missing, null or not a
    string are ignored.

    Args:
        messages: Message list; each message is a dict read via its 'role' and 'content' keys
        min_repeats: Number of identical assistant messages that counts as a loop (default 3)

    Returns:
        bool: Whether stuck in loop
    """
    # Fewer messages than the threshold can never contain enough repeats.
    if not messages or len(messages) < min_repeats:
        return False

    occurrences: Dict[str, int] = {}

    for msg in messages:
        if msg.get('role') != 'assistant':
            continue

        content = msg.get('content')
        # 'content' may be present but null or structured; only plain text is
        # compared, so such messages are skipped instead of crashing .strip().
        if not isinstance(content, str):
            continue

        text = content.strip()
        if not text:
            continue

        occurrences[text] = occurrences.get(text, 0) + 1
        # Stop at the first text that reaches the threshold.
        if occurrences[text] >= min_repeats:
            return True

    return False

def count_action_observation_turns(messages: List[Dict]) -> int:
    """
    Count action-observation turns

    One turn is counted for every message whose 'role' is 'assistant';
    an empty or missing message list gives 0.

    Args:
        messages: Message list

    Returns:
        int: Number of action-observation turns
    """
    assistant_replies = 0
    for entry in messages or ():
        if entry.get('role') == 'assistant':
            assistant_replies += 1
    return assistant_replies

def calculate_tool_call_accuracy(messages: List[Dict]) -> Tuple[int, int, float]:
    """
    Calculate tool call accuracy

    Every message with role 'user' or 'tool' after the first one is treated as
    the response to one tool call. The first such message is left out of both
    the total and the correct count; it is assumed to be the initial task
    prompt rather than a tool response (NOTE(review): confirm this holds for
    every trajectory format fed to this script).

    A response counts as a correct tool call when its text contains none of
    the error markers below, compared case-insensitively. The totals are also
    printed to stdout.

    Args:
        messages: Message list

    Returns:
        Tuple[int, int, float]: (total tool calls, correct tool calls, accuracy percentage)
    """
    error_patterns = [
        'command not found',
        'No such',
        'Permission denied',
        'error',
        'Error',
        'ERROR',
        'failed',
        'Failed',
        'FAILED'
    ]
    # Matching is case-insensitive, so lowercase once and drop the duplicates
    # that only differ in case.
    lowered_patterns = tuple(dict.fromkeys(pattern.lower() for pattern in error_patterns))

    responses = [msg for msg in (messages or []) if msg.get('role') in ('user', 'tool')]
    # Skip the leading prompt in BOTH the total and the correctness pass, so
    # correct can never exceed total and an empty trajectory yields 0, not -1.
    observations = responses[1:]

    total_tool_calls = len(observations)
    correct_tool_calls = 0

    for msg in observations:
        content = msg.get('content')
        if content is None:
            text = ''
        elif isinstance(content, str):
            text = content
        else:
            # Structured content: search its textual representation.
            text = str(content)
        text = text.lower()

        if not any(pattern in text for pattern in lowered_patterns):
            correct_tool_calls += 1

    print(f"Total tool calls: {total_tool_calls}, Correct tool calls: {correct_tool_calls}")

    accuracy = (correct_tool_calls / total_tool_calls * 100) if total_tool_calls > 0 else 0.0
    return total_tool_calls, correct_tool_calls, accuracy

def analyze_trajectory(jsonl_file: str, empty_patch_ids: Set[str], resolved_ids: Set[str], unresolved_ids: Set[str]) -> Tuple[int, Dict, List[Dict]]:
    """
    Analyze token consumption in model_stats, loop detection and turn statistics for each JSON object in JSONL file, with classification statistics

    Each non-blank line is parsed as one JSON object. For every valid record
    the token counts are read from its 'model_stats', the loop / turn / tool
    call metrics are computed from its 'messages', and the record is assigned
    to a category by looking up its extracted instance ID in the three ID
    sets (checked in the order empty_patch, resolved, unresolved; anything
    else is 'other').

    A line that is not valid JSON, or that is valid JSON but not an object,
    produces a placeholder record flagged with "parse_error": True. Such
    records are included in the returned record list and in the returned
    count, but contribute to no total, no category and no average: averages
    and the loop percentage are taken over valid records only.

    Args:
        jsonl_file: JSONL file path, each line is an independent JSON object
        empty_patch_ids: Set of empty patch IDs
        resolved_ids: Set of resolved IDs
        unresolved_ids: Set of unresolved IDs

    Returns:
        Tuple[int, Dict, List[Dict]]: (total JSON count including parse errors, overall statistics, detailed statistics for each JSON)
    """
    metric_names = ('count', 'tokens_sent', 'tokens_received', 'turns', 'stuck_in_loop', 'tool_calls', 'correct_tool_calls')
    category_stats = {
        name: dict.fromkeys(metric_names, 0)
        for name in ('empty_patch', 'resolved', 'unresolved', 'other')
    }

    total_jsons = 0
    json_stats = []

    with open(jsonl_file, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            if not line.strip():  # Skip empty lines
                continue

            total_jsons += 1

            # Keep the try body minimal: only the parse itself can raise JSONDecodeError
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                data = None

            if not isinstance(data, dict):
                # Unparseable line, or JSON that is not an object (list, number, ...):
                # record a placeholder that downstream code filters out via 'parse_error'.
                placeholder_id = f'parse_error_line_{line_num}'
                json_stats.append({
                    "line_number": line_num,
                    "instance_id": placeholder_id,
                    "trajectory_instance_id": placeholder_id,
                    "category": 'other',
                    "tokens_sent": 0,
                    "tokens_received": 0,
                    "total_tokens": 0,
                    "messages_count": 0,
                    "turns": 0,
                    "stuck_in_loop": False,
                    "tool_calls": 0,
                    "correct_tool_calls": 0,
                    "tool_call_accuracy": 0,
                    "parse_error": True
                })
                continue

            # Get token information from model_stats (a missing or null value counts as 0)
            model_stats = data.get('model_stats')
            if not isinstance(model_stats, dict):
                model_stats = {}
            tokens_sent = model_stats.get('tokens_sent') or 0
            tokens_received = model_stats.get('tokens_received') or 0

            # Get messages for loop detection and turn statistics
            messages = data.get('messages') or []
            is_stuck_in_loop = detect_stuck_in_loop(messages)
            turns = count_action_observation_turns(messages)
            tool_calls, correct_tool_calls, tool_accuracy = calculate_tool_call_accuracy(messages)

            trajectory_instance_id = data.get('instance_id', f'line_{line_num}')
            instance_id = extract_swe_bench_instance_id(trajectory_instance_id)

            # Determine classification
            if instance_id in empty_patch_ids:
                category = 'empty_patch'
            elif instance_id in resolved_ids:
                category = 'resolved'
            elif instance_id in unresolved_ids:
                category = 'unresolved'
            else:
                category = 'other'

            # Update classification statistics
            bucket = category_stats[category]
            bucket['count'] += 1
            bucket['tokens_sent'] += tokens_sent
            bucket['tokens_received'] += tokens_received
            bucket['turns'] += turns
            bucket['tool_calls'] += tool_calls
            bucket['correct_tool_calls'] += correct_tool_calls
            if is_stuck_in_loop:
                bucket['stuck_in_loop'] += 1

            json_stats.append({
                "line_number": line_num,
                "instance_id": instance_id,  # Clean SWE-bench instance ID
                "trajectory_instance_id": trajectory_instance_id,  # Original trajectory instance ID
                "category": category,
                "tokens_sent": tokens_sent,
                "tokens_received": tokens_received,
                "total_tokens": tokens_sent + tokens_received,
                "messages_count": len(messages),
                "turns": turns,
                "stuck_in_loop": is_stuck_in_loop,
                "tool_calls": tool_calls,
                "correct_tool_calls": correct_tool_calls,
                "tool_call_accuracy": round(tool_accuracy, 2),
                "model_stats": model_stats  # Preserve complete model_stats information
            })

    # Every valid record landed in exactly one category, so the overall totals
    # are the sums over the categories.
    def _overall(metric: str) -> int:
        return sum(bucket[metric] for bucket in category_stats.values())

    valid_jsons = _overall('count')
    total_tokens_sent = _overall('tokens_sent')
    total_tokens_received = _overall('tokens_received')
    total_turns = _overall('turns')
    stuck_in_loop_count = _overall('stuck_in_loop')
    total_tool_calls = _overall('tool_calls')
    total_correct_tool_calls = _overall('correct_tool_calls')

    # Calculate overall statistics; denominators use valid records only so that
    # parse-error lines (which add nothing to the numerators) do not drag the
    # averages down.
    stuck_in_loop_percentage = (stuck_in_loop_count / valid_jsons * 100) if valid_jsons > 0 else 0
    avg_turns = total_turns / valid_jsons if valid_jsons > 0 else 0
    overall_tool_accuracy = (total_correct_tool_calls / total_tool_calls * 100) if total_tool_calls > 0 else 0

    total_stats = {
        "total_tokens_sent": total_tokens_sent,
        "total_tokens_received": total_tokens_received,
        "total_tokens_combined": total_tokens_sent + total_tokens_received,
        "avg_tokens_sent": total_tokens_sent / valid_jsons if valid_jsons > 0 else 0,
        "avg_tokens_received": total_tokens_received / valid_jsons if valid_jsons > 0 else 0,
        "avg_tokens_combined": (total_tokens_sent + total_tokens_received) / valid_jsons if valid_jsons > 0 else 0,
        "stuck_in_loop_count": stuck_in_loop_count,
        "stuck_in_loop_percentage": stuck_in_loop_percentage,
        "total_turns": total_turns,
        "avg_turns": avg_turns,
        "total_tool_calls": total_tool_calls,
        "total_correct_tool_calls": total_correct_tool_calls,
        "overall_tool_accuracy": overall_tool_accuracy,
        "category_stats": category_stats
    }

    return total_jsons, total_stats, json_stats

def _summarize_instances(instances: List[Dict]) -> Dict:
    """Build the parallel per-instance metric lists, plus their averages, for one group of records."""
    turns_list = [stat["turns"] for stat in instances]
    stuck_in_loop_list = [1 if stat["stuck_in_loop"] else 0 for stat in instances]
    tool_accuracy_list = [stat["tool_call_accuracy"] for stat in instances]
    prompt_tokens_list = [stat["tokens_sent"] for stat in instances]
    completion_tokens_list = [stat["tokens_received"] for stat in instances]

    def _mean(values, scale=1):
        # An empty group reports 0 instead of dividing by zero.
        return round(sum(values) / len(values) * scale, 2) if values else 0

    return {
        "instance_ids": [stat["instance_id"] for stat in instances],
        "turns_list": turns_list,
        "stuck_in_loop_list": stuck_in_loop_list,
        "tool_accuracy_list": tool_accuracy_list,
        "prompt_tokens_list": prompt_tokens_list,
        "completion_tokens_list": completion_tokens_list,
        "averages": {
            "avg_turns": _mean(turns_list),
            "avg_stuck_in_loop_rate": _mean(stuck_in_loop_list, 100),
            "avg_tool_accuracy": _mean(tool_accuracy_list),
            "avg_prompt_tokens": _mean(prompt_tokens_list),
            "avg_completion_tokens": _mean(completion_tokens_list)
        }
    }


def generate_metrics_report(jsonl_file: str, output_dir, empty_patch_ids: Set[str], resolved_ids: Set[str], unresolved_ids: Set[str]):
    """
    Analyze one trajectory JSONL file, write its metrics report as JSON and print a summary

    The report is written to ``<output_dir>/<input file stem>.json`` (an
    existing file with that name is overwritten) and contains an overall
    summary, parallel per-instance lists, a per-category breakdown and the
    generation time. Records flagged as parse errors by analyze_trajectory
    are left out of all lists.

    Args:
        jsonl_file: JSONL trajectory file to analyze
        output_dir: Directory for the report, as a str or Path; created if it does not exist
        empty_patch_ids: Set of empty patch IDs
        resolved_ids: Set of resolved IDs
        unresolved_ids: Set of unresolved IDs
    """
    _, total_stats, json_stats = analyze_trajectory(jsonl_file, empty_patch_ids, resolved_ids, unresolved_ids)

    # Output file is named after the input file (without its extension)
    output_dir = Path(output_dir)  # accept plain strings as well as Path objects
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / f"{Path(jsonl_file).stem}.json"

    # Parse-error placeholders carry no metrics; keep real records only
    valid_stats = [stat for stat in json_stats if not stat.get('parse_error', False)]

    # Group data by classification
    categories = ['empty_patch', 'resolved', 'unresolved', 'other']
    category_data = {
        category: _summarize_instances([stat for stat in valid_stats if stat.get('category') == category])
        for category in categories
    }

    # Same lists over all valid records; their averages are not used here
    # because the summary takes its overall figures from total_stats.
    overall = _summarize_instances(valid_stats)

    stats = {
        "summary": {
            "total_instances": len(valid_stats),
            "avg_turns": round(total_stats["avg_turns"], 2),
            "stuck_in_loop_percentage": round(total_stats["stuck_in_loop_percentage"], 2),
            "overall_tool_accuracy": round(total_stats["overall_tool_accuracy"], 2),
            "total_prompt_tokens": total_stats["total_tokens_sent"],
            "total_completion_tokens": total_stats["total_tokens_received"],
            "avg_prompt_tokens": round(total_stats["avg_tokens_sent"], 2),
            "avg_completion_tokens": round(total_stats["avg_tokens_received"], 2),
            "classification_counts": {cat: total_stats["category_stats"][cat]["count"] for cat in categories}
        },
        "instance_ids": overall["instance_ids"],
        "turns_list": overall["turns_list"],
        "stuck_in_loop_list": overall["stuck_in_loop_list"],
        "tool_accuracy_list": overall["tool_accuracy_list"],
        "prompt_tokens_list": overall["prompt_tokens_list"],
        "completion_tokens_list": overall["completion_tokens_list"],
        "category_breakdown": category_data,
        "timestamp": datetime.now().isoformat()
    }

    # Write to JSON file
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=4, ensure_ascii=False)

    # Print results
    print("JSONL file trajectory analysis results:")
    print(f"Total instances: {len(valid_stats)}")
    print("")
    print("Classification statistics:")
    for category in categories:
        cat_stats = total_stats["category_stats"][category]
        count = cat_stats["count"]
        if count > 0:
            # Micro-average over the category: all correct calls / all calls
            tool_accuracy = cat_stats['correct_tool_calls'] / cat_stats['tool_calls'] * 100 if cat_stats['tool_calls'] > 0 else 0
            print(f"  {category}: {count} instances")
            print(f"    Average turns: {cat_stats['turns']/count:.2f}")
            print(f"    Loop rate: {cat_stats['stuck_in_loop']/count*100:.2f}%")
            print(f"    Tool accuracy: {tool_accuracy:.2f}%")
            print(f"    Average input tokens: {cat_stats['tokens_sent']/count:.2f}")
            print(f"    Average output tokens: {cat_stats['tokens_received']/count:.2f}")
    print("")
    print("Key metrics statistics:")
    print(f"  Average turns (Avg Turns): {total_stats['avg_turns']:.2f}")
    print(f"  Loop percentage (% Stuck in Loop): {total_stats['stuck_in_loop_percentage']:.2f}%")
    print(f"  Tool call accuracy (% Tool Accuracy): {total_stats['overall_tool_accuracy']:.2f}%")
    print(f"  Total input tokens (Total Prompt Tokens): {total_stats['total_tokens_sent']}")
    print(f"  Total output tokens (Total Completion Tokens): {total_stats['total_tokens_received']}")
    print(f"  Average input tokens (Avg Prompt Tokens): {total_stats['avg_tokens_sent']:.2f}")
    print(f"  Average output tokens (Avg Completion Tokens): {total_stats['avg_tokens_received']:.2f}")
    print("")
    print("Data list lengths:")
    print(f"  Instance IDs: {len(stats['instance_ids'])}")
    print(f"  Turns List: {len(stats['turns_list'])}")
    print(f"  Stuck in Loop List: {len(stats['stuck_in_loop_list'])}")
    print(f"  Tool Accuracy List: {len(stats['tool_accuracy_list'])}")
    print(f"  Prompt Tokens List: {len(stats['prompt_tokens_list'])}")
    print(f"  Completion Tokens List: {len(stats['completion_tokens_list'])}")
    print("")
    print(f"Detailed statistics saved to: {output_file}")

def get_evaluation_file_for_jsonl(jsonl_file: str) -> str:
    """
    Infer corresponding evaluation file name based on JSONL file name

    Only the file stem is inspected. Rules are tried in order and the first
    hit wins:
        1. Case-sensitive substrings 'openhands-deepseek-v3', then 'moatless'.
        2. Case-insensitive substrings: 'swe_smith', 'openhands', 'moatless',
           'deepseek' (r1 vs. v3), 'qwen2.5' (32b / 14b / 7b, defaulting to
           32b), 'qwen3' (32b / 14b only), 'mistral', 'devstral'.

    Args:
        jsonl_file: Path or name of the trajectory JSONL file

    Returns:
        str: Evaluation file path under data/patch-evaluation/, or an empty
        string when no rule matches
    """
    stem = Path(jsonl_file).stem

    def evaluation_path(name: str) -> str:
        return f"data/patch-evaluation/{name}.json"

    # Explicit, case-sensitive mappings take precedence over the heuristics.
    for needle, target in (
        ('openhands-deepseek-v3', 'openhands-deepseek-v3'),
        ('moatless', 'moatless-deepseek-v3'),
    ):
        if needle in stem:
            return evaluation_path(target)

    lowered = stem.lower()

    # Agent frameworks are recognised before bare model names.
    for needle, target in (
        ('swe_smith', 'swe_smith-deepseek-v3'),
        ('openhands', 'openhands-deepseek-v3'),
        ('moatless', 'moatless-deepseek-v3'),
    ):
        if needle in lowered:
            return evaluation_path(target)

    if 'deepseek' in lowered:
        return evaluation_path('deepseek-r1' if 'r1' in lowered else 'deepseek-v3')

    if 'qwen2.5' in lowered:
        for size in ('32b', '14b', '7b'):
            if size in lowered:
                return evaluation_path(f'qwen2.5-coder-{size}')
        # No recognised size in the name: fall back to the 32b evaluation.
        return evaluation_path('qwen2.5-coder-32b')

    if 'qwen3' in lowered:
        for size in ('32b', '14b'):
            if size in lowered:
                return evaluation_path(f'qwen3-{size}')
        # Other qwen3 sizes have no mapping.
        return ""

    for family in ('mistral', 'devstral'):
        if family in lowered:
            return evaluation_path(family)

    # Nothing matched: no evaluation file is known for this trajectory file.
    return ""

def main():
    """
    Produce a metrics report for every configured trajectory file

    Reports go to data/traj-evaluation (created if missing). The input files
    are listed in jsonl_file_list; when that list is left empty, every
    *.jsonl file in data/trajectory_original is used instead. For each input
    the matching patch evaluation file is looked up by name, its
    classification IDs are loaded, and the report is generated.
    """
    report_dir = Path("data/traj-evaluation")
    report_dir.mkdir(parents=True, exist_ok=True)

    jsonl_file_list = [
        'data/trajectory_original/deepseek-r1_empty_ids.jsonl',
        'data/trajectory_original/deepseek-r1_resolved_ids.jsonl',
        'data/trajectory_original/deepseek-r1_unresolved_ids.jsonl',
    ]

    # With no explicit files configured, fall back to scanning the trajectory folder
    if not jsonl_file_list:
        target_folder = Path("data/trajectory_original")
        if not (target_folder.exists() and target_folder.is_dir()):
            print(f"Target folder {target_folder} does not exist or is not a valid directory. Please run Evaluate_Trajectory_By_Rule/split_traj_by_patch_correctness/generate-parameter-of-trajectory2report.py firstly, and specify file paths in jsonl_file_list")
            return

        jsonl_file_list = [str(found) for found in target_folder.glob("*.jsonl")]
        if not jsonl_file_list:
            print(f"No .jsonl files found in folder {target_folder}")
            return
        print(f"Found {len(jsonl_file_list)} .jsonl files in folder {target_folder}. Please run Evaluate_Trajectory_By_Rule/split_traj_by_patch_correctness/generate-parameter-of-trajectory2report.py firstly, and specify file paths in jsonl_file_list")

    separator = "\n" + "=" * 50 + "\n"
    for jsonl_file in jsonl_file_list:
        print(f"Analyzing file: {jsonl_file}")

        # Look up the evaluation file that classifies this trajectory file's instances
        evaluation_file = get_evaluation_file_for_jsonl(jsonl_file)
        print(f"Using evaluation file: {evaluation_file}")

        # Load the three ID classifications
        classification = load_classification_ids(evaluation_file)
        empty_patch_ids, resolved_ids, unresolved_ids = classification
        print(f"Loaded classifications: empty_patch({len(empty_patch_ids)}), resolved({len(resolved_ids)}), unresolved({len(unresolved_ids)})")

        generate_metrics_report(jsonl_file, report_dir, empty_patch_ids, resolved_ids, unresolved_ids)
        print(separator)

# Run the analysis only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
