#!/usr/bin/env python3
import os
import json
import sys
import argparse
import re
import random
import importlib.util
from pathlib import Path
from typing import Dict, List, Any, Optional

def extract_dataset_name_from_args(dataset_name):
    """Return the output dataset name for the CLI ``--dataset_name`` value.

    The generated name is the given value with a ``_cpt_data`` suffix, e.g.
    ``"airline"`` -> ``"airline_cpt_data"``.
    """
    return "{}_cpt_data".format(dataset_name)

def load_existing_dataset_info(dataset_info_path):
    """Load an existing ``dataset_info.json`` registry, or start an empty one.

    Args:
        dataset_info_path: Path to the JSON registry file.

    Returns:
        The parsed JSON content, or an empty dict when the file does not
        exist or contains only whitespace (e.g. a freshly created file).

    Raises:
        json.JSONDecodeError: If the file is non-empty but not valid JSON.
            This is deliberately not swallowed: ``main()`` rewrites the file
            after updating it, so silently falling back to ``{}`` would erase
            every existing registration.
    """
    if not os.path.exists(dataset_info_path):
        return {}
    # Read as UTF-8 explicitly (like the other loaders in this file) so
    # non-ASCII names/paths parse identically regardless of the locale.
    with open(dataset_info_path, 'r', encoding='utf-8') as f:
        raw = f.read()
    if not raw.strip():
        return {}
    return json.loads(raw)

def load_policy_file(policy_path):
    """Read a policy document (e.g. a Markdown ``.md`` file) as UTF-8 text.

    Args:
        policy_path: Path to the policy file.

    Returns:
        The file content with leading/trailing whitespace stripped.

    Raises:
        FileNotFoundError: If ``policy_path`` does not exist.
    """
    if os.path.exists(policy_path):
        with open(policy_path, 'r', encoding='utf-8') as handle:
            return handle.read().strip()
    raise FileNotFoundError(f"Policy file not found: {policy_path}")

def load_dataset_files(dataset_paths):
    """Load several JSON dataset files, keyed by their filename stem.

    Each file ``.../users.json`` is stored under ``"users"``. Missing files,
    invalid JSON and any other load failure print a warning and are skipped,
    so one bad file does not abort the whole run.

    Args:
        dataset_paths: Iterable of paths to JSON files.

    Returns:
        Dict mapping filename stem -> parsed JSON content.
    """
    loaded = {}

    for file_path in dataset_paths:
        if os.path.exists(file_path):
            try:
                with open(file_path, 'r', encoding='utf-8') as fh:
                    payload = json.load(fh)

                stem = os.path.splitext(os.path.basename(file_path))[0]
                loaded[stem] = payload
                print(f"Loaded dataset: {stem} with {len(payload)} entries")
            except json.JSONDecodeError as err:
                print(f"Warning: Failed to parse JSON file {file_path}: {err}")
            except Exception as err:
                # Best-effort loading: report and move on to the next file.
                print(f"Warning: Failed to load dataset {file_path}: {err}")
        else:
            print(f"Warning: Dataset file not found: {file_path}")

    return loaded

