import os
import json
import random


def add_dataset_name_for_one_data(datas, dataset_name):
    """Stamp every record in ``datas`` with the name of its source dataset.

    Each element of ``datas`` is modified in place: its ``"dataset_name"``
    entry is set (or overwritten) to ``dataset_name``.

    Args:
        datas: Iterable of mutable mappings (e.g. a list of dicts).
        dataset_name: Value to store under the ``"dataset_name"`` key.

    Returns:
        The very same ``datas`` object that was passed in.
    """
    for record in datas:
        record["dataset_name"] = dataset_name
    return datas

def add_dataset_name(dataset_name):
    """Load every task of a dataset and tag each record with the dataset name.

    For each task directory ``data/task_splits/<dataset_name>/task_<id>`` the
    four files ``train.json``, ``dev.json``, ``test.json`` and ``tables.json``
    are read, and every record in them is passed through
    ``add_dataset_name_for_one_data``.

    Args:
        dataset_name: Name of the dataset sub-directory under
            ``data/task_splits``.

    Returns:
        A list with one entry per task; each entry is a 4-element list
        ``[train, dev, test, tables]`` holding the tagged JSON contents.

    Raises:
        FileNotFoundError: If the dataset directory or any expected split
            file does not exist.
    """
    dataset_dir = f'data/task_splits/{dataset_name}'
    # The task count is simply the number of directory entries, so the
    # directory is expected to contain exactly task_0 .. task_{n-1}.
    task_num = len(os.listdir(dataset_dir))
    tasks = []
    for task_id in range(task_num):
        one_task = []
        for mode in ['train', 'dev', 'test', 'tables']:
            task_path = f'{dataset_dir}/task_{task_id}/{mode}.json'
            # Close each file deterministically instead of leaking the handle.
            with open(task_path, 'r', encoding='utf-8') as f:
                datas = json.load(f)
            datas = add_dataset_name_for_one_data(datas, dataset_name)
            one_task.append(datas)
        tasks.append(one_task)
    return tasks

def save_tasks(combine_tasks, combine_dataset_name):
    """Write a sequence of tasks to ``data/task_splits/<combine_dataset_name>``.

    Task ``i`` is stored in the sub-directory ``task_<i>`` as four JSON files
    (``train.json``, ``dev.json``, ``test.json``, ``tables.json``), written
    with an indent of 2. Existing files are overwritten.

    Args:
        combine_tasks: Sequence of tasks; each task is indexable and holds
            its splits in the order ``[train, dev, test, tables]``.
        combine_dataset_name: Name of the output dataset directory.

    Raises:
        IndexError: If a task has fewer than four splits.
    """
    root_dir = f'data/task_splits/{combine_dataset_name}'
    # Create the dataset directory even when there are no tasks to write.
    os.makedirs(root_dir, exist_ok=True)

    for i, one_task in enumerate(combine_tasks):
        one_task_dir = f'{root_dir}/task_{i}'
        os.makedirs(one_task_dir, exist_ok=True)
        for j, mode in enumerate(['train', 'dev', 'test', 'tables']):
            path = f'{one_task_dir}/{mode}.json'
            # Look the split up before opening so a malformed task does not
            # leave an empty file behind.
            split = one_task[j]
            # The context manager guarantees the data is flushed and the
            # handle closed as soon as the split has been written.
            with open(path, 'w', encoding='utf-8') as f:
                json.dump(split, f, indent=2)


if __name__ == '__main__':
    # Fixed seed so the (currently commented-out) random sampling below is
    # reproducible.
    random.seed(2023)

    # Load every task of each source dataset, tagging records with the name.
    spider_tasks = add_dataset_name('spider')
    wikisql_tasks = add_dataset_name('wikisql')
    cosql_tasks = add_dataset_name('cosql')

    # combine1: eight hand-picked Spider tasks followed by three WikiSQL tasks.
    spider_task_ids = [0, 13, 5, 15, 10, 12, 11, 9]
    wikisql_task_ids = [2, 7, 3]
    combine_tasks_1 = (
        [spider_tasks[task_id] for task_id in spider_task_ids]
        + [wikisql_tasks[task_id] for task_id in wikisql_task_ids]
    )

    # One line per split (train, dev, test): "(position, size) " per task,
    # with positions counted from 1.
    for split_idx in range(3):
        split_sizes = [len(task[split_idx]) for task in combine_tasks_1]
        print(''.join(f'({pos}, {size}) ' for pos, size in enumerate(split_sizes, start=1)))

    # save_tasks(combine_tasks_1, 'combine1')

    # combine2 recipe (disabled): sample 11 Spider task ids at random and
    # take the remaining ids from CoSQL.
    # spider_task_ids = random.sample([x for x in range(16)], 11)
    # cosql_task_ids = [x for x in range(16) if x not in spider_task_ids]
    # print(spider_task_ids)
    # print(cosql_task_ids)

    # spider_task_ids = [12, 11, 7, 6, 5, 9, 10, 14, 1, 15, 0]
    # cosql_task_ids = [2, 3, 4, 8, 13]
    # combine_tasks_2 = [spider_tasks[i] for i in spider_task_ids] + [cosql_tasks[i] for i in cosql_task_ids]
    # save_tasks(combine_tasks_2, 'combine2')
