#!/usr/bin/env python3
import os
import json
import sys
import argparse
import re
import random
import importlib.util
from pathlib import Path

def extract_dataset_name_from_path(directory_path, suffix=""):
    """Derive a dataset name from the last component of ``directory_path``.

    Trailing slashes are ignored. The result is ``<dir>_new_cpt_data``, with
    ``_<suffix>`` appended when ``suffix`` is truthy.
    """
    name_parts = [os.path.basename(directory_path.rstrip('/')), "new_cpt_data"]
    if suffix:
        name_parts.append(f"{suffix}")
    return "_".join(name_parts)

def load_existing_dataset_info(dataset_info_path):
    """Return the parsed contents of ``dataset_info_path``.

    An empty dict is returned when the file does not exist, so callers can
    start a new registry.
    """
    if not os.path.exists(dataset_info_path):
        return {}
    with open(dataset_info_path, 'r') as info_file:
        return json.load(info_file)

def extract_policy_id(policy_content):
    """Return the ``P<digits>`` id from the Policy.md heading.

    Looks for the first ``# Agent Policy Document #P<digits>`` occurrence and
    falls back to ``"P00000"`` when there is none.
    """
    heading = re.search(r'# Agent Policy Document #(P\d+)', policy_content)
    if heading is None:
        return "P00000"
    return heading.group(1)

def parse_task_specifications(policy_content):
    """Split Policy.md into ``### Task_Type_<N>`` sections.

    A section runs from its heading up to the next ``### `` or the end of the
    text. Required layers are collected from phrases like
    ``profile instance at each of the layer 1, layer 2``.

    Returns:
        Dict keyed by task number. Each value holds ``name``
        (e.g. ``"Task_Type_3"``), ``required_layers`` (sorted, de-duplicated
        ints) and ``content`` (the stripped section body).
    """
    section_pattern = r'### (Task_Type_\d+)(.*?)(?=### |$)'
    layer_list_pattern = r'profile instance at each of the layer (\d+(?:, layer \d+)*)'

    tasks = {}
    for section_name, section_body in re.findall(section_pattern, policy_content, re.DOTALL):
        number = int(section_name.rsplit('_', 1)[-1])

        layer_set = set()
        for layer_list in re.findall(layer_list_pattern, section_body):
            # Re-prefix the first number so every entry reads "layer <n>".
            layer_set.update(int(n) for n in re.findall(r'layer (\d+)', 'layer ' + layer_list))

        tasks[number] = {
            'name': section_name,
            'required_layers': sorted(layer_set),
            'content': section_body.strip(),
        }
    return tasks

def get_available_layers(directory_path):
    """Return the sorted layer numbers that have a loadable profiles file.

    A file in ``<directory_path>/Profiles`` counts as layer ``N`` only when
    its name is exactly ``f'profiles_{N}.json'``. That is the name
    ``load_profiles`` opens for layer ``N``. ``int()`` alone would also
    accept names such as ``profiles_01.json`` (-> 1) or
    ``profiles_1_0.json`` (-> 10). Those would report layers whose file
    cannot be loaded, or report a layer twice.

    Args:
        directory_path: Dataset directory expected to contain ``Profiles/``.

    Returns:
        Sorted list of ints. Empty when ``Profiles`` is missing or is not a
        directory.
    """
    profiles_dir = os.path.join(directory_path, 'Profiles')
    if not os.path.isdir(profiles_dir):
        return []

    prefix, suffix = 'profiles_', '.json'
    layers = []
    for filename in os.listdir(profiles_dir):
        if not (filename.startswith(prefix) and filename.endswith(suffix)):
            continue
        try:
            layer_num = int(filename[len(prefix):-len(suffix)])
        except ValueError:
            continue
        # Keep only names that round-trip to the file load_profiles() opens.
        if filename == f'{prefix}{layer_num}{suffix}':
            layers.append(layer_num)

    return sorted(layers)

