import json
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Tuple, Set, Set


def load_classification_ids(evaluation_file: str) -> Tuple[Set[str], Set[str], Set[str]]:
    """
    Load classification IDs from patch evaluation file

    The file is expected to contain a JSON object whose 'empty_patch_ids',
    'resolved_ids' and 'unresolved_ids' entries are lists of instance IDs.
    Loading is best-effort: if the file cannot be read, cannot be decoded, or
    does not hold a JSON object, a warning is printed and three empty sets are
    returned instead of raising.

    Args:
        evaluation_file: Path to patch evaluation JSON file

    Returns:
        Tuple[Set[str], Set[str], Set[str]]: (empty_patch_ids, resolved_ids, unresolved_ids)
    """
    try:
        with open(evaluation_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # OSError covers missing or unreadable paths (FileNotFoundError is a
        # subclass); ValueError covers json.JSONDecodeError and decoding errors.
        print(f"Warning: Unable to load classification file {evaluation_file}: {e}")
        return set(), set(), set()

    if not isinstance(data, dict):
        # e.g. a top-level JSON list: there are no named ID lists to read.
        print(
            f"Warning: Unable to load classification file {evaluation_file}: "
            f"expected a JSON object, got {type(data).__name__}"
        )
        return set(), set(), set()

    def ids_for(key: str) -> Set[str]:
        # A missing, null or otherwise non-list entry is treated as "no IDs".
        value = data.get(key)
        return set(value) if isinstance(value, list) else set()

    return ids_for('empty_patch_ids'), ids_for('resolved_ids'), ids_for('unresolved_ids')


def detect_stuck_in_loop(messages: List[Dict]) -> bool:
    """
    Detect if trajectory is stuck in a loop

    Assistant messages are tallied by their whitespace-stripped content using a
    dictionary; the trajectory is considered stuck as soon as the same
    non-empty content has been seen 3 times. Messages whose content is not a
    string (for example None) are ignored rather than raising.

    Args:
        messages: List of messages

    Returns:
        bool: Whether stuck in loop
    """
    repeat_threshold = 3

    # Fewer messages than the threshold can never contain enough repetitions.
    if len(messages) < repeat_threshold:
        return False

    occurrences: Dict[str, int] = {}

    for msg in messages:
        if msg.get('role') != 'assistant':
            continue

        content = msg.get('content')
        if not isinstance(content, str):
            # Missing or non-text content cannot be compared as a sentence.
            continue

        content = content.strip()
        if not content:
            continue

        occurrences[content] = occurrences.get(content, 0) + 1
        if occurrences[content] >= repeat_threshold:
            # No need to scan the rest once one repetition reaches the threshold.
            return True

    return False

def count_action_observation_turns(messages: List[Dict]) -> int:
    """
    Return the number of turns in a trajectory.

    A turn is approximated by one assistant reply, so the result is simply the
    number of messages whose role is 'assistant'.

    Args:
        messages: List of messages

    Returns:
        int: Number of assistant messages (0 for an empty or missing list)
    """
    if not messages:
        return 0

    assistant_replies = [entry for entry in messages if entry.get('role') == 'assistant']
    return len(assistant_replies)

def calculate_tool_call_accuracy(messages: List[Dict]) -> Tuple[int, int, float]:
    """
    Calculate tool call accuracy

    Tool calls are not inspected directly. Every 'user' or 'tool' message after
    the first one is treated as the observation of one tool call, and the call
    counts as correct when that observation contains none of the known error
    markers (compared case-insensitively).

    The first 'user'/'tool' message is excluded from both the total and the
    correct count; it is assumed to be the initial task prompt rather than an
    observation. Keeping both counts on the same set of messages guarantees
    0 <= correct <= total.

    Args:
        messages: List of messages

    Returns:
        Tuple[int, int, float]: (total tool calls, correct tool calls, accuracy percentage)
    """
    # Lower-case markers; observations are lower-cased before matching.
    error_markers = (
        'command not found',
        'no such',
        'permission denied',
        'error',
        'failed',
    )

    replies = [msg for msg in messages if msg.get('role') in ('user', 'tool')]
    # Skip the leading message (assumed task prompt); an empty list stays empty.
    observations = replies[1:]

    total_tool_calls = len(observations)
    correct_tool_calls = 0

    for msg in observations:
        content = msg.get('content')
        if content is None:
            text = ''
        elif isinstance(content, str):
            text = content
        else:
            # Structured content (e.g. a list of parts) is searched via its text form.
            text = str(content)
        text = text.lower()

        if not any(marker in text for marker in error_markers):
            correct_tool_calls += 1

    print(f"Total tool calls: {total_tool_calls}, Correct tool calls: {correct_tool_calls}")

    accuracy = (correct_tool_calls / total_tool_calls * 100) if total_tool_calls > 0 else 0
    return total_tool_calls, correct_tool_calls, accuracy

def extract_usage_info(data: Dict) -> Dict:
    """
    Locate the usage record inside one decoded JSON payload.

    Three locations are tried in order, and the first location that applies
    decides the result (later locations are not consulted afterwards):
      1. a root-level 'usage' key;
      2. a dict-valued 'fncall_response' (its 'usage' entry, if any);
      3. a list-valued 'messages' (the first dict entry that has a 'usage' key).

    Args:
        data: JSON data dictionary

    Returns:
        The usage value found, or an empty dict when nothing was found or the
        value found is not a dictionary
    """
    def as_usage(candidate):
        # Anything other than a dict is reported as "no usage".
        return candidate if isinstance(candidate, dict) else {}

    if 'usage' in data:
        return as_usage(data['usage'])

    if 'fncall_response' in data and isinstance(data['fncall_response'], dict):
        response = data['fncall_response']
        return as_usage(response['usage']) if 'usage' in response else {}

    if 'messages' in data and isinstance(data['messages'], list):
        for entry in data['messages']:
            if isinstance(entry, dict) and 'usage' in entry:
                return as_usage(entry['usage'])

    return {}

def sum_usage_values(usage1: Dict, usage2: Dict) -> Dict:
    """
    Merge two usage dictionaries into a new one, adding up matching counters.

    For a key present in both inputs: two numbers are added, two dictionaries
    are merged recursively, and any other combination takes the value from
    usage2. Keys present in only one input are carried over unchanged.

    Args:
        usage1: First usage dictionary
        usage2: Second usage dictionary

    Returns:
        Accumulated usage dictionary (a shallow copy of usage1 with usage2 merged in)
    """
    numeric = (int, float)
    merged = usage1.copy()

    for name in usage2:
        incoming = usage2[name]

        if name not in merged:
            merged[name] = incoming
            continue

        current = merged[name]
        if isinstance(incoming, numeric) and isinstance(current, numeric):
            merged[name] = current + incoming
        elif isinstance(incoming, dict) and isinstance(current, dict):
            merged[name] = sum_usage_values(current, incoming)
        else:
            # Mismatched or non-summable types: the later value wins.
            merged[name] = incoming

    return merged

def find_all_json_files(folder_path: Path) -> List[Path]:
    """
    List the '*.json' files directly inside a folder.

    Args:
        folder_path: Folder path

    Returns:
        Matching paths sorted by filename; an empty list when the path does
        not exist or is not a directory
    """
    if folder_path.exists() and folder_path.is_dir():
        return sorted(folder_path.glob("*.json"), key=lambda entry: entry.name)
    return []

def calculate_instance_token_usage(instance_id: str, base_dir: str = "OpenHands/evaluation/evaluation_outputs/outputs/princeton-nlp__SWE-bench_Lite-test/CodeActAgent/deepseek-v3_maxiter_30_N_v0.42.0-no-hint-run_1/llm_completions") -> Tuple[int, int]:
    """
    Sum the token usage recorded for one instance.

    Every '*.json' file in '<base_dir>/<instance_id>' is read, its usage record
    is extracted with extract_usage_info, and the records are accumulated with
    sum_usage_values. Files that fail to load are reported and skipped.

    Args:
        instance_id: instance ID (also the name of the instance's sub-folder)
        base_dir: Base path of OpenHands output directory

    Returns:
        Tuple[int, int]: (prompt_tokens, completion_tokens); (0, 0) when the
        instance folder is missing or holds no JSON files
    """
    instance_folder = Path(base_dir) / instance_id

    if not (instance_folder.exists() and instance_folder.is_dir()):
        print(f"Warning: Instance folder {instance_folder} not found")
        return 0, 0

    completion_files = find_all_json_files(instance_folder)
    if len(completion_files) == 0:
        print(f"Warning: No JSON files found in {instance_folder}")
        return 0, 0

    accumulated = {}
    for completion_file in completion_files:
        try:
            with open(completion_file, 'r', encoding='utf-8') as handle:
                payload = json.load(handle)

            file_usage = extract_usage_info(payload)
            if file_usage:
                accumulated = sum_usage_values(accumulated, file_usage)
        except (json.JSONDecodeError, FileNotFoundError, KeyError) as err:
            # One bad file must not abort the whole instance.
            print(f"Warning: Error processing file {completion_file}: {err}")

    return accumulated.get('prompt_tokens', 0), accumulated.get('completion_tokens', 0)

def analyze_trajectory(jsonl_file: str, empty_patch_ids: Set[str], resolved_ids: Set[str], unresolved_ids: Set[str]) -> Tuple[int, Dict, List[Dict]]:
    """
    Analyze loop detection and turn statistics for each JSON object in JSONL file, with classification statistics

    Each non-blank line is expected to be a JSON object with a 'messages' list
    and, optionally, an 'instance_id'. For every such object the function
    computes loop detection, turn count and tool-call accuracy from the
    messages, and looks up token usage via calculate_instance_token_usage
    (using that function's default base directory).

    Lines that cannot be decoded, or that decode to something other than a JSON
    object, are recorded with 'parse_error': True and zeroed metrics. They still
    count toward the returned total and toward the 'other' category count, so
    the overall averages are taken over all non-blank lines.

    Args:
        jsonl_file: JSONL file path, each line is an independent JSON object
        empty_patch_ids: Set of empty patch IDs
        resolved_ids: Set of resolved IDs
        unresolved_ids: Set of unresolved IDs

    Returns:
        Tuple[int, Dict, List[Dict]]: (total JSON count, overall statistics, detailed statistics for each JSON)
    """
    counter_fields = (
        'count', 'prompt_tokens', 'completion_tokens', 'turns',
        'stuck_in_loop', 'tool_calls', 'correct_tool_calls',
    )
    # One independent zeroed counter dict per classification.
    category_stats = {
        name: dict.fromkeys(counter_fields, 0)
        for name in ('empty_patch', 'resolved', 'unresolved', 'other')
    }

    total_jsons = 0
    total_turns = 0
    stuck_in_loop_count = 0
    total_tool_calls = 0
    total_correct_tool_calls = 0
    total_prompt_tokens = 0
    total_completion_tokens = 0
    json_stats = []

    with open(jsonl_file, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            if not line.strip():  # Skip empty lines
                continue

            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                data = None

            if not isinstance(data, dict):
                # Undecodable line, or valid JSON that is not an object (list,
                # number, null, ...): there is no trajectory to analyze.
                total_jsons += 1
                category_stats['other']['count'] += 1
                json_stats.append({
                    "line_number": line_num,
                    "instance_id": f'parse_error_line_{line_num}',
                    "category": 'other',
                    "messages_count": 0,
                    "turns": 0,
                    "stuck_in_loop": False,
                    "tool_calls": 0,
                    "correct_tool_calls": 0,
                    "tool_call_accuracy": 0,
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "parse_error": True
                })
                continue

            messages = data.get('messages', [])
            if not isinstance(messages, list):
                # e.g. "messages": null — analyze as an empty trajectory.
                messages = []

            # Loop detection and turn statistics
            is_stuck_in_loop = detect_stuck_in_loop(messages)
            turns = count_action_observation_turns(messages)
            tool_calls, correct_tool_calls, tool_accuracy = calculate_tool_call_accuracy(messages)

            # Fall back to a line-based ID when the object carries none.
            instance_id = data.get('instance_id', f'line_{line_num}')

            # Token usage statistics - get from OpenHands directory
            prompt_tokens, completion_tokens = calculate_instance_token_usage(instance_id)

            total_jsons += 1
            total_turns += turns
            total_tool_calls += tool_calls
            total_correct_tool_calls += correct_tool_calls
            total_prompt_tokens += prompt_tokens
            total_completion_tokens += completion_tokens
            if is_stuck_in_loop:
                stuck_in_loop_count += 1

            # Determine classification (first matching set wins)
            if instance_id in empty_patch_ids:
                category = 'empty_patch'
            elif instance_id in resolved_ids:
                category = 'resolved'
            elif instance_id in unresolved_ids:
                category = 'unresolved'
            else:
                category = 'other'

            bucket = category_stats[category]
            bucket['count'] += 1
            bucket['prompt_tokens'] += prompt_tokens
            bucket['completion_tokens'] += completion_tokens
            bucket['turns'] += turns
            bucket['tool_calls'] += tool_calls
            bucket['correct_tool_calls'] += correct_tool_calls
            if is_stuck_in_loop:
                bucket['stuck_in_loop'] += 1

            json_stats.append({
                "line_number": line_num,
                "instance_id": instance_id,
                "category": category,
                "messages_count": len(messages),
                "turns": turns,
                "stuck_in_loop": is_stuck_in_loop,
                "tool_calls": tool_calls,
                "correct_tool_calls": correct_tool_calls,
                "tool_call_accuracy": round(tool_accuracy, 2),
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens
            })

    # Calculate overall statistics (all guarded against an empty input file)
    stuck_in_loop_percentage = (stuck_in_loop_count / total_jsons * 100) if total_jsons > 0 else 0
    avg_turns = total_turns / total_jsons if total_jsons > 0 else 0
    overall_tool_accuracy = (total_correct_tool_calls / total_tool_calls * 100) if total_tool_calls > 0 else 0
    avg_prompt_tokens = total_prompt_tokens / total_jsons if total_jsons > 0 else 0
    avg_completion_tokens = total_completion_tokens / total_jsons if total_jsons > 0 else 0

    total_stats = {
        "stuck_in_loop_count": stuck_in_loop_count,
        "stuck_in_loop_percentage": stuck_in_loop_percentage,
        "total_turns": total_turns,
        "avg_turns": avg_turns,
        "total_tool_calls": total_tool_calls,
        "total_correct_tool_calls": total_correct_tool_calls,
        "overall_tool_accuracy": overall_tool_accuracy,
        "total_prompt_tokens": total_prompt_tokens,
        "total_completion_tokens": total_completion_tokens,
        "avg_prompt_tokens": avg_prompt_tokens,
        "avg_completion_tokens": avg_completion_tokens,
        "category_stats": category_stats
    }

    return total_jsons, total_stats, json_stats

def _collect_metric_lists(records: List[Dict]) -> Dict[str, list]:
    """Turn per-instance records into parallel, column-wise metric lists."""
    return {
        "instance_ids": [record["instance_id"] for record in records],
        "turns_list": [record["turns"] for record in records],
        # Booleans are stored as 0/1 so they can be averaged.
        "stuck_in_loop_list": [1 if record["stuck_in_loop"] else 0 for record in records],
        "tool_accuracy_list": [record["tool_call_accuracy"] for record in records],
        "prompt_tokens_list": [record["prompt_tokens"] for record in records],
        "completion_tokens_list": [record["completion_tokens"] for record in records],
    }


def _rounded_mean(values: list, scale: int = 1):
    """Return mean(values) * scale rounded to 2 decimals, or 0 for an empty list."""
    return round(sum(values) / len(values) * scale, 2) if values else 0


def generate_metrics_report(jsonl_file: str, output_dir, empty_patch_ids: Set[str], resolved_ids: Set[str], unresolved_ids: Set[str]):
    """
    Analyze a trajectory JSONL file, write the metrics as JSON and print a summary

    The report is written to '<output_dir>/<input file stem>.json' (the output
    directory is created if needed) and contains a summary, column-wise metric
    lists for all successfully parsed instances, a per-category breakdown and
    the generation time. Lines flagged as parse errors are left out of the
    lists and of 'total_instances'; the summary averages are taken as computed
    by analyze_trajectory.

    Args:
        jsonl_file: JSONL file path, each line is an independent JSON object
        output_dir: Directory for the report (str or Path)
        empty_patch_ids: Set of empty patch IDs
        resolved_ids: Set of resolved IDs
        unresolved_ids: Set of unresolved IDs
    """
    _total_jsons, total_stats, json_stats = analyze_trajectory(jsonl_file, empty_patch_ids, resolved_ids, unresolved_ids)

    # Output filename mirrors the input filename (extension replaced by .json)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / f"{Path(jsonl_file).stem}.json"

    valid_stats = [stat for stat in json_stats if not stat.get('parse_error', False)]

    # Group data by classification
    categories = ['empty_patch', 'resolved', 'unresolved', 'other']
    category_data = {}

    for category in categories:
        columns = _collect_metric_lists(
            [stat for stat in valid_stats if stat.get('category') == category]
        )
        category_data[category] = {
            **columns,
            "averages": {
                "avg_turns": _rounded_mean(columns["turns_list"]),
                "avg_stuck_in_loop_rate": _rounded_mean(columns["stuck_in_loop_list"], scale=100),
                "avg_tool_accuracy": _rounded_mean(columns["tool_accuracy_list"]),
                "avg_prompt_tokens": _rounded_mean(columns["prompt_tokens_list"]),
                "avg_completion_tokens": _rounded_mean(columns["completion_tokens_list"])
            }
        }

    stats = {
        "summary": {
            "total_instances": len(valid_stats),
            "avg_turns": round(total_stats["avg_turns"], 2),
            "stuck_in_loop_percentage": round(total_stats["stuck_in_loop_percentage"], 2),
            "overall_tool_accuracy": round(total_stats["overall_tool_accuracy"], 2),
            "total_prompt_tokens": total_stats["total_prompt_tokens"],
            "total_completion_tokens": total_stats["total_completion_tokens"],
            "avg_prompt_tokens": round(total_stats["avg_prompt_tokens"], 2),
            "avg_completion_tokens": round(total_stats["avg_completion_tokens"], 2),
            "classification_counts": {cat: total_stats["category_stats"][cat]["count"] for cat in categories}
        },
        **_collect_metric_lists(valid_stats),
        "category_breakdown": category_data,
        "timestamp": datetime.now().isoformat()
    }

    # Write to JSON file
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=4, ensure_ascii=False)

    # Print simplified results
    print("JSONL file trajectory analysis results - Key metrics statistics:")
    print(f"Total instances: {len(valid_stats)}")
    print()
    print("Classification statistics:")
    for category in categories:
        cat_stats = total_stats["category_stats"][category]
        count = cat_stats["count"]
        if count > 0:
            tool_accuracy = (
                cat_stats['correct_tool_calls'] / cat_stats['tool_calls'] * 100
                if cat_stats['tool_calls'] > 0 else 0
            )
            print(f"  {category}: {count} instances")
            print(f"    Average turns: {cat_stats['turns']/count:.2f}")
            print(f"    Loop rate: {cat_stats['stuck_in_loop']/count*100:.2f}%")
            print(f"    Tool accuracy: {tool_accuracy:.2f}%")
            print(f"    Average input tokens: {cat_stats['prompt_tokens']/count:.2f}")
            print(f"    Average output tokens: {cat_stats['completion_tokens']/count:.2f}")
    print()
    print("Key metrics averages:")
    print(f"  Average turns (Turns): {total_stats['avg_turns']:.2f}")
    print(f"  Loop percentage (% Stuck in Loop): {total_stats['stuck_in_loop_percentage']:.2f}%")
    print(f"  Tool call accuracy (% Tool Call Suc.): {total_stats['overall_tool_accuracy']:.2f}%")
    print(f"  Average input tokens (Avg Prompt Tokens): {total_stats['avg_prompt_tokens']:.2f}")
    print(f"  Average output tokens (Avg Completion Tokens): {total_stats['avg_completion_tokens']:.2f}")
    print(f"  Total input tokens (Total Prompt Tokens): {total_stats['total_prompt_tokens']}")
    print(f"  Total output tokens (Total Completion Tokens): {total_stats['total_completion_tokens']}")
    print()
    print(f"Detailed statistics saved to: {output_file}")

def get_evaluation_file_for_jsonl(jsonl_file: str) -> str:
    """
    Infer corresponding evaluation filename based on JSONL filename

    The filename stem is first checked against an explicit, case-sensitive
    mapping. Otherwise the lower-cased stem is matched against a list of
    families in priority order; the first family found decides the result,
    optionally refined by a variant marker (model size or release).

    Args:
        jsonl_file: Path of the trajectory file

    Returns:
        Path of the form 'data/patch-evaluation/<name>.json', or an empty
        string when no family matches (or a 'qwen3' name has no known size)
    """
    stem = Path(jsonl_file).stem
    lowered = stem.lower()

    def evaluation_path(evaluation_name: str) -> str:
        return f"data/patch-evaluation/{evaluation_name}.json"

    # Explicit mapping, matched against the stem as written (case-sensitive).
    explicit_names = {
        'openhands-deepseek-v3': 'openhands-deepseek-v3'
    }
    for fragment, evaluation_name in explicit_names.items():
        if fragment in stem:
            return evaluation_path(evaluation_name)

    # (family marker, ordered (variant marker, name) pairs, name when no variant matches)
    families = (
        ('openhands', (), 'openhands-deepseek-v3'),
        ('moatless', (), 'moatless-deepseek-v3'),
        ('deepseek', (('r1', 'deepseek-r1'),), 'deepseek-v3'),
        ('qwen2.5', (
            ('32b', 'qwen2.5-coder-32b'),
            ('14b', 'qwen2.5-coder-14b'),
            ('7b', 'qwen2.5-coder-7b'),
        ), 'qwen2.5-coder-32b'),
        # qwen3 has no fallback: an unknown size yields no evaluation file.
        ('qwen3', (
            ('32b', 'qwen3-32b'),
            ('14b', 'qwen3-14b'),
        ), None),
        ('mistral', (), 'mistral'),
        ('devstral', (), 'devstral'),
    )

    for family, variants, fallback in families:
        if family not in lowered:
            continue
        chosen = fallback
        for marker, evaluation_name in variants:
            if marker in lowered:
                chosen = evaluation_name
                break
        # Only the first matching family is considered.
        return evaluation_path(chosen) if chosen is not None else ""

    # No known family: no evaluation file can be inferred.
    return ""

def main():
    """
    Generate a metrics report for every configured trajectory file.

    Reports are written to 'data/traj-evaluation'. The input files come from
    the hard-coded jsonl_file_list; when that list is empty, every '*.jsonl'
    file in 'data/trajectory_original' is used instead.
    """
    # Fixed output folder, created on demand
    output_dir = Path("data/traj-evaluation")
    output_dir.mkdir(parents=True, exist_ok=True)

    jsonl_file_list = ['data/trajectory_original/openhands-deepseek-v3.json']

    # With no explicit files, fall back to scanning the trajectory folder
    if len(jsonl_file_list) == 0:
        target_folder = Path("data/trajectory_original")
        if not (target_folder.exists() and target_folder.is_dir()):
            print(f"Target folder {target_folder} does not exist or is not a valid directory. Please run Evaluate_Trajectory_By_Rule/split_traj_by_patch_correctness/generate-parameter-of-trajectory2report.py firstly, and specify file paths in jsonl_file_list")
            return

        jsonl_file_list = [str(found) for found in target_folder.glob("*.jsonl")]
        if not jsonl_file_list:
            print(f"No .jsonl files found in folder {target_folder}. Please run Evaluate_Trajectory_By_Rule/split_traj_by_patch_correctness/generate-parameter-of-trajectory2report.py firstly, and specify file paths in jsonl_file_list")
            return
        print(f"Found {len(jsonl_file_list)} .jsonl files in folder {target_folder}")

    separator = "\n" + "="*50 + "\n"
    for jsonl_file in jsonl_file_list:
        print(f"Analyzing file: {jsonl_file}")

        # Pick the patch-evaluation file that matches this trajectory file
        evaluation_file = get_evaluation_file_for_jsonl(jsonl_file)
        print(f"Using evaluation file: {evaluation_file}")

        # Load the three ID sets used to classify instances
        empty_patch_ids, resolved_ids, unresolved_ids = load_classification_ids(evaluation_file)
        print(f"Loaded classifications: empty_patch({len(empty_patch_ids)}), resolved({len(resolved_ids)}), unresolved({len(unresolved_ids)})")

        generate_metrics_report(jsonl_file, output_dir, empty_patch_ids, resolved_ids, unresolved_ids)
        print(separator)

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
