import os
import json
import random
from tqdm import tqdm


# Sub-paths of the training sets to merge, written as "<prefix>/<dataset>".
# Grouped by prefix directory; the flattened order is prefix by prefix,
# and within each prefix in the order the names are listed here.
dataset_list = [
    f"{prefix}/{name}"
    for prefix, names in (
        ("NER", ("ace2005-ner", "conll-2003", "few-nerd-supervised")),
        ("RC", ("fewrel", "semeval", "tacred")),
        ("ED", ("ace2005-dygie", "maven", "RichERE")),
        ("EAE", ("ace2005-eae", "RichERE")),
        ("SC", ("goemo", "sst-5")),
    )
    for name in names
]

# Merge the train.json of every dataset in dataset_list into a single
# all_data.json, keeping at most MAX "request_states" entries per dataset.

# One loaded (and possibly downsampled) JSON object per dataset, in list order.
all_data = []
# Per-dataset cap on the number of "request_states" entries kept.
MAX = 30000
# Running total of "request_states" entries kept across all datasets.
count = 0

for dataset in tqdm(dataset_list):
    # Use a context manager so each input file is closed as soon as it has
    # been parsed instead of leaking one open handle per dataset.
    train_path = os.path.join("../unified_data", dataset, "train.json")
    with open(train_path, encoding="utf-8") as file:
        data = json.load(file)
    if len(data["request_states"]) > MAX:
        # Report which datasets are being downsampled.
        print(dataset)
        # NOTE(review): no random seed is set, so the sampled subset differs
        # from run to run -- seed the RNG if reproducibility is required.
        data["request_states"] = random.sample(data["request_states"], k=MAX)
    count += len(data["request_states"])
    all_data.append(data)

print(count)
with open("../unified_data/all_data.json", "w", encoding="utf-8") as f:
    json.dump(all_data, f, indent=4)
