import numpy as np
import os
from datasets import load_from_disk
from tqdm import tqdm
import argparse, json


# Number of worker processes used for both filtering and saving each cluster.
NUM_PROC = 16


def parse_args():
    """Parse the command line; the only option is ``--config`` (path to a JSON file)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='/path/to/config', help='Config Path')
    return parser.parse_args()


def load_config(path):
    """Read the JSON config at *path* and return the decoded object.

    The file is opened in a ``with`` block so the handle is closed even when
    decoding fails.
    """
    with open(path, encoding="utf-8") as config_file:
        return json.load(config_file)


def validate_labels(labels, num_rows):
    """Return *labels* as a 1-D integer array holding one entry per dataset row.

    Args:
        labels: Array-like of cluster assignments.
        num_rows: Number of rows in the dataset the labels will be attached to.

    Raises:
        ValueError: If *labels* is not one-dimensional or its length differs
            from *num_rows*.
    """
    labels = np.asarray(labels)
    if labels.ndim != 1:
        raise ValueError(
            f"cluster assignments must be 1-D, got an array of shape {labels.shape}"
        )
    if labels.shape[0] != num_rows:
        raise ValueError(
            f"got {labels.shape[0]} cluster assignments for a dataset with {num_rows} rows"
        )
    return labels.astype(int)


def main():
    """Write every cluster of the dataset as its own dataset on disk.

    Config keys read: ``dataset_path``, ``working_dir`` and ``clustering_k``.
    Cluster assignments are loaded from ``<working_dir>/cluster_assignments.npy``
    and cluster ``i`` is saved to ``<working_dir>/dataset_split/<i>``.
    """
    args = parse_args()
    config = load_config(args.config)

    full_dataset = load_from_disk(config["dataset_path"])

    # Fail early with a clear message if the assignments do not line up with
    # the dataset, instead of surfacing an opaque error from add_column.
    labels = validate_labels(
        np.load(os.path.join(config["working_dir"], "cluster_assignments.npy")),
        full_dataset.num_rows,
    )

    modified_dataset = full_dataset

    # An existing 'label' column is dropped so the cluster assignments replace it.
    if 'label' in full_dataset.column_names:
        modified_dataset = full_dataset.remove_columns(['label'])

    modified_dataset = modified_dataset.add_column("label", labels)

    split_root = os.path.join(config["working_dir"], "dataset_split")

    # Save each cluster individually as a separate dataset on disk.
    for i in tqdm(range(config["clustering_k"])):
        # input_columns hands only the label value to the predicate, so the
        # other columns are not materialised for every row on every pass.
        cluster_dataset = modified_dataset.filter(
            lambda label: label == i,
            input_columns=["label"],
            num_proc=NUM_PROC,
            keep_in_memory=False,
        )
        cluster_dataset.save_to_disk(os.path.join(split_root, f"{i}"), num_proc=NUM_PROC)


# The guard keeps the job from re-running when worker processes started for
# num_proc re-import this module (as happens under the "spawn" start method).
if __name__ == "__main__":
    main()