#!/usr/bin/env python3

import json
import argparse
import os
import sys
import importlib
from typing import List, Dict, Any

# Path to the tau-bench evaluation results JSON file. main() checks that it
# exists and passes it to transform_tau_bench_data(), which iterates the loaded
# JSON and reads each item's 'final_status', 'instruction' and 'conversation'.
Root_Path = '/code/jiateng-sandbox/taubench_application/Single_turn_tau_bench/tau_bench_retail_Qwen_Qwen3_32B_20250805_200219.json'

# Output configuration.
# output_dataset_name: used both as the output file stem ("<name>.json") and as
#   the key registered in dataset_info.json.
# output_dir: directory that receives the dataset file and dataset_info.json;
#   main() creates it if it does not exist.
output_dataset_name = 'self_cot_tau_bench'
output_dir = '/code/jiateng-sandbox/intern_project/third_party/LLaMA-Factory/data'


def load_retail_tools() -> str:
    """Render the tau-bench retail tool specifications as plain text.

    The tau-bench checkout and its retail environment directory are appended
    to ``sys.path`` (when not already present), then the retail ``data`` and
    ``tools`` modules are imported. Every class in ``tools.ALL_TOOLS`` is
    asked for ``get_info()`` and formatted as::

        Tool: <name>
        Description: <description>
        Parameters:
          - <param> (<type>[*]): <description>

    where ``*`` marks a parameter listed under ``required``.

    Returns:
        The per-tool sections joined by a newline. A tool whose info cannot be
        read or formatted is reported on stdout and skipped. If the modules
        cannot be imported (or any other setup step fails), the fixed string
        ``"Tool specifications could not be loaded"`` is returned instead.
    """
    tau_bench_root = "/code/jiateng-sandbox/taubench_application/tau-bench"
    domain_path = os.path.join(tau_bench_root, "tau_bench", "envs", "retail")

    # Retail directory first, then the repository root, each added only once.
    for search_path in (domain_path, tau_bench_root):
        if search_path not in sys.path:
            sys.path.append(search_path)

    try:
        # The loaded retail data is not used below; the call is kept so that a
        # failure here still leads to the fallback return.
        # NOTE(review): presumably needed for tool initialisation - confirm.
        importlib.import_module("tau_bench.envs.retail.data").load_data()

        tools_module = importlib.import_module("tau_bench.envs.retail.tools")

        # Keyed by class name, so a repeated name keeps only the last class.
        registry = {cls.__name__: cls for cls in tools_module.ALL_TOOLS}

        sections = []
        for class_name, cls in registry.items():
            try:
                function_info = cls.get_info().get('function', {})
                name = function_info.get('name', class_name)
                description = function_info.get('description', 'No description available')
                schema = function_info.get('parameters', {})

                pieces = [f"Tool: {name}\n", f"Description: {description}\n"]

                if schema and 'properties' in schema:
                    pieces.append("Parameters:\n")
                    for param_name, param_info in schema['properties'].items():
                        param_type = param_info.get('type', 'unknown')
                        param_desc = param_info.get('description', 'No description')
                        marker = '*' if param_name in schema.get('required', []) else ''
                        pieces.append(f"  - {param_name} ({param_type}{marker}): {param_desc}\n")

                sections.append("".join(pieces))
            except Exception as e:
                # Best effort: one malformed tool must not hide the others.
                print(f"Error getting info for tool {class_name}: {e}")

        return "\n".join(sections)

    except Exception as e:
        print(f"Error loading retail tools: {e}")
        return "Tool specifications could not be loaded"


def build_system_prompt() -> str:
    """Build the system prompt from the prompt template and the retail tool specs.

    The template file is read as UTF-8 (matching every other file access in
    this script, instead of depending on the platform's locale encoding). If
    the file cannot be opened or decoded, the problem is reported on stdout and
    a short built-in template is used instead.

    Returns:
        The template text with every ``{Tool Specifications}`` placeholder
        replaced by the output of ``load_retail_tools()``.
    """
    template_path = "/code/jiateng-sandbox/taubench_application/Single_turn_tau_bench/prompt_template_none.txt"

    try:
        with open(template_path, 'r', encoding='utf-8') as f:
            template = f.read()
    except (OSError, UnicodeError) as e:
        # Missing/unreadable file or invalid UTF-8: fall back rather than abort.
        print(f"Error loading prompt template: {e}")
        template = "Based on the provided policy document and tools, help the user complete their request.\n\n{Tool Specifications}"

    # Get retail tool specifications
    tool_specifications = load_retail_tools()

    # Substitute the placeholder with the rendered tool list.
    return template.replace("{Tool Specifications}", tool_specifications)


def _observation_for(tool_result) -> str:
    """Return the observation text that follows one assistant turn."""
    if not tool_result:
        # Covers both a missing tool result and an empty one.
        return "No tool was executed or tool returned empty result. Please continue with the task."

    status = tool_result.get('status', 'unknown')
    payload = tool_result.get('result', 'No result')

    if status != 'success':
        return f"Tool execution failed. Error: {payload}. Please continue with the task."
    if payload:
        return f"Tool execution result: {json.dumps(payload)}"
    return "Tool execution was successful but returned empty result. Please continue with the task."


