from datasets import load_from_disk, Dataset, Features, Image as HFImage, Value, Sequence
import argparse
import os
from tqdm import tqdm
from PIL import Image
import io

def convert_dataset(input_path, info, use_original_as_fallback=False, args=None, lower_threshold=None, upper_threshold=None):
    # Load dataset from disk
    print(f"Loading dataset from ./evolved_data/{input_path}")
    dataset = load_from_disk(f'./evolved_data/{input_path}')

    # Filter out samples with progress bar
    print(f"Filtering dataset with {len(dataset)} samples...")
    has_evolved = []
    has_actual_evolved = []  # Track samples with actual evolved questions (no fallback)
    no_evolved_but_kept = []  # Track samples with no evolved question but kept due to fallback
    
    # New counters for harder evolved questions
    has_harder_evolved = []  # Track samples with any harder evolved questions
    harder_evolved_count = 0  # Count all harder evolved questions

    for i, sample in tqdm(enumerate(dataset), total=len(dataset), desc="Filtering"):
        # Check if evolved_question exists and is not None
        if 'evolved_question' in sample and sample['evolved_question'] is not None:
            has_evolved.append(i)
            has_actual_evolved.append(i)
        elif use_original_as_fallback:
            has_evolved.append(i)
            no_evolved_but_kept.append(i)
            
        # Check for harder evolved questions in all_evolved_questions
        if ('all_evolved_questions' in sample and 'all_is_harder' in sample and 
            isinstance(sample['all_evolved_questions'], list) and 
            isinstance(sample['all_is_harder'], list)):
            
            # Check if any evolved questions are marked as harder
            for j, is_harder in enumerate(sample['all_is_harder']):
                if is_harder and j < len(sample['all_evolved_questions']):
                    has_harder_evolved.append(i)
                    harder_evolved_count += 1
                    break

    valid_samples = dataset.select(has_evolved)
    actual_evolved_samples = dataset.select(has_actual_evolved)
    
    # Select samples with harder evolved questions for the new combined_all dataset
    harder_evolved_samples = dataset.select(has_harder_evolved)

    print(f"Filtered to {len(valid_samples)} valid samples")
    print(f"Samples with actual evolved questions: {len(actual_evolved_samples)}")
    print(f"Samples with harder evolved questions: {len(harder_evolved_samples)}")
    print(f"Total harder evolved questions: {harder_evolved_count}")

    if use_original_as_fallback:
        print(f"Including {len(no_evolved_but_kept)} samples using original question as fallback")

    # Create original and evolved datasets with a single mapping operation
    def transform_data(example):
        """Build the paired 'original' and 'evolved' records for one row.

        The returned dict has three keys: 'original', 'evolved' and
        'has_actual_evolved'. When the row carries no non-None
        'evolved_question', the original question text is reused for the
        evolved record and that record is flagged with 'is_fallback'.
        Both records carry the row's 'original_answer' as their answer.
        """

        def _with_single_image_tag(text):
            # Returns (text, valid): the tag is prepended when absent, and a
            # prompt holding more than one tag is left as-is but marked invalid.
            tag_count = text.count("<image>")
            if tag_count == 0:
                return "<image>\n" + text, True
            return text, tag_count == 1

        question = example['original_question']

        # A missing key and an explicit None are both treated as "no evolved
        # question", in which case the original text stands in for it.
        candidate = example.get('evolved_question')
        is_fallback = candidate is None
        evolved_text = question if is_fallback else candidate

        # Raw bytes are decoded into a PIL image; any other non-PIL value is
        # passed through unchanged.
        picture = example['image']
        if not isinstance(picture, Image.Image) and isinstance(picture, bytes):
            picture = Image.open(io.BytesIO(picture))

        original_problem, original_valid = _with_single_image_tag(question)
        evolved_problem, evolved_valid = _with_single_image_tag(evolved_text)

        # Pass counts default to 0 when the row does not provide them.
        original_pass = example.get('original_pass', 0)
        evolved_pass = example.get('evolved_pass', 0)

        sample_id = example['id']
        answer = example['original_answer']

        original_record = {
            'id': sample_id,
            'images': [picture],
            'problem': original_problem,
            'answer': answer,
            'valid': original_valid,
            'pass_rates': original_pass,
            'is_evolved': False
        }
        evolved_record = {
            'id': sample_id,
            'images': [picture],
            'problem': evolved_problem,
            'answer': answer,
            'valid': evolved_valid,
            'is_fallback': is_fallback,
            'pass_rates': evolved_pass,
            'is_evolved': True
        }

        return {
            'original': original_record,
            'evolved': evolved_record,
            'has_actual_evolved': not is_fallback,
        }

    # New function to transform data including all harder evolved questions
    def transform_data_with_harder(example):
        """Build the original record plus one record per harder evolved question.

        The returned dict has an 'original' record and a 'harder_evolved' list.
        An evolved question is added to the list only when its entry in
        'all_is_harder' is truthy, the question is not None, and it contains at
        most one "<image>" tag (the tag is prepended when absent). Each added
        record gets the id ``"<row id>_harder_<index>"`` and the row's
        'original_answer'.

        Pass rate of each added record:
            * taken from ``all_pass_rates[index]`` when that entry is present;
              strings such as "7/16" (or a bare "7") yield the leading integer,
              ints/floats are used as-is, and a numeric 0 is kept as 0;
            * falls back to the row's 'evolved_pass' (None when absent) when the
              entry is missing, None, an empty string, or an unparsable string;
            * is None for any other entry type.
        """

        def _with_single_image_tag(text):
            # Returns (text, valid): the tag is prepended when absent, and a
            # prompt holding more than one tag is left as-is but marked invalid.
            tag_count = text.count("<image>")
            if tag_count == 0:
                return "<image>\n" + text, True
            return text, tag_count == 1

        def _resolve_pass_rate(raw, fallback):
            # Only None / "" mean "not recorded". A plain truthiness test would
            # also discard a genuine 0 (no attempt passed), so it is avoided.
            if raw is None or raw == "":
                return fallback
            if isinstance(raw, str):
                # "7/16" -> 7; a string without a slash is parsed the same way.
                try:
                    return int(raw.split('/')[0])
                except ValueError:
                    return fallback
            if isinstance(raw, (int, float)):
                return raw
            return None

        original_problem = example['original_question']

        # Raw bytes are decoded into a PIL image; any other non-PIL value is
        # passed through unchanged.
        pil_image = example['image']
        if not isinstance(pil_image, Image.Image) and isinstance(pil_image, bytes):
            pil_image = Image.open(io.BytesIO(pil_image))

        original_problem, original_valid = _with_single_image_tag(original_problem)

        result = {
            'original': {
                'id': example['id'],
                'images': [pil_image],
                'problem': original_problem,
                'answer': example['original_answer'],
                'valid': original_valid,
                'pass_rates': example.get('original_pass', 0),
                'is_evolved': False
            },
            'harder_evolved': []
        }

        all_questions = example.get('all_evolved_questions')
        harder_flags = example.get('all_is_harder')
        if not (isinstance(all_questions, list) and isinstance(harder_flags, list)):
            return result

        # A missing or None 'all_pass_rates' behaves like an empty list.
        recorded_rates = example.get('all_pass_rates') or []
        fallback_rate = example.get('evolved_pass')  # No default value

        for i, (evolved_question, is_harder) in enumerate(zip(all_questions, harder_flags)):
            if not is_harder or evolved_question is None:
                continue

            evolved_question, evolved_valid = _with_single_image_tag(evolved_question)
            if not evolved_valid:
                # Questions with several image tags are dropped entirely.
                continue

            raw_rate = recorded_rates[i] if i < len(recorded_rates) else None

            result['harder_evolved'].append({
                'id': f"{example['id']}_harder_{i}",
                'images': [pil_image],
                'problem': evolved_question,
                'answer': example['original_answer'],
                'valid': True,
                'pass_rates': _resolve_pass_rate(raw_rate, fallback_rate),
                'is_evolved': True
            })

        return result

    # Transform in a single pass
    transformed = valid_samples.map(
        transform_data,
        num_proc=64,  # Specify number directly to avoid overhead
        desc="Transforming dataset"
    )
    
    # Transform harder evolved samples
    transformed_harder = harder_evolved_samples.map(
        transform_data_with_harder,
        num_proc=64,
        desc="Transforming dataset with harder evolved questions"
    )

    # Extract and filter the original and evolved datasets
    def extract_original(example):
        """Flatten the 'original' record of a transformed row.

        Returns a ``(record, keep)`` pair. ``record`` holds the id, images,
        problem, answer and pass_rates of the original question with
        'is_evolved' set to False. ``keep`` combines the 'valid' flags of both
        the original and the evolved record, so it is only truthy when both are.
        """
        source = example['original']
        copied_keys = ('id', 'images', 'problem', 'answer', 'pass_rates')
        record = {key: source[key] for key in copied_keys}
        record['is_evolved'] = False

        keep = source['valid'] and example['evolved']['valid']
        return record, keep

    def extract_evolved(example):
        """Flatten the 'evolved' record of a transformed row.

        Returns a ``(record, keep)`` pair. ``record`` holds the id, images,
        problem, answer and pass_rates of the evolved question, with
        'is_evolved' set to True and the record's 'is_fallback' flag carried
        over. ``keep`` combines the 'valid' flags of both the original and the
        evolved record, so it is only truthy when both are.
        """
        source = example['evolved']
        copied_keys = ('id', 'images', 'problem', 'answer', 'pass_rates')
        record = {key: source[key] for key in copied_keys}
        record['is_evolved'] = True
        record['is_fallback'] = source['is_fallback']

        keep = example['original']['valid'] and source['valid']
        return record, keep

    # Apply extraction and filtering for original dataset
    print("Extracting and filtering original dataset...")
    original_data = []
    for item in tqdm(transformed, desc="Processing original"):
        entry, is_valid = extract_original(item)
        if is_valid:
            original_data.append(entry)

    # Apply extraction and filtering for evolved dataset...
    print("Extracting and filtering evolved dataset...")
    evolved_data = []
    fallback_count = 0
    actual_evolved_data = []  # Store only non-fallback evolved samples

    for item in tqdm(transformed, desc="Processing evolved"):
        entry, is_valid = extract_evolved(item)
        if is_valid:
            evolved_data.append(entry)
            if entry.get('is_fallback', False):
                fallback_count += 1
            else:
                # Store only non-fallback evolved samples separately
                evolved_entry = entry.copy()
                if 'is_fallback' in evolved_entry:
                    del evolved_entry['is_fallback']
                actual_evolved_data.append(evolved_entry)

    # Create combined dataset: all originals + all actual evolved samples
    combined_data = []
    # First add all original samples
    combined_data.extend(original_data)
    # Then add all actual evolved samples (without fallbacks)
    combined_data.extend(actual_evolved_data)

    # Create combined_all dataset: all originals + all harder evolved questions
    combined_all_data = []
    harder_evolved_questions = []
    # First add all original samples
    combined_all_data.extend(original_data)
    # Then add all harder evolved questions
    harder_evolved_count = 0
    for item in tqdm(transformed_harder, desc="Processing harder evolved for combined_all"):
        # Add all harder evolved questions from this item
        for harder_evolved in item['harder_evolved']:
            if item['original']['valid']:
                entry = {
                    'id': harder_evolved['id'],
                    'images': harder_evolved['images'],
                    'problem': harder_evolved['problem'],
                    'answer': harder_evolved['answer'],
                    'pass_rates': harder_evolved['pass_rates'],
                    'is_evolved': True
                }
                combined_all_data.append(entry)
                harder_evolved_questions.append(entry)
                harder_evolved_count += 1

    # Create ori_only and evol_only datasets (moved from generate_additional_datasets)
    print("Creating ori_only and evol_only datasets...")
    evol_only_data = []
    ori_only_data = []

    # Collect samples for evol_only and ori_only
    for item in tqdm(transformed, desc="Processing ori_only and evol_only"):
        # Skip fallback samples for evol_only
        if not item.get('has_actual_evolved', False):
            continue
            
        # Check if both original and evolved are valid
        if item['original']['valid'] and item['evolved']['valid']:
            # Add to evol_only dataset (only evolved questions)
            evol_entry = {
                'id': item['evolved']['id'],
                'images': item['evolved']['images'],
                'problem': item['evolved']['problem'],
                'answer': item['evolved']['answer'],
                'pass_rates': item['evolved']['pass_rates'],
                'is_evolved': True
            }
            evol_only_data.append(evol_entry)
            
            # Add to ori_only dataset (original questions where evolved versions exist)
            ori_entry = {
                'id': item['original']['id'],
                'images': item['original']['images'],
                'problem': item['original']['problem'],
                'answer': item['original']['answer'],
                'pass_rates': item['original']['pass_rates'],
                'is_evolved': False
            }
            ori_only_data.append(ori_entry)

    # Clean up 'is_fallback' from evolved_data for consistency
    for entry in evolved_data:
        if 'is_fallback' in entry:
            del entry['is_fallback']

    # Define features to ensure proper image handling
    features = Features({
        'id': Value('string'),
        'images': Sequence(HFImage()),
        'problem': Value('string'),
        'answer': Value('string'),
        'pass_rates': Value('int32'),  # Changed from 'pass' to 'pass_rates'
        'is_evolved': Value('bool')    # Added new field for distinguishing original/evolved
    })

    # Create datasets from filtered data with explicit features
    original_dataset = Dataset.from_list(original_data, features=features)
    evolved_dataset = Dataset.from_list(evolved_data, features=features)
    combined_dataset = Dataset.from_list(combined_data, features=features)
    combined_all_dataset = Dataset.from_list(combined_all_data, features=features)
    
    # Create ori_only and evol_only datasets
    ori_only_dataset = Dataset.from_list(ori_only_data, features=features)
    evol_only_dataset = Dataset.from_list(evol_only_data, features=features)
    
    print(f"After image tag filtering: {len(original_dataset)} valid original samples")
    print(f"Overlap count (original used as evolved): {fallback_count} samples")
    print(f"Combined dataset (all original + actual evolved): {len(combined_dataset)} samples")
    print(f"Combined_all dataset (all original + all harder evolved): {len(combined_all_dataset)} samples")
    print(f"Number of harder evolved questions added: {harder_evolved_count}")
    print(f"Ori_only dataset (original with evolved counterparts): {len(ori_only_dataset)} samples")
    print(f"Evol_only dataset (only evolved questions): {len(evol_only_dataset)} samples")

    # NEW: Create filtered datasets based on pass rate thresholds
    if lower_threshold is not None or upper_threshold is not None:
        print(f"\nCreating filtered datasets with pass rate thresholds: lower={lower_threshold}, upper={upper_threshold}")
        
        # Filter evolved dataset with pass rate thresholds
        evolved_filtered_data = []
        evolved_filtered_fallback_count = 0
        
        # Keep track of which original questions need to be included as fallbacks
        original_fallback_ids = set()
        
        # Count None pass rates
        none_pass_rates_count = 0
        
        # Process each transformed item to create a filtered evolved dataset with fallbacks
        for item in tqdm(transformed, desc="Filtering evolved dataset by pass rates"):
            # Get the evolved question and check if it meets the threshold criteria
            evolved_entry = item['evolved']
            original_entry = item['original']
            pass_rate = evolved_entry['pass_rates']
            
            # Count None pass rates
            if pass_rate is None:
                none_pass_rates_count += 1
            
            # Check if pass rate is within threshold limits - FIXED to handle None values
            within_limits = True
            if pass_rate is None:
                # If pass_rate is None, treat it as outside the limits
                within_limits = False
            else:
                if lower_threshold is not None and pass_rate < lower_threshold:
                    within_limits = False
                if upper_threshold is not None and pass_rate > upper_threshold:
                    within_limits = False
            
            # Create entry dictionary
            entry_dict = {
                'id': evolved_entry['id'],
                'images': evolved_entry['images'],
                'problem': evolved_entry['problem'],
                'answer': evolved_entry['answer'],
                'pass_rates': pass_rate,
                'is_evolved': True
            }
            
            if item['has_actual_evolved'] and within_limits and evolved_entry['valid']:
                # If it's an actual evolved question (not a fallback) and within thresholds, add it
                evolved_filtered_data.append(entry_dict)
            elif item['has_actual_evolved'] and not within_limits and original_entry['valid']:
                # If it's outside thresholds, use original as fallback
                fallback_entry = {
                    'id': original_entry['id'],
                    'images': original_entry['images'],
                    'problem': original_entry['problem'],
                    'answer': original_entry['answer'],
                    'pass_rates': original_entry['pass_rates'],
                    'is_evolved': False
                }
                evolved_filtered_data.append(fallback_entry)
                evolved_filtered_fallback_count += 1
                original_fallback_ids.add(original_entry['id'])
            elif not item['has_actual_evolved'] and original_entry['valid']:
                # If it's already a fallback (no actual evolved question), keep the original
                fallback_entry = {
                    'id': original_entry['id'],
                    'images': original_entry['images'],
                    'problem': original_entry['problem'],
                    'answer': original_entry['answer'],
                    'pass_rates': original_entry['pass_rates'],
                    'is_evolved': False
                }
                evolved_filtered_data.append(fallback_entry)
                evolved_filtered_fallback_count += 1
                original_fallback_ids.add(original_entry['id'])
        
        print(f"Number of evolved entries with None pass rates: {none_pass_rates_count}")
        
        # Filter combined dataset with pass rate thresholds
        combined_filtered_data = []
        
        # First add all original samples
        for entry in original_data:
            combined_filtered_data.append(entry)
        
        # Then add filtered evolved samples
        filtered_evolved_count = 0
        for entry in actual_evolved_data:
            pass_rate = entry['pass_rates']
            
            # Check if pass rate is within threshold limits - FIXED to handle None values
            within_limits = True
            if pass_rate is None:
                # If pass_rate is None, treat it as outside the limits
                within_limits = False
            else:
                if lower_threshold is not None and pass_rate < lower_threshold:
                    within_limits = False
                if upper_threshold is not None and pass_rate > upper_threshold:
                    within_limits = False
            
            if within_limits:
                combined_filtered_data.append(entry)
                filtered_evolved_count += 1
        
        # Filter combined_all dataset with pass rate thresholds
        combined_all_filtered_data = []
        
        # First add all original samples
        for entry in original_data:
            combined_all_filtered_data.append(entry)
        
        # Then add filtered harder evolved samples
        filtered_harder_evolved_count = 0
        harder_none_pass_rates_count = 0
        
        for entry in harder_evolved_questions:
            pass_rate = entry['pass_rates']
            
            # Count None pass rates
            if pass_rate is None:
                harder_none_pass_rates_count += 1
            
            # Check if pass rate is within threshold limits - FIXED to handle None values
            within_limits = True
            if pass_rate is None:
                # If pass_rate is None, treat it as outside the limits
                within_limits = False
            else:
                if lower_threshold is not None and pass_rate < lower_threshold:
                    within_limits = False
                if upper_threshold is not None and pass_rate > upper_threshold:
                    within_limits = False
            
            if within_limits:
                combined_all_filtered_data.append(entry)
                filtered_harder_evolved_count += 1
        
        print(f"Number of harder evolved entries with None pass rates: {harder_none_pass_rates_count}")
        
        # Create datasets from filtered data
        evolved_filtered_dataset = Dataset.from_list(evolved_filtered_data, features=features)
        combined_filtered_dataset = Dataset.from_list(combined_filtered_data, features=features)
        combined_all_filtered_dataset = Dataset.from_list(combined_all_filtered_data, features=features)
        
        print(f"Filtered evolved dataset: {len(evolved_filtered_dataset)} samples")
        print(f"  - Original questions used as fallback: {evolved_filtered_fallback_count}")
        print(f"  - Actual evolved questions: {len(evolved_filtered_dataset) - evolved_filtered_fallback_count}")
        print(f"Filtered combined dataset: {len(combined_filtered_dataset)} samples")
        print(f"  - Original questions: {len(original_data)}")
        print(f"  - Filtered evolved questions: {filtered_evolved_count}")
        print(f"Filtered combined_all dataset: {len(combined_all_filtered_dataset)} samples")
        print(f"  - Original questions: {len(original_data)}")
        print(f"  - Filtered harder evolved questions: {filtered_harder_evolved_count}")
        
        # Helper function for creating train/test splits
        def create_train_test_split(dataset, test_size=0, random_seed=42):
            """Shuffle ``dataset`` reproducibly and split it into (train, test).

            Args:
                dataset: Object supporting ``len()`` and ``select(indices)``.
                test_size: Number of rows to put in the test split. Values
                    outside ``[0, len(dataset)]`` are clamped to that range.
                random_seed: Seed that determines the shuffle order.

            Returns:
                ``(train_dataset, test_dataset)``, each produced by
                ``dataset.select`` on a disjoint part of the shuffled indices.
            """
            import random

            # A private generator gives the same permutation as seeding the
            # module-level functions, but leaves the process-wide random state
            # untouched for any other code that relies on it.
            rng = random.Random(random_seed)
            shuffled_indices = list(range(len(dataset)))
            rng.shuffle(shuffled_indices)

            # Clamp so a negative size cannot become an end-relative slice
            # (which would put almost every row into the test split).
            n_test = max(0, min(test_size, len(dataset)))
            test_indices = shuffled_indices[:n_test]
            train_indices = shuffled_indices[n_test:]

            train_dataset = dataset.select(train_indices)
            test_dataset = dataset.select(test_indices)

            return train_dataset, test_dataset
        
        # Create train/test splits for all datasets
        original_train, original_test = create_train_test_split(original_dataset, test_size=min(0, len(original_dataset)))
        evolved_train, evolved_test = create_train_test_split(evolved_dataset, test_size=min(0, len(evolved_dataset)))
        combined_train, combined_test = create_train_test_split(combined_dataset, test_size=min(0, len(combined_dataset)))
        combined_all_train, combined_all_test = create_train_test_split(combined_all_dataset, test_size=min(0, len(combined_all_dataset)))
        
        # Add train/test splits for our new filtered datasets
        evolved_filtered_train, evolved_filtered_test = create_train_test_split(
            evolved_filtered_dataset, 
            test_size=min(0, len(evolved_filtered_dataset))
        )
        combined_filtered_train, combined_filtered_test = create_train_test_split(
            combined_filtered_dataset, 
            test_size=min(0, len(combined_filtered_dataset))
        )
        combined_all_filtered_train, combined_all_filtered_test = create_train_test_split(
            combined_all_filtered_dataset, 
            test_size=min(0, len(combined_all_filtered_dataset))
        )
        
        # Create output directories including the threshold info in the directory names
        threshold_suffix = ""
        if lower_threshold is not None:
            threshold_suffix += f"_lower{lower_threshold}"
        if upper_threshold is not None:
            threshold_suffix += f"_upper{upper_threshold}"
        
        # Output directories
        output_dirs = [
            f'./data/verifiable_data/{info}_original/train',
            f'./data/verifiable_data/{info}_original/test',
            f'./data/verifiable_data/{info}_evolved/train',
            f'./data/verifiable_data/{info}_evolved/test',
            f'./data/verifiable_data/{info}_combined/train',
            f'./data/verifiable_data/{info}_combined/test',
            f'./data/verifiable_data/{info}_combined_all/train',
            f'./data/verifiable_data/{info}_combined_all/test',
            f'./data/verifiable_data/{info}_ori_only/train',
            f'./data/verifiable_data/{info}_evol_only/train',
            # New directories for filtered datasets
            f'./data/verifiable_data/{info}_evolved{threshold_suffix}/train',
            f'./data/verifiable_data/{info}_evolved{threshold_suffix}/test',
            f'./data/verifiable_data/{info}_combined{threshold_suffix}/train',
            f'./data/verifiable_data/{info}_combined{threshold_suffix}/test',
            f'./data/verifiable_data/{info}_combined_all{threshold_suffix}/train',
            f'./data/verifiable_data/{info}_combined_all{threshold_suffix}/test',
        ]
        
        for directory in output_dirs:
            os.makedirs(directory, exist_ok=True)
        
        # Save the original datasets
        print("Saving datasets to parquet files...")
        original_train.to_parquet(f'./data/verifiable_data/{info}_original/train/data.parquet')
        original_test.to_parquet(f'./data/verifiable_data/{info}_original/test/data.parquet')
        evolved_train.to_parquet(f'./data/verifiable_data/{info}_evolved/train/data.parquet')
        evolved_test.to_parquet(f'./data/verifiable_data/{info}_evolved/test/data.parquet')
        combined_train.to_parquet(f'./data/verifiable_data/{info}_combined/train/data.parquet')
        combined_test.to_parquet(f'./data/verifiable_data/{info}_combined/test/data.parquet')
        combined_all_train.to_parquet(f'./data/verifiable_data/{info}_combined_all/train/data.parquet')
        combined_all_test.to_parquet(f'./data/verifiable_data/{info}_combined_all/test/data.parquet')
        
        # Save ori_only and evol_only datasets
        ori_only_dataset.to_parquet(f'./data/verifiable_data/{info}_ori_only/train/data.parquet')
        evol_only_dataset.to_parquet(f'./data/verifiable_data/{info}_evol_only/train/data.parquet')
        
        # Save the filtered datasets
        evolved_filtered_train.to_parquet(f'./data/verifiable_data/{info}_evolved{threshold_suffix}/train/data.parquet')
        evolved_filtered_test.to_parquet(f'./data/verifiable_data/{info}_evolved{threshold_suffix}/test/data.parquet')
        combined_filtered_train.to_parquet(f'./data/verifiable_data/{info}_combined{threshold_suffix}/train/data.parquet')
        combined_filtered_test.to_parquet(f'./data/verifiable_data/{info}_combined{threshold_suffix}/test/data.parquet')
        combined_all_filtered_train.to_parquet(f'./data/verifiable_data/{info}_combined_all{threshold_suffix}/train/data.parquet')
        combined_all_filtered_test.to_parquet(f'./data/verifiable_data/{info}_combined_all{threshold_suffix}/test/data.parquet')
        
        print(f"Original dataset: {len(original_train)} train, {len(original_test)} test samples")
        print(f"Evolved dataset: {len(evolved_train)} train, {len(evolved_test)} test samples")
        print(f"Combined dataset: {len(combined_train)} train, {len(combined_test)} test samples")
        print(f"Combined_all dataset: {len(combined_all_train)} train, {len(combined_all_test)} test samples")
        print(f"Ori_only dataset: {len(ori_only_dataset)} train samples")
        print(f"Evol_only dataset: {len(evol_only_dataset)} train samples")
        print(f"Filtered evolved dataset: {len(evolved_filtered_train)} train, {len(evolved_filtered_test)} test samples")
        print(f"Filtered combined dataset: {len(combined_filtered_train)} train, {len(combined_filtered_test)} test samples")
        print(f"Filtered combined_all dataset: {len(combined_all_filtered_train)} train, {len(combined_all_filtered_test)} test samples")
        
        # Return all datasets including filtered ones
        return (
            original_dataset, evolved_dataset, combined_dataset, combined_all_dataset, 
            ori_only_dataset, evol_only_dataset, evolved_filtered_dataset, combined_filtered_dataset,
            combined_all_filtered_dataset
        )
    else:
        # If no thresholds provided, continue with the original code
        # Helper function for creating train/test splits
        def create_train_test_split(dataset, test_size=0, random_seed=42):
            """Shuffle ``dataset`` reproducibly and split it into (train, test).

            Args:
                dataset: Object supporting ``len()`` and ``select(indices)``.
                test_size: Number of rows to put in the test split. Values
                    outside ``[0, len(dataset)]`` are clamped to that range.
                random_seed: Seed that determines the shuffle order.

            Returns:
                ``(train_dataset, test_dataset)``, each produced by
                ``dataset.select`` on a disjoint part of the shuffled indices.
            """
            import random

            # A private generator gives the same permutation as seeding the
            # module-level functions, but leaves the process-wide random state
            # untouched for any other code that relies on it.
            rng = random.Random(random_seed)
            shuffled_indices = list(range(len(dataset)))
            rng.shuffle(shuffled_indices)

            # Clamp so a negative size cannot become an end-relative slice
            # (which would put almost every row into the test split).
            n_test = max(0, min(test_size, len(dataset)))
            test_indices = shuffled_indices[:n_test]
            train_indices = shuffled_indices[n_test:]

            train_dataset = dataset.select(train_indices)
            test_dataset = dataset.select(test_indices)

            return train_dataset, test_dataset
        
        # Create train/test splits
        original_train, original_test = create_train_test_split(original_dataset)
        evolved_train, evolved_test = create_train_test_split(evolved_dataset)
        combined_train, combined_test = create_train_test_split(combined_dataset)
        combined_all_train, combined_all_test = create_train_test_split(combined_all_dataset)
        
        # Create output directories
        output_dirs = [
            f'./data/verifiable_data/{info}_original/train',
            f'./data/verifiable_data/{info}_original/test',
            f'./data/verifiable_data/{info}_evolved/train',
            f'./data/verifiable_data/{info}_evolved/test',
            f'./data/verifiable_data/{info}_combined/train',
            f'./data/verifiable_data/{info}_combined/test',
            f'./data/verifiable_data/{info}_combined_all/train',
            f'./data/verifiable_data/{info}_combined_all/test',
            f'./data/verifiable_data/{info}_ori_only/train',
            f'./data/verifiable_data/{info}_evol_only/train',
        ]

        for directory in output_dirs:
            os.makedirs(directory, exist_ok=True)

        # Save datasets in parquet format
        print("Saving datasets to parquet files...")
        original_train.to_parquet(f'./data/verifiable_data/{info}_original/train/data.parquet')
        original_test.to_parquet(f'./data/verifiable_data/{info}_original/test/data.parquet')
        evolved_train.to_parquet(f'./data/verifiable_data/{info}_evolved/train/data.parquet')
        evolved_test.to_parquet(f'./data/verifiable_data/{info}_evolved/test/data.parquet')
        combined_train.to_parquet(f'./data/verifiable_data/{info}_combined/train/data.parquet')
        combined_test.to_parquet(f'./data/verifiable_data/{info}_combined/test/data.parquet')
        combined_all_train.to_parquet(f'./data/verifiable_data/{info}_combined_all/train/data.parquet')
        combined_all_test.to_parquet(f'./data/verifiable_data/{info}_combined_all/test/data.parquet')
        
        # Save ori_only and evol_only datasets
        ori_only_dataset.to_parquet(f'./data/verifiable_data/{info}_ori_only/train/data.parquet')
        evol_only_dataset.to_parquet(f'./data/verifiable_data/{info}_evol_only/train/data.parquet')

        print(f"Original dataset: {len(original_train)} train, {len(original_test)} test samples")
        print(f"Evolved dataset: {len(evolved_train)} train, {len(evolved_test)} test samples")
        print(f"Combined dataset: {len(combined_train)} train, {len(combined_test)} test samples")
        print(f"Combined_all dataset: {len(combined_all_train)} train, {len(combined_all_test)} test samples")
        print(f"Ori_only dataset: {len(ori_only_dataset)} train samples")
        print(f"Evol_only dataset: {len(evol_only_dataset)} train samples")

        # Modified return statement to include ori_only and evol_only datasets
        return original_dataset, evolved_dataset, combined_dataset, combined_all_dataset, ori_only_dataset, evol_only_dataset


