import json
import os
import random

# Sub-directories (relative to the working directory) that each hold a
# `{group}_data.json` file for every split processed below.
DATASET_TYPES = ['decomposed_queries', 'image_tasks', 'agent_tasks']

# Make sure the output folder exists before doing any work, so the script
# does not fail at the very end with FileNotFoundError.
os.makedirs('combined', exist_ok=True)

# Build one combined file per split by concatenating the per-type datasets.
for group in ['train', 'val', 'test']:

    # Read in each dataset.
    # NOTE(review): the shuffling/slicing below assumes every file holds a
    # top-level JSON array -- confirm against the producers of these files.
    dataset_dict = {}
    for dataset_type in DATASET_TYPES:
        # Explicit UTF-8: without it `open` falls back to the locale
        # encoding, which can mis-decode non-ASCII JSON on some platforms.
        with open(f'{dataset_type}/{group}_data.json', 'r', encoding='utf-8') as f:
            dataset_dict[dataset_type] = json.load(f)

    # For val/test, every dataset type contributes the same number of
    # examples; train is left untouched and keeps all of its examples.
    if group in ['val', 'test']:
        # Shuffle first so the truncation below keeps a random subset
        # rather than simply the first entries of each file.
        for dataset_type in DATASET_TYPES:
            random.shuffle(dataset_dict[dataset_type])

        # Find the smallest length among them
        min_len = min(len(dataset_dict[dt]) for dt in DATASET_TYPES)
        print(min_len)

        # Slice each dataset to min_len
        for dt in DATASET_TYPES:
            dataset_dict[dt] = dataset_dict[dt][:min_len]

    # Combine the datasets, in DATASET_TYPES order.
    combined_list = []
    for dataset_type in DATASET_TYPES:
        combined_list.extend(dataset_dict[dataset_type])

    # The val split is additionally halved. It is shuffled first so the
    # retained half mixes all dataset types instead of dropping the last ones.
    if group == "val":
        random.shuffle(combined_list)
        combined_list = combined_list[:len(combined_list) // 2]

    # Dump into the combined folder
    with open(f'combined/{group}_data.json', 'w', encoding='utf-8') as f:
        json.dump(combined_list, f, indent=4)
