{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iX41yOwnHL4m"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import sklearn.metrics"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Functions for fitting Gaussians and computing Mahalanobis distance\n",
        "def compute_mean_and_cov(embeds, labels):\n",
        "  \"\"\"Computes class-specific means and a shared covariance matrix.\n",
        "\n",
        "  The computation follows Eq (1) in [1].\n",
        "\n",
        "  Each class contributes the scatter matrix of its samples around its own\n",
        "  mean; the scatter matrices are summed and divided by the total number of\n",
        "  labels, so a single covariance matrix is shared by all classes.\n",
        "\n",
        "  Args:\n",
        "    embeds: An array-like of size [n_train_sample, n_dim], where\n",
        "      n_train_sample is the sample size of the training set and n_dim is the\n",
        "      dimension of the embedding.\n",
        "    labels: An array-like of size [n_train_sample, ] holding the class id of\n",
        "      each row of `embeds`.\n",
        "\n",
        "  Returns:\n",
        "    mean_list: A list of len n_class whose i-th element is an np.array of\n",
        "      size [n_dim, ], the mean of the fitted Gaussian distribution for the\n",
        "      i-th class (classes are ordered as returned by np.unique(labels)).\n",
        "    cov: The shared covariance matrix of size [n_dim, n_dim].\n",
        "  \"\"\"\n",
        "  # Accept plain Python sequences: on a list, `labels == class_id` is a single\n",
        "  # bool rather than an element-wise mask, which breaks the row selection.\n",
        "  embeds = np.asarray(embeds)\n",
        "  labels = np.asarray(labels)\n",
        "\n",
        "  n_dim = embeds.shape[1]\n",
        "  class_ids = np.unique(labels)\n",
        "  mean_list = []\n",
        "  cov = np.zeros((n_dim, n_dim))\n",
        "\n",
        "  for class_id in class_ids:\n",
        "    data = embeds[labels == class_id]\n",
        "    data_mean = np.mean(data, axis=0)\n",
        "    centered = data - data_mean  # computed once, used on both sides below\n",
        "    cov += np.dot(centered.T, centered)\n",
        "    mean_list.append(data_mean)\n",
        "  # Normalize by the total sample count, not per class.\n",
        "  cov = cov / len(labels)\n",
        "  return mean_list, cov\n",
        "\n",
        "\n",
        "def compute_mahalanobis_distance(embeds, mean_list, cov, epsilon=1e-20):\n",
        "  \"\"\"Computes Mahalanobis distances between the inputs and the fitted Gaussians.\n",
        "\n",
        "  The computation follows Eq.(2) in [1].\n",
        "\n",
        "  For a sample x and a class mean mu the returned value is\n",
        "  (x - mu)^T V^-1 (x - mu) with V = cov + epsilon * I. No square root is\n",
        "  taken, so the values are squared distances.\n",
        "\n",
        "  Args:\n",
        "    embeds: An array-like of size [n_test_sample, n_dim], where n_test_sample\n",
        "      is the sample size of the test set and n_dim is the size of the\n",
        "      embeddings.\n",
        "    mean_list: A list of len n_class whose i-th element is an np.array of\n",
        "      size [n_dim, ], the mean of the fitted Gaussian distribution for the\n",
        "      i-th class.\n",
        "    cov: The shared covariance matrix of size [n_dim, n_dim].\n",
        "    epsilon: The small value added to the diagonal of the covariance matrix to\n",
        "      avoid singularity.\n",
        "\n",
        "  Returns:\n",
        "    out: An np.array of size [n_test_sample, n_class] where the [i, j] element\n",
        "      is the Mahalanobis distance between the i-th sample and the j-th class\n",
        "      Gaussian.\n",
        "  \"\"\"\n",
        "  embeds = np.asarray(embeds)\n",
        "  cov = np.asarray(cov)\n",
        "  n_sample = embeds.shape[0]\n",
        "  n_class = len(mean_list)\n",
        "\n",
        "  v = cov + np.eye(cov.shape[0]) * epsilon  # avoid singularity\n",
        "  vi = np.linalg.inv(v)\n",
        "\n",
        "  out = np.zeros((n_sample, n_class))\n",
        "  # One vectorized pass per class: the row-wise sum of (diff . vi) * diff is\n",
        "  # the quadratic form for every sample at once, so no n_class x n_class\n",
        "  # product is built just to read its diagonal.\n",
        "  for j, mean in enumerate(mean_list):\n",
        "    diff = embeds - mean  # [n_sample, n_dim]\n",
        "    out[:, j] = np.sum(np.dot(diff, vi) * diff, axis=1)\n",
        "  return out\n",
        "\n",
        "\n",
        "def compute_ood_metrics(targets,\n",
        "                        predictions,\n",
        "                        tpr_thres=0.95,\n",
        "                        targets_threshold=None):\n",
        "  \"\"\"Scores a detector with AUROC, AUPRC and the FPR at a fixed TPR.\n",
        "\n",
        "  AUROC - area under the Receiver Operating Characteristic curve.\n",
        "  AUPRC - area under the Precision-Recall curve (average precision).\n",
        "  FPRN  - false positive rate at the first point returned by roc_curve whose\n",
        "          true positive rate is >= N, where N is `tpr_thres`.\n",
        "\n",
        "  Args:\n",
        "    targets: np.ndarray of targets, either 0 or 1, or continuous values.\n",
        "    predictions: np.ndarray of predictions, any value.\n",
        "    tpr_thres: float, the true positive rate at which FPRN is read off.\n",
        "    targets_threshold: float or None. When given, targets strictly below it\n",
        "      become 0 and all other targets become 1 before scoring.\n",
        "\n",
        "  Returns:\n",
        "    A dictionary with the keys 'auroc', 'auprc' and 'fprn'.\n",
        "  \"\"\"\n",
        "  if targets_threshold is not None:\n",
        "    raw_targets = np.array(targets)\n",
        "    below_threshold = raw_targets < targets_threshold\n",
        "    negatives = np.zeros_like(raw_targets, dtype=np.int32)\n",
        "    positives = np.ones_like(raw_targets, dtype=np.int32)\n",
        "    targets = np.where(below_threshold, negatives, positives)\n",
        "\n",
        "  fpr, tpr, _ = sklearn.metrics.roc_curve(targets, predictions)\n",
        "  # np.argmax over a boolean array yields the index of the first True entry.\n",
        "  first_hit = np.argmax(tpr >= tpr_thres)\n",
        "  fprn = fpr[first_hit]\n",
        "\n",
        "  auroc = sklearn.metrics.roc_auc_score(targets, predictions)\n",
        "  auprc = sklearn.metrics.average_precision_score(targets, predictions)\n",
        "  return {'auroc': auroc, 'auprc': auprc, 'fprn': fprn}\n"
      ],
      "metadata": {
        "id": "mb7j8xMm-XbT"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Steps for computing Relative Mahalanobis distance (RMD) OOD score"
      ],
      "metadata": {
        "id": "ey44b79vA1et"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# (1) Prepare feature embeddings, embs_train_ind (NxD), for in-domain training data.\n",
        "# (2) Prepare the same number of feature embeddings, embs_train_ood (NxD), for\n",
        "# general domain data (e.g. C4 for summarization, or ParaCrawl for translation).\n",
        "\n",
        "# (3) Fit Gaussian distributions for in-domain and general domain respectively.\n",
        "# Each set is fitted as a single class, so its labels are all zeros and must\n",
        "# have exactly one entry per row of the embeddings being fitted.\n",
        "mean_list, cov = compute_mean_and_cov(\n",
        "    embs_train_ind, np.zeros(len(embs_train_ind)))\n",
        "mean_list0, cov0 = compute_mean_and_cov(\n",
        "    embs_train_ood, np.zeros(len(embs_train_ood)))\n",
        "\n",
        "# (4) Prepare feature embeddings, embs_ind and embs_ood, for the test in-domain\n",
        "# and test OOD data.\n",
        "\n",
        "# (5) Compute the RMD OOD score for the test examples: dist - dist_0.\n",
        "# With a single fitted class the distance matrix has one column, so it is\n",
        "# flattened to a vector.\n",
        "mdist_ind = compute_mahalanobis_distance(embs_ind, mean_list, cov).reshape(-1)\n",
        "mdist_ood = compute_mahalanobis_distance(embs_ood, mean_list, cov).reshape(-1)\n",
        "\n",
        "mdist0_ind = compute_mahalanobis_distance(embs_ind, mean_list0, cov0).reshape(-1)\n",
        "mdist0_ood = compute_mahalanobis_distance(embs_ood, mean_list0, cov0).reshape(-1)\n",
        "\n",
        "scores_ind, scores_ood = [mdist_ind - mdist0_ind, mdist_ood - mdist0_ood]\n",
        "\n",
        "# (6) Compute AUROC for OOD detection (in-domain is labeled 0, OOD is labeled 1).\n",
        "labels_ind, labels_ood = [np.zeros_like(scores_ind), np.ones_like(scores_ood)]\n",
        "auc_rmd = compute_ood_metrics(\n",
        "    np.concatenate((labels_ind, labels_ood)),\n",
        "    np.concatenate((scores_ind, scores_ood)))\n",
        "print('OOD metrics based on Relative Mahalanobis distance', auc_rmd)"
      ],
      "metadata": {
        "id": "IV41iUDdDR4N"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}