#!/usr/bin/env python3
"""
Script to download and prepare the seed datasets for SmolTraces as mentioned in the paper.

The datasets to collect are:
1. OlympicArena - Scientific domains
2. AGIEval - Logic
3. LiveCodeBench v4 - Coding problems of varying difficulty
4. NuminaMath - Math
5. OmniMath - Math

This script handles downloading, basic validation, and preparation of these datasets.
"""

import os
import sys
import argparse
import logging
from datasets import load_dataset
import pandas as pd
import json
from tqdm import tqdm
from datasets import Dataset
from typing import Dict, List, Any, Optional, Union

# Set up logging: configure the root logger once, at module load, so that every
# logging.* call in this script emits timestamped messages at INFO level and above.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Define the datasets and their information.
#
# Each entry maps a short dataset name (the value accepted by --datasets) to a
# dictionary with the following keys:
#   hf_path               Hugging Face Hub id passed to load_dataset()
#   configs               (optional) configuration names; when present, every
#                         config/split combination is loaded separately
#   splits                split names to load
#   description, domain   free text written into the generated README table
#   version_tag           (optional) forwarded to load_dataset() for LiveCodeBench
#   max_samples           (optional) cap on the rows kept per split;
#                         download_dataset() keeps the *first* N rows
#   use_manual_processing (optional) when True, main() normalizes the dataset
#                         with normalize_dataset_manually() instead of Dataset.map
# download_dataset() also honors an optional "filter_func" key (a callable
# passed to Dataset.filter); no entry below sets it.
SEED_DATASETS = {
    "OlympicArena": {
        "hf_path": "GAIR/OlympicArena",  
        "configs": ["Math", "Physics", "Chemistry", "Biology", "Geography", "Astronomy", "CS"],
        "splits": ["val", "test"],  
        "description": "Scientific domains benchmark with ~11k bilingual problems",
        "domain": "science"
    },
    "AGIEval": {
        "hf_path": "lighteval/agi_eval_en",
        "configs": ["aqua_rat", "logiqa-en", "lsat-ar", "lsat-lr", "lsat-rc", "math", "sat-en", "sat-math"],
        "splits": ["train", "validation"], 
        "description": "Logic and reasoning benchmark from standardized tests",
        "domain": "logic"
    },
    "LiveCodeBench": {
        "hf_path": "livecodebench/code_generation_lite", 
        "splits": ["test"],
        "description": "713 coding problems of varying difficulty",
        "domain": "coding",
        "version_tag": "release_v4",
        "max_samples": 713,
        "use_manual_processing": True  # Flag to indicate manual processing for large datasets
    },
    "NuminaMath": {
        "hf_path": "AI-MO/NuminaMath-TIR",
        "splits": ["train", "test"],
        "description": "Math competition dataset with ~72k problem-solution pairs",
        "domain": "math",
        # Cap at 20k samples per split. NOTE(review): the original intent was a
        # 20k *random* subset "as per paper", but download_dataset() keeps the
        # first N rows - confirm which is wanted.
        "max_samples": 20000
    },
    "OmniMath": {
        "hf_path": "KbsdJames/Omni-MATH",
        "splits": ["test"],  # Only has test split
        "description": "Olympiad-level math benchmark with ~4.4k problems",
        "domain": "math"
    }
}

def _clip_long_strings(sample, max_chars=500):
    """Return a copy of ``sample`` whose over-long string values are shortened."""
    return {
        key: (value[:max_chars] + "... [truncated]")
        if isinstance(value, str) and len(value) > max_chars
        else value
        for key, value in sample.items()
    }


def _load_split(dataset_name, dataset_info, config, split, sample_size, log_prefix):
    """Load one split (optionally of one config) from the Hugging Face Hub."""
    # A truthy sample_size becomes a slice on the split spec, e.g. "train[:100]".
    split_spec = f"{split}[:{sample_size}]" if sample_size else split
    load_args = [dataset_name]
    load_kwargs = {"split": split_spec, "trust_remote_code": True}

    if config is not None:
        load_args.append(config)
    elif dataset_name == "livecodebench/code_generation_lite" and "version_tag" in dataset_info:
        # Handle special case for LiveCodeBench which needs version_tag
        version_tag = dataset_info["version_tag"]
        logging.info(f"Using version_tag: {version_tag}")
        load_kwargs["version_tag"] = version_tag

    dataset = load_dataset(*load_args, **load_kwargs)
    if sample_size:
        logging.info(f"{log_prefix}Loaded {len(dataset)} samples (limited to {sample_size})")
    else:
        logging.info(f"{log_prefix}Loaded complete dataset with {len(dataset)} samples")
    return dataset


