{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6febdda4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-10-11T03:34:09.649712Z",
     "iopub.status.busy": "2023-10-11T03:34:09.649075Z",
     "iopub.status.idle": "2023-10-11T03:34:12.039455Z",
     "shell.execute_reply": "2023-10-11T03:34:12.037161Z"
    },
    "papermill": {
     "duration": 2.408779,
     "end_time": "2023-10-11T03:34:12.043947",
     "exception": false,
     "start_time": "2023-10-11T03:34:09.635168",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import csv\n",
    "import json\n",
    "import os\n",
    "from typing import Dict, List, Any, Union\n",
    "\n",
    "import numpy as np\n",
    "from sklearn.metrics import ndcg_score\n",
    "import torch\n",
    "import transformers\n",
    "\n",
    "device = torch.device(f\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "data_directory = \"data\"\n",
    "results_directory = \"results\"\n",
    "\n",
    "seeds = range(33, 43)\n",
    "batch_size = 32\n",
    "num_steps = 10000\n",
    "lr_decay_gamma = 0.7\n",
    "steps_per_decay = 500"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c78d3dd8",
   "metadata": {
    "papermill": {
     "duration": 0.00874,
     "end_time": "2023-10-11T03:34:12.063335",
     "exception": false,
     "start_time": "2023-10-11T03:34:12.054595",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Load dataset\n",
    "\n",
    "<https://webscope.sandbox.yahoo.com/catalog.php?datatype=c>\n",
    "\n",
    "Assuming Yahoo! LETOR dataset is downloaded to `data_directory` as follows:\n",
    "\n",
    "`data_directory`  \n",
    "  ├─ set1.train.txt  \n",
    "  ├─ set1.test.txt  \n",
    "  ├─ set1.valid.txt (optional)  \n",
    "  ├─ set2.train.txt  \n",
    "  ├─ set2.test.txt  \n",
    "  ├─ set2.valid.txt (optional)  \n",
    "  └─ feature_index.json (provided)\n",
    "\n",
    "Each file has the following content (3 truncated rows are shown):\n",
    "\n",
    "```\n",
    "1 qid:19945 1:0.74142 6:0.90265 7:0.8087 8:0.79522 9:0.80003 11:0.56756 12:0.064688\n",
    "2 qid:19945 8:0.898 9:0.80003 12:0.064688 17:0.035654 21:0.050978 27:0.16583 28:0.75984\n",
    "2 qid:19945 6:0.91949 8:0.86723 9:0.80003 12:0.064688 17:0.04154 20:0.96351 21:0.050978\n",
    "```\n",
    "\n",
    "Where column 1 is relevance label, column 2 is query id, and the rest are\n",
    "features as key-value pairs (values are defaulted to 0 for missing keys).  The\n",
    "keys are 1-indexed (1 to 700).\n",
    "\n",
    "We additionally provide an index file `feature_index.json` indicating which\n",
    "features (keys) are available on Set 1, Set 2, or on both (Set 1 & 2)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "41291cd7",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-10-11T03:34:12.077611Z",
     "iopub.status.busy": "2023-10-11T03:34:12.076798Z",
     "iopub.status.idle": "2023-10-11T03:37:07.480634Z",
     "shell.execute_reply": "2023-10-11T03:37:07.479628Z"
    },
    "papermill": {
     "duration": 175.423928,
     "end_time": "2023-10-11T03:37:07.493236",
     "exception": false,
     "start_time": "2023-10-11T03:34:12.069308",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of shared features: 415\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train dataset size (Set 1): 19944\n",
      "Train dataset size (Set 2): 1266\n",
      "Test dataset size (Set 1): 6983\n",
      "Test dataset size (Set 2): 3798\n"
     ]
    }
   ],
   "source": [
    "n_features = 700\n",
    "\n",
    "\n",
    "def load_yahoo(\n",
    "    path: str,\n",
    "    return_type: str = 'np'\n",
    ") -> List[Dict[str, Union[np.ndarray, torch.Tensor]]]:\n",
    "  # use return_type='pt' to avoid memory communication overhead.\n",
    "  # yahoo letor requires about 2.8 gb.\n",
    "\n",
    "  # `dataset` is a list of dict, and each dict contains 3 key-value pairs,\n",
    "  # \"query_id\": int, \"label\": 1d numpy array, \"feature\": 2d numpy array.\n",
    "  dataset = []\n",
    "\n",
    "  with open(path, \"r\") as f:\n",
    "    reader = csv.reader(f, delimiter=\" \")\n",
    "\n",
    "    query_id_to_idx = {}\n",
    "    for row in reader:\n",
    "      query_id = int(row[1].split(\":\")[1])\n",
    "      if query_id not in query_id_to_idx:\n",
    "        query_id_to_idx[query_id] = len(query_id_to_idx)\n",
    "        dataset.append({\"query_id\": query_id, \"label\": [], \"feature\": []})\n",
    "\n",
    "      # index in `dataset` for current query\n",
    "      idx = query_id_to_idx[query_id]\n",
    "\n",
    "      feature = [0 for _ in range(n_features)]\n",
    "      for x in row[2:]:\n",
    "        feature_idx, feature_value = x.split(\":\")\n",
    "        feature[int(feature_idx) - 1] = float(feature_value)\n",
    "      dataset[idx][\"feature\"].append(feature)\n",
    "      dataset[idx][\"label\"].append(float(row[0]))\n",
    "\n",
    "  if return_type == 'np':\n",
    "    return [{k: np.array(v) if k != 'query_id' else v\n",
    "             for k, v in x.items()}\n",
    "            for x in dataset]\n",
    "  elif return_type == 'pt':\n",
    "    return [{\n",
    "        \"query_id\": torch.tensor(x[\"query_id\"]).long().to(device),\n",
    "        \"label\": torch.tensor(x[\"label\"]).float().to(device),\n",
    "        \"feature\": torch.tensor(x[\"feature\"]).float().to(device)\n",
    "    } for x in dataset]\n",
    "\n",
    "\n",
    "# Build a boolean mask over the `n_features` feature slots marking the\n",
    "# features listed under \"Set 1 & 2\" in feature_index.json.  Feature keys are\n",
    "# 1-indexed, hence the `- 1` when converting a key to a position.\n",
    "path = f\"{data_directory}/feature_index.json\"\n",
    "with open(path, \"r\") as f:\n",
    "  shared_features = json.load(f)[\"Set 1 & 2\"]\n",
    "shared_features_mask = torch.full((n_features,), False)\n",
    "for idx in shared_features:\n",
    "  shared_features_mask[int(idx) - 1] = True\n",
    "\n",
    "print(\"Number of shared features:\", shared_features_mask.sum().item())\n",
    "\n",
    "# Load datasets as torch tensors placed on `device` (return_type='pt'); each\n",
    "# dataset is a list with one entry per query.\n",
    "train_dataset_set1 = load_yahoo(f\"{data_directory}/set1.train.txt\",\n",
    "                                return_type='pt')\n",
    "train_dataset_set2 = load_yahoo(f\"{data_directory}/set2.train.txt\",\n",
    "                                return_type='pt')\n",
    "test_dataset_set1 = load_yahoo(f\"{data_directory}/set1.test.txt\",\n",
    "                               return_type='pt')\n",
    "test_dataset_set2 = load_yahoo(f\"{data_directory}/set2.test.txt\",\n",
    "                               return_type='pt')\n",
    "\n",
    "# Sizes below are numbers of queries, not numbers of items.\n",
    "print(\"Train dataset size (Set 1):\", len(train_dataset_set1))\n",
    "print(\"Train dataset size (Set 2):\", len(train_dataset_set2))\n",
    "print(\"Test dataset size (Set 1):\", len(test_dataset_set1))\n",
    "print(\"Test dataset size (Set 2):\", len(test_dataset_set2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d1478af8",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-10-11T03:37:07.509286Z",
     "iopub.status.busy": "2023-10-11T03:37:07.508527Z",
     "iopub.status.idle": "2023-10-11T03:37:07.527642Z",
     "shell.execute_reply": "2023-10-11T03:37:07.526694Z"
    },
    "papermill": {
     "duration": 0.028654,
     "end_time": "2023-10-11T03:37:07.529783",
     "exception": false,
     "start_time": "2023-10-11T03:37:07.501129",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Create dataloaders\n",
    "\n",
    "\n",
    "class DataCollatorWithPadding:\n",
    "  \"\"\"Collator for batching data.\n",
    "\n",
    "  Given a batch of examples loaded above, this collator will pad the features\n",
    "  and output a dict of tensors,\n",
    "  \"query_id\": shape (batch_size,), \"labels\": shape (batch_size, max_length),\n",
    "  \"features\": shape (batch_size, max_length, n_features).\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(self, padding_value: float = -100):\n",
    "    self.padding_value = padding_value\n",
    "\n",
    "  def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:\n",
    "    query_ids = [x[\"query_id\"] for x in examples]\n",
    "    labels = [x[\"label\"] for x in examples]\n",
    "    features = [x[\"feature\"] for x in examples]\n",
    "\n",
    "    max_length = max([len(x) for x in features])\n",
    "    if type(features[0]) == np.ndarray:\n",
    "      query_ids = torch.from_numpy(np.stack(query_ids)).long()\n",
    "      labels = torch.from_numpy(\n",
    "          np.stack([\n",
    "              np.pad(x, ((0, max_length - len(x))),\n",
    "                     constant_values=self.padding_value) for x in labels\n",
    "          ])).float()\n",
    "      features = torch.from_numpy(\n",
    "          np.stack([\n",
    "              np.pad(x, ((0, max_length - len(x)), (0, 0)),\n",
    "                     constant_values=self.padding_value) for x in features\n",
    "          ])).float()\n",
    "\n",
    "    elif type(features[0]) == torch.Tensor:\n",
    "      query_ids = torch.stack(query_ids)\n",
    "      labels = torch.stack([\n",
    "          torch.nn.functional.pad(x, (0, max_length - len(x)),\n",
    "                                  value=self.padding_value) for x in labels\n",
    "      ])\n",
    "      features = torch.stack([\n",
    "          torch.nn.functional.pad(x, (0, 0, 0, max_length - len(x)),\n",
    "                                  value=self.padding_value) for x in features\n",
    "      ])\n",
    "\n",
    "    return {\"query_id\": query_ids, \"labels\": labels, \"features\": features}\n",
    "\n",
    "\n",
    "def cycle(iterable):\n",
    "  \"\"\"Repeatly cycle through an iterable by caching StopIteration.\"\"\"\n",
    "  iterator = iter(iterable)\n",
    "  while True:\n",
    "    try:\n",
    "      yield next(iterator)\n",
    "    except StopIteration:\n",
    "      iterator = iter(iterable)\n",
    "      yield next(iterator)\n",
    "\n",
    "\n",
    "def _make_dataloader(dataset, shuffle):\n",
    "  \"\"\"Wraps `dataset` in a DataLoader that pads every batch of queries.\"\"\"\n",
    "  return torch.utils.data.DataLoader(\n",
    "      dataset,\n",
    "      batch_size=batch_size,\n",
    "      shuffle=shuffle,\n",
    "      collate_fn=DataCollatorWithPadding())\n",
    "\n",
    "\n",
    "# Training loaders shuffle the queries; test loaders keep the dataset order.\n",
    "train_dataloader_set1 = _make_dataloader(train_dataset_set1, shuffle=True)\n",
    "train_dataloader_set2 = _make_dataloader(train_dataset_set2, shuffle=True)\n",
    "test_dataloader_set1 = _make_dataloader(test_dataset_set1, shuffle=False)\n",
    "test_dataloader_set2 = _make_dataloader(test_dataset_set2, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7581cbac",
   "metadata": {
    "papermill": {
     "duration": 0.006557,
     "end_time": "2023-10-11T03:37:07.544377",
     "exception": false,
     "start_time": "2023-10-11T03:37:07.537820",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Model definitions and helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e697d430",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-10-11T03:37:07.557135Z",
     "iopub.status.busy": "2023-10-11T03:37:07.556242Z",
     "iopub.status.idle": "2023-10-11T03:37:07.569608Z",
     "shell.execute_reply": "2023-10-11T03:37:07.568264Z"
    },
    "papermill": {
     "duration": 0.021505,
     "end_time": "2023-10-11T03:37:07.571573",
     "exception": false,
     "start_time": "2023-10-11T03:37:07.550068",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Helpers in forward pass, including gradient reversal and softmax cross entropy\n",
    "# ranking loss\n",
    "\n",
    "\n",
    "class ScaleBackwardGrad(torch.autograd.Function):\n",
    "\n",
    "  @staticmethod\n",
    "  def forward(ctx, x, lambd):\n",
    "    ctx.lambd = lambd\n",
    "    return x.view_as(x)\n",
    "\n",
    "  @staticmethod\n",
    "  def backward(ctx, grad_output):\n",
    "    return grad_output * ctx.lambd, None\n",
    "\n",
    "\n",
    "class GradientReversalLayer(torch.nn.Module):\n",
    "\n",
    "  def __init__(self, lambd=1):\n",
    "    super(GradientReversalLayer, self).__init__()\n",
    "    self.lambd = lambd\n",
    "\n",
    "  def forward(self, x):\n",
    "    return ScaleBackwardGrad.apply(x, -self.lambd)\n",
    "\n",
    "  def extra_repr(self):\n",
    "    return 'lambda={}'.format(self.lambd)\n",
    "\n",
    "\n",
    "def softmax_cross_entropy_with_logits(logits,\n",
    "                                      labels,\n",
    "                                      ignore_index=-100) -> torch.Tensor:\n",
    "  # logits.shape = labels.size = (batch_size, list_size)\n",
    "  log_softmax = torch.nn.functional.log_softmax(logits, dim=1)\n",
    "  loss = -torch.sum(log_softmax * labels * (labels != ignore_index), dim=1)\n",
    "  return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "40e5e474",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-10-11T03:37:07.586143Z",
     "iopub.status.busy": "2023-10-11T03:37:07.585110Z",
     "iopub.status.idle": "2023-10-11T03:37:07.614273Z",
     "shell.execute_reply": "2023-10-11T03:37:07.613273Z"
    },
    "papermill": {
     "duration": 0.039518,
     "end_time": "2023-10-11T03:37:07.617306",
     "exception": false,
     "start_time": "2023-10-11T03:37:07.577788",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class MLPRanker(torch.nn.Module):\n",
    "  \"\"\"MLP scorer with separate input towers for shared and disjoint features.\n",
    "\n",
    "  Features selected by `shared_features_mask` and the remaining features are\n",
    "  each encoded by their own two-layer tower.  The concatenated encodings pass\n",
    "  through a hidden layer (producing `feature_dim`-sized item features) and,\n",
    "  after dropout, a linear layer that outputs one score per item.\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(self, feature_dim=256):\n",
    "    super().__init__()\n",
    "    num_shared = shared_features_mask.sum().item()\n",
    "    num_disjoint = n_features - num_shared\n",
    "\n",
    "    def make_tower(input_dim):\n",
    "      return torch.nn.Sequential(\n",
    "          torch.nn.Linear(input_dim, 1024),\n",
    "          torch.nn.ReLU(),\n",
    "          torch.nn.Linear(1024, 256),\n",
    "          torch.nn.ReLU(),\n",
    "      )\n",
    "\n",
    "    self.input_layer_shared = make_tower(num_shared)\n",
    "    self.input_layer_disjoint = make_tower(num_disjoint)\n",
    "    # 512 = the two 256-wide tower outputs concatenated.\n",
    "    self.hidden_layer = torch.nn.Sequential(\n",
    "        torch.nn.Linear(512, feature_dim),\n",
    "        torch.nn.ReLU(),\n",
    "    )\n",
    "    self.output_layer = torch.nn.Linear(feature_dim, 1)\n",
    "    self.dropout = torch.nn.Dropout(0.1)\n",
    "\n",
    "  def forward(self, x, output_features=False):\n",
    "    shared_mask = shared_features_mask.to(x.device)\n",
    "    encoded_shared = self.input_layer_shared(x[:, :, shared_mask])\n",
    "    encoded_disjoint = self.input_layer_disjoint(x[:, :, ~shared_mask])\n",
    "    encoded = torch.cat((encoded_shared, encoded_disjoint), dim=-1)\n",
    "    features = self.hidden_layer(encoded)\n",
    "    # Dropout is applied only on the path to the scores; the returned\n",
    "    # features are the pre-dropout hidden activations.\n",
    "    logits = self.output_layer(self.dropout(features))\n",
    "    return (logits, features) if output_features else logits\n",
    "\n",
    "\n",
    "class MLPDiscriminator(torch.nn.Module):\n",
    "  \"\"\"Discriminator with 3-layer MLP architecture.\"\"\"\n",
    "\n",
    "  def __init__(self, feature_dim=256):\n",
    "    super(MLPDiscriminator, self).__init__()\n",
    "    self.feature_dim = feature_dim\n",
    "    self.net_ad = torch.nn.Sequential(\n",
    "        torch.nn.Linear(feature_dim, 256),\n",
    "        torch.nn.ReLU(),\n",
    "        torch.nn.Linear(256, 256),\n",
    "        torch.nn.ReLU(),\n",
    "        torch.nn.Linear(256, 256),\n",
    "        torch.nn.ReLU(),\n",
    "        torch.nn.Linear(256, 1),\n",
    "    )\n",
    "\n",
    "  def forward(self, x):\n",
    "    logits = self.net_ad(x)\n",
    "    return logits\n",
    "\n",
    "\n",
    "class TransformerDiscriminator(torch.nn.Module):\n",
    "  \"\"\"Discriminator with three T5 encoder blocks followed by mean-pooling.\n",
    "\n",
    "  Takes per-item features of shape (batch_size, list_size, feature_dim) and\n",
    "  outputs one logit per list, shape (batch_size, 1).\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(self, feature_dim=256):\n",
    "    super(TransformerDiscriminator, self).__init__()\n",
    "    self.feature_dim = feature_dim\n",
    "    config = transformers.T5Config(\n",
    "        d_model=256,\n",
    "        d_kv=32,\n",
    "        d_ff=1024,\n",
    "        num_layers=3,\n",
    "        num_decoder_layers=0,\n",
    "        num_heads=4,\n",
    "        dropout_rate=0.0,\n",
    "        is_encoder_decoder=False,\n",
    "    )\n",
    "    if feature_dim != 256:\n",
    "      # Maps the incoming features to the transformer width (d_model=256).\n",
    "      # The in/out sizes were previously swapped, which made forward() fail\n",
    "      # for any feature_dim other than 256.\n",
    "      self.projection = torch.nn.Linear(feature_dim, 256, bias=False)\n",
    "    self.blocks = torch.nn.ModuleList([\n",
    "        transformers.models.t5.modeling_t5.T5Block(\n",
    "            config, has_relative_attention_bias=False)\n",
    "        for _ in range(config.num_layers)\n",
    "    ])\n",
    "    self.output_layer = torch.nn.Linear(256, 1)\n",
    "\n",
    "  def forward(self, x, mask=None):\n",
    "    # mask is a 2d bool tensor of shape (batch_size, list_size): True for real\n",
    "    # items, False for padded ones.  Defaults to all items being real.\n",
    "    if mask is None:\n",
    "      mask = torch.ones(x.shape[:-1], dtype=torch.bool, device=x.device)\n",
    "    else:\n",
    "      mask = mask.to(device=x.device, dtype=torch.bool)\n",
    "    if self.feature_dim != 256:\n",
    "      x = self.projection(x)\n",
    "    # T5 attention adds `attention_mask` to the attention scores, so it has to\n",
    "    # be additive: 0 for real items and a large negative value for padding.\n",
    "    # Passing the bool mask directly would add +1 to real items and leave\n",
    "    # padded items attendable.\n",
    "    attention_bias = torch.zeros(mask.shape, dtype=x.dtype, device=x.device)\n",
    "    attention_bias = attention_bias.masked_fill(~mask,\n",
    "                                                torch.finfo(x.dtype).min)\n",
    "    attention_bias = attention_bias[:, None, None, :]\n",
    "    for block in self.blocks:\n",
    "      layer_outputs = block(x, attention_mask=attention_bias)\n",
    "      x = layer_outputs[0]\n",
    "    # Mean-pool over real items; the clamp avoids 0/0 for all-padding rows.\n",
    "    valid = mask[:, :, None].to(x.dtype)\n",
    "    x = (x * valid).sum(dim=1) / valid.sum(dim=1).clamp(min=1)\n",
    "    logits = self.output_layer(x)\n",
    "    return logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4fd37401",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-10-11T03:37:07.638901Z",
     "iopub.status.busy": "2023-10-11T03:37:07.638315Z",
     "iopub.status.idle": "2023-10-11T03:37:07.662603Z",
     "shell.execute_reply": "2023-10-11T03:37:07.661527Z"
    },
    "papermill": {
     "duration": 0.037648,
     "end_time": "2023-10-11T03:37:07.665469",
     "exception": false,
     "start_time": "2023-10-11T03:37:07.627821",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def evaluate(model, dataloader):\n",
    "  # results is a list of dicts, each dict contains 3 key-value pairs,\n",
    "  # \"query_id\": int, \"label\": 1d numpy array, \"prediction\": 1d numpy array.\n",
    "  results = []\n",
    "  with torch.no_grad():\n",
    "    for batch in dataloader:\n",
    "      scores = model(batch[\"features\"].to(device))\n",
    "      for i in range(len(batch)):\n",
    "        mask = batch[\"labels\"][i, :] != -100\n",
    "        results.append({\n",
    "            \"query_id\":\n",
    "                int(batch[\"query_id\"][i].item()),\n",
    "            \"label\":\n",
    "                batch[\"labels\"][i, mask].detach().cpu().numpy().reshape(-1),\n",
    "            \"prediction\":\n",
    "                scores[i, mask].detach().cpu().numpy().reshape(-1),\n",
    "        })\n",
    "  return results\n",
    "\n",
    "\n",
    "def add_results(results_1, results_2):\n",
    "  if results_1 is None or results_2 is None:\n",
    "    return results_1 or results_2\n",
    "  else:\n",
    "    results = []\n",
    "    # assuming the results are of the same structure\n",
    "    for result_1, result_2 in zip(results_1, results_2):\n",
    "      results.append({\n",
    "          \"query_id\": result_1[\"query_id\"],\n",
    "          \"label\": result_1[\"label\"],\n",
    "          \"prediction\": result_1[\"prediction\"] + result_2[\"prediction\"],\n",
    "      })\n",
    "    return results\n",
    "\n",
    "\n",
    "def results_to_metrics(results):\n",
    "  \"\"\"Computes mean NDCG@5/10/20 over the queries in `results`.\"\"\"\n",
    "  # note that the numbers and sig tests on our paper are evaluated and computed\n",
    "  # using google's internal tools, and may differ from sklearn's ndcg_score\n",
    "  # results\n",
    "\n",
    "  def ndcg_score_(labels, predictions, k):\n",
    "    # sklearn.metrics.ndcg_score expects equally long lists, so it is applied\n",
    "    # to one query at a time; queries with a single item are left out.\n",
    "    per_query_scores = [\n",
    "        ndcg_score([query_labels], [query_predictions], k=k)\n",
    "        for query_labels, query_predictions in zip(labels, predictions)\n",
    "        if len(query_labels) > 1\n",
    "    ]\n",
    "    return np.mean(per_query_scores)\n",
    "\n",
    "  labels = []\n",
    "  predictions = []\n",
    "  for result in results:\n",
    "    labels.append(result[\"label\"])\n",
    "    predictions.append(result[\"prediction\"])\n",
    "\n",
    "  metrics = {}\n",
    "  metrics[\"ndcg@5\"] = ndcg_score_(labels, predictions, k=5)\n",
    "  metrics[\"ndcg@10\"] = ndcg_score_(labels, predictions, k=10)\n",
    "  metrics[\"ndcg@20\"] = ndcg_score_(labels, predictions, k=20)\n",
    "  return metrics\n",
    "\n",
    "\n",
    "def results_to_trec(results, output_path, model_name='null'):\n",
    "  with open(f\"{output_path}.trec\", 'w') as f:\n",
    "    for result in results:\n",
    "      prediction = result[\"prediction\"]\n",
    "      doc_id = np.arange(len(prediction))  # artificially create doc_id\n",
    "      result[\"doc_id\"] = doc_id\n",
    "      sorted_idx = prediction.argsort()[::-1]\n",
    "      for i in range(len(result[\"label\"])):\n",
    "        f.write(\n",
    "            f\"{result['query_id']} Q0 {result['query_id']}.doc{doc_id[sorted_idx[i]]} {i + 1} {prediction[sorted_idx[i]]} {model_name}\\n\"\n",
    "        )\n",
    "\n",
    "\n",
    "def dataset_to_qrels(dataset, output_path):\n",
    "  with open(f\"{output_path}\", 'w') as f:\n",
    "    f.write('query-id\\tcorpus-id\\tscore\\n')\n",
    "    for x in dataset:\n",
    "      for i in range(len(x[\"label\"])):\n",
    "        f.write(\n",
    "            f\"{x['query_id']}\\t{x['query_id']}.doc{i}\\t{int(x['label'][i])}\\n\"\n",
    "        )  # artificially create doc_id"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc93b389",
   "metadata": {
    "papermill": {
     "duration": 0.008528,
     "end_time": "2023-10-11T03:37:07.683853",
     "exception": false,
     "start_time": "2023-10-11T03:37:07.675325",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Training and evaluation with different methods\n",
    "\n",
    "For each method, we train `len(seeds)` models, one per random seed.  We then\n",
    "sum the scores predicted by all models for each test item and compute the\n",
    "metrics w.r.t. these aggregated scores.  We also write `.trec` files containing\n",
    "the aggregated scores."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4bea25f1",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-10-11T03:37:07.696855Z",
     "iopub.status.busy": "2023-10-11T03:37:07.696273Z",
     "iopub.status.idle": "2023-10-11T03:37:14.877342Z",
     "shell.execute_reply": "2023-10-11T03:37:14.876456Z"
    },
    "papermill": {
     "duration": 7.191122,
     "end_time": "2023-10-11T03:37:14.880461",
     "exception": false,
     "start_time": "2023-10-11T03:37:07.689339",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Create the output directory; exist_ok makes re-running the cell safe and\n",
    "# avoids the gap between checking for the directory and creating it.\n",
    "os.makedirs(results_directory, exist_ok=True)\n",
    "\n",
    "# Export the test labels of both sets as qrels files.\n",
    "dataset_to_qrels(test_dataset_set1,\n",
    "                 f\"{results_directory}/yahoo_letor_set1.qrels\")\n",
    "dataset_to_qrels(test_dataset_set2,\n",
    "                 f\"{results_directory}/yahoo_letor_set2.qrels\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "650192c5",
   "metadata": {
    "papermill": {
     "duration": 0.010999,
     "end_time": "2023-10-11T03:37:14.902098",
     "exception": false,
     "start_time": "2023-10-11T03:37:14.891099",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "### Supervised training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6c4d6704",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-10-11T03:37:14.921382Z",
     "iopub.status.busy": "2023-10-11T03:37:14.920505Z",
     "iopub.status.idle": "2023-10-11T03:59:56.830909Z",
     "shell.execute_reply": "2023-10-11T03:59:56.829355Z"
    },
    "papermill": {
     "duration": 1361.932181,
     "end_time": "2023-10-11T03:59:56.844348",
     "exception": false,
     "start_time": "2023-10-11T03:37:14.912167",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lr=4e-05\n",
      "  Set 1: {'ndcg@5': 0.7428128573573406, 'ndcg@10': 0.7784310397975203, 'ndcg@20': 0.8166909746608406}\n",
      "  Set 2: {'ndcg@5': 0.777508134352758, 'ndcg@10': 0.8029083848600129, 'ndcg@20': 0.8499299923878426}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Supervised baseline: every optimizer step uses one labeled batch from each\n",
    "# of Set 1 and Set 2.\n",
    "lrs = [4e-5]\n",
    "\n",
    "for lr in lrs:\n",
    "  # Running sums of test predictions over the models trained with this lr.\n",
    "  results_set1 = None\n",
    "  results_set2 = None\n",
    "\n",
    "  for seed in seeds:\n",
    "    transformers.set_seed(seed)\n",
    "\n",
    "    model = MLPRanker().to(device)\n",
    "    model.train()\n",
    "    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n",
    "    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,\n",
    "                                                       gamma=lr_decay_gamma)\n",
    "\n",
    "    # Endless batch streams; each restarts its dataloader when exhausted.\n",
    "    iterator_set1 = cycle(train_dataloader_set1)\n",
    "    iterator_set2 = cycle(train_dataloader_set2)\n",
    "\n",
    "    for step in range(num_steps):\n",
    "      # The step loss is the sum of the mean ranking losses of both batches.\n",
    "      loss = 0\n",
    "      for iterator in [iterator_set1, iterator_set2]:\n",
    "        batch = next(iterator)\n",
    "        logits = model(batch[\"features\"].to(device))\n",
    "        # Flatten the model output to (batch_size, list_size) for the loss.\n",
    "        loss += softmax_cross_entropy_with_logits(\n",
    "            logits.view(len(batch[\"features\"]), -1),\n",
    "            batch[\"labels\"].to(device)).mean()\n",
    "      loss.backward()\n",
    "      optimizer.step()\n",
    "      optimizer.zero_grad()\n",
    "      # The scheduler advances once every `steps_per_decay` optimizer steps,\n",
    "      # not once per step.\n",
    "      if (step + 1) % steps_per_decay == 0:\n",
    "        scheduler.step()\n",
    "\n",
    "    model.eval()\n",
    "    # Add this model's test predictions to the running sums.\n",
    "    results_set1 = add_results(results_set1,\n",
    "                               evaluate(model, test_dataloader_set1))\n",
    "    results_set2 = add_results(results_set2,\n",
    "                               evaluate(model, test_dataloader_set2))\n",
    "\n",
    "  # Write the aggregated predictions and report metrics computed on them.\n",
    "  results_to_trec(\n",
    "      results_set1,\n",
    "      f\"{results_directory}/yahoo_letor_set1_supervised_lr={lr}.trec\",\n",
    "      \"supervised\",\n",
    "  )\n",
    "  results_to_trec(\n",
    "      results_set2,\n",
    "      f\"{results_directory}/yahoo_letor_set2_supervised_lr={lr}.trec\",\n",
    "      \"supervised\",\n",
    "  )\n",
    "\n",
    "  metrics_set1 = results_to_metrics(results_set1)\n",
    "  metrics_set2 = results_to_metrics(results_set2)\n",
    "  print(f\"lr={lr}\")\n",
    "  print(f\"  Set 1: {metrics_set1}\")\n",
    "  print(f\"  Set 2: {metrics_set2}\")\n",
    "  print()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6e80891",
   "metadata": {
    "papermill": {
     "duration": 0.005649,
     "end_time": "2023-10-11T03:59:56.858550",
     "exception": false,
     "start_time": "2023-10-11T03:59:56.852901",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "### Zero-shot learning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d8c33bbc",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-10-11T03:59:56.870089Z",
     "iopub.status.busy": "2023-10-11T03:59:56.869665Z",
     "iopub.status.idle": "2023-10-11T04:11:52.595813Z",
     "shell.execute_reply": "2023-10-11T04:11:52.593644Z"
    },
    "papermill": {
     "duration": 715.74866,
     "end_time": "2023-10-11T04:11:52.612454",
     "exception": false,
     "start_time": "2023-10-11T03:59:56.863794",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lr=0.0008\n",
      "  Set 1: {'ndcg@5': 0.7535653084601909, 'ndcg@10': 0.7852975416304792, 'ndcg@20': 0.8231135592552501}\n",
      "  Set 2: {'ndcg@5': 0.747945382469857, 'ndcg@10': 0.7793506375122675, 'ndcg@20': 0.8326996005716688}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Zero-shot baseline: models are trained on Set 1 batches only and are then\n",
    "# evaluated on the test data of both Set 1 and Set 2.\n",
    "lrs = [8e-4]\n",
    "\n",
    "for lr in lrs:\n",
    "  # Running sums of test predictions over the models trained with this lr.\n",
    "  results_set1 = None\n",
    "  results_set2 = None\n",
    "\n",
    "  for seed in seeds:\n",
    "    transformers.set_seed(seed)\n",
    "\n",
    "    model = MLPRanker().to(device)\n",
    "    model.train()\n",
    "    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n",
    "    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,\n",
    "                                                       gamma=lr_decay_gamma)\n",
    "\n",
    "    # Endless batch stream; restarts the dataloader when it is exhausted.\n",
    "    iterator_set1 = cycle(train_dataloader_set1)\n",
    "\n",
    "    for step in range(num_steps):\n",
    "      batch = next(iterator_set1)\n",
    "      logits = model(batch[\"features\"].to(device))\n",
    "      # Flatten the model output to (batch_size, list_size) for the loss.\n",
    "      loss = softmax_cross_entropy_with_logits(\n",
    "          logits.view(len(batch[\"features\"]), -1),\n",
    "          batch[\"labels\"].to(device)).mean()\n",
    "      loss.backward()\n",
    "      optimizer.step()\n",
    "      optimizer.zero_grad()\n",
    "      # The scheduler advances once every `steps_per_decay` optimizer steps,\n",
    "      # not once per step.\n",
    "      if (step + 1) % steps_per_decay == 0:\n",
    "        scheduler.step()\n",
    "\n",
    "    model.eval()\n",
    "    # Add this model's test predictions to the running sums.\n",
    "    results_set1 = add_results(results_set1,\n",
    "                               evaluate(model, test_dataloader_set1))\n",
    "    results_set2 = add_results(results_set2,\n",
    "                               evaluate(model, test_dataloader_set2))\n",
    "\n",
    "  # Write the aggregated predictions and report metrics computed on them.\n",
    "  results_to_trec(\n",
    "      results_set1,\n",
    "      f\"{results_directory}/yahoo_letor_set1_zeroshot_lr={lr}.trec\",\n",
    "      \"zeroshot\",\n",
    "  )\n",
    "  results_to_trec(\n",
    "      results_set2,\n",
    "      f\"{results_directory}/yahoo_letor_set2_zeroshot_lr={lr}.trec\",\n",
    "      \"zeroshot\",\n",
    "  )\n",
    "\n",
    "  metrics_set1 = results_to_metrics(results_set1)\n",
    "  metrics_set2 = results_to_metrics(results_set2)\n",
    "  print(f\"lr={lr}\")\n",
    "  print(f\"  Set 1: {metrics_set1}\")\n",
    "  print(f\"  Set 2: {metrics_set2}\")\n",
    "  print()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad127ab0",
   "metadata": {
    "papermill": {
     "duration": 0.005264,
     "end_time": "2023-10-11T04:11:52.625480",
     "exception": false,
     "start_time": "2023-10-11T04:11:52.620216",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "### Item-level alignment (ItemDA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3c52a60e",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-10-11T04:11:52.638028Z",
     "iopub.status.busy": "2023-10-11T04:11:52.637457Z",
     "iopub.status.idle": "2023-10-11T04:48:29.010465Z",
     "shell.execute_reply": "2023-10-11T04:48:29.008295Z"
    },
    "papermill": {
     "duration": 2196.395981,
     "end_time": "2023-10-11T04:48:29.026323",
     "exception": false,
     "start_time": "2023-10-11T04:11:52.630342",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lr=0.0008, ad_lr=2, lambd=0.4\n",
      "  Set 1: {'ndcg@5': 0.7513071247131285, 'ndcg@10': 0.7840645657742897, 'ndcg@20': 0.8221046657034589}\n",
      "  Set 2: {'ndcg@5': 0.7577907814224591, 'ndcg@10': 0.7838200822785192, 'ndcg@20': 0.8369013069311035}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# ItemDA: the ranker is supervised on set 1 only, while a per-item\n",
    "# discriminator (fed through GradientReversalLayer) is trained to tell\n",
    "# set 1 items from set 2 items.\n",
    "# Every (lr, ad_lr, lambd) combination is trained once per seed.\n",
    "lrs = [8e-4]\n",
    "ad_lrs = [2]  # as multiple of lr\n",
    "lambds = [0.4]\n",
    "\n",
    "for lr in lrs:\n",
    "  for ad_lr in ad_lrs:\n",
    "    for lambd in lambds:\n",
    "      # updated through add_results once per seed\n",
    "      results_set1 = None\n",
    "      results_set2 = None\n",
    "\n",
    "      for seed in seeds:\n",
    "        transformers.set_seed(seed)\n",
    "\n",
    "        model = MLPRanker().to(device)\n",
    "        model.train()\n",
    "        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n",
    "        # only the ranker's optimizer is attached to the decay schedule\n",
    "        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,\n",
    "                                                           gamma=lr_decay_gamma)\n",
    "\n",
    "        discriminator = MLPDiscriminator().to(device)\n",
    "        optimizer_ad = torch.optim.AdamW(discriminator.parameters(),\n",
    "                                         lr=lr * ad_lr)\n",
    "        # NOTE(review): presumably identity in the forward pass and a\n",
    "        # gradient scaled by -lambd in the backward pass; confirm against\n",
    "        # the GradientReversalLayer definition\n",
    "        grl = GradientReversalLayer(lambd).to(device)\n",
    "\n",
    "        iterator_set1 = cycle(train_dataloader_set1)\n",
    "        iterator_set2 = cycle(train_dataloader_set2)\n",
    "\n",
    "        for step in range(num_steps):\n",
    "          loss = 0\n",
    "\n",
    "          # a doubles as the domain label: 0 for set 1, 1 for set 2\n",
    "          for a, iterator in enumerate([iterator_set1, iterator_set2]):\n",
    "            batch = next(iterator)\n",
    "            # move each tensor to the device once and reuse it below\n",
    "            inputs = batch[\"features\"].to(device)\n",
    "            labels = batch[\"labels\"].to(device)\n",
    "            # items labelled -100 are excluded from the discriminator input\n",
    "            # (presumably padding; confirm against the collate function)\n",
    "            mask = labels != -100\n",
    "\n",
    "            logits, features = model(inputs, output_features=True)\n",
    "            if a == 0:\n",
    "              # only perform supervised learning on set 1\n",
    "              loss += softmax_cross_entropy_with_logits(\n",
    "                  logits.view(len(inputs), -1), labels).mean()\n",
    "\n",
    "            # one row per item, keeping only the unmasked items\n",
    "            features_ = grl(features).view(\n",
    "                len(inputs) * len(inputs[0]), -1)[mask.view(-1)]\n",
    "            logits_ad = discriminator(features_)\n",
    "            # BCE computed directly on the logits: numerically stable\n",
    "            # replacement for torch.sigmoid followed by BCELoss\n",
    "            loss += torch.nn.functional.binary_cross_entropy_with_logits(\n",
    "                logits_ad, torch.full_like(logits_ad, a))\n",
    "\n",
    "          # a single backward pass feeds both optimizers\n",
    "          loss.backward()\n",
    "          optimizer.step()\n",
    "          optimizer.zero_grad()\n",
    "          optimizer_ad.step()\n",
    "          optimizer_ad.zero_grad()\n",
    "          if (step + 1) % steps_per_decay == 0:\n",
    "            scheduler.step()\n",
    "\n",
    "        model.eval()\n",
    "        results_set1 = add_results(results_set1,\n",
    "                                   evaluate(model, test_dataloader_set1))\n",
    "        results_set2 = add_results(results_set2,\n",
    "                                   evaluate(model, test_dataloader_set2))\n",
    "\n",
    "      results_to_trec(\n",
    "          results_set1,\n",
    "          f\"{results_directory}/yahoo_letor_set1_itemda_lr={lr}_ad_lr={ad_lr}_lambd={lambd}.trec\",\n",
    "          \"itemda\",\n",
    "      )\n",
    "      results_to_trec(\n",
    "          results_set2,\n",
    "          f\"{results_directory}/yahoo_letor_set2_itemda_lr={lr}_ad_lr={ad_lr}_lambd={lambd}.trec\",\n",
    "          \"itemda\",\n",
    "      )\n",
    "\n",
    "      metrics_set1 = results_to_metrics(results_set1)\n",
    "      metrics_set2 = results_to_metrics(results_set2)\n",
    "      print(f\"lr={lr}, ad_lr={ad_lr}, lambd={lambd}\")\n",
    "      print(f\"  Set 1: {metrics_set1}\")\n",
    "      print(f\"  Set 2: {metrics_set2}\")\n",
    "      print()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4da9b30",
   "metadata": {
    "papermill": {
     "duration": 0.005519,
     "end_time": "2023-10-11T04:48:29.039492",
     "exception": false,
     "start_time": "2023-10-11T04:48:29.033973",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "### List-level alignment (ListDA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "087904b5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-10-11T04:48:29.050936Z",
     "iopub.status.busy": "2023-10-11T04:48:29.050372Z",
     "iopub.status.idle": "2023-10-11T06:15:02.789310Z",
     "shell.execute_reply": "2023-10-11T06:15:02.787348Z"
    },
    "papermill": {
     "duration": 5193.761286,
     "end_time": "2023-10-11T06:15:02.805282",
     "exception": false,
     "start_time": "2023-10-11T04:48:29.043996",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lr=0.0008, ad_lr=2, lambd=0.8\n",
      "  Set 1: {'ndcg@5': 0.7484887634992445, 'ndcg@10': 0.7823414315441499, 'ndcg@20': 0.8218222924302829}\n",
      "  Set 2: {'ndcg@5': 0.7613440145946231, 'ndcg@10': 0.7860804736579625, 'ndcg@20': 0.8373959081928443}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# ListDA: the ranker is supervised on set 1 only, while a discriminator\n",
    "# that receives each batch's features together with the item mask (fed\n",
    "# through GradientReversalLayer) is trained to tell set 1 from set 2.\n",
    "# Every (lr, ad_lr, lambd) combination is trained once per seed.\n",
    "lrs = [8e-4]\n",
    "ad_lrs = [2]  # as multiple of lr\n",
    "lambds = [0.8]\n",
    "\n",
    "for lr in lrs:\n",
    "  for ad_lr in ad_lrs:\n",
    "    for lambd in lambds:\n",
    "      # updated through add_results once per seed\n",
    "      results_set1 = None\n",
    "      results_set2 = None\n",
    "\n",
    "      for seed in seeds:\n",
    "        transformers.set_seed(seed)\n",
    "\n",
    "        model = MLPRanker().to(device)\n",
    "        model.train()\n",
    "        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n",
    "        # only the ranker's optimizer is attached to the decay schedule\n",
    "        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,\n",
    "                                                           gamma=lr_decay_gamma)\n",
    "\n",
    "        discriminator = TransformerDiscriminator().to(device)\n",
    "        optimizer_ad = torch.optim.AdamW(discriminator.parameters(),\n",
    "                                         lr=lr * ad_lr)\n",
    "        # NOTE(review): presumably identity in the forward pass and a\n",
    "        # gradient scaled by -lambd in the backward pass; confirm against\n",
    "        # the GradientReversalLayer definition\n",
    "        grl = GradientReversalLayer(lambd).to(device)\n",
    "\n",
    "        iterator_set1 = cycle(train_dataloader_set1)\n",
    "        iterator_set2 = cycle(train_dataloader_set2)\n",
    "\n",
    "        for step in range(num_steps):\n",
    "          loss = 0\n",
    "\n",
    "          # a doubles as the domain label: 0 for set 1, 1 for set 2\n",
    "          for a, iterator in enumerate([iterator_set1, iterator_set2]):\n",
    "            batch = next(iterator)\n",
    "            # move each tensor to the device once and reuse it below\n",
    "            inputs = batch[\"features\"].to(device)\n",
    "            labels = batch[\"labels\"].to(device)\n",
    "            # True for items whose label is not -100 (presumably the\n",
    "            # non-padded items; confirm against the collate function)\n",
    "            mask = labels != -100\n",
    "\n",
    "            logits, features = model(inputs, output_features=True)\n",
    "            if a == 0:\n",
    "              # only perform supervised learning on set 1\n",
    "              loss += softmax_cross_entropy_with_logits(\n",
    "                  logits.view(len(inputs), -1), labels).mean()\n",
    "\n",
    "            # the discriminator gets the unflattened features plus the mask\n",
    "            features_ = grl(features)\n",
    "            logits_ad = discriminator(features_, mask=mask)\n",
    "            # BCE computed directly on the logits: numerically stable\n",
    "            # replacement for torch.sigmoid followed by BCELoss\n",
    "            loss += torch.nn.functional.binary_cross_entropy_with_logits(\n",
    "                logits_ad, torch.full_like(logits_ad, a))\n",
    "\n",
    "          # a single backward pass feeds both optimizers\n",
    "          loss.backward()\n",
    "          optimizer.step()\n",
    "          optimizer.zero_grad()\n",
    "          optimizer_ad.step()\n",
    "          optimizer_ad.zero_grad()\n",
    "          if (step + 1) % steps_per_decay == 0:\n",
    "            scheduler.step()\n",
    "\n",
    "        model.eval()\n",
    "        results_set1 = add_results(results_set1,\n",
    "                                   evaluate(model, test_dataloader_set1))\n",
    "        results_set2 = add_results(results_set2,\n",
    "                                   evaluate(model, test_dataloader_set2))\n",
    "\n",
    "      results_to_trec(\n",
    "          results_set1,\n",
    "          f\"{results_directory}/yahoo_letor_set1_listda_lr={lr}_ad_lr={ad_lr}_lambd={lambd}.trec\",\n",
    "          \"listda\",\n",
    "      )\n",
    "      results_to_trec(\n",
    "          results_set2,\n",
    "          f\"{results_directory}/yahoo_letor_set2_listda_lr={lr}_ad_lr={ad_lr}_lambd={lambd}.trec\",\n",
    "          \"listda\",\n",
    "      )\n",
    "\n",
    "      metrics_set1 = results_to_metrics(results_set1)\n",
    "      metrics_set2 = results_to_metrics(results_set2)\n",
    "      print(f\"lr={lr}, ad_lr={ad_lr}, lambd={lambd}\")\n",
    "      print(f\"  Set 1: {metrics_set1}\")\n",
    "      print(f\"  Set 2: {metrics_set2}\")\n",
    "      print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "779b9110",
   "metadata": {
    "papermill": {
     "duration": 0.005676,
     "end_time": "2023-10-11T06:15:02.818542",
     "exception": false,
     "start_time": "2023-10-11T06:15:02.812866",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rankda",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "papermill": {
   "default_parameters": {},
   "duration": 9656.94337,
   "end_time": "2023-10-11T06:15:05.337723",
   "environment_variables": {},
   "exception": null,
   "input_path": "yahoo_letor_ copy.ipynb",
   "output_path": "yahoo_letor.ipynb",
   "parameters": {},
   "start_time": "2023-10-11T03:34:08.394353",
   "version": "2.3.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
