import numpy as np
from datasets import concatenate_datasets, load_from_disk

if __name__ == "__main__":
    data = load_from_disk("data_dir/wikipedia_bio")
    # shuffle the data
    data = data.shuffle(seed=1234)
    # split the data in half
    male_data = data.filter(lambda x: x["gender"] == "male")
    other_data = data.filter(lambda x: x["gender"] != "male")
    male_data = male_data.train_test_split(test_size=0.1, seed=1234)
    other_data = other_data.train_test_split(test_size=0.1, seed=1234)
    train_data = concatenate_datasets([male_data["train"], other_data["train"]])
    test_data = concatenate_datasets([male_data["test"], other_data["test"]])
    # save the data
    gender_counts = train_data["gender"]
    print("Gender counts in training data:")
    print(np.unique(gender_counts, return_counts=True))

    print("Gender counts in test data:")
    print(np.unique(test_data["gender"], return_counts=True))
    train_data.save_to_disk("data_dir/wikibio_train")
    test_data.save_to_disk("data_dir/wikibio_test")
    print(
        "Train and test datasets saved to data_dir/wikibio_train and data_dir/wikibio_test"
    )
