import random
import os
import json


def wrap_datas(dataset, datas):
    """Namespace every example's ``db_id`` under *dataset*, in place.

    Each ``db_id`` value is rewritten to ``"<dataset>##<db_id>"`` so examples
    from different source datasets carry distinct database identifiers.

    Args:
        dataset: Prefix to prepend (e.g. ``'spider'`` or ``'cosql'``).
        datas: Iterable of dicts that each have a ``'db_id'`` string; the
            dicts themselves are modified.

    Returns:
        The very same ``datas`` object that was passed in.
    """
    for example in datas:
        example['db_id'] = dataset + '##' + example['db_id']
    return datas

def wrap_dbs(dataset, dbs):
    """Prefix the ``db_id`` of every database schema entry with *dataset*.

    The entries are modified in place; each ``db_id`` becomes
    ``"<dataset>##<db_id>"``.

    Args:
        dataset: Source-dataset name used as the prefix.
        dbs: Iterable of schema dicts, each holding a ``'db_id'`` string.

    Returns:
        The same ``dbs`` object, now with rewritten ids.
    """
    for schema in dbs:
        original_id = schema['db_id']
        schema['db_id'] = '##'.join((dataset, original_id))
    return dbs

def _read_json(path):
    """Load and return the JSON document stored at *path*.

    The file is closed deterministically, and UTF-8 (the encoding JSON text
    is specified to use) is pinned instead of relying on the locale default.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _write_json(obj, path):
    """Serialize *obj* as JSON into *path*, creating the parent directory.

    The ``with`` block guarantees the output is flushed and closed even on
    interpreters that do not reclaim file objects by reference counting.
    """
    out_dir = os.path.dirname(path)
    if out_dir:  # os.makedirs('') raises, so skip for bare file names
        os.makedirs(out_dir, exist_ok=True)
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(obj, f)


def fuse_task(*, num_tasks=16, modes=('train', 'dev', 'test'),
              spider_path='../../text2sql/continual-text2sql/data/task_splits/spider_context/task_{}/{}.json',
              cosql_path='./data/task_splits/cosql_context/task_{}/{}.json',
              fusion_path='./data/task_splits/fusion_context/task_{}/{}.json'):
    """Merge per-task Spider and CoSQL splits into a single "fusion" split.

    For each task id in ``range(num_tasks)`` and each split name in *modes*,
    the Spider and CoSQL example lists are loaded. Every example's ``db_id``
    is prefixed with its source dataset (``spider##...`` / ``cosql##...``).
    The concatenation (Spider first, then CoSQL) is written to *fusion_path*.
    The per-task ``tables`` file is fused the same way. Because the two
    prefixes differ, database ids from the two datasets stay distinct in the
    merged output.

    All paths are ``str.format`` templates filled with ``(task_id, name)``,
    where *name* is a split name or ``'tables'``. Relative paths resolve
    against the current working directory. Calling with no arguments
    reproduces the original hard-coded behaviour (16 tasks, train/dev/test).

    Args:
        num_tasks: Number of task directories to process (ids ``0 .. num_tasks-1``).
        modes: Split names to fuse for every task.
        spider_path: Template for the Spider input files.
        cosql_path: Template for the CoSQL input files.
        fusion_path: Template for the output files; missing directories are
            created on demand.

    Raises:
        FileNotFoundError: If an expected input file does not exist.
        json.JSONDecodeError: If an input file is not valid JSON.
    """
    for task_id in range(num_tasks):
        for mode in modes:
            spider_datas = wrap_datas('spider', _read_json(spider_path.format(task_id, mode)))
            cosql_datas = wrap_datas('cosql', _read_json(cosql_path.format(task_id, mode)))
            _write_json(spider_datas + cosql_datas, fusion_path.format(task_id, mode))

        # Schema files use the same templates, with 'tables' as the name.
        spider_dbs = wrap_dbs('spider', _read_json(spider_path.format(task_id, 'tables')))
        cosql_dbs = wrap_dbs('cosql', _read_json(cosql_path.format(task_id, 'tables')))
        _write_json(spider_dbs + cosql_dbs, fusion_path.format(task_id, 'tables'))


if __name__ == "__main__":
    # Script entry point: fuse every task split using the default paths.
    fuse_task()