{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# The Principle of Selecting Polynomial Coefficients\n"
      ],
      "metadata": {
        "id": "ZYvxDePHkNB6"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import functools\n",
        "import json\n",
        "import math\n",
        "from multiprocessing import Pool\n",
        "\n",
        "import colabtools\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import tensorflow as tf\n",
        "from absl import logging\n",
        "from colabtools import proto, stubby\n",
        "from scipy import optimize\n",
        "from scipy import special\n",
        "from tqdm import tqdm\n",
        "\n",
        "np.seterr(divide='ignore')"
      ],
      "metadata": {
        "id": "qlQd_YGwIgkd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def split_concatenated_input(concatenated_input, seperator_list):\n",
        "  \"\"\"Splits a concatenated input string into segments.\n",
        "\n",
        "  The input concatenated_input is of format \"seperator_list[0] segment_0\n",
        "  seperator_list[1] segment_1 ...... seperator_list[i] segment_i ...\".\n",
        "\n",
        "  Args:\n",
        "    concatenated_input: A string of multiple concatenated segments.\n",
        "    seperator_list: A list of seperators used to split the concatenated_input.\n",
        "\n",
        "  Returns:\n",
        "    A list of segments in the concatenated_input.\n",
        "  \"\"\"\n",
        "\n",
        "  concatenated_input = concatenated_input.strip()\n",
        "  concatenated_input = concatenated_input[len(seperator_list[0]):]\n",
        "\n",
        "  segments = []\n",
        "  for i in range(1, len(seperator_list)):\n",
        "    current_sep = seperator_list[i]\n",
        "    snippets = concatenated_input.split(current_sep)\n",
        "\n",
        "    # Under rare cases, the text contains multiple seperator strings and we\n",
        "    # always assume the first snippet corresponds to the target segment.\n",
        "    segments.append(snippets[0].strip())\n",
        "    concatenated_input = current_sep.join(snippets[1:])\n",
        "\n",
        "  segments.append(concatenated_input)\n",
        "  return segments\n",
        "\n",
        "\n",
        "def load_weak_supervision_data(file_path):\n",
        "  \"\"\"Loads teacher (weak-supervision) predictions from TFRecord files.\n",
        "\n",
        "  Each record is a serialized `tf.train.Example` with bytes features\n",
        "  \"index\", \"model_input\" (a concatenated \"label entailment: label: ...\n",
        "  title: ... passage: ...\" string) and \"model_output\" (the teacher's output\n",
        "  text), plus a float feature \"model_output_score\".\n",
        "\n",
        "  The teacher's answer is class 1 iff the first space-separated token of\n",
        "  \"model_output\" is \"yes\" (case-insensitive), otherwise class 0, and\n",
        "  exp(model_output_score) is used as the probability of that answer.\n",
        "\n",
        "  Args:\n",
        "    file_path: A file path or glob pattern for `tf.data.Dataset.list_files`.\n",
        "\n",
        "  Returns:\n",
        "    A pandas DataFrame with columns \"index\", \"label\", \"title\", \"passage\",\n",
        "    \"class\" and \"class_probs\" (the list [p(class 0), p(class 1)]).\n",
        "  \"\"\"\n",
        "  tf_dataset = tf.data.Dataset.list_files(file_path)\n",
        "  tf_dataset = tf_dataset.flat_map(tf.data.TFRecordDataset)\n",
        "\n",
        "  seperator_list = [\"label entailment: label: \", \"title: \", \"passage: \"]\n",
        "  data = []\n",
        "  for raw_record in tqdm(tf_dataset):\n",
        "    example = tf.train.Example()\n",
        "    example.ParseFromString(raw_record.numpy())\n",
        "    feature = example.features.feature\n",
        "    index = int(feature[\"index\"].bytes_list.value[0].decode())\n",
        "    label, title, passage = split_concatenated_input(\n",
        "        feature[\"model_input\"].bytes_list.value[0].decode(), seperator_list)\n",
        "\n",
        "    model_output = feature[\"model_output\"].bytes_list.value[0].decode()\n",
        "    first_token = model_output.split(\" \")[0].lower()\n",
        "    # NOTE(review): assumes model_output_score is a log-probability, so that\n",
        "    # exp() lies in [0, 1] -- confirm against the teacher export.\n",
        "    answer_prob = np.exp(feature[\"model_output_score\"].float_list.value[0])\n",
        "    if first_token == \"yes\":\n",
        "      class_label = 1\n",
        "      class_probs = [1.0 - answer_prob, answer_prob]\n",
        "    else:\n",
        "      class_label = 0\n",
        "      class_probs = [answer_prob, 1.0 - answer_prob]\n",
        "    data.append({\n",
        "        \"index\": index,\n",
        "        \"label\": label,\n",
        "        \"title\": title,\n",
        "        \"passage\": passage,\n",
        "        \"class\": class_label,\n",
        "        \"class_probs\": class_probs,\n",
        "    })\n",
        "\n",
        "  return pd.DataFrame(data)\n",
        "\n",
        "\n",
        "def load_labeled_data(file_path):\n",
        "  \"\"\"Loads gold labels from TFRecord files.\n",
        "\n",
        "  Each record is a serialized `tf.train.Example` with bytes features\n",
        "  \"index\", \"model_input\" (a concatenated \"label entailment: label: ...\n",
        "  title: ... passage: ...\" string) and \"class\" (\"yes\" or \"no\",\n",
        "  case-insensitive).\n",
        "\n",
        "  Args:\n",
        "    file_path: A file path or glob pattern for `tf.data.Dataset.list_files`.\n",
        "\n",
        "  Returns:\n",
        "    A pandas DataFrame with columns \"index\", \"label\", \"title\", \"passage\"\n",
        "    and \"class\" (1 for \"yes\", 0 for \"no\").\n",
        "\n",
        "  Raises:\n",
        "    ValueError: If a record's \"class\" is neither \"yes\" nor \"no\".\n",
        "  \"\"\"\n",
        "  tf_dataset = tf.data.Dataset.list_files(file_path)\n",
        "  tf_dataset = tf_dataset.flat_map(tf.data.TFRecordDataset)\n",
        "\n",
        "  seperator_list = [\"label entailment: label: \", \"title: \", \"passage: \"]\n",
        "  class_ids = {\"no\": 0, \"yes\": 1}\n",
        "  data = []\n",
        "  for raw_record in tqdm(tf_dataset):\n",
        "    example = tf.train.Example()\n",
        "    example.ParseFromString(raw_record.numpy())\n",
        "    feature = example.features.feature\n",
        "    index = int(feature[\"index\"].bytes_list.value[0].decode())\n",
        "    label, title, passage = split_concatenated_input(\n",
        "        feature[\"model_input\"].bytes_list.value[0].decode(), seperator_list)\n",
        "    # Decode the class once instead of once per branch.\n",
        "    class_text = feature[\"class\"].bytes_list.value[0].decode().lower()\n",
        "    if class_text not in class_ids:\n",
        "      raise ValueError(f\"wrong label: {class_text!r}.\")\n",
        "    data.append({\n",
        "        \"index\": index,\n",
        "        \"label\": label,\n",
        "        \"title\": title,\n",
        "        \"passage\": passage,\n",
        "        \"class\": class_ids[class_text],\n",
        "    })\n",
        "\n",
        "  return pd.DataFrame(data)\n",
        "\n",
        "\n",
        "def sigmoid(x):\n",
        "  return np.where(x >= 0, 1 / (1 + np.exp(-x)), np.exp(x) / (1 + np.exp(x)))\n",
        "\n",
        "\n",
        "def kl_loss(t, s):\n",
        "  \"\"\"t: teacher logits s: student logits\"\"\"\n",
        "  assert abs(sum(t) - 1.0) < 0.001, \"t is not a prob. distribution\"\n",
        "  assert abs(sum(s) - 1.0) < 0.001, \"s is not a prob. distribution\"\n",
        "  return t[1] * np.log(t[1]) + t[0] * np.log(t[0]) - t[1] * np.log(\n",
        "      s[1]) - t[0] * np.log(s[0])\n",
        "\n",
        "\n",
        "def poly_kl_loss(t, s, N=1, eps_list=[1.0], gamma_list=[1.0]):\n",
        "  \"\"\"t: teacher logits s: student logits N: expansion order eps_list: A list of hparms epsilon for each expansion order gamma_list: A list of hparms gamma for each expansion order\"\"\"\n",
        "  assert len(eps_list) == N, \"length of eps_list should be equal to N\"\n",
        "  assert len(gamma_list) == N, \"length of gamma_list should be equal to N\"\n",
        "  loss = kl_loss(t, s)\n",
        "  for i in range(N):\n",
        "    loss += t[1] * eps_list[i] * (s[0]**(i + 1))\n",
        "    loss += t[0] * gamma_list[i] * (s[1]**(i + 1))\n",
        "  return loss\n",
        "\n",
        "\n",
        "def l_pkd(pos_logit, N=1, eps_list=[1.0], gamma_list=[1.0]):\n",
        "  p_pos = sigmoid(pos_logit)\n",
        "  p = np.array([1. - p_pos, p_pos]).flatten()\n",
        "  assert len(eps_list) == N, \"length of eps_list should be equal to N\"\n",
        "  assert len(gamma_list) == N, \"length of gamma_list should be equal to N\"\n",
        "  logit = np.array([0, pos_logit])\n",
        "  partition = 1 + p[1] / (p[0] + 1e-9)\n",
        "  perturb = np.array([0., 0.])\n",
        "  for i in range(N):\n",
        "    perturb += np.array(\n",
        "        [gamma_list[i] * p[1]**(i + 1), eps_list[i] * p[0]**(i + 1)])\n",
        "  return -logit + np.log(partition) + perturb\n",
        "\n",
        "\n",
        "def l_pkd_grad(pos_logit, N=1, eps_list=[1.0], gamma_list=[1.0]):\n",
        "  p_pos = sigmoid(pos_logit)\n",
        "  p = np.array([1. - p_pos, p_pos]).flatten()\n",
        "  assert len(eps_list) == N, \"length of eps_list should be equal to N\"\n",
        "  assert len(gamma_list) == N, \"length of gamma_list should be equal to N\"\n",
        "  g_logit = np.array([0, 1])\n",
        "  partition = 1 + p[1] / (p[0] + 1e-9)\n",
        "  g_partition = np.array([np.exp(pos_logit), np.exp(pos_logit)])\n",
        "  g_log_partition = g_partition / partition\n",
        "  g_perturb = np.array([0., 0.])\n",
        "  for i in range(N):\n",
        "    g_p1 = p[0] * p[1]\n",
        "    g_gamma = gamma_list[i] * (i + 1) * g_p1**i\n",
        "    g_p0 = -p[0] * p[1]\n",
        "    g_eps = eps_list[i] * (i + 1) * g_p0**i\n",
        "    g_perturb += np.array([g_gamma, g_eps]).flatten()\n",
        "  return -g_logit + g_log_partition.flatten() + g_perturb"
      ],
      "metadata": {
        "id": "OzlnEoxGzVTd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Approach 1: Equivalent teacher"
      ],
      "metadata": {
        "id": "FNcBe8NOzZ5U"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Consider binary classification with teacher prediction $\\mathbf{p}^t(x) =\n",
        "\\left(\\frac{1}{1+\\exp{f^t(x)}}, \\frac{\\exp{f^t(x)}}{1+\\exp{f^t(x)}}\\right)\\in\\mathbb{R}^2$. We\n",
        "want to learn a transformed $\\tilde{f}^t(x)$ such that $$\n",
        "\\mathbf{p}^{t_{\\text{PKD}}}(x)^Tl((0, \\tilde{f}^t(x))) +\n",
        "\\mathbf{p}^{t_{\\text{PKD}}}(x)^T\\log(\\mathbf{p}^{t_{\\text{PKD}}}(x)) =\n",
        "\\mathbf{p}^t(x)^Tl_{\\text{PKD-K}}((0, \\tilde{f}^t(x))) +\n",
        "\\mathbf{p}^t(x)^T\\log(\\mathbf{p}^t(x)), $$ where\n",
        "$\\mathbf{p}^{t_{\\text{PKD}}}(x) = \\left(\\frac{1}{1+\\exp{\\tilde{f}^t(x)}},\n",
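        "\\frac{\\exp{\\tilde{f}^t(x)}}{1+\\exp{\\tilde{f}^t(x)}}\\right)$, $l$ is the per-class\n",
        "cross-entropy loss and $l_{\\text{PKD-K}}$ is the order-$K$ PKD loss (cross-entropy\n",
        "plus polynomial terms with coefficients $\\epsilon_k$ and $\\gamma_k$).\n",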
        "\n",
        "We solve this equation for $\\tilde{f}^t(x)$ numerically with `scipy.optimize.fsolve`.\n",
        "\n",
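        "We then choose the coefficients by random search, minimizing the following loss\n",
        "over $N$ labeled examples with one-hot labels $\\mathbf{y}_n$:\n",
        "$$\n",
        "    l(\\{\\epsilon_{k},\\gamma_{k},k\\in[K]\\}) =\n",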
        "    \\left(\\frac{1}{N}\\sum_{n\\in[N]}\\left[\\|\\mathbf{p}^{t_{\\text{PKD}}}(x_n)-\\mathbf{y}_n\\|_2\\right]\\right)^2 + \\frac{1}{N}\\sum_{n\\in[N]}\\left[\\left(\\mathbf{p}^{t_{\\text{PKD}}}(x_n)^T\\log\\mathbf{p}^{t_{\\text{PKD}}}(x_n)\\right)^2\\right]\n",
        "$$"
      ],
      "metadata": {
        "id": "Dz7cFZK5zXBd"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def equivalent_teacher_equation(pos_logit,\n",
        "                                p_t,\n",
        "                                N=1,\n",
        "                                eps_list=[1.0],\n",
        "                                gamma_list=[1.0]):\n",
        "  p_pkd_pos = sigmoid(pos_logit)\n",
        "  p_pkd = np.array([1. - p_pkd_pos, p_pkd_pos]).flatten()\n",
        "  lhs = sum(p_pkd * l_pkd(pos_logit, N=0, eps_list=[], gamma_list=[]) +\n",
        "            p_pkd * np.log(np.clip(p_pkd, a_min=1e-9, a_max=1 - 1e-9)))\n",
        "  rhs = sum(p_t *\n",
        "            l_pkd(pos_logit, N=N, eps_list=eps_list, gamma_list=gamma_list) +\n",
        "            p_t * np.log(p_t))\n",
        "  return lhs - rhs\n",
        "\n",
        "\n",
        "def equivalent_teacher_equation_grad(pos_logit,\n",
        "                                     p_t,\n",
        "                                     N=1,\n",
        "                                     eps_list=[1.0],\n",
        "                                     gamma_list=[1.0]):\n",
        "  p_pkd_pos = sigmoid(pos_logit)\n",
        "  p_pkd = np.array([1. - p_pkd_pos, p_pkd_pos]).flatten()\n",
        "  p_pkd_grad = (-p_pkd[0] * p_pkd[1], p_pkd[0] * p_pkd[1])\n",
        "  lhs = sum(p_pkd_grad * l_pkd(pos_logit, N=0, eps_list=[], gamma_list=[]) +\n",
        "            p_pkd * l_pkd_grad(pos_logit, N=0, eps_list=[], gamma_list=[]) +\n",
        "            p_pkd_grad * np.log(np.clip(p_pkd, a_min=1e-9, a_max=1 - 1e-9)) +\n",
        "            p_pkd_grad)\n",
        "  rhs = sum(\n",
        "      p_t *\n",
        "      l_pkd_grad(pos_logit, N=N, eps_list=eps_list, gamma_list=gamma_list))\n",
        "  return lhs - rhs\n",
        "\n",
        "\n",
        "def p_pkd(p, N=1, eps_list=[1.0], gamma_list=[1.0]):\n",
        "  tilde_f = optimize.fsolve(\n",
        "      functools.partial(\n",
        "          equivalent_teacher_equation,\n",
        "          p_t=p,\n",
        "          N=N,\n",
        "          eps_list=eps_list,\n",
        "          gamma_list=gamma_list),\n",
        "      fprime=functools.partial(\n",
        "          equivalent_teacher_equation_grad,\n",
        "          p_t=p,\n",
        "          N=N,\n",
        "          eps_list=eps_list,\n",
        "          gamma_list=gamma_list),\n",
        "      x0=np.log((p[1] + 1e-9) / (p[0] + 1e-9)))\n",
        "  p_pkd_pos = sigmoid(tilde_f)\n",
        "  return np.array([1 - p_pkd_pos, p_pkd_pos]).flatten()\n",
        "\n",
        "\n",
        "def get_hparam_score(data, N=1, eps_list=[1.0], gamma_list=[1.0]):\n",
        "  pts = []\n",
        "  transformed_pts = []\n",
        "  yts = []\n",
        "  for _, x in data.iterrows():\n",
        "    pt = x['class_probs']\n",
        "    pts.append(pt)\n",
        "    transformed_p = p_pkd(pt, N=N, eps_list=eps_list, gamma_list=gamma_list)\n",
        "    transformed_pts.append(transformed_p)\n",
        "    label = np.zeros(2)\n",
        "    label[x['class']] = 1.\n",
        "    yts.append(label)\n",
        "\n",
        "  distance = 0.\n",
        "  entropy2 = 0.\n",
        "  for p, y in zip(transformed_pts, yts):\n",
        "    p = np.clip(p, 1e-9, 1. - 1e-9)\n",
        "    distance += np.sqrt(sum((p - y)**2))\n",
        "    entropy2 += (sum(p * np.log(p)))**2.\n",
        "  term1 = (distance / len(yts))**2.\n",
        "  term2 = entropy2 / len(yts)\n",
        "  print(f'term1: {term1}, term2: {term2}')\n",
        "  return term1 + term2, term1, term2\n",
        "\n",
        "\n",
        "def get_hparam_and_score(inputs):\n",
        "  eps_list, gamma_list, data = inputs\n",
        "  score, score1, score2 = get_hparam_score(\n",
        "      data, N=len(eps_list), eps_list=eps_list, gamma_list=gamma_list)\n",
        "  return eps_list, gamma_list, score"
      ],
      "metadata": {
        "id": "a7iyMvqbzXZ_"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "N_TRIALS = 100  #@param{type: \"number\"}\n",
        "TEACHER_IDS = \"1,2,3\"  #@param{type: \"string\"}\n",
        "TASK = \"mnli\"  #@param [\"mnli\", \"sst2\", \"boolq\"]\n",
        "SEED = 0  #@param{type: \"integer\"}\n",
        "MAX_PKD_ORDER = 6  #@param{type: \"integer\"}\n",
        "NUM_EVAL_EXAMPLES = 1000  #@param{type: \"integer\"}\n",
        "\n",
        "TEACHER_IDS = [int(x.strip()) for x in TEACHER_IDS.split(\",\")]\n",
        "\n",
        "# Seeded generator so the sampled coefficients, and hence the reported best\n",
        "# hyperparameters, are reproducible.\n",
        "rng = np.random.default_rng(SEED)\n",
        "# Pool() needs at least one worker, so guard N_TRIALS < 10.\n",
        "num_workers = max(1, N_TRIALS // 10)\n",
        "\n",
        "# The gold labels do not depend on the teacher, so load them only once.\n",
        "label_path = f'/{TASK}_label.tfrecord'\n",
        "labeled_data = load_labeled_data(label_path)[[\"index\", \"class\"]]\n",
        "\n",
        "for teacher_id in TEACHER_IDS:\n",
        "  teacher_path = f'/teacher_{teacher_id}.tfrecord'\n",
        "  pred_data = load_weak_supervision_data(teacher_path)[[\"index\", \"class_probs\"]]\n",
        "  data = pred_data.merge(\n",
        "      labeled_data,\n",
        "      on=['index'],\n",
        "      suffixes=['_pred', ''])\n",
        "  eval_data = data.head(NUM_EVAL_EXAMPLES)\n",
        "  for pkd_order in range(1, MAX_PKD_ORDER + 1):\n",
        "    eps_list = rng.uniform(low=-1., high=10., size=(N_TRIALS, pkd_order))\n",
        "    gamma_list = rng.uniform(low=-1., high=10., size=(N_TRIALS, pkd_order))\n",
        "    tasks = [(eps, gamma, eval_data) for eps, gamma in zip(eps_list, gamma_list)]\n",
        "    with Pool(num_workers) as pool:\n",
        "      results = pool.map(get_hparam_and_score, tasks)\n",
        "    # A failed root solve can yield a NaN score, and np.argmin would return\n",
        "    # that NaN entry, so ignore NaNs when picking the best trial.\n",
        "    best_idx = np.nanargmin([score for _, _, score in results])\n",
        "    best_eps, best_gamma, best_score = results[best_idx]\n",
        "    print(f'\\n======\\nteacher_id: {teacher_id} pkd_order: {pkd_order}')\n",
        "    print(f'best hparam and score: {best_eps}, {best_gamma}, {best_score}\\n\\n')"
      ],
      "metadata": {
        "id": "XktpKTQizeQH"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Approach 2: directly optimize the ideal student"
      ],
      "metadata": {
        "id": "mX0MbiIuzgNf"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Consider binary classification with teacher prediction $\\mathbf{p}^t(x) =\n",
        "\\left(\\frac{1}{1+\\exp{f^t(x)}}, \\frac{\\exp{f^t(x)}}{1+\\exp{f^t(x)}}\\right)\\in\\mathbb{R}^2$. We\n",
        "can directly solve for the ideal student logit $f^s(x)$ as the minimizer over $f(x)$ of $$\n",
        "\\mathbf{p}^t(x)^Tl_{\\text{PKD-K}}((0, f(x))) +\n",
        "\\mathbf{p}^t(x)^T\\log(\\mathbf{p}^t(x)). $$\n",
        "The code below finds this minimizer as a root of the derivative with respect to $f(x)$, using `scipy.optimize.fsolve`."
      ],
      "metadata": {
        "id": "S81ae5TOzkIK"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def target_function(pos_logit, p_t, N, eps_list, gamma_list):\n",
        "  p_pkd_pos = sigmoid(pos_logit)\n",
        "  p_pkd = np.array([1. - p_pkd_pos, p_pkd_pos]).flatten()\n",
        "  return sum(p_t *\n",
        "             l_pkd(pos_logit, N=N, eps_list=eps_list, gamma_list=gamma_list))\n",
        "\n",
        "\n",
        "def target_function_grad(pos_logit, p_t, N, eps_list, gamma_list):\n",
        "  res = sum(\n",
        "      p_t *\n",
        "      l_pkd_grad(pos_logit, N=N, eps_list=eps_list, gamma_list=gamma_list))\n",
        "  return res\n",
        "\n",
        "\n",
        "def ideal_student(p, N=1, eps_list=[1.0], gamma_list=[1.0]):\n",
        "  tilde_f = optimize.fsolve(\n",
        "      functools.partial(\n",
        "          target_function_grad,\n",
        "          p_t=p,\n",
        "          N=N,\n",
        "          eps_list=eps_list,\n",
        "          gamma_list=gamma_list),\n",
        "      x0=np.log((p[1] + 1e-9) / (p[0] + 1e-9)))\n",
        "  p_pkd_pos = sigmoid(tilde_f)\n",
        "  return np.array([1 - p_pkd_pos, p_pkd_pos]).flatten()\n",
        "\n",
        "\n",
        "def get_hparam_score2(data, N=1, eps_list=[1.0], gamma_list=[1.0]):\n",
        "  pts = []\n",
        "  transformed_pts = []\n",
        "  yts = []\n",
        "  for _, x in data.iterrows():\n",
        "    pt = x['class_probs']\n",
        "    pts.append(pt)\n",
        "    transformed_p = ideal_student(\n",
        "        pt, N=N, eps_list=eps_list, gamma_list=gamma_list)\n",
        "    transformed_pts.append(transformed_p)\n",
        "    label = np.zeros(2)\n",
        "    label[x['class']] = 1.\n",
        "    yts.append(label)\n",
        "\n",
        "  distance = 0.\n",
        "  for p, y in zip(transformed_pts, yts):\n",
        "    distance += np.sqrt(sum((np.clip(p, 0., 1.) - y)**2))\n",
        "  return distance / len(yts)\n"
      ],
      "metadata": {
        "id": "8UYJamFdzeK-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
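        "# For multi-class\n",
        "\n",
        "Note: the code cell below redefines `equivalent_teacher_equation`, `p_pkd`, `get_hparam_score` and `get_hparam_and_score`, shadowing the binary versions above. Re-run the binary cells before re-running Approach 1."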
      ],
      "metadata": {
        "id": "ezqdX0KNzz1n"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Equivalent teacher"
      ],
      "metadata": {
        "id": "IpvfQFBez5se"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Consider multiclass classification with teacher prediction $\\mathbf{p}^t(x) =\n",
        "\\text{softmax}(f^t(x))\\in\\mathbb{R}^C$. We\n",
        "want to learn a transformed $\\tilde{f}^t(x)$ such that $$\n",
        "\\mathbf{p}^{t_{\\text{PKD}}}(x)^Tl(\\tilde{f}^t(x)) +\n",
        "\\mathbf{p}^{t_{\\text{PKD}}}(x)^T\\log(\\mathbf{p}^{t_{\\text{PKD}}}(x)) =\n",
        "\\mathbf{p}^t(x)^Tl_{\\text{PKD-K}}(\\tilde{f}^t(x)) +\n",
        "\\mathbf{p}^t(x)^T\\log(\\mathbf{p}^t(x)), $$ where\n",
        "$\\mathbf{p}^{t_{\\text{PKD}}}(x) = \\text{softmax}\\tilde{f}^t(x)$.\n",
        "\n",
        "We solve this equation for $\\tilde{f}^t(x)$ (with its first logit fixed to 0) numerically with `scipy.optimize.fsolve`.\n",
        "\n",
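        "We then choose the coefficients $\\eta_{c,k}$ (one per class $c$ and order $k$) that minimize the following loss\n",
        "over $N$ labeled examples with one-hot labels $\\mathbf{y}_n$:\n",
        "$$\n",
        "    l(\\{\\eta_{c,k},c\\in[C],k\\in[K]\\}) =\n",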
        "    \\left(\\frac{1}{N}\\sum_{n\\in[N]}\\left[\\|\\mathbf{p}^{t_{\\text{PKD}}}(x_n)-\\mathbf{y}_n\\|_2\\right]\\right)^2 + \\frac{1}{N}\\sum_{n\\in[N]}\\left[\\left(\\mathbf{p}^{t_{\\text{PKD}}}(x_n)^T\\log\\mathbf{p}^{t_{\\text{PKD}}}(x_n)\\right)^2\\right]\n",
        "$$"
      ],
      "metadata": {
        "id": "pygp4pSEz9Iw"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def equivalent_teacher_equation(logits,\n",
        "                                p_t,\n",
        "                                N=1, num_classes=3, eta=np.ones((3, 1))):\n",
        "  logits = np.concatenate([[0], logits])\n",
        "  p_pkd = special.softmax(logits)\n",
        "  lhs = sum(p_pkd * l_pkd(logits, N=0, eta=[]) +\n",
        "            p_pkd * np.log(np.clip(p_pkd, a_min=1e-9, a_max=1 - 1e-9)))\n",
        "  rhs = sum(p_t *\n",
        "            l_pkd(logits, N=N, num_classes=num_classes, eta=eta) +\n",
        "            p_t * np.log(p_t))\n",
        "  return [lhs - rhs]*(num_classes-1)\n",
        "\n",
        "def p_pkd(pt, N=1, num_classes=3, eta=np.ones((3, 1))):\n",
        "  tilde_f = optimize.fsolve(\n",
        "      functools.partial(\n",
        "          equivalent_teacher_equation,\n",
        "          p_t=pt,\n",
        "          N=N,\n",
        "          num_classes=num_classes,\n",
        "          eta=eta),\n",
        "      x0=np.log(pt/pt[0])[1:])\n",
        "  tilde_f = np.concatenate([[0], tilde_f])\n",
        "  return special.softmax(tilde_f)\n",
        "\n",
        "def get_hparam_score(data, N=1, num_classes=3, eta=np.ones((3, 1))):\n",
        "  pts = []\n",
        "  transformed_pts = []\n",
        "  yts = []\n",
        "  for _, x in data.iterrows():\n",
        "    pt = x['class_probs']\n",
        "    pts.append(pt)\n",
        "    transformed_p = p_pkd(pt, N=N, num_classes=num_classes,\n",
        "          eta=eta)\n",
        "    transformed_pts.append(transformed_p)\n",
        "    label = np.zeros(num_classes)\n",
        "    label[x['class']] = 1.\n",
        "    yts.append(label)\n",
        "\n",
        "  distance = 0.\n",
        "  entropy2 = 0.\n",
        "  for p, y in zip(transformed_pts, yts):\n",
        "    p = np.clip(p, 1e-9, 1. - 1e-9)\n",
        "    distance += np.sqrt(sum((p - y)**2))\n",
        "    entropy2 += (sum(p * np.log(p)))**2.\n",
        "  term1 = (distance / len(yts))**2.\n",
        "  term2 = entropy2 / len(yts)\n",
        "  print(f'term1: {term1}, term2: {term2}')\n",
        "  return term1 + term2, term1, term2\n",
        "\n",
        "\n",
        "def get_hparam_and_score(inputs):\n",
        "  eta, data = inputs\n",
        "  score, score1, score2 = get_hparam_score(\n",
        "      data, N=eta.shape[1], num_classes=eta.shape[0], eta=eta)\n",
        "  return eta, score"
      ],
      "metadata": {
        "id": "upJbOsDhz7aY"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}