{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import openai\n",
    "import pandas as pd\n",
    "import time\n",
    "import requests\n",
    "import json\n",
    "import jsonlines\n",
    "import re\n",
    "\n",
    "# Read the API key from the environment rather than hardcoding it in the notebook.\n",
    "import os\n",
    "sk = os.environ.get(\"OPENAI_API_KEY\", \"YOUR API KEY\")\n",
    "openai.api_key = sk\n",
    "\n",
    "def chat(input_data, model=\"gpt-3.5-turbo\", temperature=0.8, max_retries=None, retry_delay=10):\n",
    "    \"\"\"Send one user message to the OpenAI chat API and return the reply text.\n",
    "\n",
    "    Args:\n",
    "        input_data: prompt text, sent as the only user message.\n",
    "        model: chat model name passed to ``openai.ChatCompletion.create``.\n",
    "        temperature: sampling temperature, passed as a request parameter.\n",
    "        max_retries: retries allowed after a failed request; ``None`` (default) retries\n",
    "            forever, as the original loop did.\n",
    "        retry_delay: seconds to sleep between attempts.\n",
    "\n",
    "    Returns:\n",
    "        The ``content`` of the first choice's message.\n",
    "\n",
    "    Raises:\n",
    "        The last request exception once ``max_retries`` is exhausted.\n",
    "    \"\"\"\n",
    "    # `temperature` is a request parameter, not a message field: putting it inside the\n",
    "    # message dict meant it was never applied to sampling.\n",
    "    nmessages = [{\"role\": \"user\", \"content\": input_data}]\n",
    "\n",
    "    attempt = 0\n",
    "    while True:\n",
    "        try:\n",
    "            response = openai.ChatCompletion.create(\n",
    "                model=model,\n",
    "                messages=nmessages,\n",
    "                temperature=temperature,\n",
    "            )\n",
    "            return response['choices'][0]['message']['content']\n",
    "        except Exception as exc:  # not a bare except, so KeyboardInterrupt still stops the loop\n",
    "            attempt += 1\n",
    "            if max_retries is not None and attempt > max_retries:\n",
    "                raise\n",
    "            print(f\"chat(): request failed ({exc!r}); retrying in {retry_delay}s\")\n",
    "            time.sleep(retry_delay)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# IDs of tuples already written to the output file, so re-running the imputation cell\n",
    "# resumes where it stopped instead of re-querying the API. The file does not exist yet\n",
    "# on the very first run, which previously raised FileNotFoundError.\n",
    "results_path = './results/Cricket_Players/GPT4_cricket_wo_evidence.jsonl'\n",
    "processed_tuples = []\n",
    "if os.path.exists(results_path):\n",
    "    with open(results_path, 'r') as f:\n",
    "        for line in f:\n",
    "            if not line.strip():\n",
    "                continue  # tolerate blank lines, e.g. a trailing newline\n",
    "            record = json.loads(line)\n",
    "            processed_tuples.append(int(record['tuple_id']))\n",
    "\n",
    "print(len(processed_tuples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import jsonlines\n",
    "import pandas as pd\n",
    "import ast\n",
    "\n",
    "template = '''What's the most likely value for the [TO-FILL] cell in the table below? Please respond using JSON: {answer_format}, the key is attribute name of each [TO-FILL], value is the predicted value for each [TO-FILL].\\n'''\n",
    "\n",
    "with open('/Users/yichendezaizai/Data_Imputation/data/cricket_players/annotated_data/folds.json', 'r') as f:\n",
    "    folds = json.load(f)\n",
    "    test_qids = folds['test']\n",
    "print(len(test_qids))\n",
    "\n",
    "# Load the cricket_ret_1.csv table, keeping only the columns shown in the prompt.\n",
    "keep_columns =['Name', 'Team', 'Type', 'ValueinCR', 'National Side', 'Batting Style', 'Bowling', 'MatchPlayed',\n",
    "       'InningsBatted', 'RunsScored']\n",
    "\n",
    "df = pd.read_csv('/Users/yichendezaizai/Data_Imputation/data/cricket_players/cricket_ret_1.csv', usecols=keep_columns)\n",
    "\n",
    "# Kept so the `print(count)` cell below still runs; nothing in this cell updates them.\n",
    "count, acc = 0, 0\n",
    "\n",
    "missing_columns = ['National Side', 'Batting Style']\n",
    "\n",
    "# Fixed worked example shown to the model above the row to impute.\n",
    "# NOTE(review): this has 9 cells but `keep_columns` has 10, and `read_csv(usecols=...)`\n",
    "# keeps the CSV's own column order -- confirm the example lines up with `df.columns`.\n",
    "example = ['Vaibhav Arora', 'PBKS', 'Bowler', '2.00', 'India', 'Right Handed', 'Right-arm fast medium', '', '']\n",
    "\n",
    "# Sets give O(1) membership checks inside the row loop.\n",
    "test_qid_set = set(test_qids)\n",
    "processed_set = set(processed_tuples)\n",
    "\n",
    "for index, row in df.iterrows():\n",
    "\n",
    "    if index not in test_qid_set or index in processed_set:\n",
    "        continue\n",
    "\n",
    "    # One pass over the columns builds both the target row and the JSON answer hint.\n",
    "    answer_format = '{'\n",
    "    target_cells = []\n",
    "    for col in df.columns:\n",
    "        if col in missing_columns:\n",
    "            target_cells.append('[TO-FILL]')\n",
    "            answer_format += col + \": \" + '\"\"' + \", \"\n",
    "        else:\n",
    "            target_cells.append(str(row[col]))\n",
    "    answer_format = answer_format[:-2] + '}'\n",
    "\n",
    "    # Only the template goes through str.format: table values containing '{' or '}'\n",
    "    # would otherwise be parsed as replacement fields and raise or be mangled.\n",
    "    input_data = template.format(answer_format=answer_format)\n",
    "    input_data += '[caption]: cricket player ' + '\\n'\n",
    "    input_data += '|' + '|'.join(df.columns) + '|\\n'\n",
    "    input_data += '|' + '|'.join(str(cell) for cell in example) + '|\\n'\n",
    "    # The example row used to end with '|\\n |', which put a stray empty cell at the\n",
    "    # start of the target row and shifted it against the header.\n",
    "    input_data += '|' + '|'.join(target_cells) + '|\\n'\n",
    "\n",
    "    print(\"---------------------------------------------------\")\n",
    "    print(f\"Input: \\n{input_data}\")\n",
    "\n",
    "    output = chat(input_data, model=\"gpt-4\", temperature=0.3)\n",
    "\n",
    "    print(f\"Output: \\n{output}\")\n",
    "\n",
    "    # Append right away so an interrupted run can resume via `processed_tuples`.\n",
    "    with jsonlines.open('./results/Cricket_Players/GPT4_cricket_wo_evidence.jsonl', 'a') as fout:\n",
    "        fout.write({'tuple_id': index, 'input': input_data, 'output': output})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `count` is set to 0 in the imputation cell above and never incremented there,\n",
    "# so this prints 0; it looks like a leftover from a removed accuracy computation.\n",
    "print(count)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data imputation with retrieved tables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of highest-scoring retrieved tuples kept per query when building `topK_results`.\n",
    "# NOTE(review): the ablation output file names hard-code 'top20'; keep them in sync if this changes.\n",
    "top_K = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "# all_scores[qid][docid] = reranker score for that (query, candidate tuple) pair.\n",
    "# Each line of the reranked file is `qid<TAB>docid<TAB>rank<TAB>score`; the rank column is\n",
    "# ignored because the ranking is rebuilt from the scores below.\n",
    "# (First-stage alternative, 3 columns qid/docid/score:\n",
    "#  /Users/yichendezaizai/Data_Imputation/retrieval_results/first_stage/BM25_top100_res_with_score_cricket_players.tsv)\n",
    "all_scores = defaultdict(dict)\n",
    "with open('/Users/yichendezaizai/Data_Imputation/retrieval_results/rerank_results/final_data/cricket_players_test.tsv', 'r') as f:\n",
    "    for line in f:\n",
    "        if not line.strip():\n",
    "            continue  # tolerate blank lines\n",
    "        qid, docid, _rank, score = line.strip().split('\\t')\n",
    "        all_scores[int(qid)][int(docid)] = float(score)\n",
    "\n",
    "qq = list(all_scores.keys())\n",
    "\n",
    "# topK_results[qid] = ids of the top_K highest-scoring candidates, best first.\n",
    "topK_results = {}\n",
    "for qid in qq:\n",
    "    ranked = sorted(all_scores[qid].items(), key=lambda x: x[1], reverse=True)\n",
    "    top_docs = [docid for docid, _score in ranked[:top_K]]\n",
    "    if top_docs:\n",
    "        topK_results[qid] = top_docs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# collection[docid] = serialized tuple text; each line is `docid<TAB>text`.\n",
    "collection = {}\n",
    "with open('/Users/yichendezaizai/Data_Imputation/data/cricket_players/annotated_data/collection.tsv', 'r', encoding='utf-8') as f:\n",
    "    for line in f:\n",
    "        line = line.strip()\n",
    "        if not line:\n",
    "            continue  # tolerate blank lines\n",
    "        # partition splits on the first tab only; find() returned -1 on a tab-less line,\n",
    "        # silently chopping the last character off the id.\n",
    "        docid, _, text = line.partition('\\t')\n",
    "        collection[int(docid)] = text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# IDs of tuples already written to the ablation output file, so re-running the imputation\n",
    "# cell resumes where it stopped. The file does not exist yet on the very first run,\n",
    "# which previously raised FileNotFoundError.\n",
    "results_path = './results/Cricket_Players/ablation/cricket_players_with_retrieved_tuples_top20.jsonl'\n",
    "processed_tuples = []\n",
    "if os.path.exists(results_path):\n",
    "    with open(results_path, 'r') as f:\n",
    "        for line in f:\n",
    "            if not line.strip():\n",
    "                continue  # tolerate blank lines, e.g. a trailing newline\n",
    "            record = json.loads(line)\n",
    "            processed_tuples.append(int(record['tuple_id']))\n",
    "\n",
    "print(len(processed_tuples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_to_table(serialized_tuple, max_header_len=None):\n",
    "    \"\"\"Render a serialized tuple as a caption line plus a one-row pipe table.\n",
    "\n",
    "    The parsing below expects text of the form\n",
    "    ``<prefix>]: <title> attribute <name1> value <v1> attribute <name2> value <v2> ...``.\n",
    "\n",
    "    Args:\n",
    "        serialized_tuple: serialized tuple text (e.g. a value of ``collection``).\n",
    "        max_header_len: if set, attribute names longer than this are truncated to it.\n",
    "            ``None`` (default) keeps names intact, matching the original behaviour, whose\n",
    "            truncation branch sat behind a flag that was always 0.\n",
    "\n",
    "    Returns:\n",
    "        Three lines: ``caption: <title>``, a header row and a value row, with cells\n",
    "        joined by ``' | '``.\n",
    "    \"\"\"\n",
    "    segments = serialized_tuple.split(' attribute ')\n",
    "\n",
    "    # The title follows the first ']: ' marker; without a marker use the whole segment\n",
    "    # instead of raising IndexError.\n",
    "    head, marker, rest = segments[0].partition(']: ')\n",
    "    title = (rest if marker else head).strip()\n",
    "\n",
    "    headers = []\n",
    "    values = []\n",
    "    for segment in segments[1:]:\n",
    "        # Split on the first ' value ' only, so a value that itself contains the word\n",
    "        # 'value' is kept whole; a segment with no value yields an empty cell.\n",
    "        name, _, value = segment.partition(' value ')\n",
    "        name = name.strip()\n",
    "        if max_header_len is not None and len(name) > max_header_len:\n",
    "            name = name[:max_header_len]\n",
    "        headers.append(name)\n",
    "        values.append(value.strip())\n",
    "\n",
    "    return 'caption: ' + title + '\\n|' + ' | '.join(headers) + ' |\\n|' + ' | '.join(values) + ' |'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import jsonlines\n",
    "import pandas as pd\n",
    "import ast\n",
    "\n",
    "template = '''Based on the retrieved tabular data, what's the most likely value for the [TO-FILL] cell in the table below? Please respond using JSON: {answer_format}, the key is attribute name of each [TO-FILL], value is the predicted value for each [TO-FILL].\\n'''\n",
    "\n",
    "with open('/Users/yichendezaizai/Data_Imputation/data/cricket_players/annotated_data/folds.json', 'r') as f:\n",
    "    folds = json.load(f)\n",
    "    test_qids = folds['test']\n",
    "print(len(test_qids))\n",
    "\n",
    "keep_columns =['Name', 'Team', 'Type', 'ValueinCR', 'National Side', 'Batting Style', 'Bowling', 'MatchPlayed',\n",
    "       'InningsBatted', 'RunsScored']\n",
    "\n",
    "df = pd.read_csv('/Users/yichendezaizai/Data_Imputation/data/cricket_players/cricket_ret_1.csv', usecols=keep_columns)\n",
    "\n",
    "missing_columns = ['National Side', 'Batting Style']\n",
    "\n",
    "# Sets give O(1) membership checks inside the row loop.\n",
    "test_qid_set = set(test_qids)\n",
    "processed_set = set(processed_tuples)\n",
    "\n",
    "for index, row in df.iterrows():\n",
    "\n",
    "    if index not in test_qid_set or index in processed_set:\n",
    "        continue\n",
    "\n",
    "    # One pass over the columns builds both the target row and the JSON answer hint.\n",
    "    answer_format = '{'\n",
    "    target_cells = []\n",
    "    for col in df.columns:\n",
    "        if col in missing_columns:\n",
    "            target_cells.append('[TO-FILL]')\n",
    "            answer_format += col + \": \" + '\"\"' + \", \"\n",
    "        else:\n",
    "            target_cells.append(str(row[col]))\n",
    "    answer_format = answer_format[:-2] + '}'\n",
    "\n",
    "    # Only the template goes through str.format: table values containing '{' or '}'\n",
    "    # would otherwise be parsed as replacement fields and raise or be mangled.\n",
    "    # This ablation includes no worked example row; the retrieved tables are the evidence.\n",
    "    input_data = template.format(answer_format=answer_format)\n",
    "    input_data += '[caption]: cricket player ' + '\\n'\n",
    "    input_data += '|' + '|'.join(df.columns) + '|\\n'\n",
    "    input_data += '|' + '|'.join(target_cells) + '|\\n'\n",
    "\n",
    "    # Append the top-K retrieved tuples as evidence, numbered from 1.\n",
    "    input_data += 'Retrieved Tables:\\n'\n",
    "    for rank, docid in enumerate(topK_results[index], start=1):\n",
    "        input_data += 'Table ' + str(rank) + ': ' + convert_to_table(collection[docid]) + '\\n\\n'\n",
    "\n",
    "    print(\"---------------------------------------------------\")\n",
    "    print(f\"Input: \\n{input_data}\")\n",
    "\n",
    "    output = chat(input_data, model=\"gpt-3.5-turbo\", temperature=0.3)\n",
    "\n",
    "    print(f\"Output: \\n{output}\")\n",
    "\n",
    "    # Append right away so an interrupted run can resume via `processed_tuples`.\n",
    "    with jsonlines.open('./results/Cricket_Players/ablation/cricket_players_with_retrieved_tuples_top20.jsonl', 'a') as fout:\n",
    "        fout.write({'tuple_id': index, 'input': input_data, 'output': output})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('/Users/yichendezaizai/Data_Imputation/data/cricket_players/annotated_data/folds.json', 'r') as f:\n",
    "    folds = json.load(f)\n",
    "    test_qids = folds['test']\n",
    "test_qid_set = set(test_qids)\n",
    "\n",
    "# qrels[qid] = ids of the candidate tuples listed as relevant for each test query.\n",
    "qrels = defaultdict(list)\n",
    "with open('/Users/yichendezaizai/Data_Imputation/data/cricket_players/annotated_data/qrels.tsv') as f:\n",
    "    for line in f:\n",
    "        if not line.strip():\n",
    "            continue  # tolerate blank lines\n",
    "        # NOTE(review): the third (relevance) column is not checked -- every listed pair is\n",
    "        # treated as relevant; confirm qrels.tsv contains only positive judgments.\n",
    "        qid, docid, rel = line.strip().split('\\t')\n",
    "        if int(qid) not in test_qid_set:\n",
    "            continue\n",
    "        qrels[int(qid)].append(int(docid))\n",
    "\n",
    "\n",
    "def calculate_recall(topk_pids, qrels, K):\n",
    "    \"\"\"Print and return mean Recall@K over every query in ``qrels``.\n",
    "\n",
    "    Args:\n",
    "        topk_pids: mapping qid -> ranked list of doc ids (best first).\n",
    "        qrels: mapping qid -> list of relevant doc ids.\n",
    "        K: rank cutoff.\n",
    "\n",
    "    Queries missing from ``topk_pids`` (or with no relevant docs) contribute 0 but stay\n",
    "    in the denominator. Returns 0.0 when ``qrels`` is empty.\n",
    "    \"\"\"\n",
    "    recall_sum = 0.0\n",
    "    for qid, qrel in qrels.items():\n",
    "        if qid not in topk_pids:\n",
    "            continue\n",
    "        relevant_docs = set(qrel)\n",
    "        retrieved_docs = set(topk_pids[qid][:K])\n",
    "        if relevant_docs:\n",
    "            recall_sum += len(relevant_docs & retrieved_docs) / len(relevant_docs)\n",
    "\n",
    "    # Average over all judged queries; guard the empty case instead of dividing by zero.\n",
    "    recall_rate = recall_sum / len(qrels) if qrels else 0.0\n",
    "    print(\"Recall@{} =\".format(K), recall_rate)\n",
    "    return recall_rate\n",
    "\n",
    "\n",
    "def calculate_success(topk_pids, qrels, K):\n",
    "    \"\"\"Print and return Success@K, rounded to 3 decimals.\n",
    "\n",
    "    Success@K is the fraction of queries in ``qrels`` with at least one relevant doc in\n",
    "    the top K of ``topk_pids``; queries missing from ``topk_pids`` count as failures.\n",
    "    Returns 0.0 when ``qrels`` is empty.\n",
    "    \"\"\"\n",
    "    hits = sum(\n",
    "        1\n",
    "        for qid, qrel in qrels.items()\n",
    "        if qid in topk_pids and set(qrel) & set(topk_pids[qid][:K])\n",
    "    )\n",
    "    success_at_k_avg = round(hits / len(qrels), 3) if qrels else 0.0\n",
    "    print(\"Success@{} =\".format(K), success_at_k_avg)\n",
    "    return success_at_k_avg\n",
    "\n",
    "\n",
    "# Evaluate on the full score-ordered ranking, not `topK_results`: the latter holds only\n",
    "# top_K docs per query, so Recall/Success at K > top_K silently equalled the value at top_K.\n",
    "full_rankings = {\n",
    "    qid: [docid for docid, _score in sorted(doc_scores.items(), key=lambda x: x[1], reverse=True)]\n",
    "    for qid, doc_scores in all_scores.items()\n",
    "}\n",
    "\n",
    "print(len(qrels))\n",
    "for K in [1, 5, 10, 20, 50, 100]:\n",
    "    calculate_recall(full_rankings, qrels, K)\n",
    "    calculate_success(full_rankings, qrels, K)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Calculate imputation accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import jsonlines\n",
    "import pandas as pd\n",
    "import ast\n",
    "import re\n",
    "\n",
    "with open('/Users/yichendezaizai/Data_Imputation/data/cricket_players/annotated_data/folds.json', 'r') as f:\n",
    "    folds = json.load(f)\n",
    "    test_qids = folds['test']\n",
    "print(len(test_qids))\n",
    "\n",
    "\n",
    "def _parse_model_output(text):\n",
    "    \"\"\"Parse a model reply into a dict of {column: predicted value}.\n",
    "\n",
    "    Tries the whole reply, then its first ``{...}`` span (e.g. when the JSON is wrapped in a\n",
    "    code fence or prose), each as a Python literal and then as JSON. Returns ``{}`` when\n",
    "    nothing parses to a dict, so that tuple is scored as incorrect instead of crashing.\n",
    "    \"\"\"\n",
    "    candidates = [text]\n",
    "    match = re.search(r'\\{.*\\}', text, re.DOTALL)\n",
    "    if match:\n",
    "        candidates.append(match.group(0))\n",
    "    for candidate in candidates:\n",
    "        for parse in (ast.literal_eval, json.loads):\n",
    "            try:\n",
    "                parsed = parse(candidate)\n",
    "            except (ValueError, SyntaxError, TypeError):\n",
    "                continue\n",
    "            if isinstance(parsed, dict):\n",
    "                return parsed\n",
    "    return {}\n",
    "\n",
    "\n",
    "def _normalize(value):\n",
    "    \"\"\"Lower-case, drop parentheses and trim whitespace so answers compare loosely.\"\"\"\n",
    "    return str(value).lower().replace('(', '').replace(')', '').strip()\n",
    "\n",
    "\n",
    "# Same path as the writer cell. The lower-case 'cricket_players' directory used here before\n",
    "# only resolved on case-insensitive file systems.\n",
    "imputed_record = {}\n",
    "with open('./results/Cricket_Players/ablation/cricket_players_with_retrieved_tuples_top20.jsonl', 'r') as f:\n",
    "    for line in f:\n",
    "        if not line.strip():\n",
    "            continue  # tolerate blank lines\n",
    "        record = json.loads(line)\n",
    "        imputed_record[int(record['tuple_id'])] = _parse_model_output(record['output'])\n",
    "\n",
    "\n",
    "# ground_truth[query_id] maps each column name to its accepted answer strings.\n",
    "ground_truth = {}\n",
    "with jsonlines.open('/Users/yichendezaizai/Data_Imputation/data/cricket_players/annotated_data/answers.jsonl', 'r') as f:\n",
    "    for line in f:\n",
    "        ground_truth[line['query_id']] = line['answers']\n",
    "\n",
    "\n",
    "keep_columns =['Name', 'Team', 'Type', 'ValueinCR', 'National Side', 'Batting Style', 'Bowling', 'MatchPlayed',\n",
    "       'InningsBatted', 'RunsScored']\n",
    "\n",
    "df = pd.read_csv('/Users/yichendezaizai/Data_Imputation/data/cricket_players/cricket_ret_1.csv', usecols=keep_columns)\n",
    "\n",
    "\n",
    "count, acc = 0, 0\n",
    "missing_predictions = 0\n",
    "missing_columns = ['National Side', 'Batting Style']\n",
    "test_qid_set = set(test_qids)\n",
    "for qid, row in df.iterrows():\n",
    "    if qid not in test_qid_set:\n",
    "        continue\n",
    "\n",
    "    # A test tuple with no saved output (e.g. the run was interrupted) is scored as wrong.\n",
    "    if qid not in imputed_record:\n",
    "        missing_predictions += 1\n",
    "    imputed_data = imputed_record.get(qid, {})\n",
    "\n",
    "    correct_values = {\n",
    "        key: [_normalize(vv) for vv in value]\n",
    "        for key, value in ground_truth[qid].items()\n",
    "    }\n",
    "\n",
    "    # A missing key or a non-string prediction no longer crashes; it just isn't a match.\n",
    "    for col in missing_columns:\n",
    "        prediction = imputed_data.get(col)\n",
    "        if prediction is not None and _normalize(prediction) in correct_values.get(col, []):\n",
    "            acc += 1\n",
    "        count += 1\n",
    "\n",
    "if missing_predictions:\n",
    "    print(f\"{missing_predictions} test tuples have no model output; scored as incorrect.\")\n",
    "print(f\"Accuracy: {acc/count}\" if count else \"Accuracy: n/a (no test tuples scored)\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
