import json

import random

# Build three random train/held-out partitions of the commonsense datasets.
# For each partition, the training sets are concatenated into a single JSON
# file, that file is registered in `dataset_info.json`, and the chosen
# train/test dataset names are recorded in `held_out_info_uni.json`.

random.seed(0)  # fixed seed so the sampled partitions are reproducible

data_dir = 'data/'

datasets = ['boolq', 'piqa', 'siqa', 'hellaswag', 'winogrande', 'arce', 'arcc', 'obqa']

# Maps 'split_<idx>' -> {'train': [...], 'test': [...]} (dataset names only).
split_info = {}

for split_idx in range(3):
    # Pick 4 datasets for training; the remaining ones are held out.
    train_datasets = random.sample(datasets, 4)
    test_datasets = [d for d in datasets if d not in train_datasets]

    split_info[f'split_{split_idx}'] = {'train': train_datasets, 'test': test_datasets}

    # Concatenate the per-dataset training examples into one list.
    train_data = []
    for ds in train_datasets:
        ds_file = f'{data_dir}/{ds}.uni_train_0-shot.json'
        with open(ds_file, 'r', encoding='utf-8') as f:
            train_data.extend(json.load(f))

    # ensure_ascii=False emits raw non-ASCII characters, so the encoding is
    # pinned to UTF-8 instead of relying on the platform's locale default.
    with open(f'{data_dir}/commonsense.uni.split.{split_idx}_train_0-shot.json', 'w', encoding='utf-8') as f:
        json.dump(train_data, f, indent=2, ensure_ascii=False)

    # NOTE: no merged test file is written for the held-out datasets; only
    # their names are kept in `split_info`.

    # Register the merged training file. The registry is re-read and
    # re-written on every iteration so it stays in step with the files
    # produced so far.
    with open(f'{data_dir}/dataset_info.json', 'r', encoding='utf-8') as f:
        dataset_info = json.load(f)
    dataset_info[f'commonsense.uni.split.{split_idx}_train_0-shot'] = {
        "file_name": f"commonsense.uni.split.{split_idx}_train_0-shot.json",
        "columns": {
            "prompt": "instruction",
            "query": "input",
            "response": "output"
        }
    }
    with open(f'{data_dir}/dataset_info.json', 'w', encoding='utf-8') as f:
        json.dump(dataset_info, f, indent=2, ensure_ascii=False)

# Persist which datasets were used for training vs. held out in each split.
with open('held_out_info_uni.json', 'w', encoding='utf-8') as f:
    json.dump(split_info, f, indent=2, ensure_ascii=False)