def load_profiles(directory_path, layer):
    """Read ``Profiles/profiles_<layer>.json`` under ``directory_path``.

    Returns the parsed JSON, or an empty dict when the file is absent.
    """
    layer_file = os.path.join(directory_path, 'Profiles', f'profiles_{layer}.json')
    if not os.path.exists(layer_file):
        return {}
    with open(layer_file, 'r') as handle:
        return json.load(handle)

def get_numeric_attributes(profiles_data):
    """Return attribute keys that are numeric in *every* profile.

    Candidate keys come from the first profile and keep its key order. A key
    is kept only if every profile has it and its value converts with
    ``float()``. The comparison generator calls ``float(profile[attr])`` on
    arbitrarily sampled profiles. Checking only the first profile, as before,
    let an attribute that was missing or non-numeric elsewhere crash
    generation.

    Note: ``float()`` accepts booleans and numeric strings such as ``"2.5"``,
    so those count as numeric too.

    Args:
        profiles_data: Mapping of profile key -> attribute dict.

    Returns:
        List of attribute keys. Empty for empty input.
    """
    if not profiles_data:
        return []

    def _is_numeric(value):
        try:
            float(value)
        except (ValueError, TypeError):
            return False
        return True

    profiles = list(profiles_data.values())
    sample_profile = profiles[0]
    return [
        attr_key
        for attr_key in sample_profile
        if all(attr_key in profile and _is_numeric(profile[attr_key]) for profile in profiles)
    ]

def _select_extreme_profile(selected_profiles, score, largest):
    """Return ``(profile_key, value)`` of the profile with the extreme score.

    Uses a strict-comparison scan, so on ties the first profile in
    ``selected_profiles`` wins.
    """
    best_value = float('-inf') if largest else float('inf')
    best_profile = None
    for profile_key, profile_data in selected_profiles:
        value = score(profile_data)
        if (value > best_value) if largest else (value < best_value):
            best_value = value
            best_profile = profile_key
    return best_profile, best_value


def generate_comparison_question_and_answer(selected_profiles, numeric_attrs):
    """Generate a random comparison question and its computed answer.

    One of six comparison types is chosen at random: the largest/smallest
    single attribute value, sum of two attributes, or absolute difference of
    two attributes. When only one numeric attribute exists, only the
    single-attribute types are eligible. Previously the two-attribute types
    could still be drawn and returned ``(None, None)``, which silently
    dropped about two thirds of the samples. Ties resolve to the first
    profile in ``selected_profiles``.

    Args:
        selected_profiles: List of ``(profile_key, profile_dict)`` pairs.
        numeric_attrs: Attribute keys whose values convert with ``float()``.

    Returns:
        ``(question_text, answer_text)``, or ``(None, None)`` when
        ``numeric_attrs`` is empty.
    """
    if len(numeric_attrs) < 1:
        return None, None

    # Define different types of comparison operations
    comparison_types = [
        "largest_single",
        "smallest_single",
        "largest_sum",
        "smallest_sum",
        "largest_difference",
        "smallest_difference"
    ]
    if len(numeric_attrs) < 2:
        # Only the single-attribute comparisons can be formed.
        comparison_types = comparison_types[:2]

    comparison_type = random.choice(comparison_types)
    extreme, kind = comparison_type.split('_', 1)  # e.g. ("largest", "sum")
    largest = extreme == "largest"

    if kind == "single":
        attr = random.choice(numeric_attrs)
        question_text = f"Which profile has the {extreme} value for {attr}?"
        best_profile, best_value = _select_extreme_profile(
            selected_profiles, lambda p: float(p[attr]), largest)
        answer_text = f"Profile {best_profile} has the {extreme} {attr} value of {best_value}."
        return question_text, answer_text

    first, second = random.sample(numeric_attrs, 2)
    if kind == "sum":
        question_text = f"Which profile has the {extreme} sum of {first} and {second}?"
        best_profile, best_value = _select_extreme_profile(
            selected_profiles, lambda p: float(p[first]) + float(p[second]), largest)
        answer_text = (f"Profile {best_profile} has the {extreme} sum of {first} and {second} "
                       f"with a total of {best_value}.")
    else:  # "difference"
        question_text = f"Which profile has the {extreme} absolute difference between {first} and {second}?"
        best_profile, best_value = _select_extreme_profile(
            selected_profiles, lambda p: abs(float(p[first]) - float(p[second])), largest)
        answer_text = (f"Profile {best_profile} has the {extreme} absolute difference between "
                       f"{first} and {second} with a difference of {best_value}.")

    return question_text, answer_text

