import json
import os
import numpy as np

# Absolute directory of this module; data files are resolved relative to it.
# abspath() guards against a relative __file__, where dirname() would return ""
# and paths would silently depend on the current working directory.
current = os.path.dirname(os.path.abspath(__file__))


def keep_n_workers(n, seed=42, path=None):
    """Keep the answers of ``n`` randomly chosen workers for every task.

    Loads the all-workers JSON answer file (a mapping
    ``task_id -> {worker_id: answer}``) and, for each task, keeps ``n``
    workers drawn uniformly without replacement.

    Args:
        n: Number of workers to keep per task.
        seed: Seed passed to ``numpy.random.default_rng``. The default (42)
            reproduces the previously hard-coded behaviour.
        path: JSON file to read. Defaults to
            ``<module dir>/data/cifar10h_p_0_spam_0_all_workers.json``.

    Returns:
        A tuple ``(new_dict, data_all, n)``. ``new_dict`` maps each task id to
        a ``{worker_id: answer}`` dict with exactly ``n`` entries, ``data_all``
        is the path that was read, and ``n`` is echoed back.

    Raises:
        ValueError: if a task has fewer than ``n`` answers, or if a task id is
            not an integer string.
    """
    if path is None:
        path = os.path.join(
            current, "data", "cifar10h_p_0_spam_0_all_workers.json"
        )
    data_all = path
    with open(data_all, "r", encoding="utf-8") as fh:
        data = json.load(fh)

    rng = np.random.default_rng(seed)
    new_dict = {}
    # Visit tasks in numeric id order, not file order. For ids "0".."N-1"
    # this matches the original range(len(data)) loop exactly, so the RNG
    # draws and outputs stay reproducible. Gaps in the ids no longer raise
    # KeyError.
    for task in sorted(data, key=int):
        answers = data[task]
        if n > len(answers):
            raise ValueError(
                f"task {task!r} has only {len(answers)} answers, "
                f"cannot keep {n} workers"
            )
        who = rng.choice(list(answers), n, replace=False)
        # rng.choice yields numpy string scalars; str() restores plain keys.
        new_dict[task] = {str(w): answers[str(w)] for w in who}
    return new_dict, data_all, n


def save_res(dico, n):
    """Serialize ``dico`` to ``data/cifar10h_p_0_spam_0_nw_{n}.json``.

    The target lives in the ``data`` folder next to this module (``current``).
    Any existing file with that name is overwritten.
    """
    out_name = f"cifar10h_p_0_spam_0_nw_{n}.json"
    out_path = os.path.join(current, "data", out_name)
    with open(out_path, "w") as out:
        json.dump(dico, out)


if __name__ == "__main__":
    # Build the 10-workers-per-task subsample and write it next to the input.
    n_workers = 10
    subsample, _, n_workers = keep_n_workers(n_workers)
    save_res(subsample, n_workers)