def _apply_filter_and_limit(dataset, dataset_info):
    """Apply the optional ``filter_func`` and ``max_samples`` settings to ``dataset``."""
    filter_func = dataset_info.get("filter_func")
    if filter_func:
        original_size = len(dataset)
        dataset = dataset.filter(filter_func)
        logging.info(f"  Filtered dataset from {original_size} to {len(dataset)} samples")

    max_samples = dataset_info.get("max_samples")
    if max_samples and len(dataset) > max_samples:
        # Deterministic cap: keep the first max_samples rows.
        dataset = dataset.select(range(max_samples))
        logging.info(f"  Limited dataset to first {max_samples} samples")
    return dataset


def _save_split_artifacts(dataset, dataset_dir, split, column_info,
                          sample_size, save_all_samples, log_prefix):
    """Write ``<split>_info.json`` and ``<split>_samples.json`` into ``dataset_dir``."""
    os.makedirs(dataset_dir, exist_ok=True)

    with open(os.path.join(dataset_dir, f"{split}_info.json"), "w", encoding="utf-8") as f:
        json.dump(column_info, f, indent=2)

    # Save samples - by default save all samples unless explicitly asked to limit
    samples_to_save = dataset
    if not save_all_samples and not sample_size and len(dataset) > 5:
        samples_to_save = dataset.select(range(5))
        logging.info(f"{log_prefix}Saving 5 sample entries as reference (out of {len(dataset)} total)")
    else:
        logging.info(f"{log_prefix}Saving all {len(samples_to_save)} sample entries")

    # Long strings are clipped so the samples file stays human-readable.
    samples = [_clip_long_strings(dict(sample)) for sample in samples_to_save]

    with open(os.path.join(dataset_dir, f"{split}_samples.json"), "w", encoding="utf-8") as f:
        json.dump(samples, f, indent=2)

    logging.info(f"{log_prefix}Saved dataset info and samples to {dataset_dir}")


def download_dataset(dataset_info, output_dir, sample_size=None, save_all_samples=True):
    """
    Download a dataset from the Hugging Face Hub and save its structure.

    For every requested split (and every config, when ``dataset_info`` has a
    "configs" key) the split is loaded, optionally filtered and capped via the
    "filter_func" / "max_samples" settings, and two files are written:
    ``<split>_info.json`` (column names and row count) and
    ``<split>_samples.json`` (rows, with strings longer than 500 characters
    truncated). Files go to ``<output_dir>/<hf_path with '/' -> '_'>`` or, for
    config datasets, to that path suffixed with ``_<config>``.

    A failure while loading or saving one split is logged and does not stop
    the remaining splits.

    Args:
        dataset_info: Dictionary with dataset information
        output_dir: Directory to save dataset information
        sample_size: Number of samples to download (None = all)
        save_all_samples: Whether to save all samples or just a reference subset (default: True)

    Returns:
        A dict mapping ``"<split>"`` (or ``"<config>_<split>"`` for config
        datasets) to the loaded dataset; it may be empty if every split
        failed. ``None`` if the entry has no usable "hf_path" or an
        unexpected error occurred (e.g. the "splits" key is missing).
    """
    # .get() so that an entry without the key reaches the warning below
    # instead of raising KeyError.
    dataset_name = dataset_info.get("hf_path")
    if not dataset_name:
        logging.warning(
            "Dataset entry has no 'hf_path', so it is not available on HuggingFace. "
            "Follow manual instructions."
        )
        return None

    has_configs = "configs" in dataset_info
    # Messages for config datasets are indented under their "Loading dataset" header.
    log_prefix = "  " if has_configs else ""
    base_dir = os.path.join(output_dir, dataset_name.replace("/", "_"))

    try:
        all_splits = {}
        configs = dataset_info["configs"] if has_configs else [None]

        for config in configs:
            if has_configs:
                logging.info(f"Loading dataset: {dataset_name}, config: {config}")

            for split in dataset_info["splits"]:
                if has_configs:
                    logging.info(f"  Split: {split}")
                    split_key = f"{config}_{split}"
                    split_label = f"{dataset_name}/{config}/{split}"
                    dataset_dir = f"{base_dir}_{config}"
                    column_info = {"config": config, "split": split}
                else:
                    logging.info(f"Loading dataset: {dataset_name}, split: {split}")
                    split_key = split
                    split_label = f"{dataset_name}/{split}"
                    dataset_dir = base_dir
                    column_info = {}

                try:
                    dataset = _load_split(
                        dataset_name, dataset_info,
                        config if has_configs else None,
                        split, sample_size, log_prefix,
                    )
                    dataset = _apply_filter_and_limit(dataset, dataset_info)

                    # Register the split before saving so a failed save does
                    # not discard a successfully loaded dataset.
                    all_splits[split_key] = dataset

                    column_info["columns"] = dataset.column_names
                    column_info["num_samples"] = len(dataset)
                    _save_split_artifacts(
                        dataset, dataset_dir, split, column_info,
                        sample_size, save_all_samples, log_prefix,
                    )
                except Exception as e:
                    logging.error(f"{log_prefix}Error loading {split_label}: {e}")

        return all_splits

    except Exception as e:
        logging.error(f"Error loading dataset {dataset_name}: {e}")
        return None