def generate_profile_comparison_qa_data(directory_path, policy_id, samples_per_layer=1000):
    """Generate profile-comparison Q&A records for every profile layer.

    For each layer reported by ``get_available_layers``, the sample budget is
    split between 2-profile and 3-profile comparisons. The 3-profile share
    gets the extra sample when ``samples_per_layer`` is odd, so the two
    shares always add up to ``samples_per_layer``. Previously ``// 2`` was
    used for both, which lost a sample. Layers with fewer than two profiles
    or no numeric attributes are skipped with a warning. 3-profile
    comparisons are skipped for layers with exactly two profiles.

    Args:
        directory_path: Dataset directory containing ``Profiles/``.
        policy_id: Policy identifier (not used here; kept for interface
            compatibility).
        samples_per_layer: Target number of samples per layer.

    Returns:
        List of ``{"text": ...}`` records. It may be shorter than requested
        when layers or comparison sizes are skipped, or when
        ``generate_comparison_question_and_answer`` returns ``(None, None)``.
    """
    qa_data = []

    available_layers = get_available_layers(directory_path)
    if not available_layers:
        print("Warning: No profile layers found")
        return qa_data

    print(f"Found {len(available_layers)} layers: {available_layers}")

    # Number of samples per comparison size; the shares sum to samples_per_layer.
    half = samples_per_layer // 2
    samples_by_profile_count = {2: half, 3: samples_per_layer - half}

    for layer in available_layers:
        print(f"Generating {samples_per_layer} profile comparison samples for layer {layer}...")

        profiles_data = load_profiles(directory_path, layer)
        if len(profiles_data) < 2:
            print(f"Warning: Layer {layer} has fewer than 2 profiles, skipping")
            continue

        numeric_attrs = get_numeric_attributes(profiles_data)
        if not numeric_attrs:
            print(f"Warning: Layer {layer} has no numeric attributes for comparison, skipping")
            continue

        profile_keys = list(profiles_data.keys())

        for num_profiles, sample_count in samples_by_profile_count.items():
            if len(profiles_data) < num_profiles:
                print(f"Warning: Layer {layer} has fewer than {num_profiles} profiles, skipping {num_profiles}-profile comparisons")
                continue

            for _ in range(sample_count):
                selected_profile_keys = random.sample(profile_keys, num_profiles)
                selected_profiles = [(key, profiles_data[key]) for key in selected_profile_keys]

                question_text, answer_text = generate_comparison_question_and_answer(selected_profiles, numeric_attrs)
                if question_text is None or answer_text is None:
                    continue

                profiles_display = "\n\n".join(
                    format_profile_display(profile_key, profile_data)
                    for profile_key, profile_data in selected_profiles
                )
                full_question = (
                    f"Consider the following profiles from layer {layer}:\n\n"
                    f"{profiles_display}\n\n"
                    f"{question_text}"
                )
                full_answer = f"Looking at the profiles from layer {layer}:\n\n{answer_text}"

                qa_data.append({
                    "text": f"Question: {full_question}\n\nAnswer: {full_answer}"
                })

    return qa_data

def randomly_sample_profile(profiles_data):
    """Pick one profile uniformly at random.

    Returns:
        ``(profile_key, profile_data)`` for the chosen entry, or ``None``
        when ``profiles_data`` is empty.
    """
    if profiles_data:
        chosen_key = random.choice(list(profiles_data))
        return chosen_key, profiles_data[chosen_key]
    return None

