import json
import os

import numpy as np
# Number of examples drawn from a dataset when building an ablation subset.
DATA_COUNT = 50
# Seed NumPy's global random generator so the drawn subset is repeatable.
np.random.seed(1331)

def split_dataset(dataset, dataset_name, count=None):
    """Draw a random subset of *dataset* and save it as an ablation file.

    Indices are drawn without replacement from NumPy's global random
    generator, so the selection follows whatever seed was set with
    ``np.random.seed``. The subset is written as indented JSON to
    ``./data/ablation_study/<dataset_name>_ablation_data.txt``; the output
    directory is created when it does not exist yet.

    Args:
        dataset: Sequence of JSON-serialisable examples.
        dataset_name: Prefix used for the output file name.
        count: Number of examples to draw. Defaults to ``DATA_COUNT``. When
            the dataset holds fewer examples than requested, all of them are
            returned (in random order) instead of raising.

    Returns:
        The list of sampled examples.

    Raises:
        ValueError: If ``count`` is negative.
    """
    if count is None:
        count = DATA_COUNT
    if count < 0:
        raise ValueError("count must be non-negative, got {}".format(count))

    # Sampling without replacement cannot yield more items than exist.
    sample_size = min(count, len(dataset))
    if sample_size:
        indices = np.random.choice(len(dataset), sample_size, replace=False)
    else:
        # Older NumPy versions reject a population of size zero outright.
        indices = []
    random_samples = [dataset[int(i)] for i in indices]

    output_path = "./data/ablation_study/{}_ablation_data.txt".format(dataset_name)
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(random_samples, f, indent=4)
    return random_samples



def read_big_bench(address, dataset_name):
    """Load a BIG-bench style task file and sample an ablation subset.

    The file at *address* must hold one JSON object with a top-level
    ``"description"`` string and an ``"examples"`` list whose items each
    carry an ``"input"`` string. Every example's input is prefixed with the
    task description followed by ``":\\n"`` before sampling.

    Args:
        address: Path to the task JSON file.
        dataset_name: Name forwarded to ``split_dataset`` for the output file.

    Returns:
        Whatever ``split_dataset`` returns for the prefixed examples.

    Raises:
        KeyError: If ``"description"``, ``"examples"`` or an example's
            ``"input"`` is missing.
    """
    with open(address, "r", encoding="utf-8") as f:
        dataset = json.load(f)

    # The description belongs to the task object itself; the "examples"
    # value is a list and cannot be indexed by a string key.
    task_prefix = dataset["description"]
    final_data = [
        {**example, "input": task_prefix + ":\n" + example["input"]}
        for example in dataset["examples"]
    ]
    return split_dataset(final_data, dataset_name)



def read_gsm_series(address, dataset_name):
    """Load a JSON Lines dataset and sample an ablation subset from it.

    Each non-blank line of the UTF-8 file at *address* is parsed as one JSON
    value.

    Args:
        address: Path to the ``.jsonl`` file.
        dataset_name: Name forwarded to ``split_dataset`` for the output file.

    Returns:
        Whatever ``split_dataset`` returns for the parsed records.
    """
    # The context manager closes the handle even if a line fails to parse.
    with open(address, encoding="utf-8") as f:
        # Blank lines (e.g. a trailing newline) are not valid JSON; skip them.
        all_data = [json.loads(line) for line in f if line.strip()]
    return split_dataset(all_data, dataset_name)


if __name__ == "__main__":
    # The GSM8K test split is a JSON Lines file, so it needs the line-wise
    # reader; read_big_bench expects one JSON object with an "examples" list.
    read_gsm_series("./data/gsm8k/test.jsonl", "gsm8k_test")