{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import csv\n",
    "from tqdm import tqdm \n",
    "import collections\n",
    "import gzip\n",
    "import pickle\n",
    "import numpy as np\n",
    "import faiss\n",
    "import os\n",
    "import pytrec_eval\n",
    "import json\n",
    "from msmarco_eval import quality_checks_qids, compute_metrics, load_reference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define params below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fill in every value marked TODO before running the rest of the notebook.\n",
    "checkpoint_path = None  # TODO: location of the dumped query and passage/document embeddings (output_dir of run_ann_data_gen.py); keep the trailing '/'\n",
    "checkpoint = None  # TODO: checkpoint the embeddings were dumped from (ie: 200000)\n",
    "data_type = 1  # 0 for document, 1 for passage\n",
    "test_set = 0  # 0 for dev_set, 1 for eval_set\n",
    "raw_data_dir = None  # TODO: directory holding the raw query files and the top-N candidate files\n",
    "processed_data_dir = None  # TODO: directory holding dev-qrel.tsv, qid2offset.pickle and pid2offset.pickle\n",
    "\n",
    "# Fail fast with a readable message instead of a SyntaxError or a late path error.\n",
    "unset_params = [\n",
    "    name\n",
    "    for name, value in [\n",
    "        (\"checkpoint_path\", checkpoint_path),\n",
    "        (\"checkpoint\", checkpoint),\n",
    "        (\"raw_data_dir\", raw_data_dir),\n",
    "        (\"processed_data_dir\", processed_data_dir),\n",
    "    ]\n",
    "    if value is None\n",
    "]\n",
    "if unset_params:\n",
    "    raise ValueError(\"Set these parameters before running the notebook: \" + \", \".join(unset_params))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Qrel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Candidates kept per query: 100 for documents, 1000 for passages.\n",
    "topN = 100 if data_type == 0 else 1000\n",
    "\n",
    "# dev_query_positive_id[query id][doc id] = relevance label, read from a\n",
    "# tab-separated file with one (query id, doc id, relevance) row per line.\n",
    "dev_query_positive_id = {}\n",
    "query_positive_id_path = os.path.join(processed_data_dir, \"dev-qrel.tsv\")\n",
    "\n",
    "with open(query_positive_id_path, 'r', encoding='utf8') as qrel_file:\n",
    "    for topicid, docid, rel in csv.reader(qrel_file, delimiter=\"\\t\"):\n",
    "        query_key, doc_key = int(topicid), int(docid)\n",
    "        dev_query_positive_id.setdefault(query_key, {})[doc_key] = int(rel)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prepare rerank data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "6668967it [00:27, 241339.72it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "number of queries with 1000 BM25 passages: 6980\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "def open_text(path):\n",
    "    \"\"\"Open path as UTF-8 text; names ending in gz are read through gzip.\"\"\"\n",
    "    if path[-2:] == \"gz\":\n",
    "        return gzip.open(path, 'rt', encoding='utf-8')\n",
    "    return open(path, 'rt', encoding='utf-8')\n",
    "\n",
    "\n",
    "qidmap_path = processed_data_dir+\"/qid2offset.pickle\"\n",
    "pidmap_path = processed_data_dir+\"/pid2offset.pickle\"\n",
    "\n",
    "# Pick the query file and the first-stage candidate file for the chosen task and split.\n",
    "if data_type == 0:\n",
    "    if test_set == 1:\n",
    "        query_path = raw_data_dir+\"/msmarco-test2019-queries.tsv.gz\"\n",
    "        passage_path = raw_data_dir+\"/msmarco-doctest2019-top100.gz\"\n",
    "    else:\n",
    "        query_path = raw_data_dir+\"/msmarco-docdev-queries.tsv\"\n",
    "        passage_path = raw_data_dir+\"/msmarco-docdev-top100.gz\"\n",
    "else:\n",
    "    if test_set == 1:\n",
    "        query_path = raw_data_dir+\"/msmarco-test2019-queries.tsv\"\n",
    "        passage_path = raw_data_dir+\"/msmarco-passagetest2019-top1000.tsv.gz\"\n",
    "    else:\n",
    "        query_path = raw_data_dir+\"/queries.dev.small.tsv\"\n",
    "        # The dev queries need the dev candidate file; the test-2019 candidates share no\n",
    "        # query id with them, which would leave bm25 empty.\n",
    "        # NOTE(review): rename here if your local copy of the dev top-1000 file is named differently.\n",
    "        passage_path = raw_data_dir+\"/top1000.dev.tsv\"\n",
    "\n",
    "# NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.\n",
    "with open(qidmap_path, 'rb') as handle:\n",
    "    qidmap = pickle.load(handle)\n",
    "\n",
    "with open(pidmap_path, 'rb') as handle:\n",
    "    pidmap = pickle.load(handle)\n",
    "\n",
    "# Raw query ids (as strings) listed in the selected query file.\n",
    "qset = set()\n",
    "with open_text(query_path) as f:\n",
    "    for qid, _query in csv.reader(f, delimiter=\"\\t\"):\n",
    "        qset.add(qid)\n",
    "\n",
    "# bm25[mapped query id] = set of mapped passage ids listed for that query in the candidate file.\n",
    "bm25 = collections.defaultdict(set)\n",
    "with open_text(passage_path) as f:\n",
    "    for line in tqdm(f):\n",
    "        if data_type == 0:\n",
    "            # six space-separated fields per line; the first character of the doc id is dropped\n",
    "            qid, _q0, pid, _rank, _score, _runstring = line.split(' ')\n",
    "            pid = pid[1:]\n",
    "        else:\n",
    "            qid, pid, _query, _passage = line.split(\"\\t\")\n",
    "        if qid in qset and int(qid) in qidmap:\n",
    "            bm25[qidmap[int(qid)]].add(pidmap[int(pid)])\n",
    "\n",
    "print(\"number of queries with \" +str(topN) + \" BM25 passages:\", len(bm25))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Calculate Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_to_string_id(result_dict):\n",
    "    \"\"\"Return a copy of a two-level dict with outer and inner keys cast to str.\n",
    "\n",
    "    Values are kept unchanged. Used to build the qrel and run dicts handed to pytrec_eval.\n",
    "    \"\"\"\n",
    "    return {\n",
    "        str(outer_key): {str(inner_key): value for inner_key, value in inner.items()}\n",
    "        for outer_key, inner in result_dict.items()\n",
    "    }\n",
    "\n",
    "\n",
    "def EvalDevQuery(query_embedding2id, passage_embedding2id, dev_query_positive_id, I_nearest_neighbor,topN):\n",
    "    \"\"\"Score ranked retrieval results against the qrels.\n",
    "\n",
    "    Args:\n",
    "        query_embedding2id: query id for each row of I_nearest_neighbor.\n",
    "        passage_embedding2id: passage/document id for each passage embedding row.\n",
    "        dev_query_positive_id: dict[query id][doc id] -> relevance label.\n",
    "        I_nearest_neighbor: per query, passage embedding row indices, best first.\n",
    "        topN: number of neighbours kept per query; also selects the recall cutoff\n",
    "            read from the pytrec_eval results (\"recall_<topN>\").\n",
    "\n",
    "    Returns:\n",
    "        Tuple (ndcg@10, evaluated query count, map@10, reciprocal rank, recall@topN,\n",
    "        hole rate@10, compute_metrics output, hole rate over all kept results,\n",
    "        per-query pytrec_eval results, prediction dict). The pytrec_eval numbers are\n",
    "        averaged over the evaluated queries; averages and hole rates are 0.0 when\n",
    "        there is nothing to average over.\n",
    "    \"\"\"\n",
    "    prediction = {}  # [qid][docid] = docscore; -rank is the score, so rank 1 scores highest\n",
    "    # Ranked lists keep the historical 1000 slots and grow only when topN is larger.\n",
    "    ranked_len = max(1000, topN)\n",
    "\n",
    "    # A \"hole\" is a retrieved passage that has no entry in the qrels of its query.\n",
    "    top10_total = 0\n",
    "    top10_holes = 0\n",
    "    all_total = 0\n",
    "    all_holes = 0\n",
    "    qids_to_ranked_candidate_passages = {}\n",
    "    for query_idx in range(len(I_nearest_neighbor)):\n",
    "        seen_pid = set()\n",
    "        query_id = query_embedding2id[query_idx]\n",
    "        prediction[query_id] = {}\n",
    "        # A query without any judgement counts every retrieved passage as a hole.\n",
    "        judged = dev_query_positive_id.get(query_id, {})\n",
    "\n",
    "        selected_ann_idx = I_nearest_neighbor[query_idx][:topN]\n",
    "        rank = 0\n",
    "\n",
    "        # List handed to compute_metrics; slots that are never filled keep the filler id 0.\n",
    "        ranked = qids_to_ranked_candidate_passages.setdefault(query_id, [0] * ranked_len)\n",
    "\n",
    "        for idx in selected_ann_idx:\n",
    "            pred_pid = passage_embedding2id[idx]\n",
    "            if pred_pid in seen_pid:\n",
    "                # several vectors can map to one document; only its first hit is ranked\n",
    "                continue\n",
    "            ranked[rank] = pred_pid\n",
    "            is_hole = pred_pid not in judged\n",
    "            all_total += 1\n",
    "            if is_hole:\n",
    "                all_holes += 1\n",
    "            if rank < 10:\n",
    "                top10_total += 1\n",
    "                if is_hole:\n",
    "                    top10_holes += 1\n",
    "            rank += 1\n",
    "            prediction[query_id][pred_pid] = -rank\n",
    "            seen_pid.add(pred_pid)\n",
    "\n",
    "    # use out of the box evaluation script\n",
    "    evaluator = pytrec_eval.RelevanceEvaluator(\n",
    "        convert_to_string_id(dev_query_positive_id), {'map_cut', 'ndcg_cut', 'recip_rank','recall'})\n",
    "    result = evaluator.evaluate(convert_to_string_id(prediction))\n",
    "\n",
    "    qids_to_relevant_passageids = {}\n",
    "    for qid in dev_query_positive_id:\n",
    "        qid = int(qid)\n",
    "        if qid not in qids_to_relevant_passageids:\n",
    "            # NOTE(review): the filter is on the passage id, not on the relevance label, so\n",
    "            # id 0 is skipped (0 is also the filler in the ranked lists) and passages judged\n",
    "            # with label 0 are kept -- confirm this is intended.\n",
    "            qids_to_relevant_passageids[qid] = [pid for pid in dev_query_positive_id[qid] if pid > 0]\n",
    "\n",
    "    ms_mrr = compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages)\n",
    "\n",
    "    eval_query_cnt = len(result)\n",
    "    recall_key = \"recall_\"+str(topN)\n",
    "    ndcg = sum(scores[\"ndcg_cut_10\"] for scores in result.values())\n",
    "    Map = sum(scores[\"map_cut_10\"] for scores in result.values())\n",
    "    mrr = sum(scores[\"recip_rank\"] for scores in result.values())\n",
    "    recall = sum(scores[recall_key] for scores in result.values())\n",
    "\n",
    "    # Guard the divisions: no evaluated query or no retrieved passage yields 0.0.\n",
    "    final_ndcg = ndcg / eval_query_cnt if eval_query_cnt else 0.0\n",
    "    final_Map = Map / eval_query_cnt if eval_query_cnt else 0.0\n",
    "    final_mrr = mrr / eval_query_cnt if eval_query_cnt else 0.0\n",
    "    final_recall = recall / eval_query_cnt if eval_query_cnt else 0.0\n",
    "    hole_rate = top10_holes / top10_total if top10_total else 0.0\n",
    "    Ahole_rate = all_holes / all_total if all_total else 0.0\n",
    "\n",
    "    return final_ndcg, eval_query_cnt, final_Map, final_mrr, final_recall, hole_rate, ms_mrr, Ahole_rate, result, prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_embedding_shard(prefix, suffix, shard):\n",
    "    \"\"\"Unpickle <checkpoint_path>/<prefix><checkpoint><suffix><shard>.pb and return its content.\"\"\"\n",
    "    file_name = prefix + str(checkpoint) + suffix + str(shard) + \".pb\"\n",
    "    # os.path.join works whether or not checkpoint_path ends with a separator\n",
    "    # NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.\n",
    "    with open(os.path.join(checkpoint_path, file_name), 'rb') as handle:\n",
    "        return pickle.load(handle)\n",
    "\n",
    "\n",
    "dev_query_embedding = []\n",
    "dev_query_embedding2id = []\n",
    "passage_embedding = []\n",
    "passage_embedding2id = []\n",
    "for i in range(8):  # shards are numbered from 0; at most 8 are read\n",
    "    try:\n",
    "        shard = (\n",
    "            load_embedding_shard(\"dev_query_\", \"__emb_p__data_obj_\", i),\n",
    "            load_embedding_shard(\"dev_query_\", \"__embid_p__data_obj_\", i),\n",
    "            load_embedding_shard(\"passage_\", \"__emb_p__data_obj_\", i),\n",
    "            load_embedding_shard(\"passage_\", \"__embid_p__data_obj_\", i),\n",
    "        )\n",
    "    except FileNotFoundError:\n",
    "        # The first missing file ends the scan; any other error (corrupt pickle,\n",
    "        # permissions, interrupt) is no longer hidden.\n",
    "        break\n",
    "    # Append only complete shards so the four lists stay aligned.\n",
    "    dev_query_embedding.append(shard[0])\n",
    "    dev_query_embedding2id.append(shard[1])\n",
    "    passage_embedding.append(shard[2])\n",
    "    passage_embedding2id.append(shard[3])\n",
    "\n",
    "if not dev_query_embedding:\n",
    "    # Stop here with a clear message instead of failing inside np.concatenate.\n",
    "    raise FileNotFoundError(\"No data found for checkpoint: \" + str(checkpoint) + \" in \" + str(checkpoint_path))\n",
    "\n",
    "dev_query_embedding = np.concatenate(dev_query_embedding, axis=0)\n",
    "dev_query_embedding2id = np.concatenate(dev_query_embedding2id, axis=0)\n",
    "passage_embedding = np.concatenate(passage_embedding, axis=0)\n",
    "passage_embedding2id = np.concatenate(passage_embedding2id, axis=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# reranking metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reranking Results for checkpoint 200000\n",
      "Reranking NDCG@10:0.3729454806160795\n",
      "Reranking map@10:0.3125216510513781\n",
      "Reranking pytrec_mrr:0.3282910349833996\n",
      "Reranking recall@1000:0.8140401146131804\n",
      "Reranking hole rate@10:0.9416614005859711\n",
      "Reranking hole rate:0.9990995606965817\n",
      "Reranking ms_mrr:{'MRR @10': 0.31865272661117866, 'QueriesRanked': 6980}\n"
     ]
    }
   ],
   "source": [
    "# Rows of passage_embedding that belong to each passage id (an id may own several rows).\n",
    "# Kept under its own name so the pid2offset map loaded earlier is not overwritten.\n",
    "pid_to_rows = collections.defaultdict(list)\n",
    "for row, pid in enumerate(passage_embedding2id):\n",
    "    pid_to_rows[pid].append(row)\n",
    "\n",
    "# Loop-invariant setup, done once instead of once per query.\n",
    "dim = passage_embedding.shape[1]\n",
    "faiss.omp_set_num_threads(16)\n",
    "\n",
    "all_dev_I = []\n",
    "for i,qid in enumerate(dev_query_embedding2id):\n",
    "    candidate_rows = []\n",
    "    if qid not in bm25:\n",
    "        print(qid,\"not in bm25\")\n",
    "    else:\n",
    "        for pid in bm25[qid]:\n",
    "            if pid in pid_to_rows:\n",
    "                candidate_rows.extend(pid_to_rows[pid])\n",
    "            else:\n",
    "                print(pid,\"not in passages\")\n",
    "    if not candidate_rows:\n",
    "        # Nothing to rerank: store an empty ranking so all_dev_I stays aligned with the\n",
    "        # query order instead of failing on an empty faiss index.\n",
    "        all_dev_I.append(np.empty(0, dtype=np.int64))\n",
    "        continue\n",
    "    candidate_rows = np.asarray(candidate_rows, dtype=np.int64)\n",
    "    # Exact inner-product ranking of this query against its own candidates only.\n",
    "    cpu_index = faiss.IndexFlatIP(dim)\n",
    "    cpu_index.add(passage_embedding[candidate_rows])\n",
    "    _, dev_I = cpu_index.search(dev_query_embedding[i:i+1], len(candidate_rows))\n",
    "    # search returns positions inside the per-query index; map them back to passage_embedding rows\n",
    "    all_dev_I.append(candidate_rows[dev_I[0]])\n",
    "\n",
    "result = EvalDevQuery(dev_query_embedding2id, passage_embedding2id, dev_query_positive_id, all_dev_I, topN)\n",
    "final_ndcg, eval_query_cnt, final_Map, final_mrr, final_recall, hole_rate, ms_mrr, Ahole_rate, metrics, prediction = result\n",
    "print(\"Reranking Results for checkpoint \"+str(checkpoint))\n",
    "print(\"Reranking NDCG@10:\" + str(final_ndcg))\n",
    "print(\"Reranking map@10:\" + str(final_Map))\n",
    "print(\"Reranking pytrec_mrr:\" + str(final_mrr))\n",
    "print(\"Reranking recall@\"+str(topN)+\":\" + str(final_recall))\n",
    "print(\"Reranking hole rate@10:\" + str(hole_rate))\n",
    "print(\"Reranking hole rate:\" + str(Ahole_rate))\n",
    "print(\"Reranking ms_mrr:\" + str(ms_mrr))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# full ranking metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Results for checkpoint 200000\n",
      "NDCG@10:0.3782199354440816\n",
      "map@10:0.31537880338381635\n",
      "pytrec_mrr:0.3336037287188709\n",
      "recall@1000:0.9524594078319008\n",
      "hole rate@10:0.9402578796561605\n",
      "hole rate:0.9989885386819485\n",
      "ms_mrr:{'MRR @10': 0.32180822758902866, 'QueriesRanked': 6980}\n"
     ]
    }
   ],
   "source": [
    "# Inner-product search of every query against all passage embeddings, keeping topN hits.\n",
    "dim = passage_embedding.shape[1]\n",
    "faiss.omp_set_num_threads(16)\n",
    "cpu_index = faiss.IndexFlatIP(dim)\n",
    "cpu_index.add(passage_embedding)\n",
    "_, dev_I = cpu_index.search(dev_query_embedding, topN)\n",
    "\n",
    "result = EvalDevQuery(dev_query_embedding2id, passage_embedding2id, dev_query_positive_id, dev_I, topN)\n",
    "(final_ndcg, eval_query_cnt, final_Map, final_mrr, final_recall,\n",
    " hole_rate, ms_mrr, Ahole_rate, metrics, prediction) = result\n",
    "\n",
    "# One line per metric, each printed as label followed by value.\n",
    "print(\"Results for checkpoint \"+str(checkpoint))\n",
    "for label, value in (\n",
    "    (\"NDCG@10:\", final_ndcg),\n",
    "    (\"map@10:\", final_Map),\n",
    "    (\"pytrec_mrr:\", final_mrr),\n",
    "    (\"recall@\"+str(topN)+\":\", final_recall),\n",
    "    (\"hole rate@10:\", hole_rate),\n",
    "    (\"hole rate:\", Ahole_rate),\n",
    "    (\"ms_mrr:\", ms_mrr),\n",
    "):\n",
    "    print(label + str(value))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}