def load_exec_module(directory_path):
    """Import ``<directory_path>/Task/exec.py`` as a fresh module object.

    Returns ``None`` when the file does not exist. The module is not
    registered in ``sys.modules``.

    NOTE: this executes exec.py's top-level code in the current process;
    only point it at trusted directories.
    """
    module_path = os.path.join(directory_path, 'Task', 'exec.py')
    if os.path.exists(module_path):
        module_spec = importlib.util.spec_from_file_location("exec_module", module_path)
        loaded = importlib.util.module_from_spec(module_spec)
        module_spec.loader.exec_module(loaded)
        return loaded
    return None

def format_profile_display(profile_key, profile_data):
    """Render one profile as ``Profile <key>: <str(data)>`` for Q&A text."""
    return f"Profile {profile_key}: {profile_data!s}"

def get_global_attributes_text(exec_module):
    """Describe the exec module's three global attribute constants.

    Reads ``GLOBAL_ATTRIBUTE_VALUE1``..``GLOBAL_ATTRIBUTE_VALUE3``, using
    ``'Unknown'`` for any that are missing. If reading or formatting a value
    raises an ordinary exception, a fixed placeholder is returned instead.
    This used to be a bare ``except:``, which also swallowed
    ``KeyboardInterrupt``/``SystemExit``. Those now propagate.

    Returns:
        A single-line ``"Global attributes: ..."`` string.
    """
    try:
        global_attr_1 = getattr(exec_module, 'GLOBAL_ATTRIBUTE_VALUE1', 'Unknown')
        global_attr_2 = getattr(exec_module, 'GLOBAL_ATTRIBUTE_VALUE2', 'Unknown')
        global_attr_3 = getattr(exec_module, 'GLOBAL_ATTRIBUTE_VALUE3', 'Unknown')
        return f"Global attributes: global_attribute_1 = {global_attr_1}, global_attribute_2 = {global_attr_2}, global_attribute_3 = {global_attr_3}"
    except Exception:
        return "Global attributes: [Unable to retrieve]"

def extract_argument_specifications(task_content):
    """Map argument numbers to their textual specification.

    Scans ``task_content`` for bullets of the form ``- arg_<N>: <text>``. A
    specification may span lines and runs until the next ``- arg_<M>`` or
    ``- Each`` bullet, or to the end of the content. Surrounding whitespace
    and trailing periods are removed. Empty specifications are dropped. If
    ``arg_<N>`` repeats, the last occurrence wins.

    Returns:
        Dict ``{N: spec_text}``.
    """
    bullet_pattern = re.compile(
        r'-\s*arg_(\d+):\s*(.*?)(?=\s*-\s*(?:arg_\d+|Each)|$)', re.DOTALL
    )
    specs = {}
    for hit in bullet_pattern.finditer(task_content):
        spec_text = hit.group(2).strip().rstrip('.')
        if not spec_text:
            continue
        specs[int(hit.group(1))] = spec_text
    return specs