def _format_public_test_cases(raw_test_cases):
    """Render LiveCodeBench public test cases as an "Example Test Cases" section.

    ``raw_test_cases`` is expected to be a JSON-encoded list of objects with
    optional "input" / "output" keys. Anything that cannot be parsed into that
    shape is embedded verbatim. Returns "" when the parsed value is empty.
    """
    verbatim_section = f"\n\nExample Test Cases:\n{raw_test_cases}"
    try:
        test_cases = json.loads(raw_test_cases)
    except (TypeError, ValueError):
        # Not a string, or not valid JSON (JSONDecodeError is a ValueError):
        # use the value as-is.
        return verbatim_section

    if not test_cases:
        return ""
    if not isinstance(test_cases, list) or not all(isinstance(tc, dict) for tc in test_cases):
        # Valid JSON but not a list of objects: fall back to the raw value
        # rather than emitting a half-formatted section.
        return verbatim_section

    parts = ["\n\nExample Test Cases:\n"]
    for case_number, test_case in enumerate(test_cases, start=1):
        parts.append(f"\nTest Case {case_number}:\n")
        if "input" in test_case:
            parts.append(f"Input:\n{test_case['input']}\n")
        if "output" in test_case:
            parts.append(f"Output:\n{test_case['output']}\n")
    return "".join(parts)


def _normalize_livecodebench_example(example):
    """Convert one LiveCodeBench record to the common question/answer/metadata form."""
    # Combine question title and content for question field
    question = ""
    if "question_title" in example:
        question += f"# {example['question_title']}\n\n"
    if example.get("question_content"):
        question += example["question_content"]

    # Add example test cases if available
    if example.get("public_test_cases"):
        question += _format_public_test_cases(example["public_test_cases"])

    # Add starter code if available
    if "starter_code" in example and example["starter_code"]:
        question += f"\n\nStarter Code:\n```\n{example['starter_code']}\n```"

    return {
        "question": question,
        # This is a seed dataset without solutions - answer will be generated by models
        "answer": "",
        # Metadata for tracking
        "metadata": {
            "dataset": "LiveCodeBench",
            "difficulty": example.get("difficulty", ""),
            "question_id": example.get("question_id", ""),
            "platform": example.get("platform", "")
        },
    }