def load_tool_specifications(tools_path):
    """Discover tool classes re-exported by a tools package's ``__init__.py``.

    Scans ``<tools_path>/__init__.py`` for relative imports and records every
    imported name whose ``<module>.py`` file exists next to ``__init__.py``.
    Supported forms:

    * ``from .module import Name``
    * ``from .module import A, B``
    * ``from .module import Name as Alias`` (records ``Name``)
    * parenthesized, possibly multi-line lists ``from .module import (A, B)``

    Args:
        tools_path: Directory of the tools package.

    Returns:
        Dict mapping each imported class name to
        ``{'module': ..., 'class': ..., 'file_path': ...}``. It is empty when
        the directory, its ``__init__.py``, or the matching module files are
        missing, or when ``__init__.py`` cannot be read.
    """
    tools_info = {}

    if not os.path.exists(tools_path):
        print(f"Warning: Tools path not found: {tools_path}")
        return tools_info

    init_file = os.path.join(tools_path, '__init__.py')
    if not os.path.exists(init_file):
        return tools_info

    try:
        with open(init_file, 'r', encoding='utf-8') as f:
            init_content = f.read()
    except (OSError, UnicodeDecodeError) as e:
        print(f"Warning: Failed to parse tools __init__.py: {e}")
        return tools_info

    # Group 2 is either a parenthesized (possibly multi-line) name list or the
    # rest of the statement up to a newline/semicolon, so "from .m import A, B"
    # yields both A and B instead of only A.
    import_pattern = re.compile(
        r'from[ \t]+\.(\w+)[ \t]+import[ \t]+(\([^)]*\)|[^\n;]*)'
    )
    for module_name, names_blob in import_pattern.findall(init_content):
        tool_file = os.path.join(tools_path, f"{module_name}.py")
        if not os.path.exists(tool_file):
            continue

        # Drop trailing comments, then the surrounding parentheses.
        names_blob = re.sub(r'#[^\n]*', '', names_blob).strip().strip('()')
        for entry in names_blob.split(','):
            parts = entry.split()
            # "Name as Alias" -> Name, the class defined in the module file.
            # Non-identifiers (e.g. "*", a line-continuation "\") are ignored.
            if parts and parts[0].isidentifier():
                class_name = parts[0]
                tools_info[class_name] = {
                    'module': module_name,
                    'class': class_name,
                    'file_path': tool_file
                }

    print(f"Found {len(tools_info)} tools: {list(tools_info.keys())}")
    return tools_info

def extract_policy_sections(policy_content):
    """Split Markdown policy text into ``{heading title: section body}``.

    Level 1-3 headings (``#``, ``##``, ``###``) start a new section. Deeper
    headings and all other lines are kept as body text of the current section.
    Text before the first heading belongs to no section and is dropped.

    Lines inside fenced code blocks (opened/closed by ``````` or ``~~~``)
    are never treated as headings, so ``# comment`` lines in code samples do
    not split sections. When the same title appears more than once, the
    bodies are joined with a blank line instead of the later one silently
    replacing the earlier.

    Args:
        policy_content: Markdown text of the policy.

    Returns:
        Dict of title -> stripped body text. A body may be ``""`` when a
        heading is immediately followed by another heading.
    """
    header_re = re.compile(r'^(#{1,3})\s+(.+?)$')
    sections = {}

    def _save(title, lines):
        """Store a section body, merging with an earlier same-titled section."""
        body = '\n'.join(lines).strip()
        parts = [part for part in (sections.get(title, ''), body) if part]
        sections[title] = '\n\n'.join(parts)

    current_section = None
    current_content = []
    in_fence = False

    for line in policy_content.split('\n'):
        is_fence = line.lstrip().startswith(('```', '~~~'))
        match = None if (in_fence or is_fence) else header_re.match(line)
        if is_fence:
            in_fence = not in_fence

        if match:
            if current_section:
                _save(current_section, current_content)
            current_section = match.group(2).strip()
            current_content = []
        elif current_section:
            current_content.append(line)

    if current_section:
        _save(current_section, current_content)

    return sections

def generate_policy_qa_data(policy_content, policy_sections, samples_per_section=100):
    """Generate templated Q&A training samples from the policy text.

    Emits four general questions whose answer is the first 1000 characters of
    the policy. Then, for every non-blank section, it emits up to three
    section-specific questions whose answer is the section body.

    Args:
        policy_content: Full policy text.
        policy_sections: Dict of section title -> section body.
        samples_per_section: Maximum number of questions per section. Only
            three templates exist, so the default of 100 imposes no limit.

    Returns:
        List of ``{"text": ...}`` records.
    """
    qa_data = []

    # Only mark the excerpt as truncated when something was actually cut off.
    excerpt = policy_content[:1000]
    if len(policy_content) > 1000:
        excerpt += "..."

    overall_questions = [
        "What is the main purpose of this policy?",
        "What are the key rules and guidelines outlined in this policy?",
        "What should an agent do when faced with requests outside their scope?",
        "How should an agent handle user confirmation before taking actions?",
    ]
    for question in overall_questions:
        qa_data.append({
            "text": f"Question: {question}\n\nAnswer: Based on the policy document:\n\n{excerpt}"
        })

    per_section_limit = max(0, samples_per_section)
    for section_title, section_content in policy_sections.items():
        if not section_content.strip():
            continue

        section_questions = [
            f"What are the key points in the {section_title} section?",
            f"How should an agent handle {section_title.lower()}?",
            f"What are the rules for {section_title.lower()}?",
        ]
        for question in section_questions[:per_section_limit]:
            qa_data.append({
                "text": (
                    f"Question: {question}\n\n"
                    f"Answer: According to the {section_title} section:\n\n"
                    f"{section_content}"
                )
            })

    print(f"Generated {len(qa_data)} policy-based Q&A samples")
    return qa_data