def generate_task_qa_data(directory_path, policy_content, policy_id, tasks, samples_per_task=5000):
    """Generate task-level Q&A records whose answers come from ``Task/exec.py``.

    Each task gets ``samples_per_task`` samples. A sample draws one random
    profile from each required layer that has profiles and renders those
    profiles into the question. The answer is built from the global
    attributes, the task description and the output of
    ``exec_module.compute_task_<N>``.

    Each layer's profiles file is read once and cached for the rest of the
    call. Previously it was re-read and re-parsed for every sample.

    Args:
        directory_path: Dataset directory containing ``Profiles/`` and
            ``Task/exec.py``.
        policy_content: Full Policy.md text (not used here; kept for
            interface compatibility).
        policy_id: Policy identifier embedded in questions and answers.
        tasks: Mapping as returned by ``parse_task_specifications``.
        samples_per_task: Number of samples drawn per task.

    Returns:
        List of ``{"text": ...}`` records. Empty when exec.py cannot be
        loaded. Samples where no required layer has profiles are skipped. A
        missing compute function, or an exception it raises, is written into
        the answer text instead of aborting generation.
    """
    qa_data = []

    exec_module = load_exec_module(directory_path)
    if not exec_module:
        print("Warning: Could not load exec.py module")
        return qa_data

    # layer -> parsed profiles; filled lazily the first time a layer is needed.
    profiles_by_layer = {}

    for task_num, task_info in tasks.items():
        print(f"Generating {samples_per_task} samples for {task_info['name']}...")

        for _ in range(samples_per_task):
            profile_descriptions = []
            instances_for_exec = {}

            for layer in task_info['required_layers']:
                if layer not in profiles_by_layer:
                    profiles_by_layer[layer] = load_profiles(directory_path, layer)
                profiles_data = profiles_by_layer[layer]
                if not profiles_data:
                    continue

                profile_key, profile_data = randomly_sample_profile(profiles_data)
                profile_descriptions.append(format_profile_display(profile_key, profile_data))

                # Convert profile_1_attribute_1 to Profile_1_Attribute_1 for exec.py;
                # keys with fewer than four '_'-separated parts are dropped.
                exec_profile_data = {}
                for attr_key, attr_value in profile_data.items():
                    parts = attr_key.split('_')
                    if len(parts) >= 4:
                        exec_profile_data[f"Profile_{parts[1]}_Attribute_{parts[3]}"] = attr_value
                instances_for_exec[layer] = exec_profile_data

            # No required layer had any profiles: nothing to ask about.
            if not instances_for_exec:
                continue

            profiles_text = "\n\n".join(profile_descriptions)
            question = (
                f"For Policy #{policy_id}, obtaining the following instances:\n\n"
                f"{profiles_text}\n\n"
                f"How would you complete {task_info['name']} using these profile instances?"
            )

            answer_parts = [
                get_global_attributes_text(exec_module),
                f"To finish {task_info['name']} for Policy #{policy_id}:",
                f"{task_info['content']}",
            ]

            # Compute results using exec.py
            try:
                compute_func = getattr(exec_module, f'compute_task_{task_num}', None)
                if compute_func:
                    results = compute_func(instances_for_exec)

                    result_parts = []
                    for i, result in enumerate(results, 1):
                        result_parts.append(f"arg_{i} is thus computed as {result}")

                    answer_parts.append("Computed arguments:")
                    answer_parts.append(", ".join(result_parts))
                    answer_parts.append(f"This task should be finished by calling finish_task_{task_num}({results})")
                else:
                    answer_parts.append(f"[Function compute_task_{task_num} not found]")
            except Exception as e:
                answer_parts.append(f"[Error computing results: {str(e)}]")

            answer = "\n\n".join(answer_parts)

            qa_data.append({
                "text": f"Question: {question}\n\nAnswer: {answer}"
            })

    return qa_data

