import json
import argparse
import os
import re
import glob
from typing import List, Dict, Any


def process_system_message(system_content: str) -> str:
    """
    Strip the embedded policy document from a system prompt and add a reference to it.

    The policy document is recognised by its header: a blank line, then
    ``# Agent Policy Document #P<number>``, another blank line and
    ``## General Instructions``. Everything from that header onwards is dropped
    and replaced by one sentence that names the policy.

    Args:
        system_content: Full text of the system message.

    Returns:
        The text preceding the policy header followed by the reference
        sentence, or ``system_content`` unchanged when no policy header is
        present.
    """
    # Capture the number from the header itself so that a "#P<n>" mentioned
    # earlier in the prompt cannot be mistaken for the policy id.
    policy_pattern = r'\n\n# Agent Policy Document #P(\d+)\n\n## General Instructions'
    match = re.search(policy_pattern, system_content)
    if match is None:
        return system_content

    base_content = system_content[:match.start()]
    policy_num = match.group(1)
    return base_content + f"\n\nYou should follow the knowledge you learnt from Policy Document #P{policy_num}."


def extract_task_instructions(system_content: str, task_type: str):
    """
    Extract the policy ID and the instruction section for one task type.

    The section starts at the line ``### Task_Type_<task_type>`` and ends with
    the sentence "- The agent should call the finish_task_<task_type> tool
    with arguments from one instance per layer at a time."

    Args:
        system_content: Full system prompt containing the policy document.
        task_type: Task type identifier as it appears after ``Task_Type_``
            (e.g. "5").

    Returns:
        A ``(policy_id, task_instructions)`` tuple of strings. ``policy_id``
        is "unknown" when no "# Agent Policy Document #P<number>" header is
        found; ``task_instructions`` is "" when the section is not found.
    """
    policy_id_match = re.search(r'# Agent Policy Document #P(\d+)', system_content)
    policy_id = policy_id_match.group(1) if policy_id_match else "unknown"

    # Escape the task type so it is always matched literally inside the regex.
    escaped_type = re.escape(task_type)
    task_pattern = (
        f'### Task_Type_{escaped_type}\\n(.*?)'
        f'\\n- The agent should call the finish_task_{escaped_type} tool '
        'with arguments from one instance per layer at a time\\.'
    )
    match = re.search(task_pattern, system_content, re.DOTALL)
    if match is None:
        return policy_id, ""

    # The whole match is exactly header + body + closing sentence, so return it
    # as-is instead of re-assembling it from string pieces.
    return policy_id, match.group(0)


def extract_task_type_from_user_message(user_message: str) -> str:
    """
    Pull the numeric task type out of a user request.

    Looks for the phrase "I want to perform task_type_<digits>" anywhere in
    the message, ignoring case, and returns the digits as a string
    (e.g. "I want to perform task_type_5" -> "5"). Returns "" when the phrase
    is absent.
    """
    found = re.search(r'I want to perform task_type_(\d+)', user_message, re.IGNORECASE)
    if found is None:
        return ""
    return found.group(1)


# Static reminder text (terminology and tool usage) that is injected into the
# first assistant turn. It ends with a blank line so that the original
# assistant text can be appended directly after it.
_TOOL_USAGE_REFERENCE = (
    "I should also recall some terminology specification and tool usage details:\n\n"
    "**Relative Profile Access:**\n"
    "When the user specifies getting a 'relative profile' or 'related profile', this means accessing other profile instances at the same layer as the current profile. "
    "To accomplish this, you should use the reference attributes from the current profile instance to find the primary keys of the target profile instances at the same layer. "
    "For example, if you are currently accessing a profile at layer 2, and the user asks for a relative profile, you should use the reference attributes in the current layer 2 profile to identify and access other layer 2 profile instances.\n\n"
    "#### Profile Access Tools\n"
    "- **Get_Profile_Layer_k**: Use this tool to directly access a specific profile instance by its primary key.\n"
    "  - **Parameter**: `index_value` (string) - The full primary key of the profile instance (e.g., \"profile_1_5\", \"profile_2_10\", \"profile_3_1\")\n"
    "  - **When to use**: \n"
    "    - When users specify a profile_id, such as \"my profile_id is profile_1_5\" or \"using profile_2_3\"\n"
    "    - When you obtain a reference attribute value from another profile instance that contains the primary key to access a different layer\n"
    "  - **Example call**: Get_Profile_Layer_1(index_value=\"profile_1_5\")\n\n"
    "- **Search_Profile_Layer_k**: Use this tool to find profile instances by their lookup attribute value.\n"
    "  - **Parameter**: `key_value` (string) - The lookup attribute value to search for\n"
    "  - **When to use**: When users specify a profile_info, such as \"my profile_info is 'engineering'\" or \"find profiles with lookup value 'sales'\"\n"
    "  - **Example call**: Search_Profile_Layer_1(key_value=\"engineering\")\n\n"
    "#### Task Completion Tools\n"
    "- **finish_task_k**: Use this tool to complete Task_Type_k with the computed arguments.\n"
    "  - **Parameter**: `attributes` (list) - A list of computed argument values in the order specified by the task requirements\n"
    "  - **When to use**: After accessing all required profile instances and computing the task arguments according to task specifications\n"
    "  - **Example call**: finish_task_1(attributes=[25, 150, 42])\n\n"
)