def process_conversation(conversation: List[Dict[str, Any]], instruction: str) -> List[Dict[str, str]]:
    """
    Convert a tau-bench conversation into a sharegpt message list.

    Layout of the result: one 'human' message carrying ``instruction``, then
    for every turn a 'gpt' message with the turn's 'response' (empty string
    when absent). Every turn except the last is followed by an 'observation'
    message derived from the turn's 'tool_result', so the list always ends
    on a 'gpt' message (or on the 'human' message for an empty conversation).
    """
    messages = [{'from': 'human', 'value': instruction}]

    final_index = len(conversation) - 1
    for position, turn in enumerate(conversation):
        messages.append({'from': 'gpt', 'value': turn.get('response', '')})

        # The closing assistant message is never followed by an observation.
        if position == final_index:
            continue

        messages.append({
            'from': 'observation',
            'value': _observation_for(turn.get('tool_result')),
        })

    return messages


def transform_tau_bench_data(data_path: str) -> List[Dict[str, Any]]:
    """
    Transform tau-bench evaluation results to CoT training data.

    Only items whose 'final_status' is 'completed_successfully' and that have
    both a non-empty 'instruction' and a non-empty 'conversation' are kept.
    Each kept item becomes ``{'conversations': [...], 'system': <prompt>}``,
    where the system prompt is built once and shared by all items.

    Args:
        data_path: Path to the JSON file with the evaluation results.

    Returns:
        The list of training items (possibly empty).

    Raises:
        OSError: If ``data_path`` cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Load the evaluation results
    with open(data_path, 'r', encoding='utf-8') as f:
        raw_data = json.load(f)

    # Build system prompt once; it is identical for every training item.
    system_prompt = build_system_prompt()

    processed_data = []

    for item in raw_data:
        # Only include data where final_status is "completed_successfully"
        if item.get('final_status') != 'completed_successfully':
            continue

        instruction = item.get('instruction', '')
        conversation = item.get('conversation', [])

        if not instruction or not conversation:
            continue

        try:
            processed_conversations = process_conversation(conversation, instruction)
        except Exception as e:
            # Best effort: a single malformed item is reported and skipped.
            print(f"Error processing item: {e}")
            continue

        processed_data.append({
            'conversations': processed_conversations,
            'system': system_prompt
        })

    # The summary names the status that was actually used as the filter.
    print(f"Processed {len(processed_data)} items with 'completed_successfully' status out of {len(raw_data)} total items")

    return processed_data


def update_dataset_info(output_dir: str, dataset_name: str):
    """Register ``dataset_name`` in ``<output_dir>/dataset_info.json``.

    Existing entries in the file are preserved; an entry with the same name is
    overwritten. The file is created when it does not exist yet. The new entry
    points at ``<dataset_name>.json`` in sharegpt format, mapping the
    'messages' column to 'conversations' and 'system' to 'system'.
    """
    registry_path = os.path.join(output_dir, 'dataset_info.json')

    # Start from the current registry if there is one, otherwise from scratch.
    registry = {}
    if os.path.exists(registry_path):
        with open(registry_path, 'r', encoding='utf-8') as source:
            registry = json.load(source)

    columns = {"messages": "conversations", "system": "system"}
    registry[dataset_name] = {
        "file_name": f"{dataset_name}.json",
        "formatting": "sharegpt",
        "columns": columns,
    }

    # Write the whole registry back.
    with open(registry_path, 'w', encoding='utf-8') as sink:
        json.dump(registry, sink, indent=2, ensure_ascii=False)

    print(f"Updated dataset_info.json with entry: {dataset_name}")


def main():
    """Convert the configured tau-bench results file into a registered dataset.

    Steps: verify that ``Root_Path`` exists, create ``output_dir`` if needed,
    transform the results, write ``<output_dataset_name>.json`` into
    ``output_dir`` and register the dataset in dataset_info.json.

    Returns True on success, and False when the input file is missing or when
    the transformation produced no training examples.
    """
    # Nothing to do without the input file.
    if not os.path.exists(Root_Path):
        print(f"Error: Input file not found: {Root_Path}")
        return False

    os.makedirs(output_dir, exist_ok=True)

    for banner in (
        f"Processing tau-bench data from: {Root_Path}",
        f"Output dataset name: {output_dataset_name}",
        f"Output directory: {output_dir}",
    ):
        print(banner)

    training_examples = transform_tau_bench_data(Root_Path)
    if not training_examples:
        print("No data was processed. Please check your input file and criteria.")
        return False

    # Write the dataset next to dataset_info.json.
    dataset_file = os.path.join(output_dir, f"{output_dataset_name}.json")
    with open(dataset_file, 'w', encoding='utf-8') as sink:
        json.dump(training_examples, sink, indent=2, ensure_ascii=False)
    print(f"Saved {len(training_examples)} training examples to: {dataset_file}")

    # Make the new file discoverable through dataset_info.json.
    update_dataset_info(output_dir, output_dataset_name)

    print("Transformation completed successfully!")
    return True


if __name__ == "__main__":
    # main() returns False when the input is missing or nothing was produced;
    # surface that as a non-zero exit status so shells and pipelines can detect it.
    sys.exit(0 if main() else 1)