#!/usr/bin/env python3
"""
Script for traversing folder structure and merging all .traj files
"""

import os
import sys
import json
from pathlib import Path
from datetime import datetime

# Names of the first-level directories (directly under the root folder) that
# should be processed; find_traj_and_patch_files skips every first-level
# directory whose name is not listed here.
#
# Alternative, longer selection kept for reference:
# evaluate_names = [
#     'qwen2.5-coder-7b', 'deepseek-r1',
#     'qwen2.5-coder-32b', 'deepseek-v3', 'devstral',
#     'qwen2.5-coder-14b', 'qwen3-14b', 'qwen3-32b',
# ]

evaluate_names = ['deepseek-r1']

class Logger:
    """
    Tee-style writer that sends every message to the console and to a log file

    The instance exposes ``write`` and ``flush`` so it can be assigned to
    ``sys.stdout``. The "console" is whatever ``sys.stdout`` was when the
    instance was created. The instance can also be used as a context manager,
    in which case the log file is closed on exit.

    Args:
        log_file: Path of the log file (str or Path). Parent directories are
            created if missing and an existing file is overwritten.
    """

    def __init__(self, log_file):
        self.log_file = log_file
        # Captured now so later redirection of sys.stdout to this object does
        # not make the logger write to itself.
        self.terminal = sys.stdout

        # Create log file directory
        Path(log_file).parent.mkdir(parents=True, exist_ok=True)

        # Open log file
        self.log = open(log_file, 'w', encoding='utf-8')

        # Write start time
        start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.write(f"Log start time: {start_time}\n")
        self.write("=" * 80 + "\n")

    def __enter__(self):
        """Return the logger itself for use in a ``with`` statement"""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the log file; exceptions from the ``with`` body propagate"""
        self.close()
        return False

    def write(self, message):
        """Write message to console and, while the log file is open, to the file"""
        self.terminal.write(message)
        # After close() only the console remains usable; skipping the file
        # avoids "I/O operation on closed file" for late writes.
        if not self.log.closed:
            self.log.write(message)
        self.flush()

    def flush(self):
        """Flush console output and, while it is open, the log file"""
        self.terminal.flush()
        if not self.log.closed:
            self.log.flush()

    def close(self):
        """Write the end-time footer and close the log file (safe to call repeatedly)"""
        if self.log.closed:
            return
        end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.write(f"\nLog end time: {end_time}\n")
        self.log.close()


def extract_content_text(content):
    """
    Extract text content from content field

    Args:
        content: Value of content field, could be string or object array containing text field

    Returns:
        The string itself when content is a string; an empty string when
        content is None or an empty list; for a non-empty list, the 'text'
        value of the first dict item that has one, otherwise the string
        representation of the first item; the string representation of
        content in every other case.
    """
    if isinstance(content, str):
        return content

    if content is None:
        return ""

    if isinstance(content, list):
        # An empty array carries no text; without this guard it would be
        # rendered as the literal "[]".
        if not content:
            return ""
        # Find first object containing text field
        for item in content:
            if isinstance(item, dict) and 'text' in item:
                return item['text']
        # No text field found: fall back to the first object
        return str(content[0])

    # Other cases, return string representation
    return str(content)


def build_message_object(item):
    """
    Copy the role, content and action attributes of a message into a new dict

    Args:
        item: Dictionary containing message information

    Returns:
        dict: New dict holding, in this order, 'role' (copied as is when
        present), 'content' (passed through extract_content_text when
        present) and 'action' (copied only when present and truthy).
        Empty when none of these apply.
    """
    built = {}

    for field in ('role', 'content', 'action'):
        if field not in item:
            continue

        value = item[field]
        if field == 'content':
            # Content may be a string or a structured array; normalise to text
            built[field] = extract_content_text(value)
        elif field == 'role' or value:
            # Role is always kept; action only when it is not empty
            built[field] = value

    return built


def _find_files_sorted(directory, pattern):
    """Return the string paths matching pattern anywhere below directory, sorted"""
    # rglob gives no ordering guarantee; sorting keeps results reproducible.
    return sorted(str(path) for path in directory.rglob(pattern))


def find_traj_and_patch_files(root_folder_path, names=None):
    """
    Traverse folder structure to find all .traj and .patch files

    Only first level subdirectories whose name is listed in names are
    visited. Inside each of their (second level) subdirectories the search is
    recursive. Directories and files are reported in sorted order so repeated
    runs over the same tree give the same result.

    Args:
        root_folder_path (str): Root folder path
        names: Collection of first level directory names to process.
            Defaults to the module level evaluate_names.

    Returns:
        tuple: (traj structured result dict, all traj files list, first level grouped traj dict,
                patch structured result dict, all patch files list, first level grouped patch dict)
        The structured dicts map first level name -> second level name -> list of paths;
        the grouped dicts map first level name -> list of paths. All six are
        empty when the root path does not exist or is not a directory.
    """
    root_path = Path(root_folder_path)

    if not root_path.exists():
        print(f"Error: Path {root_folder_path} does not exist")
        return {}, [], {}, {}, [], {}

    if not root_path.is_dir():
        print(f"Error: Path {root_folder_path} is not a directory")
        return {}, [], {}, {}, [], {}

    if names is None:
        names = evaluate_names

    # Initialize traj file related variables
    traj_result = {}
    all_traj_files = []
    first_level_traj_dict = {}

    # Initialize patch file related variables
    patch_result = {}
    all_patch_files = []
    first_level_patch_dict = {}

    print(f"Starting to traverse root directory: {root_folder_path}")
    print("=" * 80)

    # Get first level subdirectories (restricted to the requested names)
    first_level_dirs = sorted(
        (d for d in root_path.iterdir() if d.is_dir() and d.name in names),
        key=lambda d: d.name,
    )
    print(f"Found {len(first_level_dirs)} first level subdirectories:")
    for dir_path in first_level_dirs:
        print(f"  - {dir_path.name}")
    print()

    # Traverse each first level subdirectory
    for first_level_dir in first_level_dirs:
        first_level_name = first_level_dir.name
        print(f"Processing first level directory: {first_level_name}")

        # Initialize results for current first level directory
        traj_result[first_level_name] = {}
        first_level_traj_dict[first_level_name] = []
        patch_result[first_level_name] = {}
        first_level_patch_dict[first_level_name] = []

        # Get second level subdirectories
        second_level_dirs = sorted(
            (d for d in first_level_dir.iterdir() if d.is_dir()),
            key=lambda d: d.name,
        )
        print(f"  Found {len(second_level_dirs)} second level subdirectories:")
        for dir_path in second_level_dirs:
            print(f"    - {dir_path.name}")

        # Traverse each second level subdirectory
        for second_level_dir in second_level_dirs:
            second_level_name = second_level_dir.name
            print(f"  Processing second level directory: {second_level_name}")

            # Find .traj files
            traj_files = _find_files_sorted(second_level_dir, "*.traj")
            all_traj_files.extend(traj_files)
            first_level_traj_dict[first_level_name].extend(traj_files)
            traj_result[first_level_name][second_level_name] = traj_files
            print(f"    Found {len(traj_files)} .traj files")

            # Find .patch files
            patch_files = _find_files_sorted(second_level_dir, "*.patch")
            all_patch_files.extend(patch_files)
            first_level_patch_dict[first_level_name].extend(patch_files)
            patch_result[first_level_name][second_level_name] = patch_files
            print(f"    Found {len(patch_files)} .patch files")

            # Print found files
            for traj_file in traj_files:
                print(f"      [TRAJ] {traj_file}")
            for patch_file in patch_files:
                print(f"      [PATCH] {patch_file}")

        print(f"  First level directory '{first_level_name}' total: {len(first_level_traj_dict[first_level_name])} .traj files, {len(first_level_patch_dict[first_level_name])} .patch files")
        print()

    return traj_result, all_traj_files, first_level_traj_dict, patch_result, all_patch_files, first_level_patch_dict


# Keys that make a history/query/messages entry worth converting.
_MESSAGE_KEYS = ('role', 'content', 'action')


def _instance_id_from_path(traj_file):
    """Join up to two parent directory names and the file stem with underscores"""
    file_path = Path(traj_file)
    # parts[-3:-1] holds the (at most two) path components right above the file
    return "_".join([*file_path.parts[-3:-1], file_path.stem])


def _messages_from_entries(entries):
    """Convert dict entries that carry role/content/action into message objects"""
    messages = []
    for entry in entries:
        if not isinstance(entry, dict):
            continue
        if not any(key in entry for key in _MESSAGE_KEYS):
            continue
        message = build_message_object(entry)
        if message:  # Ensure message object is not empty
            messages.append(message)
    return messages


def _messages_from_step(step):
    """Convert one trajectory step (a dict) into a list of message objects"""
    # A per-step history takes precedence, then a per-step query list
    if 'history' in step and isinstance(step['history'], list):
        return _messages_from_entries(step['history'])
    if 'query' in step and isinstance(step['query'], list):
        return _messages_from_entries(step['query'])

    # Neither list exists: synthesise messages from the step's own fields, in order
    messages = []
    if 'action' in step and step['action']:
        messages.append({'role': 'user', 'content': step['action']})
    if 'thought' in step and step['thought']:
        messages.append({
            'role': 'assistant',
            'content': f"Thinking process: {step['thought']}"
        })
    if 'response' in step and step['response']:
        messages.append({'role': 'assistant', 'content': step['response']})
    if 'observation' in step and step['observation']:
        messages.append({'role': 'system', 'content': step['observation']})
    return messages


def _messages_from_traj_data(traj_data):
    """Extract messages from the first supported layout found in parsed .traj data"""
    # Direct history field (most common case in target format)
    if 'history' in traj_data and isinstance(traj_data['history'], list):
        return _messages_from_entries(traj_data['history'])

    # Trajectory field: iterate through the steps
    if 'trajectory' in traj_data and isinstance(traj_data['trajectory'], list):
        messages = []
        for step in traj_data['trajectory']:
            if isinstance(step, dict):
                messages.extend(_messages_from_step(step))
        return messages

    # Messages array (some files might have this format directly)
    if 'messages' in traj_data and isinstance(traj_data['messages'], list):
        return _messages_from_entries(traj_data['messages'])

    return []


def merge_single_group_traj_files(traj_files_list, group_name, output_file):
    """
    Merge .traj files of a single group, output as jsonl file in target format
    Each .traj file corresponds to one jsonl record, format: {"messages": [...], "instance_id": ...}
    A 'model_stats' key is added to the record when the file's 'info' dict has one.

    Files that cannot be read or parsed are reported and skipped. Files
    without any conversation content produce a warning and no record, but
    still count as processed.

    Args:
        traj_files_list (list): List of .traj file paths
        group_name (str): Group name
        output_file (str): Output jsonl file path

    Returns:
        int: Number of files read and parsed without error; 0 when the list
        is empty or the output file could not be written
    """
    if not traj_files_list:
        print(f"Group '{group_name}' has no .traj files found")
        return 0

    merged_records = []  # Store records for each file
    processed_count = 0

    print(f"Starting to merge {len(traj_files_list)} .traj files in group '{group_name}'...")

    for traj_file in traj_files_list:
        # Per-file boundary: any failure is reported and the file is skipped
        try:
            print(f"  Processing: {traj_file}")

            with open(traj_file, 'r', encoding='utf-8') as f:
                traj_data = json.load(f)

            instance_id = _instance_id_from_path(traj_file)
            file_messages = _messages_from_traj_data(traj_data)

            # Only add record when there's conversation content
            if file_messages:
                record = {
                    "messages": file_messages,
                    "instance_id": instance_id
                }

                # Extract model_stats from info (if exists)
                if 'info' in traj_data and isinstance(traj_data['info'], dict):
                    if 'model_stats' in traj_data['info']:
                        record['model_stats'] = traj_data['info']['model_stats']

                merged_records.append(record)
                print(f"    Extracted {len(file_messages)} conversation messages, instance_id: {instance_id}")
                if 'model_stats' in record:
                    print(f"    Contains model statistics: {record['model_stats']}")
            else:
                print(f"    Warning: No valid conversation content found in file {traj_file}")

            processed_count += 1
            if processed_count % 10 == 0:
                print(f"    Processed: {processed_count}/{len(traj_files_list)} files")

        except Exception as e:
            print(f"    Error processing file {traj_file}: {e}")
            continue

    # Save merged data to jsonl file
    try:
        # Ensure output directory exists
        output_path = Path(output_file)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_file, 'w', encoding='utf-8') as f:
            for record in merged_records:
                json.dump(record, f, ensure_ascii=False)
                f.write('\n')

        print(f"  Group '{group_name}' merge completed!")
        print(f"    Processed {processed_count} .traj files in total")
        print(f"    Generated {len(merged_records)} jsonl records")
        total_messages = sum(len(record['messages']) for record in merged_records)
        print(f"    Contains {total_messages} conversation messages in total")
        print(f"    Merge result saved to: {output_file}")

        return processed_count

    except Exception as e:
        print(f"  Error saving merge result for group '{group_name}': {e}")
        return 0


def merge_traj_files_by_group(first_level_traj_dict, timestamp, output_dir=None):
    """
    Merge .traj files by first level grouping, output one jsonl file per group

    Each non-empty group is written to "<output_dir>/<group name>.jsonl" via
    merge_single_group_traj_files. Empty groups are recorded in the result
    with no output file.

    Args:
        first_level_traj_dict (dict): Traj files dictionary grouped by first level
        timestamp (str): Accepted for interface compatibility; this function
            does not read it
        output_dir (str or Path, optional): Directory receiving the jsonl
            files. Defaults to "data/temp/traj_jsonl_before_split". Created
            if missing.

    Returns:
        dict: Maps each group name to a dict with 'processed_count',
        'output_file' (None for empty groups) and 'total_files'. Empty when
        first_level_traj_dict is empty.
    """
    if not first_level_traj_dict:
        print("No .traj files found in any group")
        return {}

    # Create output directory
    if output_dir is None:
        output_dir = "data/temp/traj_jsonl_before_split"
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("\nStarting to merge .traj files by group...")
    print(f"Output directory: {output_dir}")
    print(f"Total {len(first_level_traj_dict)} groups")
    print("=" * 80)

    group_results = {}

    for group_name, traj_files in first_level_traj_dict.items():
        print(f"\nProcessing group: {group_name}")
        print(f"Contains {len(traj_files)} .traj files")
        print("-" * 50)

        if not traj_files:
            print(f"  Group '{group_name}' has no .traj files, skipping")
            group_results[group_name] = {
                'processed_count': 0,
                'output_file': None,
                'total_files': 0
            }
            continue

        # Generate output filename
        output_file = output_dir / f"{group_name}.jsonl"

        # Merge files of current group
        processed_count = merge_single_group_traj_files(traj_files, group_name, str(output_file))

        group_results[group_name] = {
            'processed_count': processed_count,
            'output_file': str(output_file),
            'total_files': len(traj_files)
        }

        print(f"  Group '{group_name}' processing completed")

    # Print summary
    print("\n" + "=" * 80)
    print("Group merge summary:")
    total_processed = 0
    total_groups_with_files = 0

    for group_name, result in group_results.items():
        if result['total_files'] > 0:
            total_groups_with_files += 1
            total_processed += result['processed_count']
            print(f"  {group_name}: Processed {result['processed_count']}/{result['total_files']} files")
            print(f"    Output file: {result['output_file']}")
        else:
            print(f"  {group_name}: No .traj files")

    print("\nTotal:")
    print(f"  Groups with files: {total_groups_with_files}")
    print(f"  Total successfully processed files: {total_processed}")
    print(f"  Output directory: {output_dir}")

    return group_results


def merge_traj_files(traj_files_list, output_file):
    """
    Merge all .traj files into one jsonl file in target format
    One record is written per .traj file that yields conversation content:
    {"messages": [...], "instance_id": ...}, plus "model_stats" when the
    file's 'info' dict provides it.

    Args:
        traj_files_list (list): List of .traj file paths
        output_file (str): Output jsonl file path

    Returns:
        int: Number of files read and parsed without error; 0 when the list
        is empty or the output file could not be written
    """
    if not traj_files_list:
        print("No .traj files found")
        return 0

    wanted_keys = ('role', 'content', 'action')

    def convert_entries(entries):
        # Keep dict entries carrying at least one wanted key and a non-empty result
        converted = []
        for entry in entries:
            if isinstance(entry, dict) and any(key in entry for key in wanted_keys):
                built = build_message_object(entry)
                if built:
                    converted.append(built)
        return converted

    def convert_step(step):
        # A step-level history list wins over a step-level query list
        for container in ('history', 'query'):
            if container in step and isinstance(step[container], list):
                return convert_entries(step[container])

        # Otherwise synthesise messages from the step's own fields, in this order
        field_plan = (
            ('action', 'user', False),
            ('thought', 'assistant', True),
            ('response', 'assistant', False),
            ('observation', 'system', False),
        )
        synthesized = []
        for field, role, is_thought in field_plan:
            value = step.get(field)
            if not value:
                continue
            content = f"Thinking process: {value}" if is_thought else value
            synthesized.append({'role': role, 'content': content})
        return synthesized

    records = []  # One entry per file with conversation content
    handled = 0

    print(f"Starting to merge all .traj files, total {len(traj_files_list)} files...")

    for traj_file in traj_files_list:
        try:
            print(f"  Processing: {traj_file}")

            with open(traj_file, 'r', encoding='utf-8') as handle:
                payload = json.load(handle)

            # instance_id: up to two parent directory names, then the file stem
            source = Path(traj_file)
            instance_id = "_".join([*source.parts[-3:-1], source.stem])

            # Pick the first supported layout: history, trajectory, messages
            if 'history' in payload and isinstance(payload['history'], list):
                conversation = convert_entries(payload['history'])
            elif 'trajectory' in payload and isinstance(payload['trajectory'], list):
                conversation = []
                for step in payload['trajectory']:
                    if isinstance(step, dict):
                        conversation.extend(convert_step(step))
            elif 'messages' in payload and isinstance(payload['messages'], list):
                conversation = convert_entries(payload['messages'])
            else:
                conversation = []

            if not conversation:
                print(f"    Warning: No valid conversation content found in file {traj_file}")
            else:
                entry = {
                    "messages": conversation,
                    "instance_id": instance_id
                }

                # Carry over model_stats from info (if exists)
                has_info = 'info' in payload and isinstance(payload['info'], dict)
                if has_info and 'model_stats' in payload['info']:
                    entry['model_stats'] = payload['info']['model_stats']

                records.append(entry)
                print(f"    Extracted {len(conversation)} conversation messages, instance_id: {instance_id}")
                if 'model_stats' in entry:
                    print(f"    Contains model statistics: {entry['model_stats']}")

            handled += 1
            if handled % 10 == 0:
                print(f"    Processed: {handled}/{len(traj_files_list)} files")

        except Exception as e:
            print(f"    Error processing file {traj_file}: {e}")
            continue

    # Save merged data to jsonl file
    try:
        # Ensure output directory exists
        Path(output_file).parent.mkdir(parents=True, exist_ok=True)

        with open(output_file, 'w', encoding='utf-8') as sink:
            for entry in records:
                json.dump(entry, sink, ensure_ascii=False)
                sink.write('\n')

        message_total = sum(len(entry['messages']) for entry in records)
        print("Merge completed!")
        print(f"Processed {handled} .traj files in total")
        print(f"Generated {len(records)} jsonl records")
        print(f"Contains {message_total} conversation messages in total")
        print(f"Merge result saved to: {output_file}")

        return handled

    except Exception as e:
        print(f"Error saving merge result: {e}")
        return 0





def main():
    """
    Interactive entry point

    Creates the "log" directory and a timestamped log file, routes print
    output through a Logger, asks for the root folder (falling back to
    "data/temp/traj_json" on empty input), lists the .traj and .patch files
    found, saves the listings and the first level grouping dictionaries under
    "log", and optionally merges the .traj files per group.

    Returns:
        list: All .traj file paths found, or an empty list when an exception
        was caught
    """
    separator = "=" * 80

    def numbered(paths):
        # "  1. path" style lines shared by the console and file listings
        return [f"{index:3d}. {path}" for index, path in enumerate(paths, 1)]

    # Create log folder
    log_dir = Path("log")
    log_dir.mkdir(exist_ok=True)

    # Log filename is based on current time
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = log_dir / f"traj_finder_{timestamp}.log"

    # Route print output through the logger until the finally block below
    logger = Logger(log_file)
    saved_stdout = sys.stdout
    sys.stdout = logger

    try:
        print("Script starting...")
        print(f"Log file: {log_file}")

        # Empty input selects the default path
        root_folder = input("Please enter the root folder path to traverse: ").strip() or "data/temp/traj_json"

        print(f"\nStarting to process folder: {root_folder}")
        print(separator)

        # Find all .traj and .patch files
        (traj_result, all_traj_files, first_level_traj_dict,
         patch_result, all_patch_files, first_level_patch_dict) = find_traj_and_patch_files(root_folder)

        print("\n" + separator)
        print("Summary:")
        print(f"Found {len(all_traj_files)} .traj files in total")
        print(f"Found {len(all_patch_files)} .patch files in total")

        # Statistics by first level folders
        for first_level, second_level_dict in traj_result.items():
            patches_here = patch_result[first_level]
            traj_total = sum(map(len, second_level_dict.values()))
            patch_total = sum(map(len, patches_here.values()))
            print(f"  {first_level}: {traj_total} .traj files, {patch_total} .patch files")
            for second_level, files in second_level_dict.items():
                if not files:
                    continue  # Only show folders with files
                print(f"    {second_level}: {len(files)} .traj files, {len(patches_here[second_level])} .patch files")

        traj_listing = numbered(all_traj_files)
        patch_listing = numbered(all_patch_files)

        print("\nAll .traj file path list:")
        print(separator)
        for line in traj_listing:
            print(line)

        print("\nAll .patch file path list:")
        print(separator)
        for line in patch_listing:
            print(line)

        # Save file list to separate file
        list_file = log_dir / f"traj_files_list_{timestamp}.txt"
        report_lines = [
            f"Found {len(all_traj_files)} .traj files in total\n",
            f"Found {len(all_patch_files)} .patch files in total\n",
            separator + "\n",
            "\n.traj file list:\n",
            *(line + "\n" for line in traj_listing),
            "\n.patch file list:\n",
            *(line + "\n" for line in patch_listing),
        ]
        with open(list_file, 'w', encoding='utf-8') as f:
            f.writelines(report_lines)

        # Save first level grouping dictionaries to JSON files (.traj first, then .patch)
        traj_json_file = log_dir / f"first_level_traj_dict_{timestamp}.json"
        patch_json_file = log_dir / f"first_level_patch_dict_{timestamp}.json"
        for target, grouping in ((traj_json_file, first_level_traj_dict),
                                 (patch_json_file, first_level_patch_dict)):
            with open(target, 'w', encoding='utf-8') as f:
                json.dump(grouping, f, ensure_ascii=False, indent=2)

        print(f"\nFile list saved to: {list_file}")
        print(f"First level .traj grouping dictionary saved to: {traj_json_file}")
        print(f"First level .patch grouping dictionary saved to: {patch_json_file}")

        # If there are traj files, provide merge option
        if any(first_level_traj_dict.values()):
            traj_count = sum(map(len, first_level_traj_dict.values()))
            print(f"\nFound {traj_count} .traj files")
            choice = input("Do you want to merge .traj files by group into jsonl format? (y/n): ").strip().lower()

            if choice != 'y':
                print("Skip merge operation")
            else:
                # Merge traj files by group
                group_results = merge_traj_files_by_group(first_level_traj_dict, timestamp)

                # Output final statistics
                counts = [result['processed_count'] for result in group_results.values()]
                successful_groups = sum(1 for count in counts if count > 0)
                total_processed_files = sum(counts)

                print("\nMerge operation completed!")
                print(f"Successfully processed groups: {successful_groups}")
                print(f"Total successfully processed files: {total_processed_files}")

        return all_traj_files

    except Exception as e:
        print(f"Error occurred: {e}")
        return []

    finally:
        # Restore original output before the final message so it is not logged
        sys.stdout = saved_stdout
        logger.close()
        print(f"Script execution completed, please check detailed log: {log_file}")


if __name__ == "__main__":
    # Script entry: run the interactive mode and keep the list of .traj
    # paths it returns in a module-level name.
    traj_files = main()