import os
import json
import pandas as pd

base_path = ''


def extract_wikitq_query(kb_dict, url_dict, split='train'):
    """Write the WikiTQ queries of one Open-WikiTable split to a CSV file.

    Reads ``<base_path>/openwikitable_dataset/<split>.json``, keeps only the
    rows whose ``dataset`` column equals ``'WikiTQ'``, and annotates each row
    with the URL of its source page and the index of the section title that
    matches the table's section. Rows for which no exactly matching section
    title is found are dropped. The result is written to
    ``<base_path>/<split>_clean.csv`` without the index column.

    Args:
        kb_dict: Mapping from URL to a record whose ``'section_titles'``
            entry is a sequence of section titles.
        url_dict: Mapping from ``original_table_id`` to a
            ``(url, section_title)`` pair.
        split: Name of the split; selects the input file and is stored in
            the ``open_wikitable_split`` column.

    Raises:
        KeyError: If a table id is missing from ``url_dict``, or its URL is
            missing from ``kb_dict`` or has no ``'section_titles'`` entry.
    """
    query_path = os.path.join(base_path, 'openwikitable_dataset', f'{split}.json')
    all_queries = pd.read_json(query_path)
    wikitq_queries = all_queries[all_queries['dataset'] == 'WikiTQ']

    url_list = []
    section_id_list = []
    for table_id in wikitq_queries['original_table_id']:
        url, section_title = url_dict[table_id]
        url_list.append(url)

        # Index of the first section whose title matches exactly; -1 marks
        # "no match" so the row can be filtered out below.
        doc_sections = kb_dict[url]['section_titles']
        sec_id = next(
            (idx for idx, sec_title in enumerate(doc_sections)
             if sec_title == section_title),
            -1,
        )
        section_id_list.append(sec_id)

    # `assign` returns a new frame, so the columns are never written onto a
    # filtered slice of the original frame (avoids SettingWithCopyWarning).
    query_df = wikitq_queries.assign(
        wikipedia_url=url_list,
        evidence_section_id=section_id_list,
        open_wikitable_split=split,
        dataset_name='OpenWikiTable',
    )

    # Remove the queries for which no exactly matching section was found.
    query_df = query_df[query_df['evidence_section_id'] != -1]

    # Remove redundant columns.
    query_df = query_df.drop(columns=['question_id', 'original_table_id', 'sql',
                                      'hard_positive_idx', 'positive_idx', 'negative_idx', 'dataset'])

    query_df.to_csv(os.path.join(base_path, f'{split}_clean.csv'), index=False)

# We will extract only the WikiTableQuestions query dataset 
# from the Open-WikiTable query dataset.
if __name__ == '__main__':
    wiki_dir = os.path.join(base_path, 'openwikitable_wikipedia')

    # Load both lookup tables inside `with` blocks so the file handles are
    # closed as soon as the JSON has been parsed.
    with open(os.path.join(wiki_dir, 'ori_table_id_to_urls.json'), 'r', encoding='utf-8') as f:
        ori_table_id_to_urls = json.load(f)
    with open(os.path.join(wiki_dir, 'infoseek_kb_wiki.json'), 'r', encoding='utf-8') as f:
        kb_wiki = json.load(f)

    # Produce one cleaned CSV per split, in the same order as before.
    for split_name in ('train', 'valid', 'test'):
        extract_wikitq_query(kb_wiki, ori_table_id_to_urls, split=split_name)