{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6459709d-a572-4950-8565-1f8e5859052b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading shards into index: 100%|███████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 127.10it/s]\n",
      "1406it [00:00, 9734.79it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 1401/1401 [00:00<00:00, 7371.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "align_loss: 0.033\n",
      "uniform_loss_ours: -0.120\n",
      "uniform_loss_standard: -0.120\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES']='7'\n",
    "import pickle\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import gc\n",
    "import time\n",
    "import numpy as np\n",
    "import glob\n",
    "from itertools import chain\n",
    "from tqdm import tqdm\n",
    "from collections import defaultdict\n",
    "\n",
    "\n",
    "def pickle_load(path):\n",
    "    with open(path, 'rb') as f:\n",
    "        obj = pickle.load(f)\n",
    "    return obj\n",
    "\n",
    "\n",
    "def load_qid2pos(file_path):\n",
    "    qid2pos = defaultdict(list)\n",
    "    with open(file_path, \"r\", encoding=\"utf-8\") as fi:\n",
    "        for idx, line in enumerate(fi):\n",
    "            if idx == 0:continue\n",
    "            qid, posid, score = line.strip().split(\"\\t\")\n",
    "            \n",
    "            if int(score) <= 0:\n",
    "                continue\n",
    "                \n",
    "            if posid not in qid2pos[qid]:\n",
    "                qid2pos[qid].append(posid)\n",
    "    return qid2pos\n",
    "\n",
    "\n",
    "def get_qry_pos_tensor(qid2pos, q_reps, q_lookup, p_reps, look_up, cuda_available):\n",
    "    \"\"\"Build row-aligned, L2-normalised (query, positive passage) embedding pairs.\n",
    "\n",
    "    For each query in `q_lookup` and each of its positives in `qid2pos` that is\n",
    "    present in `look_up`, one row is appended to both outputs: the query's\n",
    "    embedding and the positive passage's embedding. A query with k indexed\n",
    "    positives therefore contributes k identical query rows.\n",
    "\n",
    "    Args:\n",
    "        qid2pos: mapping qid -> iterable of positive passage ids.\n",
    "        q_reps: query embeddings; row i belongs to q_lookup[i].\n",
    "        q_lookup: sequence of query ids.\n",
    "        p_reps: passage embeddings; row j belongs to look_up[j].\n",
    "        look_up: sequence of passage ids (for duplicate ids the first row wins).\n",
    "        cuda_available: if True, move both outputs to the current CUDA device.\n",
    "\n",
    "    Returns:\n",
    "        (qry_tensor, pos_tensor), each (num_pairs, dim), unit-norm along the last dim.\n",
    "        NOTE(review): assumes embedding rows are numpy arrays (or array-likes that\n",
    "        np.stack accepts) -- confirm against the encoder output.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if no (query, positive) pair could be matched to the index.\n",
    "    \"\"\"\n",
    "    # Dict lookup instead of `in look_up` / look_up.index(), which each cost\n",
    "    # O(len(look_up)) per positive; setdefault keeps the first row, like list.index.\n",
    "    pid2row = {}\n",
    "    for row, pid in enumerate(look_up):\n",
    "        pid2row.setdefault(pid, row)\n",
    "\n",
    "    qry_list = []\n",
    "    pos_list = []\n",
    "    for idx, qid in tqdm(enumerate(q_lookup)):\n",
    "        # .get avoids inserting empty entries into a defaultdict for unjudged queries\n",
    "        for posid in qid2pos.get(qid, ()):\n",
    "            pos_idx = pid2row.get(posid)\n",
    "            if pos_idx is None:\n",
    "                continue  # positive not present in the passage index\n",
    "            qry_list.append(q_reps[idx])\n",
    "            pos_list.append(p_reps[pos_idx])\n",
    "\n",
    "    if not qry_list:\n",
    "        raise ValueError(\"no (query, positive passage) pairs matched the passage index\")\n",
    "\n",
    "    # Stack once in numpy; torch.tensor over a Python list of ndarrays is very slow.\n",
    "    qry_tensor = torch.as_tensor(np.stack(qry_list))\n",
    "    pos_tensor = torch.as_tensor(np.stack(pos_list))\n",
    "\n",
    "    if cuda_available:\n",
    "        qry_tensor = qry_tensor.cuda()\n",
    "        pos_tensor = pos_tensor.cuda()\n",
    "\n",
    "    qry_tensor = F.normalize(qry_tensor, p=2, dim=-1)\n",
    "    pos_tensor = F.normalize(pos_tensor, p=2, dim=-1)\n",
    "\n",
    "    return qry_tensor, pos_tensor\n",
    "\n",
    "\n",
    "def get_align_loss(x, y, alpha=2):\n",
    "    return (x - y).norm(p=2, dim=1).pow(alpha).mean()\n",
    "\n",
    "\n",
    "def get_uniform_loss_ours(x, y, alpha=2, t=2):\n",
    "    \"\"\"Cross-set uniformity: log(mean_{i,j} exp(-t * ||x_i - y_j||_2^alpha)).\n",
    "\n",
    "    Every row of `x` is compared with every row of `y`. Each row's terms are\n",
    "    reduced with logsumexp as soon as they are produced, so the full\n",
    "    len(x) * len(y) set of terms is never held in memory at once, and taking the\n",
    "    log of a very small mean cannot underflow to -inf.\n",
    "\n",
    "    Args:\n",
    "        x: 2-D tensor, iterated row by row (in __main__: the query embeddings).\n",
    "        y: 2-D tensor with the same feature dim (in __main__: the passage embeddings).\n",
    "        alpha: exponent applied to the L2 distance.\n",
    "        t: scale applied to the negated, powered distance.\n",
    "\n",
    "    Returns:\n",
    "        0-dim tensor holding the log-mean.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `x` or `y` has no rows.\n",
    "    \"\"\"\n",
    "    import math\n",
    "\n",
    "    n_x, n_y = x.shape[0], y.shape[0]\n",
    "    if n_x == 0 or n_y == 0:\n",
    "        raise ValueError(\"get_uniform_loss_ours needs at least one row in both x and y\")\n",
    "\n",
    "    row_lse = []\n",
    "    for sub_x in tqdm(x):\n",
    "        neg_scaled = (sub_x - y).norm(p=2, dim=1).pow(alpha).mul(-t)\n",
    "        row_lse.append(torch.logsumexp(neg_scaled, dim=0))\n",
    "    # log(mean(exp(z))) == logsumexp(z) - log(N), with N = n_x * n_y terms in total\n",
    "    total = torch.logsumexp(torch.stack(row_lse), dim=0)\n",
    "    return total - math.log(n_x * n_y)\n",
    "\n",
    "def get_uniform_loss_standard(x, y, t=2):\n",
    "    tot = torch.cat([x, y], dim=0)\n",
    "    return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log()\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "\n",
    "    # NOTE(review): absolute, machine-specific paths; adjust before running elsewhere.\n",
    "    DATA_DIR = \"/data/private/sunsi/dataset/beir\"\n",
    "    RES_DIR = \"/data/private/sunsi/experiments/hybrid-transformers/results\"\n",
    "    FOLDER_NAME = \"inference-beir.gpt2-xl.bottom-8\"\n",
    "    DATASET = \"arguana\"  ## fiqa, arguana \n",
    "    # Use the GPU only when one is visible; fall back to CPU instead of crashing on .cuda().\n",
    "    cuda_available = torch.cuda.is_available()\n",
    "\n",
    "    ## load qry\n",
    "    query_reps = os.path.join(RES_DIR, FOLDER_NAME, DATASET, \"query/qry.pt\")\n",
    "    q_reps, q_lookup = pickle_load(query_reps)\n",
    "\n",
    "    ## load psg\n",
    "    passage_reps = os.path.join(RES_DIR, FOLDER_NAME, DATASET, \"corpus/*\")\n",
    "    # glob order is arbitrary; sort so the passage row order is reproducible across runs\n",
    "    index_files = sorted(glob.glob(passage_reps))\n",
    "    if not index_files:\n",
    "        raise FileNotFoundError(\"no passage shards match %s\" % passage_reps)\n",
    "\n",
    "    shards = tqdm(map(pickle_load, index_files), desc='Loading shards into index', total=len(index_files))\n",
    "\n",
    "    p_reps = []\n",
    "    look_up = []\n",
    "    for _p_reps, p_lookup in shards:\n",
    "        p_reps.append(_p_reps)\n",
    "        look_up.extend(p_lookup)  # BEIR doc ids are str, not int: keep a flat list of ids\n",
    "    p_reps = np.concatenate(p_reps, axis=0)\n",
    "\n",
    "    ## load qrels\n",
    "    qrel_path = os.path.join(DATA_DIR, DATASET, \"qrels/test.tsv\")\n",
    "    qid2pos = load_qid2pos(qrel_path)\n",
    "\n",
    "    ## get qry, pos tensor (one row per (query, indexed positive) pair)\n",
    "    qry_tensor, pos_tensor = get_qry_pos_tensor(qid2pos, q_reps, q_lookup, p_reps, look_up, cuda_available=cuda_available)\n",
    "\n",
    "    ## [1] get align loss\n",
    "    align_loss = get_align_loss(qry_tensor, pos_tensor)\n",
    "\n",
    "    ## [2] get uniform loss against the whole unit-normalised passage collection\n",
    "    psg_tensor = torch.tensor(p_reps)\n",
    "    if cuda_available:\n",
    "        psg_tensor = psg_tensor.cuda()\n",
    "    psg_tensor = F.normalize(psg_tensor, p=2, dim=-1)\n",
    "    uniform_loss_ours = get_uniform_loss_ours(qry_tensor, psg_tensor)\n",
    "    uniform_loss_standard = get_uniform_loss_standard(qry_tensor, psg_tensor)\n",
    "\n",
    "    print(\"align_loss: %.3f\"%align_loss.item())\n",
    "    print(\"uniform_loss_ours: %.3f\"%uniform_loss_ours.item())\n",
    "    print(\"uniform_loss_standard: %.3f\"%uniform_loss_standard.item())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cocondenser",
   "language": "python",
   "name": "cocondenser"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