def generate_argument_specific_qa_data(directory_path, policy_content, policy_id, tasks, samples_per_task=5000):
    """Generate one Q&A record per computed argument of each task.

    Each task gets ``samples_per_task`` samples. A sample draws one random
    profile from each required layer that has profiles and runs
    ``exec_module.compute_task_<N>`` once. It then emits one record for every
    result position ``i`` (1-based) that has an ``arg_<i>`` specification in
    the task text.

    Changes from the previous version:
      * Each layer's profiles file is read once and cached instead of on
        every sample.
      * The compute result is materialised with ``list()`` inside the
        ``try``. A generator or non-iterable result is then skipped like any
        other compute failure, instead of crashing on ``len()`` outside the
        handler.
      * The profile text and global-attribute line are built once per sample
        rather than once per argument.

    Args:
        directory_path: Dataset directory containing ``Profiles/`` and
            ``Task/exec.py``.
        policy_content: Full Policy.md text (not used here; kept for
            interface compatibility).
        policy_id: Policy identifier embedded in questions.
        tasks: Mapping as returned by ``parse_task_specifications``.
        samples_per_task: Number of samples drawn per task.

    Returns:
        List of ``{"text": ...}`` records. Empty when exec.py cannot be
        loaded. Samples whose compute function is missing or raises are
        skipped.
    """
    qa_data = []

    exec_module = load_exec_module(directory_path)
    if not exec_module:
        print("Warning: Could not load exec.py module")
        return qa_data

    # layer -> parsed profiles; filled lazily the first time a layer is needed.
    profiles_by_layer = {}

    for task_num, task_info in tasks.items():
        print(f"Generating {samples_per_task} argument-specific samples for {task_info['name']}...")

        arg_specs = extract_argument_specifications(task_info['content'])

        for _ in range(samples_per_task):
            profile_descriptions = []
            instances_for_exec = {}

            for layer in task_info['required_layers']:
                if layer not in profiles_by_layer:
                    profiles_by_layer[layer] = load_profiles(directory_path, layer)
                profiles_data = profiles_by_layer[layer]
                if not profiles_data:
                    continue

                profile_key, profile_data = randomly_sample_profile(profiles_data)
                profile_descriptions.append(format_profile_display(profile_key, profile_data))

                # Convert profile_1_attribute_1 to Profile_1_Attribute_1 for exec.py;
                # keys with fewer than four '_'-separated parts are dropped.
                exec_profile_data = {}
                for attr_key, attr_value in profile_data.items():
                    parts = attr_key.split('_')
                    if len(parts) >= 4:
                        exec_profile_data[f"Profile_{parts[1]}_Attribute_{parts[3]}"] = attr_value
                instances_for_exec[layer] = exec_profile_data

            # No required layer had any profiles: nothing to ask about.
            if not instances_for_exec:
                continue

            try:
                compute_func = getattr(exec_module, f'compute_task_{task_num}', None)
                if not compute_func:
                    continue
                results = list(compute_func(instances_for_exec))
            except Exception:
                continue

            profiles_text = "\n\n".join(profile_descriptions)
            global_attrs_text = get_global_attributes_text(exec_module)

            for arg_num, arg_value in enumerate(results, 1):
                if arg_num not in arg_specs:
                    continue

                question = (
                    f"For Policy #{policy_id}, obtaining the following instances:\n\n"
                    f"{profiles_text}\n\n"
                    f"What is the value of argument {arg_num} to complete {task_info['name']}?"
                )
                answer = (
                    f"{global_attrs_text}\n\n"
                    f"For argument {arg_num}: {arg_specs[arg_num]}\n\n"
                    f"The computed value is: {arg_value}"
                )

                qa_data.append({
                    "text": f"Question: {question}\n\nAnswer: {answer}"
                })

    return qa_data

def _load_policy_section(directory_path):
    """Build the Part (1) records from ``<directory_path>/Policy``.

    Reads ``Policy_QA.json``, a list of objects whose ``question`` /
    ``answer`` keys default to ``''``, and then ``Policy.md``. Each Q&A pair
    becomes one ``{"text": ...}`` record, followed by one record holding the
    whole stripped Policy document.

    Returns:
        Tuple ``(entries, policy_content, policy_id)``. When the Policy
        directory or either file is missing, a warning is printed and
        ``([], "", "P00000")`` is returned.
    """
    policy_dir = os.path.join(directory_path, 'Policy')
    if not os.path.exists(policy_dir):
        print(f"Warning: Policy directory not found in {directory_path}")
        return [], "", "P00000"

    qa_path = os.path.join(policy_dir, 'Policy_QA.json')
    md_path = os.path.join(policy_dir, 'Policy.md')
    if not (os.path.exists(qa_path) and os.path.exists(md_path)):
        print(f"Warning: Policy_QA.json or Policy.md not found in {policy_dir}")
        return [], "", "P00000"

    with open(qa_path, 'r') as qa_file:
        qa_pairs = json.load(qa_file)
    with open(md_path, 'r') as md_file:
        document = md_file.read().strip()

    policy_id = extract_policy_id(document)

    entries = [
        {"text": f"Question: {pair.get('question', '')}\n\nAnswer: {pair.get('answer', '')}"}
        for pair in qa_pairs
    ]
    entries.append({"text": f"Policy Document Content:\n\n{document}"})

    print(f"Processed {len(qa_pairs)} Policy Q&A pairs")
    print("Added Policy document content")
    return entries, document, policy_id