def generate_dataset_qa_data(datasets, policy_content, samples_per_dataset=500):
    """Generate templated Q&A samples from randomly chosen dataset entries.

    For each dataset (dict -> ``(key, value)`` pairs, list -> ``(index, item)``
    pairs; other shapes are skipped), up to ``samples_per_dataset`` *distinct*
    entries are drawn. Datasets named ``users``, ``flights`` and
    ``reservations`` get three tailored questions per entry; any other dataset
    gets one generic question. The answer is the entry's JSON, truncated to
    800 characters.

    Args:
        datasets: Dict of dataset name -> parsed JSON content.
        policy_content: Currently unused; kept for interface compatibility.
        samples_per_dataset: Maximum number of entries sampled per dataset.

    Returns:
        List of ``{"text": ...}`` records.
    """
    qa_data = []

    for dataset_name, dataset_content in datasets.items():
        print(f"Generating Q&A for dataset: {dataset_name}")

        if isinstance(dataset_content, dict):
            dataset_items = list(dataset_content.items())
        elif isinstance(dataset_content, list):
            dataset_items = list(enumerate(dataset_content))
        else:
            continue

        # Sample without replacement so each selected entry is distinct;
        # clamp at 0 so a negative request yields no samples instead of raising.
        sample_size = max(0, min(samples_per_dataset, len(dataset_items)))
        for key, item in random.sample(dataset_items, sample_size):
            if dataset_name == 'users':
                questions = [
                    f"What information is available for user {key}?",
                    f"What are the payment methods for user {key}?",
                    f"What is the membership status of user {key}?",
                ]
            elif dataset_name == 'flights':
                questions = [
                    f"What are the details for flight {key}?",
                    f"What is the status of flight {key}?",
                    f"What are the available seats for flight {key}?",
                ]
            elif dataset_name == 'reservations':
                questions = [
                    f"What are the details for reservation {key}?",
                    f"Who are the passengers for reservation {key}?",
                    f"What flights are included in reservation {key}?",
                ]
            else:
                questions = [f"What information is available for {dataset_name} entry {key}?"]

            # Serialize once per entry; only add "..." when actually truncated.
            item_json = json.dumps(item, indent=2)
            answer_body = item_json[:800] + ("..." if len(item_json) > 800 else "")

            for question in questions:
                qa_data.append({
                    "text": (
                        f"Question: {question}\n\n"
                        f"Answer: Based on the {dataset_name} data:\n\n"
                        f"{answer_body}"
                    )
                })

    print(f"Generated {len(qa_data)} dataset-based Q&A samples")
    return qa_data

