{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 29,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import pickle\n",
        "from bounds import *"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
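        "### Compute bounds for real-world datasets\n",
        "\n",
        "For each dataset, the pre-clustered test split of two modalities and the label is converted into a discrete joint distribution (`convert_data_to_distribution`), which is passed to `get_bounds`. Results are cached in `all_results.pickle`, and each cell skips entries that are already present."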
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 35,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Load pre-computed results if there are any; otherwise start from an empty dict\n",
        "# so the cells below can populate it from scratch on a fresh run.\n",
        "# NOTE: pickle.load can execute arbitrary code -- only load result files you produced yourself.\n",
        "\n",
        "try:\n",
        "    with open('/usr0/home/yuncheng/MultiBench/synthetic/bounds/all_results.pickle', 'rb') as f:\n",
        "        results = pickle.load(f)\n",
        "except FileNotFoundError:\n",
        "    results = dict()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 23,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Compute bounds separately for each of the 3 modality pairs of the sentiment (mosei, mosi),\n",
        "# humor, and sarcasm datasets, which all have vision, audio, and text modalities.\n",
        "# Results are stored in `results` under '<dataset>_<m1>_<m2>' and checkpointed to disk\n",
        "# after every pair so an interrupted run does not lose finished work.\n",
        "\n",
        "RESULTS_PATH = '/usr0/home/yuncheng/MultiBench/synthetic/bounds/all_results.pickle'\n",
        "MODALITY_PAIRS = [('vision', 'text'), ('vision', 'audio'), ('audio', 'text')]\n",
        "\n",
        "for dataset in ['humor', 'sarcasm', 'mosei', 'mosi']:\n",
        "    pair_keys = {(m1, m2): f'{dataset}_{m1}_{m2}' for m1, m2 in MODALITY_PAIRS}\n",
        "    if all(key in results for key in pair_keys.values()):\n",
        "        # every pair is already cached: skip reading the cluster pickle entirely\n",
        "        continue\n",
        "\n",
        "    with open(f'/usr0/home/yuncheng/MultiBench/synthetic/bounds/{dataset}_cluster.pickle', 'rb') as f:\n",
        "        data = pickle.load(f)\n",
        "    test = data['test']\n",
        "\n",
        "    # The label entropy (base 2) and class count depend only on y, so compute them\n",
        "    # once per dataset instead of once per modality pair.\n",
        "    y = np.array(test['labels'])\n",
        "    _, counts = np.unique(y, return_counts=True)\n",
        "    ent = entropy(counts / np.sum(counts), base=2)\n",
        "\n",
        "    for (m1, m2), key in pair_keys.items():\n",
        "        if key in results:\n",
        "            continue\n",
        "        P, _ = convert_data_to_distribution(np.array(test[m1]), np.array(test[m2]), y)\n",
        "        all_quantities, all_bounds = get_bounds(P)\n",
        "        all_quantities['entropy'] = ent\n",
        "        all_quantities['n_classes'] = len(counts)\n",
        "        results[key] = {'all_quantities': all_quantities, 'all_bounds': all_bounds}\n",
        "        print(dataset, f'{m1}+{m2}', all_bounds)\n",
        "        with open(RESULTS_PATH, 'wb') as f:\n",
        "            pickle.dump(results, f)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "metadata": {},
      "outputs": [],
      "source": [
        "# avmnist: bounds for the (image, audio) modality pair.\n",
        "# The cluster pickle is only read when the result is not cached yet, so re-running this\n",
        "# cell with a populated `results` does not need the dataset file.\n",
        "\n",
        "if 'avmnist' not in results:\n",
        "    with open('/usr0/home/yuncheng/MultiBench/synthetic/bounds/avmnist_cluster.pickle', 'rb') as f:\n",
        "        data = pickle.load(f)\n",
        "    x1 = np.array(data['test']['image'])\n",
        "    x2 = np.array(data['test']['audio'])\n",
        "    y = np.array(data['test']['labels'])\n",
        "    _, counts = np.unique(y, return_counts=True)\n",
        "    ent = entropy(counts / np.sum(counts), base=2)  # entropy of the label distribution\n",
        "    P, _ = convert_data_to_distribution(x1, x2, y)\n",
        "    all_quantities, all_bounds = get_bounds(P)\n",
        "    all_quantities['entropy'] = ent\n",
        "    all_quantities['n_classes'] = len(counts)\n",
        "    results['avmnist'] = {'all_quantities': all_quantities, 'all_bounds': all_bounds}\n",
        "    print('avmnist', all_bounds)\n",
        "    # persist immediately so the result is reused by later runs\n",
        "    with open('/usr0/home/yuncheng/MultiBench/synthetic/bounds/all_results.pickle', 'wb') as f:\n",
        "        pickle.dump(results, f)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 21,
      "metadata": {},
      "outputs": [],
      "source": [
        "# mimic: bounds for the (modal1, modal2) modality pair.\n",
        "# The cluster pickle is only read when the result is not cached yet, so re-running this\n",
        "# cell with a populated `results` does not need the dataset file.\n",
        "\n",
        "if 'mimic' not in results:\n",
        "    with open('/usr0/home/yuncheng/MultiBench/synthetic/bounds/mimic_cluster.pickle', 'rb') as f:\n",
        "        data = pickle.load(f)\n",
        "    x1 = np.array(data['test']['modal1'])\n",
        "    x2 = np.array(data['test']['modal2'])\n",
        "    y = np.array(data['test']['labels'])\n",
        "    _, counts = np.unique(y, return_counts=True)\n",
        "    ent = entropy(counts / np.sum(counts), base=2)  # entropy of the label distribution\n",
        "    P, _ = convert_data_to_distribution(x1, x2, y)\n",
        "    all_quantities, all_bounds = get_bounds(P)\n",
        "    all_quantities['entropy'] = ent\n",
        "    all_quantities['n_classes'] = len(counts)\n",
        "    results['mimic'] = {'all_quantities': all_quantities, 'all_bounds': all_bounds}\n",
        "    print('mimic', all_bounds)\n",
        "    # persist immediately so the result is reused by later runs\n",
        "    with open('/usr0/home/yuncheng/MultiBench/synthetic/bounds/all_results.pickle', 'wb') as f:\n",
        "        pickle.dump(results, f)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 36,
      "metadata": {},
      "outputs": [],
      "source": [
        "# enrico: bounds for the (image, wireframe) modality pair.\n",
        "# The cluster pickle is only read when the result is not cached yet, so re-running this\n",
        "# cell with a populated `results` does not need the dataset file.\n",
        "\n",
        "if 'enrico' not in results:\n",
        "    with open('/usr0/home/yuncheng/MultiBench/synthetic/bounds/enrico_cluster.pickle', 'rb') as f:\n",
        "        data = pickle.load(f)\n",
        "    x1 = np.array(data['test']['image'])\n",
        "    x2 = np.array(data['test']['wireframe'])\n",
        "    y = np.array(data['test']['labels'])\n",
        "    _, counts = np.unique(y, return_counts=True)\n",
        "    ent = entropy(counts / np.sum(counts), base=2)  # entropy of the label distribution\n",
        "    P, _ = convert_data_to_distribution(x1, x2, y)\n",
        "    all_quantities, all_bounds = get_bounds(P)\n",
        "    all_quantities['entropy'] = ent\n",
        "    all_quantities['n_classes'] = len(counts)\n",
        "    results['enrico'] = {'all_quantities': all_quantities, 'all_bounds': all_bounds}\n",
        "    print('enrico', all_bounds)\n",
        "    # persist immediately so the result is reused by later runs\n",
        "    with open('/usr0/home/yuncheng/MultiBench/synthetic/bounds/all_results.pickle', 'wb') as f:\n",
        "        pickle.dump(results, f)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 37,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "sarcasm_vision_text\n",
            "0.16099189806501513\n",
            "0.012885638908304407\n",
            "0.018920194901199654\n",
            "0.49403082275072646\n",
            "{'lower_R': 0.035792193369361236, 'lower_U1': 0.029758818482420252, 'lower_U2': 0.018572587075325386, 'lower_diff': 0.0727645267110582, 'upper': 0.7919976015823501}\n",
            "\n",
            "sarcasm_vision_audio\n",
            "0.16227887205457028\n",
            "0.011598910990317037\n",
            "0.02860549232864884\n",
            "0.31061942091086114\n",
            "{'lower_R': 0.005208499062894667, 'lower_U1': -0.011797147241049966, 'lower_U2': 0.038803897539506554, 'lower_diff': 0.05787696275263708, 'upper': 0.7823120580833305}\n",
            "\n",
            "sarcasm_audio_text\n",
            "0.1771434108991574\n",
            "0.01374067109713172\n",
            "0.0027702216436247583\n",
            "0.5083906674354036\n",
            "{'lower_R': 0.05735795841094382, 'lower_U1': 0.06833116294873684, 'lower_U2': 0.04814128370518645, 'lower_diff': 0.11104172930663579, 'upper': 0.7911410298169537}\n",
            "\n",
            "humor_vision_text\n",
            "0.012326482789573602\n",
            "0.029332138602113696\n",
            "6.345698359307539e-16\n",
            "0.20928762498147135\n",
            "{'lower_R': -0.008161148287795461, 'lower_U1': 0.021170990311625535, 'lower_U2': -0.024181781142573074, 'lower_diff': 0.0039275041024201506, 'upper': 0.9582150673812831}\n",
            "\n",
            "humor_vision_audio\n",
            "0.013272006680228076\n",
            "0.028386614714879004\n",
            "1.6273565684158568e-05\n",
            "0.23575065819497978\n",
            "{'lower_R': 0.0033044003910599473, 'lower_U1': 0.03167474153414381, 'lower_U2': -0.02340619448950741, 'lower_diff': 0.004863928803522145, 'upper': 0.958198793812179}\n",
            "\n",
            "humor_audio_text\n",
            "0.011294215253418831\n",
            "0.0019689035664655474\n",
            "0.0009994363364219713\n",
            "0.07927403387832731\n",
            "{'lower_R': 0.000341541016227076, 'lower_U1': 0.0013361696661972015, 'lower_U2': -0.0006685959461673451, 'lower_diff': 0.009727855244540931, 'upper': 0.9856111336166569}\n",
            "\n",
            "mosei_vision_text\n",
            "0.021572929961820853\n",
            "0.001580949878484888\n",
            "0.011658950012159903\n",
            "0.04264747633405388\n",
            "{'lower_R': 0.006134479238445012, 'lower_U1': -0.003943520900746251, 'lower_U2': 0.01322783814350937, 'lower_diff': 0.010184020313321721, 'upper': 0.9649301358116967}\n",
            "\n",
            "mosei_vision_audio\n",
            "0.017509457044517476\n",
            "0.005644422765419957\n",
            "5.320363085419916e-06\n",
            "0.02898867687581463\n",
            "{'lower_R': 0.0035382669954397983, 'lower_U1': 0.009177369422626707, 'lower_U2': -0.004192648005586466, 'lower_diff': 0.006362754104281078, 'upper': 0.9765837654911412}\n",
            "\n",
            "mosei_audio_text\n",
            "0.017514777355177317\n",
            "1.065577902260447e-15\n",
            "0.01571710261918092\n",
            "0.03310884037662838\n",
            "{'lower_R': 0.004128502226037112, 'lower_U1': -0.011588600398220412, 'lower_U2': 0.014646618817057356, 'lower_diff': 0.0031533327990398985, 'upper': 0.966511085689802}\n",
            "\n",
            "mosi_vision_text\n",
            "0.0310364103609508\n",
            "0.01714326340106777\n",
            "4.661760319724213e-14\n",
            "0.3105052762006219\n",
            "{'lower_R': 0.014071259726027519, 'lower_U1': 0.031214523126501092, 'lower_U2': -0.0006880728412246469, 'lower_diff': 0.028818229875314207, 'upper': 0.9249448497448565}\n",
            "\n",
            "mosi_vision_audio\n",
            "0.03381135624368165\n",
            "0.014368317524450192\n",
            "0.00548822131315872\n",
            "0.2823096416830247\n",
            "{'lower_R': 0.011985721675679128, 'lower_U1': 0.020865817880309296, 'lower_U2': 0.00040765985872707154, 'lower_diff': 0.031014600088469062, 'upper': 0.9194566284256328}\n",
            "\n",
            "mosi_audio_text\n",
            "0.02992075915856\n",
            "0.009378818378770365\n",
            "0.0011156512017003063\n",
            "0.1405526923239786\n",
            "{'lower_R': 0.001581893332054607, 'lower_U1': 0.009845060516843208, 'lower_U2': 0.0021095696637299325, 'lower_diff': 0.02082997773708413, 'upper': 0.9327092947678883}\n",
            "\n",
            "avmnist\n",
            "0.17412296422503123\n",
            "1.3930971662904117\n",
            "0.04405933091257233\n",
            "0.25976352205752024\n",
            "{'lower_R': 0.08718663493117015, 'lower_U1': 1.4362559372887724, 'lower_U2': -1.2550029745016955, 'lower_diff': -1.0624312951047772, 'upper': 1.7081430646928113}\n",
            "\n",
            "mimic\n",
            "0.004657249272865171\n",
            "0.2503856569594271\n",
            "2.881232154245327e-08\n",
            "0.0160503223457999\n",
            "{'lower_R': -0.0012610296850349423, 'lower_U1': 0.2491250679572009, 'lower_U2': -0.2495194487461873, 'lower_diff': -0.12185296156045777, 'upper': 0.4069562216068458}\n",
            "\n",
            "enrico\n",
            "0.4860432273346665\n",
            "0.5636430815658506\n",
            "0.7438917746457177\n",
            "1.0186823758541625\n",
            "{'lower_R': 0.013074015939091688, 'lower_U1': -0.16706907108462787, 'lower_U2': 0.20688402668561667, 'lower_diff': -0.5486113403654733, 'upper': 2.0922930021619592}\n",
            "\n"
          ]
        }
      ],
      "source": [
        "# Reload the results from disk (works in a fresh kernel without re-running the cells above)\n",
        "# and print, for every entry, the quantities R, U2, U1, S followed by the bound dict.\n",
        "# A distinct name is used so the cluster `data` from earlier cells is not shadowed.\n",
        "with open('/usr0/home/yuncheng/MultiBench/synthetic/bounds/all_results.pickle', 'rb') as f:\n",
        "    saved_results = pickle.load(f)\n",
        "for name, res in saved_results.items():\n",
        "    print(name)\n",
        "    quantities = res['all_quantities']\n",
        "    for q in ['R', 'U2', 'U1', 'S']:\n",
        "        print(f'{q}:', quantities[q])\n",
        "    print(res['all_bounds'])\n",
        "    print()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "multibench",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.6"
    },
    "vscode": {
      "interpreter": {
        "hash": "8158f520b0615a91d72976457965394544e0f25ca15232774db0f5a21042574b"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