def normalize_dataset_manually(dataset: Dataset, dataset_name: str, output_path: str) -> None:
    """
    Normalize datasets with large string fields manually to avoid PyArrow concat issues.

    Records are read one at a time by index, converted to a dict with
    "question", "answer" and "metadata" keys, and written to ``output_path``
    as JSON Lines (one object per line).

    Args:
        dataset: The dataset to normalize
        dataset_name: Name of the dataset for specific handling
        output_path: Path to save the normalized data

    Raises:
        ValueError: If no manual handler exists for ``dataset_name``. This is
            checked before ``output_path`` is opened, so no empty file is
            left behind.
    """
    # Add more dataset-specific handlers as needed
    if dataset_name == "LiveCodeBench":
        normalize_example = _normalize_livecodebench_example
    else:
        raise ValueError(f"No manual normalization handler for dataset '{dataset_name}'")

    logging.info(f"Manual normalization for {dataset_name} with {len(dataset)} samples")

    with open(output_path, "w", encoding="utf-8") as f:
        for index in tqdm(range(len(dataset)), desc=f"Normalizing {dataset_name}"):
            example = dataset[index]
            if not isinstance(example, dict):
                example = dict(example)

            # Write the normalized example
            f.write(json.dumps(normalize_example(example)) + "\n")

    logging.info(f"Saved manually normalized dataset to {output_path}")

def normalize_dataset_format(dataset: Dataset, dataset_name: str) -> Dataset:
    """
    Normalize datasets to a common format with 'question' and 'answer' fields.

    Each record is mapped to a dict with "question", "answer" and
    "original_data" (the unmodified input record) via ``dataset.map``.

    Args:
        dataset: The dataset to normalize
        dataset_name: Name of the dataset for specific handling

    Returns:
        Normalized dataset. For "LiveCodeBench", if ``dataset.map`` raises,
        a warning is logged and the input dataset is returned unchanged.
    """
    if dataset_name == "LiveCodeBench":
        # LiveCodeBench v4 format - this might cause issues with large strings
        def map_livecode(example):
            # Combine problem description and example test cases
            question = example["description"] if "description" in example else ""
            if "example_test_cases" in example:
                question += f"\n\nExample Test Cases:\n{example['example_test_cases']}"

            # Prefer "solution"; fall back to "canonical_solution", then "".
            if "solution" in example:
                answer = example["solution"]
            else:
                answer = example.get("canonical_solution", "")

            return {
                "question": question,
                "answer": answer,
                "original_data": example
            }

        try:
            return dataset.map(map_livecode)
        except Exception as e:
            logging.warning(f"Exception during dataset.map for LiveCodeBench: {e}")
            logging.warning("Will handle this dataset separately with manual processing")
            return dataset

    # Default handling for unknown datasets
    def map_default(example):
        # Try to find question and answer fields with various common names
        question_field = next((field for field in ["question", "problem", "prompt"] if field in example), None)
        answer_field = next((field for field in ["answer", "solution", "response"] if field in example), None)

        if question_field is None:
            # Just pass through the original data if we can't find a question field
            return {
                "question": str(example),
                "answer": "",
                "original_data": example
            }

        # A record with a question but no answer column keeps its question
        # (with an empty answer) instead of being stringified wholesale.
        return {
            "question": example[question_field],
            "answer": example[answer_field] if answer_field is not None else "",
            "original_data": example
        }

    return dataset.map(map_default)

def save_dataset_metadata(output_dir):
    """
    Write a README.md summarizing every entry of SEED_DATASETS.

    The README holds a title, a markdown table with one row per dataset
    (name, domain, description, a fixed "Varies" size, Hub source) and a
    short usage note.

    Args:
        output_dir: Directory where datasets are saved; README.md is written here
    """
    header = (
        "# SmolTraces Seed Datasets\n"
        "\n"
        "This directory contains the seed datasets used for training SmolTraces models as referenced in the paper.\n"
        "\n"
        "## Datasets\n"
        "\n"
        "| Dataset | Domain | Description | Size | Source |\n"
        "|---------|--------|-------------|------|--------|\n"
    )

    table_rows = []
    for name, info in SEED_DATASETS.items():
        source = info["hf_path"]
        description = info['description']
        # Only the NuminaMath row advertises its sample cap.
        if name == "NuminaMath" and info.get("max_samples"):
            description = description + f" (Subset of {info['max_samples']})"
        table_rows.append(f"| {name} | {info['domain']} | {description} | Varies | {source} |\n")

    usage = (
        "\n"
        "## Usage\n"
        "\n"
        "Each dataset has been downloaded with its structure preserved. Sample files are provided \n"
        "to understand the format of each dataset.\n"
    )

    readme_path = os.path.join(output_dir, "README.md")
    with open(readme_path, "w") as readme_file:
        readme_file.write(header + "".join(table_rows) + usage)

    logging.info(f"Saved dataset metadata to {readme_path}")

