

import json
import os
import sys
from collections import Counter

from datasets import load_dataset

# Force UTF-8 stdout on Windows so non-ASCII text printed by this script
# (e.g. exception messages from the download step) cannot raise
# UnicodeEncodeError on a legacy code-page console. sys.stdout may be None
# (pythonw) or a replacement stream without reconfigure() (e.g. IDLE), so
# only reconfigure when the stream actually supports it.
if sys.platform == 'win32' and hasattr(sys.stdout, 'reconfigure'):
    sys.stdout.reconfigure(encoding='utf-8')


# ------------------------------
# Config
# ------------------------------
SEED = 42  # shuffle seed shared by every dataset sample

# One entry per dataset to download. Keys read by load_and_prepare_dataset:
#   name               label used in record ids, the "dataset" field and logs
#   hf_name, config    Hugging Face dataset id and optional config name,
#                      forwarded to load_dataset
#   split              split to sample from
#   question_field     row column holding the question text
#   answer_field       row column holding the answer (stored as str)
#   body_field         optional row column prepended to the question (ASDiv)
#   local_file         JSON cache; reused instead of downloading if it exists
#   n_samples          rows to sample (all rows if the split is smaller)
#   trust_remote_code  forwarded to load_dataset
DATASETS_CONFIG = [
    dict(
        name="GSM8K",
        hf_name="gsm8k",
        config="main",
        split="train",
        question_field="question",
        answer_field="answer",
        local_file="dataset_GSM8K.json",
        n_samples=1000,
        trust_remote_code=False,
    ),
    dict(
        name="ASDiv",
        hf_name="EleutherAI/asdiv",
        split="validation",
        question_field="question",
        answer_field="answer",
        body_field="body",
        local_file="dataset_ASDiv.json",
        n_samples=1000,
        trust_remote_code=False,
    ),
    dict(
        name="SVAMP",
        hf_name="ChilleD/SVAMP",
        split="train",
        question_field="Body",
        answer_field="Answer",
        local_file="dataset_SVAMP.json",
        n_samples=1000,
        trust_remote_code=False,
    ),
]

# ------------------------------
# Helper Functions
# ------------------------------
def _read_cache(local_file):
    """Return the records stored in *local_file*, or None if it cannot be used.

    A truncated or otherwise invalid cache (bad JSON, bad encoding, I/O error,
    or a top-level value that is not a list) is reported and ignored so the
    caller can rebuild it instead of crashing the whole run.
    """
    try:
        with open(local_file, "r", encoding='utf-8') as f:
            records = json.load(f)
    except (OSError, ValueError) as e:  # JSONDecodeError/UnicodeDecodeError are ValueErrors
        print(f"Ignoring unreadable cache file {local_file}: {e}")
        return None
    if not isinstance(records, list):
        print(f"Ignoring cache file {local_file}: expected a JSON list")
        return None
    return records


def _write_cache(local_file, records):
    """Write *records* to *local_file* as indented JSON via a temp file + rename.

    Renaming a fully written temp file into place means an interrupted run does
    not leave a truncated cache that later runs would try to load.
    """
    tmp_file = f"{local_file}.tmp"
    try:
        with open(tmp_file, "w", encoding='utf-8') as f:
            json.dump(records, f, indent=2)
        os.replace(tmp_file, local_file)
    except BaseException:
        # Clean up the partial temp file, then let the caller see the error.
        try:
            os.remove(tmp_file)
        except OSError:
            pass
        raise


def _join_text(*parts):
    """Join the non-empty parts with single spaces (no stray leading/trailing space)."""
    texts = (str(part).strip() for part in parts if part is not None)
    return " ".join(text for text in texts if text)


def _compose_question(row, dataset_config):
    """Build the full question text for one dataset row."""
    name = dataset_config['name']
    if name == "SVAMP":
        # SVAMP's problem text is spread over its 'Body' and 'Question' columns;
        # the configured question_field is not used for it.
        return _join_text(row.get('Body', ''), row.get('Question', ''))

    question = row.get(dataset_config['question_field'], "")
    # Any config may declare a body_field to prepend; ASDiv defaults to 'body'.
    body_field = dataset_config.get('body_field', 'body' if name == "ASDiv" else None)
    if body_field is None:
        return question
    return _join_text(row.get(body_field, ""), question)


