{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Abb68XTAnSwm"
      },
      "source": [
        "We implement MiDAS here, using pretrained MiDAS encoder and expert models. We use an ML experiment framework to run experiments.\n",
        "\n",
        "Data is assumed to be stored in Google Drive. References to the datasets are provided in the paper."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t8t8OcDQDypK"
      },
      "source": [
        "# Setup Notebook"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "i0-UNku6DPqU"
      },
      "source": [
        "## Connect to Drive"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "w_P6hPRDqyi4",
        "outputId": "555215f2-125c-44e6-bfc2-c2cec258ee50"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Thu May 19 00:22:22 2022       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   40C    P8    10W /  70W |      0MiB / 15109MiB |      0%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "|  No running processes found                                                 |\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        }
      ],
      "source": [
        "!nvidia-smi"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aZtdWNypnP72"
      },
      "outputs": [],
      "source": [
        "import os.path as osp\n",
        "import os, glob, shutil"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "xiPHbFAZjWlv",
        "outputId": "35f52ecf-b004-4adf-ae64-9670f44b541c"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Mounted at /content/drive\n"
          ]
        }
      ],
      "source": [
        "# Set up the Google Drive connection\n",
        "#if not osp.exists(\"./drive\"):\n",
        "from google.colab import drive\n",
        "\n",
        "#https://stackoverflow.com/questions/69822304/google-colab-google-drive-can%c2%b4t-be-mounted-anymore-browser-popup-google-dri\n",
        "drive.mount('/content/drive', force_remount=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OJ8JWzSoG5P1"
      },
      "source": [
        "## Data Downloads"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "jaWg-q8XkFM_",
        "outputId": "793b73e7-3bb8-40e6-ad3b-81a782a333f1"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Covid-FN does not exist. Copying...\n"
          ]
        }
      ],
      "source": [
        "if not osp.exists(\"Covid-FN\"):\n",
        "  print(\"Covid-FN does not exist. Copying...\")\n",
        "  !cp -r ./drive/MyDrive/Datasets/Covid-FN .\n",
        "else:\n",
        "  print(\"Covid-FN already exists\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "W4bTaG6FkJiE",
        "outputId": "f695eb71-0c93-4cf9-a483-03baac71dc7c"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Unpacked Covid-FN/kagglefn_short.zip to Data/kagglefn_short\n",
            "Unpacked Covid-FN/recov.zip to Data/recov\n",
            "Unpacked Covid-FN/cov_rumor.zip to Data/cov_rumor\n",
            "Unpacked Covid-FN/cov19_fn_title.zip to Data/cov19_fn_title\n",
            "Unpacked Covid-FN/cov19_fn_text.zip to Data/cov19_fn_text\n",
            "Unpacked Covid-FN/coaid_tweets.zip to Data/coaid_tweets\n",
            "Unpacked Covid-FN/covid_fn.zip to Data/covid_fn\n",
            "Unpacked Covid-FN/coaid_news.zip to Data/coaid_news\n",
            "Unpacked Covid-FN/kagglefn_long.zip to Data/kagglefn_long\n",
            "Unpacked Covid-FN/cmu_miscov19.zip to Data/cmu_miscov19\n",
            "Unpacked Covid-FN/covid_cq.zip to Data/covid_cq\n"
          ]
        }
      ],
      "source": [
        "# Unpack every dataset archive found in Covid-FN/ into its own folder under Data/.\n",
        "sources = glob.glob(\"Covid-FN/*.zip\")\n",
        "dest = \"Data\"\n",
        "os.makedirs(dest, exist_ok=True)\n",
        "for source in sources:\n",
        "  base = osp.basename(source)\n",
        "  dname = osp.splitext(base)[0]\n",
        "  ddest = osp.join(dest,dname)\n",
        "  # Test the extraction target (Data/<name>) rather than the bare name in the\n",
        "  # working directory, so archives that are already unpacked are skipped on re-runs.\n",
        "  if not osp.exists(ddest):\n",
        "    shutil.unpack_archive(source, ddest)\n",
        "    print(\"Unpacked %s to %s\"%(source, ddest))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "27LK1AmgfuCk"
      },
      "outputs": [],
      "source": [
        "!cp ./drive/MyDrive/Datasets/PreTrained/Albertv2/30k* ."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qEpcMzQ1oikH"
      },
      "outputs": [],
      "source": [
        "!cp ./drive/MyDrive/Datasets/PreTrained/Albertv2/pytorch_model.bin ."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TqJoyEW1f_a1"
      },
      "source": [
        "## Git Clone"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oisq4kE8IMHi"
      },
      "source": [
        "### From Source"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "goBMKUmagBIx",
        "outputId": "494d13f5-c6c9-4867-a29e-2659c0d36667"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Cloning into 'GLAMOR'...\n",
            "remote: Enumerating objects: 4296, done.\u001b[K\n",
            "remote: Counting objects: 100% (711/711), done.\u001b[K\n",
            "remote: Compressing objects: 100% (375/375), done.\u001b[K\n",
            "remote: Total 4296 (delta 407), reused 534 (delta 275), pack-reused 3585\u001b[K\n",
            "Receiving objects: 100% (4296/4296), 1.64 MiB | 16.14 MiB/s, done.\n",
            "Resolving deltas: 100% (2841/2841), done.\n"
          ]
        }
      ],
      "source": [
        "! git clone -b colabel-multicb https://github.com/asuprem/GLAMOR"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "T80AC-kx4v4Y",
        "outputId": "ca107650-a156-499d-cd39-8996d0ca8f3c"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Obtaining file:///content/GLAMOR\n",
            "Requirement already satisfied: scikit-learn>=1.0.2 in /usr/local/lib/python3.7/dist-packages (from ednaml==0.1.4) (1.0.2)\n",
            "Requirement already satisfied: torch>=1.10.* in /usr/local/lib/python3.7/dist-packages (from ednaml==0.1.4) (1.11.0+cu113)\n",
            "Collecting torchinfo>=1.6.5\n",
            "  Downloading torchinfo-1.6.6-py3-none-any.whl (21 kB)\n",
            "Requirement already satisfied: torchvision>=0.11.* in /usr/local/lib/python3.7/dist-packages (from ednaml==0.1.4) (0.12.0+cu113)\n",
            "Requirement already satisfied: Pillow>=7.1.2 in /usr/local/lib/python3.7/dist-packages (from ednaml==0.1.4) (7.1.2)\n",
            "Requirement already satisfied: tqdm>=4.63.* in /usr/local/lib/python3.7/dist-packages (from ednaml==0.1.4) (4.64.0)\n",
            "Collecting sentencepiece>=0.1.96\n",
            "  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)\n",
            "\u001b[K     |████████████████████████████████| 1.2 MB 7.3 MB/s \n",
            "\u001b[?25hRequirement already satisfied: scipy>=1.1.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn>=1.0.2->ednaml==0.1.4) (1.4.1)\n",
            "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn>=1.0.2->ednaml==0.1.4) (1.1.0)\n",
            "Requirement already satisfied: numpy>=1.14.6 in /usr/local/lib/python3.7/dist-packages (from scikit-learn>=1.0.2->ednaml==0.1.4) (1.21.6)\n",
            "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn>=1.0.2->ednaml==0.1.4) (3.1.0)\n",
            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.10.*->ednaml==0.1.4) (4.2.0)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from torchvision>=0.11.*->ednaml==0.1.4) (2.23.0)\n",
            "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision>=0.11.*->ednaml==0.1.4) (1.24.3)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision>=0.11.*->ednaml==0.1.4) (2021.10.8)\n",
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision>=0.11.*->ednaml==0.1.4) (2.10)\n",
            "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision>=0.11.*->ednaml==0.1.4) (3.0.4)\n",
            "Installing collected packages: torchinfo, sentencepiece, ednaml\n",
            "  Running setup.py develop for ednaml\n",
            "Successfully installed ednaml-0.1.4 sentencepiece-0.1.96 torchinfo-1.6.6\n"
          ]
        }
      ],
      "source": [
        "!pip install -e GLAMOR/"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tRDFo-J3q-RX"
      },
      "outputs": [],
      "source": [
        "#! rm -rf -- GLAMOR"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d5j3WfN0fpIT"
      },
      "source": [
        "### From PyPI"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f7dkOhZi08dU"
      },
      "outputs": [],
      "source": [
        "#! python -V"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FwqgjiZ331ik"
      },
      "outputs": [],
      "source": [
        "#! pip3 install --pre ednaml==0.1.4"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t6HDMaGzs2rd"
      },
      "source": [
        "# MiDAS"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9MafR4mBKavE"
      },
      "source": [
        "## Data Crawlers and Generators"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xA3U0U4SyUKV"
      },
      "outputs": [],
      "source": [
        "import ednaml"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Nap_NoJmKZtB"
      },
      "outputs": [],
      "source": [
        "import csv, glob, os\n",
        "from ednaml.crawlers import Crawler\n",
        "class MiDASCrawler(Crawler):\n",
        "  \"\"\"Collects [text, label, dataset index] records from extracted fake-news dataset folders.\"\"\"\n",
        "\n",
        "  def __init__(self, logger = None, data_folder=\"Data\", include=[\"cmu_miscov19\", \"covid19_fn_title\", \"kagglefn_short\"]):\n",
        "    \"\"\"Crawls the Data folder with all datasets already extracted to their individual folders.\n",
        "    Assumes specific file construction: files are separated into splits\n",
        "    (train, test, and val), with label subsets fake and true, with naming convention:\n",
        "\n",
        "    <datasetname>-<labelsubset>-<split>.csv\n",
        "\n",
        "    Args:\n",
        "      logger: object with an `info` method; when None, nothing is logged.\n",
        "      data_folder: folder that contains one sub-folder per dataset.\n",
        "      include: names of the dataset sub-folders to crawl (read only, never modified).\n",
        "\n",
        "    After construction, `self.metadata[split][\"crawl\"]` holds the records of each split\n",
        "    and `self.metadata[split][\"classes\"]` holds the class counts for \"fnews\" and \"dataset\".\n",
        "    \"\"\"\n",
        "    # NOTE(review): the default lists \"covid19_fn_title\", while the archives unpacked\n",
        "    # earlier in this notebook produce a folder named \"cov19_fn_title\" -- confirm which\n",
        "    # name is intended; a name with no matching folder is skipped.\n",
        "    if logger is not None:\n",
        "      logger.info(\"Crawling %s for %s\"%(data_folder, str(include)))\n",
        "\n",
        "    self.classes = {}\n",
        "    self.classes[\"fnews\"] = 2\n",
        "    self.metadata = {}\n",
        "\n",
        "    self.data_folder = data_folder\n",
        "    # glob returns paths in arbitrary order; sort them so that each dataset receives\n",
        "    # the same index on every run and on every machine.\n",
        "    datasets_folders = sorted(\n",
        "      item for item in glob.glob(os.path.join(data_folder, \"*\"))\n",
        "      if os.path.basename(item) in include\n",
        "    )\n",
        "    self.classes[\"dataset\"] = len(datasets_folders)\n",
        "\n",
        "    # Report requested datasets that have no folder instead of dropping them silently.\n",
        "    found = [os.path.basename(item) for item in datasets_folders]\n",
        "    missing = [name for name in include if name not in found]\n",
        "    if missing and logger is not None:\n",
        "      logger.info(\"Requested datasets not found in %s: %s\"%(data_folder, str(missing)))\n",
        "\n",
        "    splits = (\"train\", \"test\", \"val\")\n",
        "    for split in splits:\n",
        "      self.metadata[split] = {\"crawl\": []}\n",
        "\n",
        "    for idx, folder in enumerate(datasets_folders):\n",
        "      name = os.path.basename(folder)\n",
        "      for split in splits:\n",
        "        # One file per label subset: <datasetname>-<fake|true>-<split>.csv\n",
        "        split_files = [os.path.join(folder, \"-\".join([name, subset, split + \".csv\"])) for subset in [\"fake\", \"true\"]]\n",
        "        self.metadata[split][\"crawl\"] += self.getTextAndLabels(split_files, idx)\n",
        "\n",
        "    for split in splits:\n",
        "      self.metadata[split][\"classes\"] = self.classes\n",
        "\n",
        "  def getTextAndLabels(self, listOfFiles, datasetidx):\n",
        "    \"\"\"Read the \"text\" and \"label\" columns of every CSV file in `listOfFiles`.\n",
        "\n",
        "    Returns a list of [text, int(label), datasetidx] records. Raises ValueError when a\n",
        "    header lacks a \"text\" or \"label\" column. Blank rows are skipped.\n",
        "    \"\"\"\n",
        "    crawl = []\n",
        "    for file in listOfFiles:\n",
        "      # newline=\"\" is what the csv module expects, so newlines embedded in quoted text\n",
        "      # fields are parsed correctly.\n",
        "      with open(file, \"r\", newline=\"\", encoding=\"utf-8\") as ofile:\n",
        "        csvread = csv.reader(ofile)\n",
        "\n",
        "        header = next(csvread)\n",
        "        textidx = header.index(\"text\")\n",
        "        labelidx = header.index(\"label\")\n",
        "\n",
        "        for row in csvread:\n",
        "          if not row:\n",
        "            continue\n",
        "          crawl.append([row[textidx], int(row[labelidx]), datasetidx])\n",
        "    return crawl"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "l9AQM02vzAOk"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from torch.utils.data import TensorDataset\n",
        "class MiDASDataset(torch.utils.data.Dataset):\n",
        "  \"\"\"Tokenized text dataset for MiDAS with optional in-memory caching and MLM masking.\n",
        "\n",
        "  Every item is a 7-tuple of tensors, in the order MiDASGenerator.collate_fn unpacks:\n",
        "  (input_ids, attention_mask, token_type_ids, masklm, input_len, label, datasetlabel).\n",
        "  `masklm` is -1 everywhere except at masked positions, where it holds the original\n",
        "  token id; masked positions also get attention_mask 0.\n",
        "\n",
        "  Keyword arguments:\n",
        "    cache: disk caching; not implemented, raises NotImplementedError when truthy.\n",
        "    memcache: tokenize the whole dataset once and serve items from memory.\n",
        "    tokenizer: object providing `tokenize` and `convert_tokens_to_ids`.\n",
        "    maxlen: padded sequence length, including [CLS] and [SEP].\n",
        "    mlm_probability: fraction of each sequence that is masked (default 0.2).\n",
        "    masking: whether to mask at all (default: only when mode == \"train\").\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, dataset, mode, transform=None, **kwargs):\n",
        "    self.dataset = dataset  # list of tuples (text, label, datasetlabel)\n",
        "    self.cache = kwargs.get(\"cache\", False)\n",
        "    self.memcache = kwargs.get(\"memcache\", False)\n",
        "    self.tokenizer = kwargs.get(\"tokenizer\")\n",
        "    self.maxlen = kwargs.get(\"maxlen\")\n",
        "    self.mlm = kwargs.get(\"mlm_probability\", 0.2)\n",
        "    self.masking = kwargs.get(\"masking\", mode==\"train\")\n",
        "\n",
        "    self.getcount = 0\n",
        "    self.refresh_flag = len(dataset)  # number of reads that make up one pass over the data\n",
        "\n",
        "    if self.cache and self.memcache:\n",
        "      raise ValueError(\"Use only one cache type\")\n",
        "\n",
        "    self.getter = self.uncachedget\n",
        "    if self.cache:\n",
        "      # Caching transformations to disk is not implemented.\n",
        "      raise NotImplementedError()\n",
        "    if self.memcache:\n",
        "      # Tokenize everything once, then build the first masked view of the cache.\n",
        "      self.getter = self.memget\n",
        "      self.input_length_cache = []\n",
        "      self.convert_to_features(self.dataset, self.tokenizer, maxlen=self.maxlen)\n",
        "      self.memcached_dataset = self.refresh_mask_ids()\n",
        "\n",
        "  def __len__(self):\n",
        "    return len(self.dataset)\n",
        "\n",
        "  def __getitem__(self, idx):\n",
        "    return self.getter(idx)\n",
        "\n",
        "  def _encode(self, text, tokenizer, maxlen):\n",
        "    \"\"\"Tokenize `text`, wrap it in [CLS]/[SEP], and zero-pad to `maxlen`.\n",
        "\n",
        "    Returns (input_ids, attention_mask, token_type_ids, input_len) as Python lists plus\n",
        "    the unpadded length.\n",
        "    \"\"\"\n",
        "    tokens = tokenizer.tokenize(text)\n",
        "    if len(tokens) > maxlen - 2:\n",
        "      tokens = tokens[0:(maxlen - 2)]  # leave room for [CLS] and [SEP]\n",
        "\n",
        "    finaltokens = [\"[CLS]\", *tokens, \"[SEP]\"]\n",
        "    input_ids = list(tokenizer.convert_tokens_to_ids(finaltokens))\n",
        "    input_len = len(input_ids)\n",
        "\n",
        "    padding = [0]*(maxlen - input_len)\n",
        "    attention_mask = [1]*input_len + padding\n",
        "    token_type_ids = [0]*len(finaltokens) + padding\n",
        "    input_ids = input_ids + padding\n",
        "\n",
        "    assert len(input_ids) == maxlen\n",
        "    assert len(attention_mask) == maxlen\n",
        "    assert len(token_type_ids) == maxlen\n",
        "    return input_ids, attention_mask, token_type_ids, input_len\n",
        "\n",
        "  def uncachedget(self, idx):\n",
        "    \"\"\"Tokenize sample `idx` on the fly and return it in the same layout as `memget`.\"\"\"\n",
        "    sample = self.dataset[idx]\n",
        "    input_ids, attention_mask, token_type_ids, input_len = self._encode(sample[0], self.tokenizer, self.maxlen)\n",
        "\n",
        "    input_ids = torch.tensor(input_ids, dtype=torch.long)\n",
        "    attention_mask = torch.tensor(attention_mask, dtype=torch.long)\n",
        "    token_type_ids = torch.tensor(token_type_ids, dtype=torch.long)\n",
        "    masklm = -1*torch.ones(input_ids.shape, dtype=torch.long)\n",
        "\n",
        "    if self.masking:\n",
        "      # Draw fresh mask positions for this read, as refresh_mask_ids does for the cache.\n",
        "      mask_ids = self.build_mask_ids([input_len])[0]\n",
        "      attention_mask[mask_ids] = 0\n",
        "      masklm[mask_ids] = input_ids[mask_ids]\n",
        "\n",
        "    return (\n",
        "      input_ids, attention_mask, token_type_ids, masklm,\n",
        "      torch.tensor(input_len, dtype=torch.long),\n",
        "      torch.tensor(sample[1], dtype=torch.long),\n",
        "      torch.tensor(sample[2], dtype=torch.long),\n",
        "    )\n",
        "\n",
        "  def memget(self, idx):\n",
        "    \"\"\"Return cached sample `idx`, redrawing the masks after each full pass when masking is on.\"\"\"\n",
        "    self.getcount+=1\n",
        "    if self.getcount > self.refresh_flag and self.masking:\n",
        "      self.memcached_dataset = self.refresh_mask_ids()\n",
        "      # The read that triggered the refresh is the first read of the new pass; count it,\n",
        "      # otherwise the refresh point drifts one sample later on every pass.\n",
        "      self.getcount = 1\n",
        "    return self.memcached_dataset[idx]\n",
        "\n",
        "  def refresh_mask_ids(self):\n",
        "    \"\"\"Build the cached TensorDataset, drawing new mask positions when masking is on.\"\"\"\n",
        "    print(\"Refreshing mask ids\")\n",
        "    if not self.masking:\n",
        "      return TensorDataset(self.all_input_ids, self.all_attention_mask, self.all_token_type_ids, self.all_masklm, self.all_lens, self.all_labels, self.all_datalabels)\n",
        "\n",
        "    self.mask_ids = self.build_mask_ids(self.input_length_cache)\n",
        "\n",
        "    # Work on copies so the unmasked tensors stay available for the next refresh.\n",
        "    all_attention_mask = self.all_attention_mask.clone()\n",
        "    all_masklm = self.all_masklm.clone()\n",
        "\n",
        "    for idx in range(self.refresh_flag):\n",
        "      all_attention_mask[idx][self.mask_ids[idx]] = 0 # Set the masking words to 0, so we do not attend to it during prediction\n",
        "      all_masklm[idx][self.mask_ids[idx]] = self.all_input_ids[idx][self.mask_ids[idx]] # Set the masking labels for these to the actual word index from all_input_ids\n",
        "\n",
        "    return TensorDataset(self.all_input_ids, all_attention_mask, self.all_token_type_ids, all_masklm, self.all_lens, self.all_labels, self.all_datalabels)\n",
        "\n",
        "  def build_mask_ids(self, input_length_cache):\n",
        "    \"\"\"Return, for each unpadded length, a tensor of randomly chosen positions to mask.\"\"\"\n",
        "    # randperm(inplength-2)+1 yields positions 1..inplength-2, so neither [CLS] (first)\n",
        "    # nor [SEP] (last) is ever masked; int(inplength*self.mlm) of them are kept.\n",
        "    return [(torch.randperm(inplength-2)+1)[:int(inplength*self.mlm)]  for inplength in input_length_cache]\n",
        "\n",
        "\n",
        "  def convert_to_features(self, dataset, tokenizer, maxlen):\n",
        "    \"\"\"Tokenize every sample of `dataset` and store the results as `self.all_*` tensors.\n",
        "\n",
        "    Also fills `self.input_length_cache` with the unpadded length of every sample.\n",
        "    \"\"\"\n",
        "    features = []\n",
        "    self.input_length_cache = []\n",
        "    for sample in dataset:\n",
        "      input_ids, attention_mask, token_type_ids, input_len = self._encode(sample[0], tokenizer, maxlen)\n",
        "      self.input_length_cache.append(input_len)\n",
        "      features.append(\n",
        "          (input_ids, attention_mask, token_type_ids, input_len, sample[1], sample[2])\n",
        "      )\n",
        "\n",
        "    self.all_input_ids = torch.tensor([f[0] for f in features], dtype=torch.long)\n",
        "    self.all_attention_mask = torch.tensor([f[1] for f in features], dtype=torch.long)\n",
        "    self.all_token_type_ids = torch.tensor([f[2] for f in features], dtype=torch.long)\n",
        "    self.all_lens = torch.tensor([f[3] for f in features], dtype=torch.long)\n",
        "    self.all_labels = torch.tensor([f[4] for f in features], dtype=torch.long)\n",
        "    self.all_datalabels = torch.tensor([f[5] for f in features], dtype=torch.long)\n",
        "    # -1 marks positions that carry no masked-LM label.\n",
        "    self.all_masklm = -1*torch.ones(self.all_input_ids.shape, dtype=torch.long)\n",
        "\n",
        "\n",
        "\n",
        "from ednaml.utils.LabelMetadata import LabelMetadata\n",
        "from ednaml.generators import TextGenerator\n",
        "class MiDASGenerator(TextGenerator):\n",
        "  \"\"\"TextGenerator that wraps crawled MiDAS records in a MiDASDataset and a DataLoader.\n",
        "\n",
        "  The tokenizer is created in build_transforms and handed to the dataset in buildDataset.\n",
        "  \"\"\"\n",
        "\n",
        "  def build_transforms(self, transform, mode, **kwargs):  #<-- generator kwargs:\n",
        "    \"\"\"Locate the tokenizer class named by kwargs[\"tokenizer\"] and instantiate it with kwargs.\"\"\"\n",
        "    from ednaml.utils import locate_class\n",
        "    tokenizer_name = kwargs.get(\"tokenizer\", \"AlbertFullTokenizer\")\n",
        "    tokenizer_class = locate_class(\n",
        "      package=\"ednaml\",\n",
        "      subpackage=\"utils\",\n",
        "      classpackage=tokenizer_name,\n",
        "      classfile=\"tokenizers\",\n",
        "    )\n",
        "    # kwargs carry the tokenizer options: vocab_file, do_lower_case, spm_model_file\n",
        "    self.tokenizer = tokenizer_class(**kwargs)\n",
        "\n",
        "  def buildDataset(self, crawler, mode, transform, **kwargs): #<-- dataset args:\n",
        "    \"\"\"Create the MiDASDataset for `mode` from the crawler's records.\"\"\"\n",
        "    records = crawler.metadata[mode][\"crawl\"]\n",
        "    # kwargs need maxlen, memcache, mlm_probability\n",
        "    return MiDASDataset(records, mode, tokenizer = self.tokenizer, **kwargs)\n",
        "\n",
        "  def buildDataLoader(self, dataset, mode, batch_size, **kwargs):\n",
        "    \"\"\"Wrap `dataset` in a DataLoader; only the train split is shuffled.\"\"\"\n",
        "    is_train = (mode==\"train\")\n",
        "    return torch.utils.data.DataLoader(\n",
        "      dataset,\n",
        "      batch_size=batch_size*self.gpus,  # batch_size is scaled by self.gpus\n",
        "      shuffle=is_train,\n",
        "      num_workers = self.workers,\n",
        "      collate_fn=self.collate_fn,\n",
        "    )\n",
        "\n",
        "  def getNumEntities(self, crawler, mode, **kwargs):  #<-- dataset args\n",
        "    \"\"\"Return LabelMetadata with the class count of every name in kwargs[\"classificationclass\"].\"\"\"\n",
        "    class_counts = crawler.metadata[mode][\"classes\"]\n",
        "    label_dict = {}\n",
        "    for label_name in kwargs.get(\"classificationclass\", [\"color\"]):\n",
        "      label_dict[label_name] = {\"classes\": class_counts[label_name]}\n",
        "    return LabelMetadata(label_dict=label_dict)\n",
        "\n",
        "  def collate_fn(self, batch):\n",
        "    \"\"\"Stack a list of 7-tuples into batch tensors and trim padding to the longest sample.\n",
        "\n",
        "    Returns (input_ids, attention_mask, token_type_ids, masklm, labels, datalabels);\n",
        "    the per-sample lengths are used for trimming and are not returned.\n",
        "    \"\"\"\n",
        "    columns = [torch.stack(column) for column in zip(*batch)]\n",
        "    input_ids, attention_mask, token_type_ids, masklm, lens, labels, datalabels = columns\n",
        "    longest = lens.max().item()\n",
        "    trimmed = [tensor[:, :longest] for tensor in (input_ids, attention_mask, token_type_ids, masklm)]\n",
        "    return trimmed[0], trimmed[1], trimmed[2], trimmed[3], labels, datalabels"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4leBz_pt8lrb"
      },
      "source": [
        "## MiDAS-Encoder Model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LFZX4EtWK1SN"
      },
      "source": [
        "### Model Definition"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Shyz_obDDvmH"
      },
      "outputs": [],
      "source": [
        "from ednaml.models.Albert import AlbertPreTrainedModel, AlbertModel, AlbertOnlyMLMHead\n",
        "from ednaml.models.Albert import AlbertConfig\n",
        "from torch.nn import CrossEntropyLoss\n",
        "from ednaml.models.ModelAbstract import ModelAbstract\n",
        "from torch import nn\n",
        "\n",
        "from ednaml.utils.layers import GradientReversalLayer\n",
        "from ednaml.utils.albert import AlbertEmbeddingAverage, AlbertPooledOutput, AlbertRawCLSOutput\n",
        "\n",
        "class MiDASAlbert(ModelAbstract):\n",
        "  def __init__(self, base, weights, metadata, normalization, parameter_groups, **kwargs):\n",
        "    \"\"\"Pass every constructor argument, unchanged, to the ModelAbstract constructor.\"\"\"\n",
        "    super().__init__(\n",
        "      base=base,\n",
        "      weights=weights,\n",
        "      metadata=metadata,\n",
        "      normalization=normalization,\n",
        "      parameter_groups=parameter_groups,\n",
        "      **kwargs\n",
        "    )\n",
        "    \n",
        "  def model_attributes_setup(self, **kwargs):\n",
        "    \"\"\"Build the AlbertConfig from kwargs and read the MiDAS-specific options.\n",
        "\n",
        "    Options read: num_decoders, domains, g_alpha (default 1.0), pooling (default\n",
        "    \"pooled\") and l_discriminator (default False).\n",
        "    \"\"\"\n",
        "    config = AlbertConfig(**kwargs)\n",
        "    self.config = config\n",
        "    self.configargs = config.getVars()\n",
        "\n",
        "    self.num_decoders = kwargs.get(\"num_decoders\")\n",
        "    self.domains = kwargs.get(\"domains\")\n",
        "    self.g_alpha = kwargs.get(\"g_alpha\", 1.0)\n",
        "    self.pool_method = kwargs.get(\"pooling\", \"pooled\")\n",
        "    # Decoder indices 0..num_decoders-1, used to address self.decoders in forward_impl.\n",
        "    self.num_decoder_range = list(range(self.num_decoders))\n",
        "    self.labeldiscriminator = kwargs.get(\"l_discriminator\", False)\n",
        "\n",
        "  \n",
        "  def model_setup(self, **kwargs):\n",
        "    \"\"\"Create the encoder, the MLM decoder heads, the discriminators and the pooling layer.\n",
        "\n",
        "    Raises NotImplementedError when self.pool_method is not \"pooled\", \"raw\" or \"average\".\n",
        "    \"\"\"\n",
        "    self.encoder, errors = AlbertModel.from_pretrained(\"pytorch_model.bin\", config=self.config, output_loading_info=True)\n",
        "    print(\"Errors \\n\\t\", errors)\n",
        "    # Build one head per decoder. `[head]*n` would register the SAME module instance n\n",
        "    # times, so all decoders would share a single set of parameters.\n",
        "    self.decoders = nn.ModuleList([AlbertOnlyMLMHead(self.config) for _ in range(self.num_decoders)])\n",
        "    self.decoders.apply(self._init_weights)\n",
        "    self.gradient_reversal = GradientReversalLayer(alpha=self.g_alpha)\n",
        "    # Dataset (domain) discriminator, applied after gradient reversal in forward_impl.\n",
        "    self.discriminator = nn.Linear(self.config.hidden_size, self.domains, bias=False)\n",
        "    self.discriminator.apply(self.weights_init_softmax)\n",
        "    if self.labeldiscriminator:\n",
        "      self.l_d = nn.Linear(self.config.hidden_size, self.config.num_labels, bias=False)\n",
        "      self.l_d.apply(self.weights_init_softmax)\n",
        "    else:\n",
        "      self.l_d = None\n",
        "    self.tie_weights()\n",
        "\n",
        "    if self.pool_method == \"pooled\":\n",
        "      self.pooler_layer = AlbertPooledOutput()\n",
        "    elif self.pool_method == \"raw\":\n",
        "      self.pooler_layer = AlbertRawCLSOutput()\n",
        "    elif self.pool_method == \"average\":\n",
        "      self.pooler_layer = AlbertEmbeddingAverage()\n",
        "    else:\n",
        "      raise NotImplementedError(\"Unsupported pooling method %s; expected 'pooled', 'raw' or 'average'\"%str(self.pool_method))\n",
        "\n",
        "\n",
        "  def tie_weights(self):\n",
        "    \"\"\"Run _tie_or_clone_weights on every decoder head against the encoder's word embeddings.\"\"\"\n",
        "    word_embeddings = self.encoder.embeddings.word_embeddings\n",
        "    for decoder in self.decoders:\n",
        "      self._tie_or_clone_weights(decoder, word_embeddings)\n",
        "        \n",
        "  def _tie_or_clone_weights(self, first_module, second_module):\n",
        "    \"\"\"Tie or clone module weights depending on whether TorchScript is in use.\n",
        "\n",
        "    With config.torchscript set, first_module gets its own copy of second_module's\n",
        "    weight; otherwise both modules share the same weight object. An existing bias on\n",
        "    first_module is zero-padded up to the weight's first dimension.\n",
        "    \"\"\"\n",
        "    source_weight = second_module.weight\n",
        "    first_module.weight = nn.Parameter(source_weight.clone()) if self.config.torchscript else source_weight\n",
        "\n",
        "    bias = getattr(first_module, 'bias', None)\n",
        "    if bias is not None:\n",
        "      extra = first_module.weight.shape[0] - bias.shape[0]\n",
        "      bias.data = torch.nn.functional.pad(bias.data, (0, extra), 'constant', 0)\n",
        "\n",
        "  def _init_weights(self, module):\n",
        "    \"\"\"Initialize one module: normal weights for Linear/Embedding, identity for LayerNorm.\"\"\"\n",
        "    is_linear = isinstance(module, nn.Linear)\n",
        "    if is_linear or isinstance(module, nn.Embedding):\n",
        "      # Slightly different from the TF version which uses truncated_normal for initialization\n",
        "      # cf https://github.com/pytorch/pytorch/pull/5617\n",
        "      module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n",
        "    elif isinstance(module, torch.nn.LayerNorm):\n",
        "      module.weight.data.fill_(1.0)\n",
        "      module.bias.data.zero_()\n",
        "    # Linear biases, when present, start at zero.\n",
        "    if is_linear and module.bias is not None:\n",
        "      module.bias.data.zero_()\n",
        "\n",
        "  def forward_impl(self, x, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, perturbation=None):\n",
        "    \"\"\"Run the encoder and decode its sequence output with the active decoder heads.\n",
        "\n",
        "    In inference mode (``self.inferencing``) the optional ``perturbation`` is added to\n",
        "    the encoder's sequence output before decoding; ``None`` means no perturbation.\n",
        "    Otherwise the pooled output is also sent through the gradient-reversal layer into\n",
        "    ``self.discriminator`` and, when ``self.labeldiscriminator`` is set, into ``self.l_d``.\n",
        "\n",
        "    Returns:\n",
        "      A 3-tuple ``(prediction_scores, second, l_d)``: ``prediction_scores`` is a list with\n",
        "      one decoder output per index in ``self.num_decoder_range``; ``second`` is the pooled\n",
        "      encoder output in inference mode and the discriminator output otherwise; ``l_d`` is\n",
        "      the ``self.l_d`` output or None.\n",
        "    \"\"\"\n",
        "    outputs = self.encoder(x,\n",
        "                        attention_mask=attention_mask,\n",
        "                        token_type_ids=token_type_ids,\n",
        "                        position_ids=position_ids,\n",
        "                        head_mask=head_mask)  # sequence_output, pooled_output, (hidden_states), (attentions)\n",
        "\n",
        "    sequence_output = outputs[0]\n",
        "    # Pooled once and reused below; pooling always uses the unperturbed encoder outputs.\n",
        "    pooled_out = self.pooler_layer(outputs)\n",
        "    if self.inferencing:\n",
        "      # The perturbation is optional: skipping it when None avoids a TypeError on `None + tensor`.\n",
        "      if perturbation is not None:\n",
        "        sequence_output = perturbation + sequence_output\n",
        "      prediction_scores = [self.decoders[idx](sequence_output) for idx in self.num_decoder_range]\n",
        "      return prediction_scores, pooled_out, None\n",
        "    prediction_scores = [self.decoders[idx](sequence_output) for idx in self.num_decoder_range]\n",
        "    # The discriminator sees the pooled output (not the raw CLS token) through gradient reversal.\n",
        "    discrimination = self.discriminator(self.gradient_reversal(pooled_out))\n",
        "    l_d = self.l_d(pooled_out) if self.labeldiscriminator else None\n",
        "    return prediction_scores, discrimination, l_d\n",
        "    \n",
        "\n",
        "  def partial_load(self, weights_path):\n",
        "    \"\"\"Load saved weights from ``weights_path`` by delegating to the parent class.\"\"\"\n",
        "    # super() is already bound to this instance; passing `self` again would shift the\n",
        "    # arguments and raise a TypeError.\n",
        "    # TODO: look at the from_pretrained function to accurately load the saved weights from .bin\n",
        "    super().partial_load(weights_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "57z4bFE9LGr6"
      },
      "source": [
        "## MiDAS-Expert Model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VEx0iOEWLLpv"
      },
      "source": [
        "### Model Definition"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UG2CyIaLLE0H"
      },
      "outputs": [],
      "source": [
        "from ednaml.models.Albert import AlbertPreTrainedModel, AlbertModel, AlbertOnlyMLMHead\n",
        "from ednaml.models.Albert import AlbertConfig\n",
        "from torch.nn import CrossEntropyLoss\n",
        "from ednaml.models.ModelAbstract import ModelAbstract\n",
        "from torch import nn\n",
        "\n",
        "from ednaml.utils.albert import AlbertEmbeddingAverage, AlbertPooledOutput, AlbertRawCLSOutput\n",
        "\n",
        "class MiDASExpert(ModelAbstract):\n",
        "  \"\"\"Expert classifier: a pretrained ALBERT encoder followed by pooling, dropout and a linear head.\"\"\"\n",
        "\n",
        "  def __init__(self, base, weights, metadata, normalization, parameter_groups, **kwargs):\n",
        "    super().__init__(base=base,\n",
        "          weights=weights,\n",
        "          metadata=metadata,\n",
        "          normalization=normalization,\n",
        "          parameter_groups=parameter_groups,\n",
        "          **kwargs)\n",
        "\n",
        "  def model_attributes_setup(self, **kwargs):\n",
        "    \"\"\"Create the ALBERT config from ``kwargs`` and read the head options.\n",
        "\n",
        "    ``num_classes`` is stored as ``self.num_labels``; ``pooling`` (default ``\"pooled\"``)\n",
        "    selects the pooling layer that ``model_setup`` builds.\n",
        "    \"\"\"\n",
        "    self.config = AlbertConfig(**kwargs)\n",
        "    self.configargs = self.config.getVars()\n",
        "    self.num_labels = kwargs.get(\"num_classes\")\n",
        "    self.pool_method = kwargs.get(\"pooling\", \"pooled\")\n",
        "\n",
        "  def model_setup(self, **kwargs):\n",
        "    \"\"\"Build dropout, classifier and pooling layer, then load the pretrained encoder.\n",
        "\n",
        "    Raises:\n",
        "      NotImplementedError: if ``self.pool_method`` is not \"pooled\", \"raw\" or \"average\".\n",
        "    \"\"\"\n",
        "    # A configured dropout probability of 0 is replaced by 0.1.\n",
        "    self.dropout = nn.Dropout(0.1 if self.config.hidden_dropout_prob == 0 else self.config.hidden_dropout_prob)\n",
        "    # NOTE(review): the head width comes from config.num_labels, not from self.num_labels\n",
        "    # (read from `num_classes`) -- confirm the two always agree.\n",
        "    self.classifier = nn.Linear(self.config.hidden_size, self.config.num_labels)\n",
        "\n",
        "    if self.pool_method == \"pooled\":\n",
        "      self.pooler_layer = AlbertPooledOutput()\n",
        "    elif self.pool_method == \"raw\":\n",
        "      self.pooler_layer = AlbertRawCLSOutput()\n",
        "    elif self.pool_method == \"average\":\n",
        "      self.pooler_layer = AlbertEmbeddingAverage()\n",
        "    else:\n",
        "      raise NotImplementedError()\n",
        "\n",
        "    # The heads are initialised before the encoder is attached, presumably so the random\n",
        "    # initialisation does not touch the pretrained encoder weights.\n",
        "    self.init_weights()\n",
        "\n",
        "    self.encoder, errors = AlbertModel.from_pretrained(\"pytorch_model.bin\", config=self.config, output_loading_info=True)\n",
        "    print(\"Errors \\n\\t\", errors)\n",
        "\n",
        "  def _tie_or_clone_weights(self, first_module, second_module):\n",
        "    \"\"\"Tie or clone module weights depending on whether TorchScript is in use.\n",
        "\n",
        "    With ``self.config.torchscript`` set, ``first_module`` gets its own copy of\n",
        "    ``second_module.weight``; otherwise both share one parameter. A non-None bias on\n",
        "    ``first_module`` is zero-padded at its end to the number of rows of the new weight.\n",
        "    \"\"\"\n",
        "    if self.config.torchscript:\n",
        "      first_module.weight = nn.Parameter(second_module.weight.clone())\n",
        "    else:\n",
        "      first_module.weight = second_module.weight\n",
        "\n",
        "    if hasattr(first_module, 'bias') and first_module.bias is not None:\n",
        "      # `nn.functional` is used (rather than `torch.nn.functional`) because this cell\n",
        "      # only imports `nn`, not `torch`.\n",
        "      first_module.bias.data = nn.functional.pad(\n",
        "          first_module.bias.data,\n",
        "          (0, first_module.weight.shape[0] - first_module.bias.shape[0]),\n",
        "          'constant',\n",
        "          0\n",
        "      )\n",
        "\n",
        "  def forward_impl(self, x, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):\n",
        "    \"\"\"Encode ``x``, pool the encoder outputs and classify.\n",
        "\n",
        "    Returns:\n",
        "      ``(logits, pooled_output, extras)`` where ``pooled_output`` is the tensor fed to the\n",
        "      classifier (after the offset and dropout below) and ``extras`` is ``outputs[2:]`` of\n",
        "      the encoder (hidden states / attentions, when the encoder returns them).\n",
        "    \"\"\"\n",
        "    outputs = self.encoder(x,\n",
        "                        attention_mask=attention_mask,\n",
        "                        token_type_ids=token_type_ids,\n",
        "                        position_ids=position_ids,\n",
        "                        head_mask=head_mask)  # sequence_output, pooled_output, (hidden_states), (attentions)\n",
        "\n",
        "    pooled_output = self.pooler_layer(outputs)  # pooling strategy is chosen by `pool_method` in model_setup\n",
        "    # NOTE(review): a constant 0.1 is added to the pooled features before dropout. Kept as-is\n",
        "    # because existing trained weights depend on it -- confirm it is intentional.\n",
        "    pooled_output = self.dropout(pooled_output+0.1)\n",
        "    logits = self.classifier(pooled_output)\n",
        "    return logits, pooled_output, outputs[2:] # list of k scores; hidden states, attentions...\n",
        "\n",
        "  def init_weights(self):\n",
        "    \"\"\"Apply ``_init_weights`` to every submodule registered so far.\"\"\"\n",
        "    self.apply(self._init_weights)\n",
        "\n",
        "  def _init_weights(self, module):\n",
        "    \"\"\"Initialise one module: normal weights for Linear/Embedding, identity for LayerNorm.\"\"\"\n",
        "    if isinstance(module, (nn.Linear, nn.Embedding)):\n",
        "      # Slightly different from the TF version which uses truncated_normal for initialization\n",
        "      # cf https://github.com/pytorch/pytorch/pull/5617\n",
        "      module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n",
        "    elif isinstance(module, nn.LayerNorm):\n",
        "      module.bias.data.zero_()\n",
        "      module.weight.data.fill_(1.0)\n",
        "    if isinstance(module, nn.Linear) and module.bias is not None:\n",
        "      module.bias.data.zero_()\n",
        "\n",
        "  def partial_load(self, weights_path):\n",
        "    \"\"\"Load saved weights from ``weights_path`` by delegating to the parent class.\"\"\"\n",
        "    # super() is already bound to this instance; passing `self` again would shift the\n",
        "    # arguments and raise a TypeError.\n",
        "    # TODO: look at the from_pretrained function to accurately load the saved weights from .bin\n",
        "    super().partial_load(weights_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MdwnXp07yaop"
      },
      "source": [
        "# Calculating L for each expert (functions)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "McAD9IEUeO-r"
      },
      "outputs": [],
      "source": [
        "%load_ext autoreload\n",
        "%autoreload 2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "3_imUFi5Asco",
        "outputId": "b15b8632-4cee-4eee-ed55-e510255ccbc9"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'1.11.0+cu113'"
            ]
          },
          "execution_count": 7,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "import torch, tqdm, ednaml\n",
        "from ednaml.core import EdnaML\n",
        "torch.__version__"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kiDZrl6YENhJ"
      },
      "source": [
        "## Calculate encoder features, expert features, expert logits, and return the true labels"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mJpvvj2VAf3G"
      },
      "outputs": [],
      "source": [
        "def calculate_features(encoder, expert_config, expert_index):\n",
        "  \"\"\"Run the encoder and one expert over the expert's training data and collect their outputs.\n",
        "\n",
        "  Args:\n",
        "    encoder: encoder model. Side effect: its ``num_decoder_range`` is set to ``[expert_index]``.\n",
        "    expert_config: path of the expert's configuration file; used to build both the expert\n",
        "      model and its dataloaders (with ``masking`` disabled).\n",
        "    expert_index: index of the encoder's decoder head to use for this expert.\n",
        "\n",
        "  Returns:\n",
        "    ``(encoder_features, expert_features, expert_logits, true_labels)``: CPU tensors,\n",
        "    each concatenated over all batches of the training dataloader.\n",
        "  \"\"\"\n",
        "  from ednaml.utils import build_model_and_load_weights\n",
        "  print(\"Loading expert from %s\"%expert_config)\n",
        "  expert = build_model_and_load_weights(expert_config, model_class=MiDASExpert, epoch=None, custom_metadata=None, add_streamhandler = False)\n",
        "  print(\"Loading data from %s\"%expert_config)\n",
        "  data_eml = EdnaML(expert_config, add_filehandler=False, add_streamhandler = False)\n",
        "  data_eml.addCrawlerClass(MiDASCrawler)\n",
        "  data_eml.addGeneratorClass(MiDASGenerator)\n",
        "  data_eml.cfg.EXECUTION.DATAREADER.DATASET_ARGS[\"masking\"] = False\n",
        "  data_eml.buildDataloaders()\n",
        "  encoder.num_decoder_range = [expert_index]\n",
        "  # NOTE(review): hard-coded width of the encoder's sequence output; presumably equal to\n",
        "  # the encoder's config.hidden_size -- confirm before using another encoder size.\n",
        "  hidden_size = 768\n",
        "  encoder_features, expert_features, expert_logits, true_labels = [],[],[],[]\n",
        "  try:\n",
        "    encoder.cuda()\n",
        "    expert.cuda()\n",
        "    encoder.inference()\n",
        "    expert.inference()\n",
        "    print(\"Getting features for %s\"%os.path.basename(expert_config))\n",
        "    with torch.inference_mode():\n",
        "      for batch in tqdm.tqdm(data_eml.train_generator.dataloader):\n",
        "        all_input_ids, all_attention_mask, all_token_type_ids, all_masklm, all_labels, all_datalabels = tuple(item.cuda() for item in batch)\n",
        "        # A zero perturbation, created directly on the GPU, leaves the encoder output unchanged.\n",
        "        perturbation = torch.zeros((*all_input_ids.shape, hidden_size), device=all_input_ids.device)\n",
        "        preds, features, _ = encoder(all_input_ids, token_type_ids = all_token_type_ids, attention_mask=all_attention_mask, perturbation=perturbation)\n",
        "        # Most likely token id at every position, fed to the expert as its input ids.\n",
        "        decoded = preds[0].max(2)[1]\n",
        "        logits, pooled, _ = expert(decoded, token_type_ids = all_token_type_ids, attention_mask=all_attention_mask)\n",
        "\n",
        "        # Only the results are moved to the CPU; the input batch is simply dropped.\n",
        "        encoder_features.append(features.detach().cpu())\n",
        "        expert_features.append(pooled.detach().cpu())\n",
        "        expert_logits.append(logits.detach().cpu())\n",
        "        true_labels.append(all_labels.detach().cpu())\n",
        "  finally:\n",
        "    # Free the GPU even when the loop fails part-way through.\n",
        "    encoder.cpu()\n",
        "    expert.cpu()\n",
        "    torch.cuda.empty_cache()\n",
        "  return (\n",
        "      torch.cat(encoder_features, dim=0),\n",
        "      torch.cat(expert_features, dim=0),\n",
        "      torch.cat(expert_logits, dim=0),\n",
        "      torch.cat(true_labels, dim=0)\n",
        "  )\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TRmGKBgQEaGu"
      },
      "source": [
        "## Pairwise Euclidean distance matrix (TODO: consider switching to cosine similarity)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "21II2CKxrGiB"
      },
      "outputs": [],
      "source": [
        "def similarity_matrix(mat):\n",
        "    \"\"\"Return the pairwise Euclidean distances between the rows of ``mat``.\n",
        "\n",
        "    Uses ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b on the Gram matrix of ``mat``.\n",
        "\n",
        "    Args:\n",
        "      mat: 2-D tensor with one sample per row.\n",
        "\n",
        "    Returns:\n",
        "      Square tensor whose entry (i, j) is the distance between rows i and j.\n",
        "    \"\"\"\n",
        "    # Gram matrix: entry (i, j) is the dot product of rows i and j.\n",
        "    gram = torch.mm(mat, mat.t())\n",
        "    # Squared row norms, broadcast along the rows of the result.\n",
        "    sq_norms = gram.diag().unsqueeze(0).expand_as(gram)\n",
        "    sq_dist = sq_norms + sq_norms.t() - 2*gram\n",
        "    # Rounding can push tiny squared distances slightly below zero, which would turn\n",
        "    # into NaN under sqrt; clamp them to zero first.\n",
        "    return sq_dist.clamp(min=0).sqrt()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mS5c1HKcEfRd"
      },
      "source": [
        "## Compute the L-value for an expert"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BzC9P7W_vfcj"
      },
      "outputs": [],
      "source": [
        "def get_L_value(expert_features, encoder_features, expert_labels, label_cluster, alpha=15):\n",
        "  \"\"\"Estimate a local Lipschitz-style ratio around the centre of one label cluster.\n",
        "\n",
        "  The cluster consists of the rows whose ``expert_labels`` equal ``label_cluster``. The\n",
        "  cluster row of ``expert_features`` closest to the cluster mean is the anchor, and the\n",
        "  ``alpha`` next-closest rows are its neighbours. For every neighbour the ratio\n",
        "  ||fx2 - fx1|| / ||x2 - x1|| is computed, with fx taken from ``expert_features`` and x\n",
        "  from ``encoder_features``.\n",
        "\n",
        "  Args:\n",
        "    expert_features: 2-D tensor, one row per sample.\n",
        "    encoder_features: 2-D tensor with the same number of rows.\n",
        "    expert_labels: 1-D tensor of labels, one per row.\n",
        "    label_cluster: label value selecting the cluster.\n",
        "    alpha: number of nearest neighbours of the anchor to use (at least 1). Fewer are\n",
        "      used when the cluster is smaller than ``alpha + 1``.\n",
        "\n",
        "  Returns:\n",
        "    ``(local_l, ratios, expert_centroid)``: the largest ratio (0-dim tensor), all ratios,\n",
        "    and the mean of the cluster's ``expert_features``.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if ``alpha < 1`` or the cluster has fewer than two rows.\n",
        "  \"\"\"\n",
        "  if alpha < 1:\n",
        "    raise ValueError(\"alpha must be at least 1, got %s\" % alpha)\n",
        "  in_cluster = expert_labels == label_cluster\n",
        "  L_x = encoder_features[in_cluster]\n",
        "  L_fx = expert_features[in_cluster]\n",
        "  if L_fx.shape[0] < 2:\n",
        "    raise ValueError(\"cluster %s has %d sample(s); at least 2 are needed\" % (label_cluster, L_fx.shape[0]))\n",
        "  expert_centroid = torch.mean(L_fx, 0)\n",
        "  # Only the distances from the centroid are needed, so compute that single row directly\n",
        "  # instead of a full pairwise distance matrix over the cluster.\n",
        "  centroid_dist = torch.sqrt(torch.sum(torch.square(L_fx - expert_centroid), 1))\n",
        "  nearest_first = torch.argsort(centroid_dist)\n",
        "  # Position 0 is the real sample closest to the centroid (the anchor); the following\n",
        "  # `alpha` positions are its neighbours.\n",
        "  anchor_idx = nearest_first[0]\n",
        "  neighbour_idx = nearest_first[1:alpha+1]\n",
        "  L_x1, L_fx1 = L_x[anchor_idx], L_fx[anchor_idx]\n",
        "  L_x2, L_fx2 = L_x[neighbour_idx], L_fx[neighbour_idx]\n",
        "  xdist = torch.sqrt(torch.sum(torch.square(L_x2-L_x1),1))\n",
        "  fxdist = torch.sqrt(torch.sum(torch.square(L_fx2-L_fx1),1))\n",
        "  # A neighbour with identical encoder features gives xdist == 0, i.e. an inf or NaN\n",
        "  # ratio; torch.max surfaces a NaN instead of hiding it depending on element order.\n",
        "  ratios = fxdist / xdist\n",
        "  local_l = torch.max(ratios)\n",
        "  return local_l, ratios, expert_centroid"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x16t47PQte3h"
      },
      "outputs": [],
      "source": [
        "def computeLfromArgs(x1,x2,fx1,fx2):\n",
        "  \"\"\"Return the largest ratio ||fx2-fx1|| / ||x2-x1|| together with both row-wise distances.\n",
        "\n",
        "  Distances are Euclidean norms taken along dimension 1.\n",
        "\n",
        "  Returns:\n",
        "    ``(local_l, fxdist, xdist)``.\n",
        "  \"\"\"\n",
        "  input_gap = torch.square(x2-x1).sum(1).sqrt()\n",
        "  output_gap = torch.square(fx2-fx1).sum(1).sqrt()\n",
        "  ratios = output_gap / input_gap\n",
        "  return max(ratios), output_gap, input_gap"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ks9YxEJMEy5U"
      },
      "source": [
        "## Calculate local L for all experts"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hYtgIlHqq3An"
      },
      "outputs": [],
      "source": [
        "def calculate_all_L(encoder_config, expert_configs):\n",
        "  \"\"\"Load the encoder once and collect features, logits and labels for every expert.\n",
        "\n",
        "  Args:\n",
        "    encoder_config: path of the encoder configuration file.\n",
        "    expert_configs: list of expert configuration paths; the list position is used as the\n",
        "      encoder decoder-head index for that expert.\n",
        "\n",
        "  Returns:\n",
        "    ``(ef, el, tlab, elab)``: lists of encoder features, expert logits, true labels and\n",
        "    argmax labels of the expert logits.\n",
        "\n",
        "  NOTE(review): each list starts with ``len(expert_configs)`` ``None`` placeholders and\n",
        "  the per-expert results are appended after them, so the real entries begin at index\n",
        "  ``len(expert_configs)``. Callers currently rely on this layout.\n",
        "  \"\"\"\n",
        "  from ednaml.utils import build_model_and_load_weights\n",
        "  print(\"Loading encoder from %s\"%encoder_config)\n",
        "  encoder = build_model_and_load_weights(encoder_config, model_class=MiDASAlbert, epoch=None, custom_metadata=None, add_streamhandler = False)\n",
        "\n",
        "  n_experts = len(expert_configs)\n",
        "  ef, el, tlab, elab = ([None]*n_experts for _ in range(4))\n",
        "  for idx, expert_config in enumerate(expert_configs):\n",
        "    config_name = os.path.basename(expert_config)\n",
        "    print(\"Starting on %s\"%config_name)\n",
        "    encoder_features, _, expert_logits, true_labels = calculate_features(\n",
        "        encoder = encoder,\n",
        "        expert_config = expert_config,\n",
        "        expert_index = idx\n",
        "    )\n",
        "    print(\"Calculating local L values for %s\"%config_name)\n",
        "    ef.append(encoder_features)\n",
        "    el.append(expert_logits)\n",
        "    tlab.append(true_labels)\n",
        "    # Predicted label of each sample = index of its largest logit.\n",
        "    elab.append(torch.argmax(expert_logits, dim=1))\n",
        "  return ef, el, tlab, elab\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mVV7-RbLE3gS"
      },
      "source": [
        "## Compute the overall L value for encoder from all local Ls"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-P9th8q5fsCW"
      },
      "outputs": [],
      "source": [
        "def computeL(ef, el, tlab, elab, alpha=10):\n",
        "  \"\"\"Return the larger of the local L values of label clusters 1 and 0 for one expert.\n",
        "\n",
        "  ``el`` plays the role of the function output and ``ef`` of its input in ``get_L_value``;\n",
        "  ``elab`` selects the clusters. ``tlab`` is accepted but not used.\n",
        "  \"\"\"\n",
        "  pos_L, _pos_ratios, _pos_centroid = get_L_value(el, ef, elab, label_cluster=1, alpha=alpha)\n",
        "  neg_L, _neg_ratios, _neg_centroid = get_L_value(el, ef, elab, label_cluster=0, alpha=alpha)\n",
        "  return max(pos_L, neg_L)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IF7QtZQJE6y3"
      },
      "source": [
        "## Perform an $m$-sweep to compare the L estimates for different neighbourhood sizes $m$ (the `alpha` argument in the code)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oKcRRQT5ftSB"
      },
      "outputs": [],
      "source": [
        "def alphasweep(ef, el, tlab, elab, alpha=(10,)):\n",
        "  \"\"\"Compute the overall L value for each neighbourhood size in ``alpha``.\n",
        "\n",
        "  Args:\n",
        "    ef, el, tlab, elab: per-expert lists (same length) passed element-wise to ``computeL``.\n",
        "    alpha: iterable of neighbourhood sizes to try.\n",
        "\n",
        "  Returns:\n",
        "    List with one entry per value in ``alpha``: the maximum of the per-expert L values.\n",
        "    A summary line is printed for every value.\n",
        "  \"\"\"\n",
        "  lvals = []\n",
        "  for a in alpha:\n",
        "    exp_l = [computeL(ef[i], el[i], tlab[i], elab[i], a) for i in range(len(ef))]\n",
        "    overall_l = max(exp_l)\n",
        "    lvals.append(overall_l)\n",
        "    estr = \"\\t\".join([\"D{1}: {0:.3f}\".format(item.item(), expert_idx) for expert_idx, item in enumerate(exp_l)])\n",
        "    print(\"Alpha: %i\\tL:%f\\t %s\"%(a, overall_l, estr))\n",
        "  return lvals"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7-fXPKYI7Grg"
      },
      "source": [
        "# Actual L Calculation using functions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "0QOmlOIrLApB",
        "outputId": "05ceda5d-4e14-4983-9a79-56a781ed7704"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Loading encoder from ./GLAMOR/profiles/MiDAS/encoder-experiments/midas-encoder-easy-5.yml\n",
            "loading weights file pytorch_model.bin\n",
            "Errors \n",
            "\t {'missing_keys': [], 'unexpected_keys': [], 'error_msgs': []}\n",
            "Starting on midas-coaid_news-1.yml\n",
            "Loading expert from ./GLAMOR/profiles/MiDAS/expert-experiments/midas-coaid_news-1.yml\n",
            "loading weights file pytorch_model.bin\n",
            "Errors \n",
            "\t {'missing_keys': [], 'unexpected_keys': [], 'error_msgs': []}\n",
            "Loading data from ./GLAMOR/profiles/MiDAS/expert-experiments/midas-coaid_news-1.yml\n",
            "Refreshing mask ids\n",
            "Refreshing mask ids\n",
            "Getting features for midas-coaid_news-1.yml\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 240/240 [00:19<00:00, 12.48it/s]\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Calculating local L values for midas-coaid_news-1.yml\n",
            "Starting on midas-kagglefnlong-1.yml\n",
            "Loading expert from ./GLAMOR/profiles/MiDAS/expert-experiments/midas-kagglefnlong-1.yml\n",
            "loading weights file pytorch_model.bin\n",
            "Errors \n",
            "\t {'missing_keys': [], 'unexpected_keys': [], 'error_msgs': []}\n",
            "Loading data from ./GLAMOR/profiles/MiDAS/expert-experiments/midas-kagglefnlong-1.yml\n",
            "Refreshing mask ids\n",
            "Refreshing mask ids\n",
            "Getting features for midas-kagglefnlong-1.yml\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 1965/1965 [08:26<00:00,  3.88it/s]\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Calculating local L values for midas-kagglefnlong-1.yml\n",
            "Starting on midas-cov19fntitle-1.yml\n",
            "Loading expert from ./GLAMOR/profiles/MiDAS/expert-experiments/midas-cov19fntitle-1.yml\n",
            "loading weights file pytorch_model.bin\n",
            "Errors \n",
            "\t {'missing_keys': [], 'unexpected_keys': [], 'error_msgs': []}\n",
            "Loading data from ./GLAMOR/profiles/MiDAS/expert-experiments/midas-cov19fntitle-1.yml\n",
            "Refreshing mask ids\n",
            "Refreshing mask ids\n",
            "Getting features for midas-cov19fntitle-1.yml\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 123/123 [02:37<00:00,  1.28s/it]\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Calculating local L values for midas-cov19fntitle-1.yml\n",
            "Starting on midas-covrumor-1.yml\n",
            "Loading expert from ./GLAMOR/profiles/MiDAS/expert-experiments/midas-covrumor-1.yml\n",
            "loading weights file pytorch_model.bin\n",
            "Errors \n",
            "\t {'missing_keys': [], 'unexpected_keys': [], 'error_msgs': []}\n",
            "Loading data from ./GLAMOR/profiles/MiDAS/expert-experiments/midas-covrumor-1.yml\n",
            "Refreshing mask ids\n",
            "Refreshing mask ids\n",
            "Getting features for midas-covrumor-1.yml\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 244/244 [00:32<00:00,  7.56it/s]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Calculating local L values for midas-covrumor-1.yml\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\n"
          ]
        }
      ],
      "source": [
        "# Collect, for every expert config, the encoder features, expert logits, true labels and\n",
        "# predicted labels.\n",
        "# NOTE: despite their names, `encoder_logits` and `encoder_labels` hold the experts' logits\n",
        "# and the argmax labels derived from them (see calculate_all_L).\n",
        "encoder_features, encoder_logits, true_labels, encoder_labels  = calculate_all_L(\n",
        "    \"midas-encoder.yml\",\n",
        "    [\n",
        "      \"midas-coaid_news-1.yml\",\n",
        "      \"midas-kagglefnlong-1.yml\",\n",
        "      \"midas-cov19fntitle-1.yml\",\n",
        "      \"midas-covrumor-1.yml\"\n",
        "    ]\n",
        ")\n",
        "#L_value = max(localL)[0].item()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ByURZx1RFPG4"
      },
      "source": [
        "## Perform m-sweep"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_9S0bXCa_l_x",
        "outputId": "60281ef2-8ca3-4d85-891c-56b3bdc912c7"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Alpha: 1\tL:0.077129\t D0: 0.077\tD1: 0.007\tD2: 0.003\tD3: 0.017\n",
            "Alpha: 10\tL:0.077129\t D0: 0.077\tD1: 0.007\tD2: 0.023\tD3: 0.034\n",
            "Alpha: 20\tL:0.083998\t D0: 0.077\tD1: 0.013\tD2: 0.026\tD3: 0.084\n",
            "Alpha: 30\tL:0.083998\t D0: 0.077\tD1: 0.021\tD2: 0.029\tD3: 0.084\n",
            "Alpha: 40\tL:0.083998\t D0: 0.077\tD1: 0.037\tD2: 0.036\tD3: 0.084\n",
            "Alpha: 50\tL:0.110128\t D0: 0.085\tD1: 0.037\tD2: 0.039\tD3: 0.110\n",
            "Alpha: 60\tL:0.110128\t D0: 0.085\tD1: 0.037\tD2: 0.045\tD3: 0.110\n",
            "Alpha: 70\tL:0.114787\t D0: 0.085\tD1: 0.037\tD2: 0.081\tD3: 0.115\n",
            "Alpha: 80\tL:0.205267\t D0: 0.097\tD1: 0.037\tD2: 0.081\tD3: 0.205\n",
            "Alpha: 90\tL:0.205267\t D0: 0.101\tD1: 0.037\tD2: 0.081\tD3: 0.205\n",
            "Alpha: 100\tL:0.205267\t D0: 0.112\tD1: 0.037\tD2: 0.081\tD3: 0.205\n",
            "Alpha: 110\tL:0.205267\t D0: 0.112\tD1: 0.037\tD2: 0.081\tD3: 0.205\n",
            "Alpha: 120\tL:0.205267\t D0: 0.112\tD1: 0.037\tD2: 0.084\tD3: 0.205\n",
            "Alpha: 130\tL:0.205267\t D0: 0.112\tD1: 0.037\tD2: 0.084\tD3: 0.205\n",
            "Alpha: 140\tL:0.310975\t D0: 0.123\tD1: 0.037\tD2: 0.084\tD3: 0.311\n",
            "Alpha: 150\tL:0.310975\t D0: 0.123\tD1: 0.037\tD2: 0.086\tD3: 0.311\n",
            "Alpha: 160\tL:0.310975\t D0: 0.123\tD1: 0.040\tD2: 0.148\tD3: 0.311\n",
            "Alpha: 170\tL:0.310975\t D0: 0.150\tD1: 0.040\tD2: 0.148\tD3: 0.311\n",
            "Alpha: 180\tL:0.310975\t D0: 0.150\tD1: 0.066\tD2: 0.148\tD3: 0.311\n",
            "Alpha: 190\tL:0.310975\t D0: 0.150\tD1: 0.066\tD2: 0.149\tD3: 0.311\n",
            "Alpha: 200\tL:0.310975\t D0: 0.180\tD1: 0.066\tD2: 0.149\tD3: 0.311\n",
            "Alpha: 210\tL:0.310975\t D0: 0.180\tD1: 0.078\tD2: 0.220\tD3: 0.311\n",
            "Alpha: 220\tL:0.310975\t D0: 0.180\tD1: 0.078\tD2: 0.220\tD3: 0.311\n",
            "Alpha: 230\tL:0.310975\t D0: 0.180\tD1: 0.078\tD2: 0.220\tD3: 0.311\n",
            "Alpha: 240\tL:0.310975\t D0: 0.180\tD1: 0.078\tD2: 0.220\tD3: 0.311\n"
          ]
        }
      ],
      "source": [
        "# Neighbourhood sizes to sweep: 1, 10, 20, ..., 240.\n",
        "asweep = [1] + [item*10 for item in range(1, 25)]\n",
        "# calculate_all_L returns lists that begin with None placeholders; drop them by value\n",
        "# rather than with a hard-coded offset that only fits exactly four experts.\n",
        "lvals = alphasweep(\n",
        "    [item for item in encoder_features if item is not None],\n",
        "    [item for item in encoder_logits if item is not None],\n",
        "    [item for item in true_labels if item is not None],\n",
        "    [item for item in encoder_labels if item is not None],\n",
        "    asweep\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "bzp1OA4Xkunv",
        "outputId": "71ee0dd3-d393-49f6-a449-a85deab52571"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[tensor(0.0771),\n",
              " tensor(0.0771),\n",
              " tensor(0.0840),\n",
              " tensor(0.0840),\n",
              " tensor(0.0840),\n",
              " tensor(0.1101),\n",
              " tensor(0.1101),\n",
              " tensor(0.1148),\n",
              " tensor(0.2053),\n",
              " tensor(0.2053),\n",
              " tensor(0.2053),\n",
              " tensor(0.2053),\n",
              " tensor(0.2053),\n",
              " tensor(0.2053),\n",
              " tensor(0.3110),\n",
              " tensor(0.3110),\n",
              " tensor(0.3110),\n",
              " tensor(0.3110),\n",
              " tensor(0.3110),\n",
              " tensor(0.3110),\n",
              " tensor(0.3110),\n",
              " tensor(0.3110),\n",
              " tensor(0.3110),\n",
              " tensor(0.3110),\n",
              " tensor(0.3110)]"
            ]
          },
          "execution_count": 17,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Overall L estimate for each swept value, in the same order as `asweep`.\n",
        "lvals"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IUJXzzWlQ7lj"
      },
      "source": [
        "# MiDAS Evaluation with a specific $\\epsilon$ value"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0om9ASVErV9M"
      },
      "outputs": [],
      "source": [
        "# L estimate used to derive epsilon in the next cell.\n",
        "# NOTE(review): hard-coded. 0.110128 matches the overall L printed by the sweep for\n",
        "# alpha 50/60; 0.037445 is presumably a single-expert value (D1 is about 0.037 in the\n",
        "# sweep output) -- confirm which one is intended.\n",
        "L_value = 0.037445\n",
        "#L_value = 0.110128"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "KSVI9d5ZSrF1",
        "outputId": "f32aec49-ff7e-4e5a-f21b-721282e78b83"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Epsilon:   26.70583522499666\n"
          ]
        }
      ],
      "source": [
        "# Epsilon is taken as the reciprocal of the chosen L estimate.\n",
        "epsilon = 1/L_value\n",
        "print(\"Epsilon:  \", epsilon)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PfSFYXil8CyP"
      },
      "outputs": [],
      "source": [
        "# NOTE(review): empty dict; no cell visible in this part of the notebook reads it --\n",
        "# confirm it is still needed.\n",
        "econfig = {}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qzevQqzOFwj2"
      },
      "source": [
        "## Perturbation helper function"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Kudy-wJqTC6k"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import numpy as np"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bxAB9wOsR48F"
      },
      "outputs": [],
      "source": [
        "def generate_perturbation(epsilon, n_dims, n_samples):\n",
        "  \"\"\"Draw ``n_samples`` random points inside an ``n_dims``-dimensional ball of radius ``epsilon``.\n",
        "\n",
        "  Standard-normal draws are normalised to unit length per sample (direction) and then\n",
        "  scaled by ``epsilon * U**(1/n_dims)`` with ``U`` uniform on [0, 1) (radius).\n",
        "\n",
        "  Returns:\n",
        "    ``torch.Tensor`` of shape ``(n_dims, n_samples)``; each column is one sample.\n",
        "  \"\"\"\n",
        "  gaussian = np.random.multivariate_normal(mean=[0], cov=np.eye(1,1), size=(n_dims, n_samples))\n",
        "  directions = np.squeeze(gaussian, -1)\n",
        "  # Normalise every column (sample) to unit Euclidean length.\n",
        "  column_norms = np.sqrt(np.sum(directions * directions, axis=0))\n",
        "  directions = directions / column_norms\n",
        "  # Radius factor per sample, then scale to the requested ball radius.\n",
        "  radius_fraction = np.random.uniform(low=0, high=1, size=(n_samples)) ** (1/n_dims)\n",
        "  scaled = directions * (radius_fraction * epsilon)\n",
        "  return torch.Tensor(scaled)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tl64gQgCF1YE"
      },
      "source": [
        "## Set up encoder model, expert model, and target data "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZVyAxIzLGr0h"
      },
      "outputs": [],
      "source": [
        "# Configuration files for the evaluation below: the encoder, the target dataset, and the\n",
        "# expert to evaluate on it.\n",
        "encoder_config = \"./GLAMOR/profiles/MiDAS/encoder-experiments/midas-encoder-easy-1.yml\"\n",
        "target_config = \"./GLAMOR/profiles/MiDAS/expert-experiments/midas-cov19fntext-1.yml\"\n",
        "expert_config = \"./GLAMOR/profiles/MiDAS/expert-experiments/midas-cov19fntitle-1.yml\"\n",
        "#epsilon = 0.0556"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "2GmNRVPr_vMc",
        "outputId": "a87c22ef-b697-4f88-c845-c12278dac04f"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "00:49:32 Loaded ednaml_model_builder from ednaml.models to build model\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Loading encoder from ./GLAMOR/profiles/MiDAS/encoder-experiments/midas-encoder-easy-1.yml\n",
            "loading weights file pytorch_model.bin\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "00:49:32 Finished instantiating model with MiDASAlbert architecture\n",
            "00:49:32 Previous stop detected. Will attempt to resume from epoch 6\n",
            "00:49:32 Loading model from drive backup.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Errors \n",
            "\t {'missing_keys': [], 'unexpected_keys': [], 'error_msgs': []}\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "00:49:33 Finished loading model state_dict from ./drive/MyDrive/Projects/MiDAS/Models/midas_encoder-v1-albert-easy_domain/midas_encoder-v1_epoch6.pth\n",
            "00:49:33 Reading data with DataReader DataReader\n",
            "00:49:33 Crawling Data for ['cov19_fn_text']\n",
            "00:49:33 Skipped generating training data, because EdnaML is in test mode.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Loading data from ./GLAMOR/profiles/MiDAS/expert-experiments/midas-cov19fntext-1.yml\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "00:49:34 Generated test data/query generator\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Refreshing mask ids\n"
          ]
        }
      ],
      "source": [
        "#def calculate_perturbed_L(encoder, expert_config, target_config, epsilon):\n",
        "import torch, tqdm, ednaml\n",
        "from ednaml.core import EdnaML\n",
        "torch.__version__\n",
        "\n",
        "from ednaml.utils import build_model_and_load_weights\n",
        "\n",
        "print(\"Loading encoder from %s\"%encoder_config)\n",
        "encoder = build_model_and_load_weights(encoder_config, model_class=MiDASAlbert, epoch=None, custom_metadata=None, add_streamhandler = True)\n",
        "print(\"Loading data from %s\"%target_config)\n",
        "data_eml = EdnaML(target_config, mode = \"test\", add_filehandler=False, add_streamhandler = True)\n",
        "data_eml.addCrawlerClass(MiDASCrawler)\n",
        "data_eml.addGeneratorClass(MiDASGenerator)\n",
        "data_eml.cfg.EXECUTION.DATAREADER.DATASET_ARGS[\"masking\"] = False\n",
        "data_eml.cfg.TEST_TRANSFORMATION.BATCH_SIZE = 1\n",
        "data_eml.buildDataloaders()\n",
        "\n",
        "encoder.num_decoder_range = [1]\n",
        "encoder.cuda()\n",
        "encoder.inference()\n",
        "#print(\"Getting features for %s\"%os.path.basename(expert_config))\n",
        "encoder_features, expert_features, expert_logits, true_labels = [],[],[],[]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "9eN2J5B98XZc",
        "outputId": "b1615adc-7edd-40a2-bbba-e733c5860817"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Loading expert from ./GLAMOR/profiles/MiDAS/expert-experiments/midas-cov19fntitle-1.yml\n",
            "loading weights file pytorch_model.bin\n",
            "Errors \n",
            "\t {'missing_keys': [], 'unexpected_keys': [], 'error_msgs': []}\n"
          ]
        }
      ],
      "source": [
        "print(\"Loading expert from %s\"%expert_config)\n",
        "expert = build_model_and_load_weights(expert_config, model_class=MiDASExpert, epoch=None, custom_metadata=None, add_streamhandler = False)\n",
        "expert.cuda()\n",
        "expert.inference()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7z87DLS6GNxU"
      },
      "source": [
        "## Pass target-config data through MiDAS, then expert"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "n_yPjlyD_w4y",
        "outputId": "e46ec6f2-9a9a-47fd-f26c-2be03450b35c"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 564/564 [03:57<00:00,  2.37it/s]\n"
          ]
        }
      ],
      "source": [
        "#encoder_features, expert_features, expert_\n",
        "expert_logits, true_labels, l_vals , elabels = [],[],[], []\n",
        "with torch.inference_mode():\n",
        "  for batch in tqdm.tqdm(data_eml.test_generator.dataloader):\n",
        "    #batch = next(iter(data_eml.test_generator.dataloader))\n",
        "    batch = tuple(item.cuda() for item in batch)\n",
        "    all_input_ids, all_attention_mask, all_token_type_ids, all_masklm, all_labels, all_datalabels = batch\n",
        "    # batch of 1\n",
        "    # generate perturbation points\n",
        "    slen = all_input_ids.shape[1]\n",
        "    perturbation = generate_perturbation(epsilon=epsilon, n_dims = 768, n_samples=16)\n",
        "    perturbation = perturbation.T[:,None,:].repeat(1,slen,1)\n",
        "    smax = torch.where(all_attention_mask[0]==1)[0][-1] - 1\n",
        "    perturbation[0,:,:]*=0  # set the first one to be original...\n",
        "    #perturbation[:,slen-1,:]*=0\n",
        "    perturbation = perturbation.cuda()\n",
        "    preds, features, _ = encoder(all_input_ids, token_type_ids = all_token_type_ids, attention_mask=all_attention_mask, perturbation = perturbation)\n",
        "    #preds, features, _ = encoder(all_input_ids, token_type_ids = all_token_type_ids, attention_mask=all_attention_mask)\n",
        "\n",
        "    decoded = preds[0].max(2)[1]\n",
        "    logits, pooled, _ = expert(decoded, token_type_ids = all_token_type_ids, attention_mask=all_attention_mask)  # make these torch.onez/zeros(), of the correct size!!!!!\n",
        "    perturbation = perturbation.detach().cpu()\n",
        "    features = features.detach().cpu()\n",
        "    logits = logits.detach().cpu()\n",
        "    pooled = pooled.detach().cpu()\n",
        "    all_labels = all_labels.detach().cpu()\n",
        "\n",
        "    lval, _, _ = computeLfromArgs(torch.zeros(15,768) + features,perturbation[1:,0,:]+features,torch.zeros(15,768)+pooled[0].unsqueeze(-1).T,pooled[1:])\n",
        "\n",
        "    #encoder_features.append(features)\n",
        "    #expert_features.append(pooled)\n",
        "    expert_logits.append(logits[0].unsqueeze(-1))\n",
        "    elabels.append(torch.argmax(logits[0]))\n",
        "    true_labels.append(all_labels)\n",
        "    l_vals.append(lval.item())\n",
        "    batch = tuple(item.cpu() for item in batch)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1VmkLPxcHD0c"
      },
      "source": [
        "## Perform accuracy measurements"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8kOA3ZwuzeC_"
      },
      "outputs": [],
      "source": [
        "l_vals = torch.Tensor(l_vals)\n",
        "accuracy = torch.mean((torch.Tensor(elabels) == torch.cat(true_labels)).float())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "SbN68pd8zsYV",
        "outputId": "d9a58d81-0088-46e5-bb3b-c9169e140b72"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Accuracy: 0.424\n",
            "L\t Acc \t Cov\n",
            "0 \t 1.000\t 0.002\n",
            "1 \t 0.349\t 0.422\n",
            "1 \t 0.415\t 0.965\n",
            "1 \t 0.424\t 1.000\n",
            "2 \t 0.424\t 1.000\n",
            "2 \t 0.424\t 1.000\n",
            "2 \t 0.424\t 1.000\n",
            "3 \t 0.424\t 1.000\n",
            "3 \t 0.424\t 1.000\n",
            "3 \t 0.424\t 1.000\n"
          ]
        }
      ],
      "source": [
        "print(\"Accuracy: %.3f\"%accuracy)\n",
        "print(\"L\\t Acc \\t Cov\")\n",
        "for i in np.linspace(min(l_vals), max(l_vals)+2, 10):\n",
        "  lcount = ((torch.Tensor(elabels)[l_vals <= i] == torch.cat(true_labels)[l_vals <= i]).float())\n",
        "  lacc = torch.mean(lcount)\n",
        "  print(\"%i \\t %0.3f\\t %.3f\"%(i,lacc,lcount.shape[0] / len(elabels)))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "p2drybCxRLEg"
      },
      "outputs": [],
      "source": [
        "econfig[expert_config] = {}\n",
        "econfig[expert_config][\"l_vals\"] = l_vals.numpy()\n",
        "econfig[expert_config][\"elabels\"] = torch.Tensor(elabels).numpy()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Zt3y9iFvIPDJ"
      },
      "outputs": [],
      "source": [
        "lvalstacked = np.stack([econfig[ecfg][\"l_vals\"] for ecfg in econfig])\n",
        "elabstacked = np.stack([econfig[ecfg][\"elabels\"] for ecfg in econfig])\n",
        "lvalminidx = np.argmin(lvalstacked,axis=0)\n",
        "lvalmin = lvalstacked[lvalminidx, range(len(lvalminidx))]\n",
        "final_labels = elabstacked[lvalminidx, range(len(lvalminidx))]\n",
        "final_true_labels =  torch.cat(true_labels)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "2kHxTM4aItZW",
        "outputId": "531e1858-580c-4a39-a0df-b67e70adb074"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Final acc: 0.424\n"
          ]
        }
      ],
      "source": [
        "final_acc = torch.mean((torch.Tensor(final_labels)==final_true_labels).float())\n",
        "print(\"Final acc: %.3f\"%final_acc.item())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "lMPYCcNuJgSG",
        "outputId": "d2dd24b5-60dd-4177-c62c-a81268ea1d01"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Accuracy: 0.424\n",
            "L\t Acc \t Cov\n",
            "0.86 \t 1.000\t 0.002\n",
            "0.96 \t 0.500\t 0.021\n",
            "1.07 \t 0.473\t 0.098\n",
            "1.17 \t 0.370\t 0.374\n",
            "1.28 \t 0.342\t 0.715\n",
            "1.38 \t 0.378\t 0.881\n",
            "1.49 \t 0.411\t 0.957\n",
            "1.59 \t 0.419\t 0.986\n",
            "1.70 \t 0.422\t 0.995\n",
            "1.80 \t 0.424\t 1.000\n"
          ]
        }
      ],
      "source": [
        "print(\"Accuracy: %.3f\"%final_acc)\n",
        "print(\"L\\t Acc \\t Cov\")\n",
        "#for i in range(max(int(min(lvalmin)),1),int(max(lvalmin)+2)):\n",
        "for i in np.linspace(min(lvalmin), max(lvalmin), 10):\n",
        "  lcount = ((torch.Tensor(final_labels)[lvalmin <= i] == final_true_labels[lvalmin <= i]).float())\n",
        "  lacc = torch.mean(lcount)\n",
        "  print(\"%3.2f \\t %0.3f\\t %.3f\"%(i,lacc,lcount.shape[0] / len(elabels)))"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "t8t8OcDQDypK",
        "i0-UNku6DPqU",
        "OJ8JWzSoG5P1",
        "TqJoyEW1f_a1",
        "d5j3WfN0fpIT",
        "t6HDMaGzs2rd",
        "9MafR4mBKavE",
        "LFZX4EtWK1SN",
        "VEx0iOEWLLpv",
        "MdwnXp07yaop",
        "IF7QtZQJE6y3"
      ],
      "machine_shape": "hm",
      "name": "EdnaML - MiDAS",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
