{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Imputation for different datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import openai\n",
    "import pandas as pd\n",
    "import time\n",
    "import requests\n",
    "import json\n",
    "import jsonlines\n",
    "import re\n",
    "\n",
    "import os\n",
    "\n",
    "# Read the key from the OPENAI_API_KEY environment variable so that a real key\n",
    "# never has to be saved inside the notebook; the placeholder stays as fallback.\n",
    "sk = os.environ.get(\"OPENAI_API_KEY\", \"YOUR API KEY\")\n",
    "openai.api_key = sk\n",
    "\n",
    "def chat(input_data, model=\"gpt-3.5-turbo\", temperature=0.8, max_retries=None, retry_wait=10):\n",
    "    \"\"\"Send one user prompt to the OpenAI chat API and return the reply text.\n",
    "\n",
    "    Args:\n",
    "        input_data: Prompt text, sent as the only (user) message.\n",
    "        model: Name of the chat model to query.\n",
    "        temperature: Sampling temperature forwarded to the API.\n",
    "        max_retries: Number of failed attempts after which the last error is\n",
    "            re-raised. None (the default) keeps retrying indefinitely.\n",
    "        retry_wait: Seconds to sleep after a failed attempt.\n",
    "\n",
    "    Returns:\n",
    "        The message content of the first choice in the response.\n",
    "    \"\"\"\n",
    "    # `temperature` is a request option; it does not belong inside a message dict.\n",
    "    nmessages = [{\"role\": \"user\", \"content\": input_data}]\n",
    "\n",
    "    failures = 0\n",
    "    while True:\n",
    "        try:\n",
    "            response = openai.ChatCompletion.create(\n",
    "                model=model,\n",
    "                messages=nmessages,\n",
    "                temperature=temperature,\n",
    "            )\n",
    "            return response['choices'][0]['message']['content']\n",
    "        except Exception as exc:  # unlike a bare except, lets KeyboardInterrupt through\n",
    "            failures += 1\n",
    "            if max_retries is not None and failures >= max_retries:\n",
    "                raise\n",
    "            print(f\"chat() attempt {failures} failed: {exc!r}; retrying in {retry_wait}s\")\n",
    "            time.sleep(retry_wait)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Obtain all possible answers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Query id -> query text, one `<qid><TAB><query>` record expected per line.\n",
    "queries = {}\n",
    "\n",
    "with open('../data/wikituples/final_data/queries.tsv', 'r') as query_file:\n",
    "    for raw_line in query_file:\n",
    "        record = raw_line.strip()\n",
    "        # Everything before the first tab is the id, everything after it the text.\n",
    "        tab_at = record.find('\\t')\n",
    "        qid, query = record[:tab_at], record[tab_at + 1:]\n",
    "        queries[qid] = query\n",
    "\n",
    "print(len(queries))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_entity_vocab(ignore_bad_title=True, min_ent_count=1):\n",
    "    \"\"\"Read the entity vocabulary file into a dict keyed by a running index.\n",
    "\n",
    "    Every line of ../data/wikituples/entity_vocab.txt must hold exactly five\n",
    "    tab-separated fields; the 2nd, 3rd and 5th are used as wiki id, title and\n",
    "    count.\n",
    "\n",
    "    Args:\n",
    "        ignore_bad_title: When true, entries with an empty title are dropped.\n",
    "        min_ent_count: Entries whose count is below this value are dropped.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping 0..N-1 to {'wiki_id': int, 'wiki_title': str, 'count': str}.\n",
    "    \"\"\"\n",
    "    vocab = {}\n",
    "    n_empty_title = 0\n",
    "    n_rare = 0\n",
    "    with open('../data/wikituples/entity_vocab.txt', 'r', encoding=\"utf-8\") as vocab_file:\n",
    "        for row in vocab_file:\n",
    "            _, wiki_id, title, _mid, freq = row.strip().split('\\t')\n",
    "            if ignore_bad_title and title == '':\n",
    "                n_empty_title += 1\n",
    "                continue\n",
    "            if int(freq) < min_ent_count:\n",
    "                n_rare += 1\n",
    "                continue\n",
    "            # Keys are consecutive integers in file order of the kept entries.\n",
    "            vocab[len(vocab)] = {\n",
    "                'wiki_id': int(wiki_id),\n",
    "                'wiki_title': title,\n",
    "                'count': freq\n",
    "            }\n",
    "    print('total number of entity: %d\\nremove because of empty title: %d\\nremove because count<%d: %d'%(len(vocab),n_empty_title,min_ent_count,n_rare))\n",
    "    return vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np \n",
    "import jsonlines\n",
    "\n",
    "# Vocabulary restricted to titled entities with a count of at least 2.\n",
    "entity_vocab = load_entity_vocab(min_ent_count=2, ignore_bad_title=True)\n",
    "all_entity_set = {entry['wiki_id'] for entry in entity_vocab.values()}\n",
    "\n",
    "# Entity id -> list of known surface forms, seeded with the entity title.\n",
    "entityid_to_text = {entry['wiki_id']: [entry['wiki_title']] for entry in entity_vocab.values()}\n",
    "\n",
    "train_tuple_to_table = {}\n",
    "train_tuple_id = 0\n",
    "with jsonlines.open('../data/wikituples/train_tables.jsonl', 'r') as reader:\n",
    "    for table in reader:\n",
    "        table_id = table.get(\"_id\", \"\")\n",
    "        pgTitle = table.get(\"pgTitle\", \"\").lower()\n",
    "        secTitle = table.get(\"sectionTitle\", \"\").lower()\n",
    "        headers = table.get(\"processed_tableHeaders\", [])\n",
    "        rows = table.get(\"tableData\", {})\n",
    "        entity_cells = np.array(table.get(\"entityCell\",[[]]))\n",
    "        subject = table['subject_column']\n",
    "        # Every cell whose first surface link targets a vocabulary entity adds\n",
    "        # its text as one more surface form of that entity.\n",
    "        for row in rows:\n",
    "            for cell in row:\n",
    "                links = cell['surfaceLinks']\n",
    "                if len(links) == 0:\n",
    "                    continue\n",
    "                linked_id = links[0]['target']['id']\n",
    "                if linked_id in all_entity_set:\n",
    "                    entityid_to_text[linked_id].append(cell['text'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ids of the tuples that already have a saved response, so an interrupted run\n",
    "# can be resumed without sending the same request again.\n",
    "processed_tuples = []\n",
    "try:\n",
    "    with open('./results/WikiTuples/GPT_wikituples_wo_evidence.jsonl', 'r') as f:\n",
    "        for line in f:\n",
    "            if not line.strip():\n",
    "                continue  # tolerate blank lines\n",
    "            line = json.loads(line)\n",
    "            tuple_id = int(line['tuple_id'])\n",
    "            processed_tuples.append(tuple_id)\n",
    "except FileNotFoundError:\n",
    "    # First run: the results file only exists once a response has been appended.\n",
    "    pass\n",
    "\n",
    "print(len(processed_tuples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import jsonlines\n",
    "import random\n",
    "\n",
    "template = '''What's the most likely value for the [TO-FILL] cell in the table below? Please respond using JSON: {answer_format}, the key is attribute name of each [TO-FILL], value is the predicted value for each [TO-FILL].\\n'''\n",
    "WO_EVIDENCE_RESULTS = './results/WikiTuples/GPT_wikituples_wo_evidence.jsonl'\n",
    "# Set copy of the already answered ids: constant-time membership in the loop.\n",
    "done_ids = set(processed_tuples)\n",
    "\n",
    "with open('../data/wikituples/missing_tables.jsonl', 'r') as f:\n",
    "    for line in f:\n",
    "        line = json.loads(line)\n",
    "        tuple_ids = line['tuple_id']\n",
    "        tableData = line['tableData']\n",
    "        headers = line['processed_tableHeaders']\n",
    "        ground_truth = line['ground_truth']\n",
    "        caption = 'caption:' + line['pgTitle'] + ' | ' + line['sectionTitle'] + ' | ' + line['tableCaption']\n",
    "        header_row = ''.join('|' + header for header in headers) + '|\\n'\n",
    "\n",
    "        for index, t_id in enumerate(tuple_ids):\n",
    "            if t_id in done_ids:\n",
    "                continue\n",
    "\n",
    "            tuple_d = tableData[index]\n",
    "\n",
    "            # Row to impute: 'N/A' cells are shown as [TO-FILL] and their headers\n",
    "            # become the keys of the requested JSON answer.\n",
    "            missing_headers = []\n",
    "            target_row = ''\n",
    "            for j in range(len(headers)):\n",
    "                if tuple_d[j] == 'N/A':\n",
    "                    missing_headers.append(headers[j])\n",
    "                    target_row += '|' + '[TO-FILL]'\n",
    "                else:\n",
    "                    target_row += '|' + tuple_d[j]\n",
    "            target_row += '|\\n'\n",
    "            answer_format = '{' + ', '.join(header + ': \"\"' for header in missing_headers) + '}'\n",
    "\n",
    "            # One randomly chosen other row of the same table (taken from the\n",
    "            # ground truth) is shown as an in-context example.\n",
    "            example_row = ''\n",
    "            other_ids = [tt for tt in tuple_ids if tt != t_id]\n",
    "            if other_ids:\n",
    "                seleteced_tuple = random.choice(other_ids)\n",
    "                example_tuple = ground_truth[tuple_ids.index(seleteced_tuple)]\n",
    "                try:\n",
    "                    for j in range(len(headers)):\n",
    "                        cell = example_tuple[j]\n",
    "                        # List cells contribute their last element as the text.\n",
    "                        example_row += '|' + (cell[-1] if type(cell) is list else cell)\n",
    "                    example_row += '|\\n'\n",
    "                except (IndexError, TypeError) as exc:\n",
    "                    # Drop the whole example rather than leave half a row in the prompt.\n",
    "                    print(f\"skipping malformed example row: {exc!r}\")\n",
    "                    example_row = ''\n",
    "            else:\n",
    "                print(\"no other examples\")\n",
    "\n",
    "            # Only the template goes through str.format, so '{' or '}' inside the\n",
    "            # caption or the table cells can no longer break the substitution.\n",
    "            input_data = template.format(answer_format=answer_format) + caption + '\\n' + header_row + example_row + target_row\n",
    "            # Kept so that prompts stay identical to the ones of earlier runs.\n",
    "            input_data = input_data.replace('Red, Green & White}|', 'Red, Green & White|')\n",
    "\n",
    "            print(\"---------------------------------------------------\")\n",
    "            print(f\"Input: \\n{input_data}\")\n",
    "\n",
    "            # The keyword is `temperature`; the former `temperate` raised TypeError.\n",
    "            output = chat(input_data, model=\"gpt-3.5-turbo\", temperature=0.3)\n",
    "\n",
    "            print(f\"Output: \\n{output}\")\n",
    "\n",
    "            # Append right away so that progress survives an interruption.\n",
    "            with jsonlines.open(WO_EVIDENCE_RESULTS, 'a') as fout:\n",
    "                fout.write({'tuple_id':t_id, 'input': input_data, 'output': output})\n",
    "            done_ids.add(t_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data Imputation with Retrieved Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "TOP_K = 5  # number of best-scoring retrieved tuples kept per query\n",
    "\n",
    "retrieved_tuples = {} # qid, top-k docids\n",
    "# qid -> {docid: score}; each line of the rerank file is qid, docid, rank, score.\n",
    "all_scores = defaultdict(dict)\n",
    "with open('../results/rerank/wikituples_test.tsv', 'r') as f:\n",
    "    for line in f:\n",
    "        qid, docid, _rank, score = line.strip().split('\\t')\n",
    "        # qid, docid, score = line.strip().split('\\t')\n",
    "        all_scores[qid][docid] = float(score)\n",
    "\n",
    "qq = list(all_scores.keys())\n",
    "\n",
    "# qid -> docids of its TOP_K highest scores, best first (the rank column of the\n",
    "# file is ignored; the order is recomputed from the scores).\n",
    "topK_results = {}\n",
    "for qid in qq:\n",
    "    score_list = sorted(all_scores[qid].items(), key=lambda x: x[1], reverse=True)\n",
    "    topK_results[qid] = [docid for docid, _score in score_list[:TOP_K]]\n",
    "\n",
    "# Relative path, like the other data files of this notebook, instead of a path\n",
    "# that only exists on one machine.\n",
    "with open('../data/wikituples/final_data/folds.json', 'r') as f:\n",
    "    folds = json.load(f)\n",
    "    test_qids = folds['test']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Document id -> serialized tuple text, one `<docid><TAB><text>` record expected\n",
    "# per line. Relative path, like the other data files of this notebook, instead\n",
    "# of a path that only exists on one machine.\n",
    "collection = {}\n",
    "with open('../data/wikituples/final_data/collection.tsv', 'r') as f:\n",
    "    for line in f:\n",
    "        line = line.strip()\n",
    "        tab_at = line.find('\\t')\n",
    "        qid, query = line[:tab_at], line[tab_at+1:]\n",
    "        collection[qid] = query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_to_table(tuple_id, serialized_tuple):\n",
    "    \"\"\"Render a serialized tuple as a caption line plus a two-row pipe table.\n",
    "\n",
    "    The text is expected to have the shape\n",
    "    '<prefix>]: <title> attribute <name> value <value> attribute <name> value <value> ...'.\n",
    "\n",
    "    Args:\n",
    "        tuple_id: Id of the tuple (an int or a numeric string).\n",
    "        serialized_tuple: Serialized text of the tuple.\n",
    "\n",
    "    Returns:\n",
    "        Three lines joined by newlines: 'caption: <title>', the header row and\n",
    "        the value row, both rows written as '|a | b |'.\n",
    "    \"\"\"\n",
    "    # Separate the caption part from the attribute/value parts.\n",
    "    parts = serialized_tuple.split(' attribute ')\n",
    "    head = parts[0]\n",
    "    _, marker, rest = head.partition(']: ')\n",
    "    # Without a ']: ' marker the whole leading part is used as the title.\n",
    "    title = (rest if marker else head).strip()\n",
    "\n",
    "    # Hard-coded workaround: headers of tuples 482835..482849 are cut to 10\n",
    "    # characters (presumably because they are overly long - TODO confirm).\n",
    "    truncate_headers = 482835 <= int(tuple_id) <= 482849\n",
    "\n",
    "    headers = []\n",
    "    values = []\n",
    "    # Extract the attribute name and the value of every part.\n",
    "    for chunk in parts[1:]:\n",
    "        # Split at the first ' value ' only, so values that contain the word\n",
    "        # 'value' themselves are kept in full.\n",
    "        attribute_name, marker, value = chunk.partition(' value ')\n",
    "        if not marker and chunk.endswith(' value'):\n",
    "            # Empty last value whose trailing space was stripped earlier.\n",
    "            attribute_name, value = chunk[:-len(' value')], ''\n",
    "        attribute_name = attribute_name.strip()\n",
    "        if truncate_headers and len(attribute_name) > 10:\n",
    "            attribute_name = attribute_name[:10]\n",
    "        headers.append(attribute_name)\n",
    "        values.append(value.strip())\n",
    "\n",
    "    # Build the table.\n",
    "    table = 'caption: ' + title + '\\n|' + ' | '.join(headers) + ' |\\n|' + ' | '.join(values) + ' |'\n",
    "    return table\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ids of the tuples that already have a saved response, so an interrupted run\n",
    "# can be resumed without sending the same request again.\n",
    "processed_tuples = []\n",
    "try:\n",
    "    with open('./results/WikiTuples/GPT4_wikituples_with_retrieved_tuples_by_monoT5.jsonl', 'r') as f:\n",
    "        for line in f:\n",
    "            if not line.strip():\n",
    "                continue  # tolerate blank lines\n",
    "            line = json.loads(line)\n",
    "            tuple_id = int(line['tuple_id'])\n",
    "            processed_tuples.append(tuple_id)\n",
    "except FileNotFoundError:\n",
    "    # First run: the results file only exists once a response has been appended.\n",
    "    pass\n",
    "\n",
    "print(len(processed_tuples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import jsonlines\n",
    "import random\n",
    "\n",
    "template = '''Based on the retrieved tabular data, what's the most likely value for the [TO-FILL] cell in the table below? Please respond using JSON: {answer_format}, the key is attribute name of each [TO-FILL], value is the predicted value for each [TO-FILL].\\n'''\n",
    "RETRIEVED_RESULTS = './results/WikiTuples/GPT4_wikituples_with_retrieved_tuples_by_monoT5.jsonl'\n",
    "# Set copies for constant-time membership tests inside the loop.\n",
    "skip_ids = set(processed_tuples)\n",
    "test_id_set = set(test_qids)\n",
    "\n",
    "with open('../data/wikituples/missing_tables.jsonl', 'r') as f:\n",
    "    for line in f:\n",
    "        line = json.loads(line)\n",
    "        tuple_ids = line['tuple_id']\n",
    "        tableData = line['tableData']\n",
    "        headers = line['processed_tableHeaders']\n",
    "        caption = 'caption:' + line['pgTitle'] + ' | ' + line['sectionTitle'] + ' | ' + line['tableCaption']\n",
    "        header_row = ''.join('|' + header for header in headers) + '|\\n'\n",
    "\n",
    "        for index, t_id in enumerate(tuple_ids):\n",
    "            # Only test tuples that have no saved response yet are sent.\n",
    "            if t_id in skip_ids or t_id not in test_id_set:\n",
    "                continue\n",
    "\n",
    "            tuple_d = tableData[index]\n",
    "\n",
    "            # Row to impute: 'N/A' cells are shown as [TO-FILL] and their headers\n",
    "            # become the keys of the requested JSON answer.\n",
    "            missing_headers = []\n",
    "            target_row = ''\n",
    "            for j in range(len(headers)):\n",
    "                if tuple_d[j] == 'N/A':\n",
    "                    missing_headers.append(headers[j])\n",
    "                    target_row += '|' + '[TO-FILL]'\n",
    "                else:\n",
    "                    target_row += '|' + tuple_d[j]\n",
    "            target_row += '|\\n'\n",
    "            answer_format = '{' + ', '.join(header + ': \"\"' for header in missing_headers) + '}'\n",
    "\n",
    "            # Only the template goes through str.format, so '{' or '}' inside the\n",
    "            # caption or the table cells can no longer break the substitution.\n",
    "            input_data = template.format(answer_format=answer_format) + caption + '\\n' + header_row + target_row\n",
    "            # Kept so that prompts stay identical to the ones of earlier runs.\n",
    "            input_data = input_data.replace('Red, Green & White}|', 'Red, Green & White|')\n",
    "            input_data = input_data.replace('Green & Gold}|', 'Green & Gold|')\n",
    "\n",
    "            # Adding retrieved tables\n",
    "            input_data += 'Retrieved Tables:\\n'\n",
    "            retrieved_tables = topK_results[str(t_id)]\n",
    "            for rank, docid in enumerate(retrieved_tables):\n",
    "                input_data += 'Table ' + str(rank+1) + ': ' + convert_to_table(docid, collection[docid]) + '\\n\\n'\n",
    "\n",
    "            print(\"---------------------------------------------------\")\n",
    "            print(f\"Input: \\n{input_data}\")\n",
    "\n",
    "            # The keyword is `temperature`; the former `temperate` raised TypeError.\n",
    "            output = chat(input_data, model=\"gpt-4\", temperature=0.3)\n",
    "\n",
    "            print(f\"Output: \\n{output}\")\n",
    "\n",
    "            # Append right away so that progress survives an interruption.\n",
    "            with jsonlines.open(RETRIEVED_RESULTS, 'a') as fout:\n",
    "                fout.write({'tuple_id':t_id, 'input': input_data, 'output': output})\n",
    "            skip_ids.add(t_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import jsonlines\n",
    "import time\n",
    "import ast\n",
    "\n",
    "# Tuple id -> saved record (with 'output' and, optionally, 'imputed_data').\n",
    "processed_tuples = {}\n",
    "with open('./results/WikiTuples/GPT4_wikituples_with_retrieved_tuples_by_monoT5.jsonl', 'r') as f:\n",
    "    for line in f:\n",
    "        line = json.loads(line)\n",
    "        _id = line['tuple_id']\n",
    "        processed_tuples[_id] = line\n",
    "\n",
    "# Relative path, like the other data files of this notebook, instead of a path\n",
    "# that only exists on one machine.\n",
    "with open('../data/wikituples/final_data/folds.json', 'r') as f:\n",
    "    folds = json.load(f)\n",
    "    test_qids = folds['test']\n",
    "test_id_set = set(test_qids)\n",
    "\n",
    "\n",
    "def normalize_text(text):\n",
    "    \"\"\"Lower-case a value and turn underscores into spaces.\n",
    "\n",
    "    Non-string values (e.g. a number predicted by the model) are converted\n",
    "    with str() first instead of failing on .lower().\n",
    "    \"\"\"\n",
    "    return str(text).lower().replace('_', ' ')\n",
    "\n",
    "\n",
    "def process_answer(answer_set):\n",
    "    \"\"\"Return the normalized form of every answer in `answer_set`.\"\"\"\n",
    "    return [normalize_text(aa) for aa in answer_set]\n",
    "\n",
    "\n",
    "def gold_answers(gold_cell):\n",
    "    \"\"\"Return the normalized acceptable answers for one ground-truth cell.\n",
    "\n",
    "    A list cell is treated as an entity cell: its first element is looked up\n",
    "    in `entityid_to_text`; when the id is unknown, the last element of the\n",
    "    cell is used as the only answer. Any other cell is a single answer.\n",
    "    \"\"\"\n",
    "    if isinstance(gold_cell, list):\n",
    "        entity_id = gold_cell[0]\n",
    "        return process_answer(set(entityid_to_text.get(entity_id, [gold_cell[-1]])))\n",
    "    # Wrap the string: iterating over it directly would yield single characters.\n",
    "    return process_answer([gold_cell])\n",
    "\n",
    "\n",
    "def is_correct(cell_value, answers):\n",
    "    \"\"\"Tell whether a prediction matches one of the gold answers.\n",
    "\n",
    "    A prediction (or, for a list, any of its items) is correct if its\n",
    "    normalized form is a non-empty substring of a gold answer. The result is a\n",
    "    single bool, so one cell can never be counted more than once.\n",
    "    \"\"\"\n",
    "    candidates = cell_value if type(cell_value) is list else [cell_value]\n",
    "    for candidate in candidates:\n",
    "        if candidate is None:\n",
    "            continue\n",
    "        predicted = normalize_text(candidate)\n",
    "        # An empty string is a substring of everything and must not count.\n",
    "        if predicted and any(predicted in answer for answer in answers):\n",
    "            return True\n",
    "    return False\n",
    "\n",
    "\n",
    "count, acc = 0,0\n",
    "with open('../data/wikituples/missing_tables.jsonl', 'r') as f:\n",
    "    for line in f:\n",
    "\n",
    "        line = json.loads(line)\n",
    "        tuple_ids = line['tuple_id']\n",
    "        tableData = line['tableData']\n",
    "        headers = line['processed_tableHeaders']\n",
    "        ground_truth = line['ground_truth']\n",
    "\n",
    "        for index, t_id in enumerate(tuple_ids):\n",
    "            if t_id not in processed_tuples:\n",
    "                continue\n",
    "\n",
    "            print(t_id)\n",
    "            if t_id not in test_id_set:\n",
    "                continue\n",
    "\n",
    "            tuple_d = tableData[index]\n",
    "            tuple_answer = ground_truth[index]\n",
    "\n",
    "            # Column positions of the cells that had to be imputed.\n",
    "            missing_pos = [j for j in range(len(headers)) if tuple_d[j] == 'N/A']\n",
    "            count += len(missing_pos)\n",
    "\n",
    "            record = processed_tuples[t_id]\n",
    "            if 'imputed_data' in record:\n",
    "                # Predictions already stored in the record, indexed by column position.\n",
    "                imputed_data = record['imputed_data']\n",
    "                imputed_values = None\n",
    "            else:\n",
    "                # The raw reply must be a dict literal {header: value}; a reply\n",
    "                # that cannot be parsed still raises here, as before.\n",
    "                imputed_data = ast.literal_eval(record['output'])\n",
    "                imputed_values = list(imputed_data.values())\n",
    "\n",
    "            for mm, pos in enumerate(missing_pos):\n",
    "                if imputed_values is None:\n",
    "                    try:\n",
    "                        cell_value = imputed_data[pos]\n",
    "                    except (IndexError, KeyError, TypeError):\n",
    "                        continue\n",
    "                else:\n",
    "                    try:\n",
    "                        cell_value = imputed_data[headers[pos]]\n",
    "                    except (KeyError, TypeError):\n",
    "                        # No key equal to the header: use the mm-th returned value.\n",
    "                        if mm >= len(imputed_values):\n",
    "                            continue\n",
    "                        cell_value = imputed_values[mm]\n",
    "\n",
    "                # Correct if the prediction occurs inside any acceptable answer.\n",
    "                if is_correct(cell_value, gold_answers(tuple_answer[pos])):\n",
    "                    acc += 1\n",
    "\n",
    "\n",
    "# Guard against a division by zero when no cell was evaluated.\n",
    "accuaracy = round(acc/count, 3) if count else 0.0\n",
    "print(f\"Imputed Accuracy: {accuaracy} on {len(test_qids)} tuples\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