def process_directory_content(directory_path, samples_per_task=5000, samples_per_layer=1000):
    """Assemble all CPT records for one task directory.

    The output is the concatenation of four parts:
      (1) Policy Q&A pairs plus the full Policy document.
      (2) Task-level and argument-level Q&A. Produced only when Policy.md
          was loaded and is non-empty.
      (3) Profile comparison Q&A for every profile layer.
      (4) A second copy of the Part (1) records, for extra weight.

    Args:
        directory_path: Path to the directory to process.
        samples_per_task: Samples per task for the Part (2) generators.
        samples_per_layer: Samples per layer for Part (3).

    Returns:
        List of ``{"text": ...}`` dicts. Part (4) reuses the same dict
        objects as Part (1).
    """
    # Part (1): general Policy Q&A plus the Policy document itself.
    part1_data, policy_content, policy_id = _load_policy_section(directory_path)
    pretrain_data = list(part1_data)

    # Part (2): task-level Q&A; requires the Policy.md text to find the tasks.
    if policy_content:
        print("Generating task-level Q&A data...")

        tasks = parse_task_specifications(policy_content)
        print(f"Found {len(tasks)} tasks: {list(tasks.keys())}")

        task_qa_data = generate_task_qa_data(
            directory_path, policy_content, policy_id, tasks, samples_per_task=samples_per_task
        )
        pretrain_data += task_qa_data
        print(f"Generated {len(task_qa_data)} comprehensive task-level Q&A samples")

        arg_qa_data = generate_argument_specific_qa_data(
            directory_path, policy_content, policy_id, tasks, samples_per_task=samples_per_task
        )
        pretrain_data += arg_qa_data
        print(f"Generated {len(arg_qa_data)} argument-specific Q&A samples")

    # Part (3): profile comparison Q&A.
    print("Generating profile comparison Q&A data...")
    comparison_qa = generate_profile_comparison_qa_data(
        directory_path, policy_id, samples_per_layer=samples_per_layer
    )
    pretrain_data += comparison_qa
    print(f"Generated {len(comparison_qa)} profile comparison Q&A samples")

    # Part (4): repeat Part (1) to increase its share of the data.
    if part1_data:
        print("Adding another copy of Policy Q&A data for increased representation...")
        pretrain_data += part1_data
        print(f"Added duplicate of {len(part1_data)} Policy-related entries from Part (1)")

    return pretrain_data

def update_dataset_info(dataset_info, dataset_name):
    """Register ``dataset_name`` in ``dataset_info`` and return the same dict.

    The entry points at ``<dataset_name>.json`` and maps its ``text`` column
    to the prompt. Any existing entry with that name is replaced; the dict
    is mutated in place.
    """
    entry = dict(
        file_name=f"{dataset_name}.json",
        columns={"prompt": "text"},
    )
    dataset_info[dataset_name] = entry
    return dataset_info

def save_dataset_info(dataset_info, dataset_info_path):
    """Write ``dataset_info`` to ``dataset_info_path`` as 2-space-indented JSON.

    The registry is serialized to a string *before* the file is opened.
    ``json.dump`` streams into an already-truncated file, so a
    non-serializable entry used to leave the shared dataset_info.json (which
    also lists other datasets) truncated or half-written. Now the
    ``TypeError`` is raised with the existing file untouched.
    """
    serialized = json.dumps(dataset_info, indent=2)
    with open(dataset_info_path, 'w') as f:
        f.write(serialized)

def save_pretrain_data(pretrain_data, output_path):
    """Stream the CPT records to ``output_path`` as a 2-space-indented JSON array."""
    with Path(output_path).open('w') as out_file:
        json.dump(pretrain_data, out_file, indent=2)