def generate_tool_qa_data(tools_info, policy_content, samples_per_tool=50):
    """Generate templated Q&A samples about the available agent tools.

    Emits one overview sample listing all tool names, then
    ``samples_per_tool`` samples per tool, each using a randomly chosen
    question template with a fixed generic answer. When no tools are known,
    no samples are produced; the overview would otherwise list nothing.

    Args:
        tools_info: Dict keyed by tool (class) name; values are not read.
        policy_content: Currently unused; kept for interface compatibility.
        samples_per_tool: Number of samples generated per tool.

    Returns:
        List of ``{"text": ...}`` records.
    """
    qa_data = []
    tool_names = list(tools_info.keys())

    if tool_names:
        qa_data.append({
            "text": (
                "Question: What tools are available for the agent?\n\n"
                f"Answer: The agent has access to the following tools: {', '.join(tool_names)}\n\n"
                "These tools allow the agent to perform various actions according to the policy guidelines."
            )
        })

    for tool_name in tool_names:
        # Templates depend only on the tool name, so build them once per tool.
        questions = [
            f"When should the agent use the {tool_name} tool?",
            f"What does the {tool_name} tool do?",
            f"How should the agent properly use {tool_name}?",
        ]
        answer = (
            f"The {tool_name} tool is available for the agent to use according to the "
            "policy guidelines. The agent should use this tool when appropriate and "
            "following proper procedures as outlined in the policy document."
        )
        for _ in range(samples_per_tool):
            question = random.choice(questions)
            qa_data.append({"text": f"Question: {question}\n\nAnswer: {answer}"})

    print(f"Generated {len(qa_data)} tool-based Q&A samples")
    return qa_data

def generate_scenario_qa_data(policy_content, datasets, tools_info, num_scenarios=1000):
    """Generate scenario-style Q&A samples combining policy text and tool names.

    Scenarios are only produced when both a ``users`` and a ``flights``
    dataset are present. Each sample picks a random scenario. Its answer
    quotes the first 500 characters of the policy and names up to three
    tools. Dataset records themselves are not included in the text, so their
    shape (dict or list) does not matter here.

    Args:
        policy_content: Full policy text.
        datasets: Dict of dataset name -> parsed JSON content.
        tools_info: Dict keyed by tool name; only the keys are used.
        num_scenarios: Number of samples to generate.

    Returns:
        List of ``{"text": ...}`` records (empty when no scenarios apply).
    """
    qa_data = []

    scenarios = []
    if 'users' in datasets and 'flights' in datasets:
        scenarios.extend([
            "A user wants to book a flight",
            "A user wants to modify an existing reservation",
            "A user wants to cancel their reservation",
            "A user is asking about baggage allowance",
            "A user wants to upgrade their cabin class",
            "A user is complaining about a delayed flight",
        ])

    if scenarios:
        # Everything except the scenario choice is loop-invariant.
        policy_excerpt = policy_content[:500]
        if len(policy_content) > 500:
            policy_excerpt += "..."

        example_tools = ', '.join(list(tools_info.keys())[:3])
        if example_tools:
            closing = (
                f"The agent should use appropriate tools like {example_tools} and follow "
                "proper procedures for user confirmation and data handling."
            )
        else:
            # Avoid the dangling "tools like  and" when no tools were found.
            closing = "The agent should follow proper procedures for user confirmation and data handling."

        for _ in range(num_scenarios):
            scenario = random.choice(scenarios)
            qa_data.append({
                "text": (
                    f"Question: How should the agent handle this scenario: {scenario}?\n\n"
                    "Answer: According to the policy, the agent should follow these guidelines:\n\n"
                    f"{policy_excerpt}\n\n"
                    f"{closing}"
                )
            })

    print(f"Generated {len(qa_data)} scenario-based Q&A samples")
    return qa_data

def process_generalizable_content(policy_path, dataset_paths, tools_path):
    """Build the full list of CPT training records.

    Loads the policy, the dataset files and the tool specifications, then
    assembles, in order: the raw policy document, followed by policy-based,
    dataset-based, tool-based and scenario-based Q&A samples.

    Args:
        policy_path: Path to the policy file.
        dataset_paths: Paths to the dataset JSON files.
        tools_path: Path to the tools directory.

    Returns:
        List of ``{"text": ...}`` records.
    """
    print("Loading policy content...")
    policy = load_policy_file(policy_path)
    sections = extract_policy_sections(policy)
    print(f"Loaded policy with {len(sections)} sections")

    # The verbatim policy is always the first record.
    records = [{"text": f"Policy Document:\n\n{policy}"}]

    print("Loading datasets...")
    loaded_datasets = load_dataset_files(dataset_paths)

    print("Loading tool specifications...")
    tools = load_tool_specifications(tools_path)

    print("Generating Q&A data...")
    records += generate_policy_qa_data(policy, sections)
    records += generate_dataset_qa_data(loaded_datasets, policy)
    records += generate_tool_qa_data(tools, policy)
    records += generate_scenario_qa_data(policy, loaded_datasets, tools)

    return records