def _build_recall_preamble(task_type: str, policy_id: str, task_instructions: str) -> str:
    """Return the task-rule recap followed by the static tool-usage reference."""
    return (
        f"Before start, let me quickly recall the rules for compute task_type_{task_type} of policy #P{policy_id}:\n\n"
        f"{task_instructions}\n\n"
        + _TOOL_USAGE_REFERENCE
    )


def _inject_preamble(content: str, preamble: str) -> str:
    """Place ``preamble`` at the start of a leading <think> block, or prepend it when there is none."""
    think_match = re.match(r'^<think>(.*?)</think>\s*(.*)', content, re.DOTALL)
    if think_match is None:
        return preamble + content

    think_content, remaining_content = think_match.groups()
    # Whitespace between </think> and the rest is normalised to a single newline.
    return f"<think>\n{preamble}{think_content}\n</think>\n{remaining_content}"


def process_conversation_messages(conversation_messages: List[Dict[str, str]], system_content: str) -> List[Dict[str, str]]:
    """
    Convert role/content chat messages into ShareGPT-style from/value turns.

    Mapping applied:
      * ``system`` messages (and any role other than ``assistant``/``user``)
        are dropped.
      * ``assistant`` messages become ``gpt`` turns. When the first user
        message names a task type and matching task instructions are found in
        ``system_content``, the very first assistant message is prefixed with
        a recap of those instructions plus a tool-usage reference (inside its
        leading ``<think>`` block when it has one).
      * The first ``user`` message becomes a ``human`` turn; later ``user``
        messages become ``observation`` turns.
      * Finally, the last ``observation`` turn (if any) is removed.

    Args:
        conversation_messages: Messages with ``role`` and ``content`` keys.
        system_content: Original (unprocessed) system prompt, used to look up
            the task instructions.

    Returns:
        The converted list of ``{'from': ..., 'value': ...}`` dicts.
    """
    # The task type is taken from the first user message only.
    task_type = ""
    for message in conversation_messages:
        if message['role'] == 'user':
            task_type = extract_task_type_from_user_message(message['content'])
            break

    policy_id, task_instructions = "", ""
    if task_type:
        policy_id, task_instructions = extract_task_instructions(system_content, task_type)

    processed_conversations = []
    seen_assistant = False
    seen_user = False

    for message in conversation_messages:
        role = message['role']
        content = message['content']

        if role == 'assistant':
            # Only the very first assistant message receives the recap.
            if not seen_assistant and task_instructions:
                preamble = _build_recall_preamble(task_type, policy_id, task_instructions)
                content = _inject_preamble(content, preamble)
            processed_conversations.append({'from': 'gpt', 'value': content})
            seen_assistant = True
        elif role == 'user':
            sender = 'observation' if seen_user else 'human'
            processed_conversations.append({'from': sender, 'value': content})
            seen_user = True
        # System messages are handled separately by the caller; skip them here.

    # Drop the last observation turn, wherever it sits in the list.
    for index in range(len(processed_conversations) - 1, -1, -1):
        if processed_conversations[index]['from'] == 'observation':
            del processed_conversations[index]
            break

    return processed_conversations