def generate_dataset_variant(directory_path, target_dataset_path, dataset_info_path, suffix, samples_per_task, samples_per_layer):
    """Build one dataset variant, write it to disk and register it.

    Writes ``<target_dataset_path>/<dataset_name>.json`` and adds or replaces
    the ``dataset_name`` entry in ``dataset_info_path``. Progress is printed
    to stdout.

    Returns:
        The dataset name derived from ``directory_path`` and ``suffix``.
    """
    dataset_name = extract_dataset_name_from_path(directory_path, suffix)
    output_json_path = os.path.join(target_dataset_path, f"{dataset_name}.json")
    banner = '=' * 60

    print(f"\n{banner}")
    print(f"Generating dataset variant: {dataset_name}")
    print(f"Samples per task: {samples_per_task}")
    print(f"Samples per layer: {samples_per_layer}")
    print(banner)

    print(f"Processing directory: {directory_path}")
    print(f"Output JSON path: {output_json_path}")

    print("Processing directory content...")
    pretrain_data = process_directory_content(
        directory_path,
        samples_per_task=samples_per_task,
        samples_per_layer=samples_per_layer,
    )
    print(f"Generated {len(pretrain_data)} total data entries")

    print("Saving pretrain data...")
    save_pretrain_data(pretrain_data, output_json_path)

    print("Updating dataset_info.json...")
    registry = update_dataset_info(load_existing_dataset_info(dataset_info_path), dataset_name)
    save_dataset_info(registry, dataset_info_path)

    print(f"Dataset variant '{dataset_name}' completed!")
    print(f"Data file: {output_json_path}")
    return dataset_name

def main():
    """Command-line entry point: build the "less" and "more" CPT dataset variants.

    Positional argument ``directory_path`` is the task directory to process.
    Both variants are written to ``--target-dataset-path`` (created if
    missing) and registered in that directory's ``dataset_info.json``.
    ``--seed`` seeds ``random`` before generation so output is reproducible.
    Without it the run stays unseeded, as before. Exits with status 1 when
    ``directory_path`` is missing or is not a directory.
    """
    parser = argparse.ArgumentParser(description='Generate CPT data from directory content')
    parser.add_argument('directory_path', help='Path to the directory to process')
    parser.add_argument(
        '--target-dataset-path',
        default='/code/jiateng-sandbox/intern_project/third_party/LLaMA-Factory/data',
        help=('Directory that receives the generated dataset files and '
              'dataset_info.json (default: %(default)s)'),
    )
    parser.add_argument(
        '--seed',
        type=int,
        default=None,
        help='Seed for the random sampler so runs are reproducible (default: unseeded)',
    )
    args = parser.parse_args()

    directory_path = args.directory_path

    # Validate input directory
    if not os.path.exists(directory_path):
        print(f"Error: Directory '{directory_path}' does not exist")
        sys.exit(1)

    if not os.path.isdir(directory_path):
        print(f"Error: '{directory_path}' is not a directory")
        sys.exit(1)

    if args.seed is not None:
        random.seed(args.seed)

    target_dataset_path = args.target_dataset_path
    dataset_info_path = os.path.join(target_dataset_path, 'dataset_info.json')

    # Create target directory if it doesn't exist
    os.makedirs(target_dataset_path, exist_ok=True)

    print(f"Target dataset path: {target_dataset_path}")
    print(f"Dataset info path: {dataset_info_path}")

    # (suffix, samples_per_task, samples_per_layer) for each variant.
    variants = [
        ("less", 5000, 1000),
        ("more", 8000, 2000),
    ]

    generated_datasets = []
    for suffix, samples_per_task, samples_per_layer in variants:
        generated_datasets.append(
            generate_dataset_variant(
                directory_path=directory_path,
                target_dataset_path=target_dataset_path,
                dataset_info_path=dataset_info_path,
                suffix=suffix,
                samples_per_task=samples_per_task,
                samples_per_layer=samples_per_layer,
            )
        )

    print(f"\n{'='*60}")
    print("CPT data generation completed successfully!")
    print("Generated datasets:")
    for dataset_name in generated_datasets:
        print(f"  - {dataset_name}")
    print(f"All datasets saved to: {target_dataset_path}")
    print(f"Dataset info updated: {dataset_info_path}")
    print(f"{'='*60}")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
