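"""
Sample texts from a HuggingFace dataset and save them to a specified directory.

The dataset is streamed and shuffled with a fixed seed. Each sampled text is
written to ``<output_dir>/sample_<i>.txt``; the list of sample names is saved
to ``text_names.yaml`` and the remaining (non-text) fields of each example to
``metadata.yaml`` in the same directory.

Command-line arguments:
    --n (int): Number of texts to sample and save (required).
    --output_dir (str): Directory to save the texts (required).
    --seed (int): Random seed for shuffling (default: 0).
    --dataset_name (str): Name of the HuggingFace dataset (default: "allenai/c4").
    --subset (str): Subset (configuration) of the dataset (default: "en").
    --split (str): Dataset split to sample from (default: "validation").
"""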
import os
import argparse
from datasets import load_dataset
import tqdm
import yaml


def main(n: int, output_dir: str, seed: int, dataset_name: str, subset: str, split: str) -> None:
    """Sample up to ``n`` texts from a streamed dataset and write them to ``output_dir``.

    The dataset is opened in streaming mode and shuffled with a
    10,000-example shuffle buffer seeded by ``seed``. Each sampled example's
    ``'text'`` field is written to ``sample_<i>.txt`` (``i`` starting at 0).
    Two YAML files are written alongside: ``text_names.yaml`` (the list of
    sample names, without extension) and ``metadata.yaml`` (one dict per
    sample holding every field of the example except ``'text'``).

    Args:
        n: Number of texts to sample. If the stream yields fewer than ``n``
            examples, only those are written; ``n <= 0`` writes no samples.
        output_dir: Directory for the output files; created if missing.
        seed: Seed passed to the dataset shuffle.
        dataset_name: Name of the HuggingFace dataset to load.
        subset: Dataset subset/configuration name passed to ``load_dataset``.
        split: Dataset split to stream from.

    Raises:
        KeyError: If an example has no ``'text'`` field.
    """
    os.makedirs(output_dir, exist_ok=True)
    dataset = load_dataset(dataset_name, subset, split=split, streaming=True, trust_remote_code=True)
    shuffled = dataset.shuffle(seed=seed, buffer_size=10_000)

    text_names = []
    metadata = []  # one dict of non-text fields per sample, aligned with text_names

    # range(n) is the first zip argument on purpose: zip stops as soon as it is
    # exhausted, so exactly n examples are consumed and no extra one is pulled
    # from the stream.
    limited = zip(range(n), shuffled)
    for index, example in tqdm.tqdm(limited, total=n, desc="Sampling texts"):
        name = f'sample_{index}'
        path = os.path.join(output_dir, f'{name}.txt')
        # Explicit UTF-8 so arbitrary web text is written identically on every
        # platform instead of depending on the locale's default encoding.
        with open(path, 'w', encoding='utf-8') as f:
            f.write(example['text'])
        text_names.append(name)
        metadata.append({k: v for k, v in example.items() if k != 'text'})

    # save the list of text names
    with open(os.path.join(output_dir, 'text_names.yaml'), 'w', encoding='utf-8') as f:
        yaml.dump(text_names, f)
    # save the metadata
    with open(os.path.join(output_dir, 'metadata.yaml'), 'w', encoding='utf-8') as f:
        yaml.dump(metadata, f)


if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, and forward
    # them to main(). Every option's dest matches a main() parameter name.
    cli = argparse.ArgumentParser(description="Sample and save texts from a HuggingFace dataset.")

    cli.add_argument("--n", type=int, required=True, help="Number of texts to sample and save.")
    cli.add_argument("--output_dir", type=str, required=True, help="Directory to save the texts.")
    cli.add_argument("--seed", type=int, default=0, help="Random seed for shuffling.")
    cli.add_argument("--dataset_name", type=str, default="allenai/c4", help="Name of the HuggingFace dataset.")
    cli.add_argument('--subset', type=str, default='en', help='Subset of the dataset to sample from (e.g., "en" for English).')
    cli.add_argument("--split", type=str, default="validation", help="Dataset split (e.g., 'train', 'test').")

    options = vars(cli.parse_args())
    # Keys are exactly: n, output_dir, seed, dataset_name, subset, split.
    main(**options)
