import json
import argparse
import os
from typing import List, Dict, Any


def process_conversation_messages(conversation_messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """
    Convert role/content chat messages into from/value conversation turns.

    Mapping, applied in input order:
      - 'system' messages are skipped here.
      - 'assistant' becomes 'gpt'.
      - The first 'user' becomes 'human'; every later 'user' becomes
        'observation'.
      - Any other role is skipped.

    After conversion the last 'observation' turn (wherever it sits in the
    list) is removed. Message content is copied through unchanged.

    Raises:
        KeyError: if a message lacks a 'role' or 'content' key.
    """
    turns = []
    human_seen = False
    last_observation_at = None  # position of the most recent observation turn

    for entry in conversation_messages:
        # Both fields are read for every message, including skipped ones.
        speaker, text = entry['role'], entry['content']

        if speaker == 'assistant':
            turns.append({'from': 'gpt', 'value': text})
        elif speaker == 'user':
            if human_seen:
                last_observation_at = len(turns)
                turns.append({'from': 'observation', 'value': text})
            else:
                human_seen = True
                turns.append({'from': 'human', 'value': text})
        # 'system' and unrecognised roles contribute no turn.

    # Drop the final observation, if there was one.
    if last_observation_at is not None:
        del turns[last_observation_at]

    return turns


def process_evaluation_data(data_root_path: str, dataset_size: int) -> List[Dict[str, Any]]:
    """
    Load evaluation results and extract up to ``dataset_size`` correct examples.

    An item is kept when its ``detailed_trajectory['final_state']`` equals
    ``'completed_correct'``, it has ``conversation_messages`` containing a
    non-empty system message, and the converted conversation is non-empty.
    Message content is passed through unchanged.

    Args:
        data_root_path: Path to the evaluation results JSON file (an array of
            item objects).
        dataset_size: Maximum number of examples to return. A value of zero
            or less yields an empty list.

    Returns:
        A list of ``{'conversations': [...], 'system': str}`` records, in the
        order the qualifying items appear in the file.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    with open(data_root_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    processed_data = []

    for item in data:
        # Check the limit *before* doing any work so that a non-positive
        # dataset_size returns nothing instead of one stray example.
        if len(processed_data) >= dataset_size:
            break

        # Skip entries that are not objects or whose trajectory is missing,
        # null, or otherwise malformed rather than crashing mid-file.
        if not isinstance(item, dict):
            continue
        trajectory = item.get('detailed_trajectory')
        if not isinstance(trajectory, dict):
            continue
        if trajectory.get('final_state') != 'completed_correct':
            continue

        conversation_messages = item.get('conversation_messages')
        if not conversation_messages:
            continue

        # The first system message is stored separately, unmodified.
        system_message = next(
            (msg['content'] for msg in conversation_messages if msg['role'] == 'system'),
            None,
        )
        if not system_message:
            continue

        processed_conversations = process_conversation_messages(conversation_messages)
        if not processed_conversations:
            continue

        processed_data.append({
            'conversations': processed_conversations,
            'system': system_message,
        })

    return processed_data


def update_dataset_info(output_root_path: str, output_dataset_name: str):
    """
    Register the dataset in ``dataset_info.json`` under ``output_root_path``.

    Any existing registry file is loaded first; the entry keyed by
    ``output_dataset_name`` is then added or overwritten and the whole
    registry is written back as indented UTF-8 JSON.
    """
    registry_path = os.path.join(output_root_path, 'dataset_info.json')

    # Start from an empty registry unless one is already on disk.
    registry = {}
    if os.path.exists(registry_path):
        with open(registry_path, 'r', encoding='utf-8') as src:
            registry = json.load(src)

    # Column mapping from sharegpt field names to the keys used in the dataset file.
    column_map = {
        "messages": "conversations",
        "system": "system"
    }
    entry = {
        "file_name": "{}.json".format(output_dataset_name),
        "formatting": "sharegpt",
        "columns": column_map
    }
    registry[output_dataset_name] = entry

    # Persist the merged registry.
    with open(registry_path, 'w', encoding='utf-8') as dst:
        json.dump(registry, dst, indent=2, ensure_ascii=False)


def main():
    """
    Command-line entry point.

    Parses the four required options, converts the evaluation results into a
    dataset file named ``<output_dataset_name>.json`` under the output root,
    and registers that dataset in ``dataset_info.json``.
    """
    cli = argparse.ArgumentParser(description='Convert evaluation results to CoT training data')

    # (flag, value type, help text) -- every option is mandatory.
    option_specs = (
        ('--data_root_path', str, 'Path to the evaluation results JSON file'),
        ('--output_dataset_name', str, 'Name of the output dataset'),
        ('--dataset_size', int, 'Number of data points to extract'),
        ('--output_root_path', str, 'Root path where to save the output dataset'),
    )
    for flag, value_type, help_text in option_specs:
        cli.add_argument(flag, type=value_type, required=True, help=help_text)

    opts = cli.parse_args()

    # The output directory must exist before anything is written into it.
    os.makedirs(opts.output_root_path, exist_ok=True)

    print(f"Processing evaluation data from: {opts.data_root_path}")
    print(f"Target dataset size: {opts.dataset_size}")

    examples = process_evaluation_data(opts.data_root_path, opts.dataset_size)
    print(f"Successfully processed {len(examples)} correct examples")

    # Write the converted examples to <output_root_path>/<output_dataset_name>.json.
    dataset_filename = f"{opts.output_dataset_name}.json"
    output_file_path = os.path.join(opts.output_root_path, dataset_filename)
    with open(output_file_path, 'w', encoding='utf-8') as out:
        json.dump(examples, out, indent=2, ensure_ascii=False)
    print(f"Saved dataset to: {output_file_path}")

    # Register the new dataset in dataset_info.json in the same directory.
    update_dataset_info(opts.output_root_path, opts.output_dataset_name)
    print(f"Updated dataset_info.json with entry: {opts.output_dataset_name}")

    print("Processing completed successfully!")


# Run the conversion only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