def update_dataset_info(dataset_info, dataset_name):
    """Register ``dataset_name`` in ``dataset_info`` (in place) and return it.

    The entry points at ``<dataset_name>.json`` and maps the ``prompt`` column
    to the records' ``text`` field. Any existing entry of the same name is
    replaced.
    """
    entry = {
        "file_name": "{}.json".format(dataset_name),
        "columns": {"prompt": "text"},
    }
    dataset_info.update({dataset_name: entry})
    return dataset_info

def save_dataset_info(dataset_info, dataset_info_path):
    """Write ``dataset_info`` to ``dataset_info_path`` as JSON with 2-space indentation.

    Any existing file at that path is overwritten.
    """
    with open(dataset_info_path, mode='w') as out_file:
        json.dump(dataset_info, fp=out_file, indent=2)

def save_pretrain_data(pretrain_data, output_path):
    """Write the training records to ``output_path`` as JSON with 2-space indentation.

    Any existing file at that path is overwritten.
    """
    with open(output_path, mode='w') as sink:
        json.dump(pretrain_data, fp=sink, indent=2)

def main():
    """Command-line entry point: generate CPT data and register it.

    Parses the CLI arguments and exits with status 1 when the policy file is
    missing. Missing dataset files or a missing tools path only produce
    warnings. It then writes ``<output_dir>/<name>_cpt_data.json`` and adds or
    updates the matching entry in ``<output_dir>/dataset_info.json``.
    """
    cli = argparse.ArgumentParser(description='Generate CPT data from policy, datasets, and tools')
    cli.add_argument('--policy_file', required=True, help='Path to the policy file (e.g., wiki.md)')
    cli.add_argument('--dataset_files', required=True, nargs='+', help='Paths to dataset JSON files')
    cli.add_argument('--tools_path', required=True, help='Path to the tools directory')
    cli.add_argument('--dataset_name', required=True, help='Name for the output dataset')
    cli.add_argument(
        '--output_dir',
        default='/code/jiateng-sandbox/intern_project/third_party/LLaMA-Factory/data',
        help='Output directory for the generated dataset',
    )
    opts = cli.parse_args()

    # Only the policy file is mandatory; the other inputs merely warn.
    if not os.path.exists(opts.policy_file):
        print(f"Error: Policy file '{opts.policy_file}' does not exist")
        sys.exit(1)

    missing_datasets = [path for path in opts.dataset_files if not os.path.exists(path)]
    for path in missing_datasets:
        print(f"Warning: Dataset file '{path}' does not exist")

    if not os.path.exists(opts.tools_path):
        print(f"Warning: Tools path '{opts.tools_path}' does not exist")

    dataset_name = extract_dataset_name_from_args(opts.dataset_name)
    output_dir = opts.output_dir
    dataset_info_path = os.path.join(output_dir, 'dataset_info.json')
    output_json_path = os.path.join(output_dir, f"{dataset_name}.json")

    os.makedirs(output_dir, exist_ok=True)

    for summary_line in (
        f"Dataset name: {dataset_name}",
        f"Policy file: {opts.policy_file}",
        f"Dataset files: {opts.dataset_files}",
        f"Tools path: {opts.tools_path}",
        f"Output JSON path: {output_json_path}",
    ):
        print(summary_line)

    print("Processing content...")
    records = process_generalizable_content(opts.policy_file, opts.dataset_files, opts.tools_path)
    print(f"Generated {len(records)} total data entries")

    print("Saving pretrain data...")
    save_pretrain_data(records, output_json_path)

    print("Updating dataset_info.json...")
    registry = load_existing_dataset_info(dataset_info_path)
    registry = update_dataset_info(registry, dataset_name)
    save_dataset_info(registry, dataset_info_path)

    print("CPT data generation completed successfully!")
    print(f"Dataset '{dataset_name}' has been created")
    print(f"Data file: {output_json_path}")
    print(f"Dataset info updated: {dataset_info_path}")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
