import json
import os
import random

from datasets import load_dataset

def _write_jsonl(path, examples, indices):
    """Write ``examples[j]`` for every ``j`` in *indices* to *path* as JSON Lines."""
    with open(path, 'w', encoding='utf-8') as f:
        for j in indices:
            # default=str: values json cannot encode natively are stringified
            # instead of raising, so records that could be written before
            # (via str()) can still be written now.
            f.write(json.dumps(examples[j], ensure_ascii=False, default=str) + '\n')


def download_and_split(dataset_name, save_dir, client_num=3, split_strategy='iid', seed=None):
    """Load a dataset and write each of its splits as per-client JSONL shards.

    The dataset is obtained with ``load_dataset(dataset_name)``. For every
    split it exposes, the directory ``<save_dir>/<dataset_name>/<split>`` is
    created and, under the ``'iid'`` strategy, the examples are shuffled and
    dealt into ``client_num`` files named ``client_<i>.jsonl`` (one JSON
    object per line, UTF-8).

    Every example is written exactly once: when the split size is not a
    multiple of ``client_num`` the first ``total % client_num`` clients each
    receive one extra example, so shard sizes differ by at most one.

    Args:
        dataset_name: Name passed to ``load_dataset``; also used as a path
            component under ``save_dir``.
        save_dir: Root directory for the output files.
        client_num: Number of client shards per split; must be >= 1.
        split_strategy: ``'iid'`` for a random, near-equal split. Any other
            value only prints a notice (the split directories are still
            created, but no shard files are written).
        seed: Optional seed for a reproducible shuffle. When ``None`` the
            module-level ``random`` state is used, so a caller's earlier
            ``random.seed(...)`` is still honoured.

    Raises:
        ValueError: If ``client_num`` is smaller than 1.
    """
    # Validate before any download or filesystem work so a bad argument
    # fails fast and leaves nothing half-created behind.
    if client_num < 1:
        raise ValueError(f"client_num must be >= 1, got {client_num!r}")

    dataset = load_dataset(dataset_name)
    # A private generator keeps a seeded run from disturbing global state.
    shuffle = random.shuffle if seed is None else random.Random(seed).shuffle

    for split in dataset.keys():
        examples = dataset[split]
        total = len(examples)
        split_dir = os.path.join(save_dir, dataset_name, split)
        os.makedirs(split_dir, exist_ok=True)

        if split_strategy != 'iid':
            # placeholder for non-iid splits
            print("Non-iid strategy not implemented here. Use prepare_non_iid.py.")
            continue

        indices = list(range(total))
        shuffle(indices)

        # divmod gives the common shard size plus how many shards need one
        # extra element so that no example is dropped.
        base, extra = divmod(total, client_num)
        start = 0
        for i in range(client_num):
            stop = start + base + (1 if i < extra else 0)
            _write_jsonl(
                os.path.join(split_dir, f'client_{i}.jsonl'),
                examples,
                indices[start:stop],
            )
            start = stop