def load_and_prepare_dataset(dataset_config, seed):
    """Return shuffled question/answer records for one dataset, using a local cache.

    If ``dataset_config['local_file']`` exists and holds a JSON list, that list
    is returned as-is. Otherwise the split is downloaded with
    ``datasets.load_dataset``, shuffled with *seed*, truncated to
    ``n_samples`` rows (all rows if the split is smaller), converted to records
    and written back to the cache file.

    Args:
        dataset_config: dict with keys ``name``, ``hf_name``, optional
            ``config``, ``split``, ``question_field``, ``answer_field``,
            optional ``body_field``, ``local_file``, ``n_samples`` and
            optional ``trust_remote_code``.
        seed: seed passed to ``Dataset.shuffle``.

    Returns:
        A list of dicts with keys ``"id"``, ``"question"``, ``"answer"`` (str)
        and ``"dataset"``. Returns ``[]`` if downloading or converting fails;
        the error is printed so one failing dataset does not abort the others.
        A failure to write the cache is only reported; the records are still
        returned.
    """
    name = dataset_config['name']
    local_file = dataset_config['local_file']
    n_questions = dataset_config['n_samples']

    if os.path.exists(local_file):
        print(f"\nLoading {name} from local file: {local_file}")
        cached = _read_cache(local_file)
        if cached is not None:
            return cached
        # Otherwise fall through and rebuild the cache from the Hub.

    print(f"\nDownloading {name} dataset from HuggingFace...")

    try:
        trust_remote_code = dataset_config.get('trust_remote_code', False)
        # The optional dataset config name is passed positionally after the id.
        hf_args = [dataset_config['hf_name']]
        if 'config' in dataset_config:
            hf_args.append(dataset_config['config'])
        ds = load_dataset(*hf_args, split=dataset_config['split'],
                          trust_remote_code=trust_remote_code)

        shuffled = ds.shuffle(seed=seed)
        if len(ds) < n_questions:
            print(f"Warning: {name} has only {len(ds)} samples, using all of them.")
            sample = shuffled
        else:
            sample = shuffled.select(range(n_questions))

        dataset_content = [
            {
                "id": f"{name}_{i}",
                "question": _compose_question(x, dataset_config),
                "answer": str(x.get(dataset_config['answer_field'], "")),
                "dataset": name,
            }
            for i, x in enumerate(sample)
        ]
    except Exception as e:
        # Deliberate best effort: report and let the remaining datasets proceed.
        print(f"Error loading {name}: {e}")
        return []

    try:
        _write_cache(local_file, dataset_content)
    except (OSError, TypeError, ValueError) as e:
        # The downloaded data is still valid; only caching failed.
        print(f"Warning: could not save {name} to {local_file}: {e}")
    else:
        print(f"Saved {name} to {local_file}")

    return dataset_content

# ------------------------------
# Main Execution
# ------------------------------
print("=" * 70)
print("Dataset Downloader")
print(f"Datasets: {', '.join(d['name'] for d in DATASETS_CONFIG)}")
print("=" * 70)

all_questions = []

for dataset_config in DATASETS_CONFIG:
    dataset_items = load_and_prepare_dataset(dataset_config, SEED)
    all_questions.extend(dataset_items)
    print(f"{dataset_config['name']}: {len(dataset_items)} questions ready")

# Tally samples per dataset in a single pass over all_questions.
counts = Counter(q['dataset'] for q in all_questions)
# load_and_prepare_dataset returns [] on failure, so a zero count means the
# dataset is missing; don't report plain success in that case.
failed = [d['name'] for d in DATASETS_CONFIG if counts[d['name']] == 0]

print("\n" + "=" * 70)
if failed:
    print(f"COMPLETE WITH ERRORS - no samples for: {', '.join(failed)}")
else:
    print("COMPLETE - Dataset files saved")
for dataset_config in DATASETS_CONFIG:
    print(f"- {dataset_config['local_file']}: {counts[dataset_config['name']]} samples")
print("=" * 70)