def process_evaluation_data(data_root_paths: List[str]) -> List[Dict[str, Any]]:
    """
    Collect training examples from every correctly completed trajectory in the given files.

    Each file is expected to hold a JSON list of items. An item is kept when
    ``item['detailed_trajectory']['final_state'] == 'completed_correct'``, it
    has ``conversation_messages`` containing a system message, and both the
    processed system prompt and the converted conversation are non-empty.

    Files that cannot be read or parsed, files whose top-level JSON value is
    not a list, and malformed items are reported/skipped instead of aborting
    the whole run.

    Args:
        data_root_paths: Paths of the evaluation result JSON files.

    Returns:
        A list of ``{'conversations': [...], 'system': str}`` dicts, in file
        order and then item order.
    """
    processed_data = []
    total_files = len(data_root_paths)

    for file_idx, data_path in enumerate(data_root_paths, 1):
        print(f"Processing file {file_idx}/{total_files}: {data_path}")

        try:
            with open(data_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both JSON decoding and UTF-8 decoding errors.
            print(f"Error reading file {data_path}: {e}")
            continue

        if not isinstance(data, list):
            print(f"Error reading file {data_path}: expected a JSON list, got {type(data).__name__}")
            continue

        file_correct_count = 0

        for item in data:
            if not isinstance(item, dict):
                continue

            # Keep only trajectories that finished with the correct answer.
            trajectory = item.get('detailed_trajectory')
            if not isinstance(trajectory, dict):
                continue
            if trajectory.get('final_state') != 'completed_correct':
                continue

            conversation_messages = item.get('conversation_messages')
            if not conversation_messages:
                continue

            # Locate the first system message once; its raw text is needed for
            # task extraction and its processed form becomes the new prompt.
            original_system_content = next(
                (msg['content'] for msg in conversation_messages if msg['role'] == 'system'),
                None,
            )
            if original_system_content is None:
                continue

            system_message = process_system_message(original_system_content)
            if not system_message:
                continue

            processed_conversations = process_conversation_messages(
                conversation_messages, original_system_content
            )
            if not processed_conversations:
                continue

            processed_data.append({
                'conversations': processed_conversations,
                'system': system_message,
            })
            file_correct_count += 1

        print(f"Found {file_correct_count} correct examples in file {file_idx}")

    return processed_data


def update_dataset_info(output_root_path: str, output_dataset_name: str):
    """
    Register a dataset in the ``dataset_info.json`` file under ``output_root_path``.

    The existing registry is loaded when the file is present; otherwise an
    empty one is started. The entry keyed by ``output_dataset_name`` is then
    inserted (or overwritten) and the registry is written back as indented
    UTF-8 JSON.
    """
    registry_path = os.path.join(output_root_path, 'dataset_info.json')

    # Start from whatever is already registered, if anything.
    registry = {}
    if os.path.exists(registry_path):
        with open(registry_path, 'r', encoding='utf-8') as src:
            registry = json.load(src)

    # Describe the new dataset: file name, format and column mapping.
    column_mapping = {
        "messages": "conversations",
        "system": "system"
    }
    registry[output_dataset_name] = {
        "file_name": f"{output_dataset_name}.json",
        "formatting": "sharegpt",
        "columns": column_mapping
    }

    serialized = json.dumps(registry, indent=2, ensure_ascii=False)
    with open(registry_path, 'w', encoding='utf-8') as dst:
        dst.write(serialized)


def main():
    """
    Command-line entry point: convert evaluation results into a CoT training dataset.

    Parses ``--data_root_path`` (a glob pattern), ``--output_dataset_name`` and
    ``--output_root_path``, processes every matching file, writes
    ``<output_root_path>/<output_dataset_name>.json`` and registers the
    dataset in ``dataset_info.json`` in the same directory. Returns early,
    after printing a message, when the pattern matches no files.
    """
    parser = argparse.ArgumentParser(description='Convert evaluation results to CoT training data')
    parser.add_argument('--data_root_path', type=str, required=True, 
                       help='Path pattern to the evaluation results JSON files (supports glob)')
    parser.add_argument('--output_dataset_name', type=str, required=True,
                       help='Name of the output dataset')
    parser.add_argument('--output_root_path', type=str, required=True,
                       help='Root path where to save the output dataset')
    
    args = parser.parse_args()
    
    # Ensure output directory exists
    os.makedirs(args.output_root_path, exist_ok=True)
    
    # glob returns matches in a filesystem-dependent order; sort them so the
    # generated dataset is reproducible from run to run and machine to machine.
    data_files = sorted(glob.glob(args.data_root_path))
    if not data_files:
        print(f"No files found matching pattern: {args.data_root_path}")
        return
    
    print(f"Found {len(data_files)} files to process")
    
    # Process the evaluation data
    processed_data = process_evaluation_data(data_files)
    
    print(f"Successfully processed {len(processed_data)} total correct examples across all files")
    
    # Save the processed dataset
    output_file_path = os.path.join(args.output_root_path, f"{args.output_dataset_name}.json")
    with open(output_file_path, 'w', encoding='utf-8') as f:
        json.dump(processed_data, f, indent=2, ensure_ascii=False)
    
    print(f"Saved dataset to: {output_file_path}")
    
    # Update dataset_info.json
    update_dataset_info(args.output_root_path, args.output_dataset_name)
    print(f"Updated dataset_info.json with entry: {args.output_dataset_name}")
    
    print("Processing completed successfully!")


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