def main():
    """
    Command-line entry point: download, normalize and document the seed datasets.

    For each requested dataset this downloads the configured splits, writes a
    ``<split>_normalized.jsonl`` file per split (one JSON object per line)
    under ``<output_dir>/<hf_path with '/' -> '_'>``, then writes a README.md
    and logs a summary of which datasets succeeded.
    """
    parser = argparse.ArgumentParser(description="Download and prepare seed datasets for SmolTraces")
    parser.add_argument("--output_dir", type=str, default="datasets/seed_data",
                        help="Directory to save the downloaded datasets")
    parser.add_argument("--sample_size", type=int, default=None,
                        help="Number of samples to download from each dataset (None = all)")
    parser.add_argument("--save_reference_only", action="store_true",
                        help="Save only 5 reference samples instead of the complete dataset")
    parser.add_argument("--datasets", type=str, nargs="+", default=list(SEED_DATASETS.keys()),
                        help=f"Datasets to download. Options: {', '.join(SEED_DATASETS.keys())}")

    args = parser.parse_args()

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Track datasets that were successfully downloaded
    successful_downloads = []

    # Process each dataset
    for dataset_name in args.datasets:
        if dataset_name not in SEED_DATASETS:
            logging.warning(f"Dataset {dataset_name} not recognized. Skipping.")
            continue

        logging.info(f"Processing dataset: {dataset_name}")
        dataset_info = SEED_DATASETS[dataset_name]

        # Pass the inverse of save_reference_only as save_all_samples
        result = download_dataset(dataset_info, args.output_dir, args.sample_size, not args.save_reference_only)
        if not result:
            # None (error) or an empty dict (no split could be loaded).
            continue

        # Apply normalization based on whether manual processing is needed
        use_manual_processing = dataset_info.get("use_manual_processing", False)

        # Output directory for normalized data. It must be created here:
        # for datasets with configs, download_dataset() only creates the
        # per-config "<hf_path>_<config>" directories, not this one.
        dataset_dir = os.path.join(args.output_dir, dataset_info["hf_path"].replace("/", "_"))
        os.makedirs(dataset_dir, exist_ok=True)

        for split_name, dataset in result.items():
            normalized_file = os.path.join(dataset_dir, f"{split_name}_normalized.jsonl")

            if use_manual_processing:
                # For datasets with large text that might cause PyArrow issues
                normalize_dataset_manually(dataset, dataset_name, normalized_file)
                continue

            # Standard processing for smaller datasets
            normalized_dataset = normalize_dataset_format(dataset, dataset_name)

            with open(normalized_file, "w") as f:
                for index in range(len(normalized_dataset)):
                    sample = normalized_dataset[index]
                    if not isinstance(sample, dict):
                        sample = dict(sample)

                    # Remove the original_data field for storage efficiency
                    clean_sample = {k: v for k, v in sample.items() if k != "original_data"}
                    f.write(json.dumps(clean_sample) + "\n")

            logging.info(f"Saved normalized dataset to {normalized_file}")

        successful_downloads.append(dataset_name)

    # Save metadata
    save_dataset_metadata(args.output_dir)

    # Summary
    logging.info("\n" + "="*50)
    logging.info(f"Downloaded {len(successful_downloads)} of {len(args.datasets)} datasets")
    logging.info(f"Successful: {', '.join(successful_downloads)}")
    if len(successful_downloads) < len(args.datasets):
        # Keep the order given on the command line (and drop duplicates) so
        # the log output is deterministic.
        succeeded = set(successful_downloads)
        failed = [name for name in dict.fromkeys(args.datasets) if name not in succeeded]
        logging.warning(f"Failed: {', '.join(failed)}")
    logging.info("="*50)

# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main() 