import sys
sys.path.append('..')

from typing import Any
from pathlib import Path

from datasets import Dataset, load_dataset
from datasets.utils.logging import disable_progress_bar
from dotenv import load_dotenv
from tqdm.auto import tqdm

# Name of the text column in every dataset this script writes to disk.
output_text_column = 'text'

# Output locations; both are relative to the current working directory.
ROOT = Path('.')
DATA = ROOT.joinpath('data')

# One-time process setup. The return values are deliberately discarded.
_ = load_dotenv()  # presumably supplies Hugging Face credentials from a .env file - confirm
_ = disable_progress_bar()  # turn off the `datasets` library's own progress bars

# Make sure the output root exists before any dataset is written into it.
DATA.mkdir(parents=True, exist_ok=True)

def create_dataset(
    output_dir: Path,
    dataset_dict: dict[str, Any],
    min_words: int = 0,
) -> None:
    """Stream a Hugging Face dataset, keep long-enough samples, save them to disk.

    The ``train`` split of the source dataset is read in streaming mode until
    ``num_samples`` samples whose text has at least ``min_words``
    whitespace-separated words have been collected. The collected samples
    (with all of their original columns) are saved to ``output_dir``, with the
    text column renamed to the module-level ``output_text_column`` if it has a
    different name.

    Args:
        output_dir: Destination directory. If it already exists the function
            returns immediately without doing any work.
        dataset_dict: Description of one source dataset with the keys
            ``'args'`` (positional arguments forwarded to ``load_dataset``),
            ``'num_samples'`` (number of samples to collect) and
            ``'text_column'`` (name of the column holding the text).
        min_words: Minimum number of whitespace-separated words a sample's
            text must contain to be kept.

    Raises:
        ValueError: If ``num_samples`` is not positive, or if the stream is
            exhausted before ``num_samples`` qualifying samples are found.
    """
    # An existing output directory is treated as "already built".
    if output_dir.exists():
        return

    args = dataset_dict['args']
    num_samples: int = dataset_dict['num_samples']
    text_column: str = dataset_dict['text_column']

    # Without this guard a non-positive target would still collect one sample,
    # because the stop condition is only evaluated after an append.
    if num_samples < 1:
        raise ValueError(f"num_samples must be a positive integer, got {num_samples!r}.")

    # Streaming avoids downloading the full corpus just to take a prefix of it.
    stream = load_dataset(*args, split='train', streaming=True, trust_remote_code=False)

    collected: list[dict[str, Any]] = []
    # The context manager closes the bar even though the loop exits via `break`.
    with tqdm(stream, desc="Filtering samples") as bar:
        for sample in bar:
            text = sample.get(text_column, "")

            # str.split() with no separator already ignores leading/trailing
            # whitespace, so no explicit strip() is needed before counting.
            if not isinstance(text, str) or len(text.split()) < min_words:
                continue

            collected.append(sample)
            # Only touch the postfix when the count changes, and let tqdm
            # redraw on its own schedule instead of forcing a refresh.
            bar.set_postfix({'Collected Samples': len(collected)}, refresh=False)

            if len(collected) >= num_samples:
                break

    if len(collected) < num_samples:
        raise ValueError(f"Only found {len(collected)} samples with at least {min_words} words.")

    # Materialize the collected rows as an in-memory Hugging Face Dataset.
    result = Dataset.from_list(collected)

    # Normalize the text column name so every saved dataset shares one schema.
    if text_column != output_text_column:
        result = result.rename_column(text_column, output_text_column)

    result.save_to_disk(output_dir)

# Source corpora to sample, keyed by the sub-directory name used under DATA.
# Every entry provides:
#   'args'        - positional arguments forwarded to `load_dataset`
#   'num_samples' - number of qualifying samples to collect
#   'text_column' - column of the source dataset that holds the text
dataset_dicts = {
    'wikipedia': {
        'args': ['wikimedia/wikipedia', '20231101.en'],
        'num_samples': 25_000,
        'text_column': 'text',
    },
    'colossal_clean_crawled_corpus': {
        'args': ['allenai/c4', 'en'],
        'num_samples': 25_000,
        'text_column': 'text',
    },
    'arxiv_pile': {  # TODO: Cite in paper
        'args': ['timaeus/pile-arxiv'],
        'num_samples': 25_000,
        'text_column': 'text',
    },
    'github_python': {
        'args': ['angie-chen55/python-github-code'],
        'num_samples': 25_000,
        'text_column': 'code',
    },
}

# Build each corpus in turn, keeping only samples of at least 500 words.
# Corpora whose output directory already exists are skipped by create_dataset.
for name, dataset_dict in dataset_dicts.items():
    create_dataset(
        output_dir=DATA / name,
        dataset_dict=dataset_dict,
        min_words=500,
    )