if __name__ == "__main__":
    # Command-line entry point: parse arguments, run the conversion, then print
    # summary statistics for the threshold-filtered datasets.
    parser = argparse.ArgumentParser(description='Convert Huggingface dataset format')
    parser.add_argument('--input_path', type=str, required=True, help='Input path relative to ./evolved_data/')
    parser.add_argument('--info', type=str, required=True, help='Info tag for output directory naming')
    parser.add_argument('--use_original_as_fallback', action='store_true',
                        help='If true, use original question as evolved when no evolved questions exist')
    parser.add_argument('--lower_threshold', type=int, default=6,
                        help='Lower threshold for pass rates (inclusive)')
    parser.add_argument('--upper_threshold', type=int, default=14,
                        help='Upper threshold for pass rates (inclusive)')
    args = parser.parse_args()

    # Pass thresholds to convert_dataset
    results = convert_dataset(
        args.input_path,
        args.info,
        args.use_original_as_fallback,
        args,
        lower_threshold=args.lower_threshold,
        upper_threshold=args.upper_threshold
    )

    # Check if filtered datasets were created.
    # NOTE(review): this condition must mirror the one convert_dataset uses to
    # decide between its 9-tuple and 6-tuple return values -- confirm they agree.
    if args.lower_threshold is not None or args.upper_threshold is not None:
        (original_dataset, evolved_dataset, combined_dataset, combined_all_dataset,
         ori_only_dataset, evol_only_dataset, evolved_filtered_dataset, combined_filtered_dataset,
         combined_all_filtered_dataset) = results

        # Print summary statistics of filtered datasets
        print("\nFiltered Dataset Statistics:")

        # Share of evolved vs. original-fallback questions in the filtered evolved dataset
        total_count = len(evolved_filtered_dataset)
        evolved_count = sum(1 for item in evolved_filtered_dataset if item['is_evolved'])
        original_count = total_count - evolved_count
        print("Filtered evolved dataset breakdown:")
        if total_count:
            print(f"  - Evolved questions: {evolved_count} ({evolved_count/total_count*100:.2f}%)")
            print(f"  - Original fallbacks: {original_count} ({original_count/total_count*100:.2f}%)")
        else:
            # Nothing survived the threshold filter; percentages are undefined,
            # so report the counts without dividing by zero.
            print(f"  - Evolved questions: {evolved_count} (n/a)")
            print(f"  - Original fallbacks: {original_count} (n/a)")

        # Pass-rate distribution of the evolved questions in each filtered dataset.
        # A dataset with no evolved question carrying a pass rate is skipped.
        pass_rate_reports = [
            ("Pass rates in filtered evolved dataset:", evolved_filtered_dataset),
            ("Pass rates in filtered combined dataset (evolved questions only):", combined_filtered_dataset),
            ("Pass rates in filtered combined_all dataset (evolved questions only):", combined_all_filtered_dataset),
        ]
        for heading, filtered_dataset in pass_rate_reports:
            pass_rates = [
                item['pass_rates'] for item in filtered_dataset
                if item['is_evolved'] and item['pass_rates'] is not None
            ]
            if not pass_rates:
                continue
            print(heading)
            print(f"  - Average: {sum(pass_rates) / len(pass_rates):.2f}")
            print(f"  - Min: {min(pass_rates)}")
            print(f"  - Max: {max(pass_rates)}")
    else:
        (original_dataset, evolved_dataset, combined_dataset, combined_all_dataset,
         ori_only_dataset, evol_only_dataset) = results