{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "eb3248ac-8a2e-41db-9e83-d8eafc4648bc",
   "metadata": {},
   "outputs": [
    {
     "ename": "ImportError",
     "evalue": "cannot import name 'estimate_default_bound_classification' from 'modules.bound_utils' (/Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/modules/bound_utils.py)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mImportError\u001b[0m                               Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[1], line 13\u001b[0m\n\u001b[1;32m     10\u001b[0m matplotlib, plt \u001b[38;5;241m=\u001b[39m import_matplotlib()\n\u001b[1;32m     11\u001b[0m get_ipython()\u001b[38;5;241m.\u001b[39mrun_line_magic(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmatplotlib\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124minline\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m---> 13\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmodules\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbound_utils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m estimate_fcmi_bound_classification, estimate_default_bound_classification, estimate_proposed_bound_classification, estimate_sgld_bound, estimate_kl_bound_classification, estimate_lg_bound_classification, estimate_interp_bound_classification\n\u001b[1;32m     14\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mscripts\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mfcmi_train_classifier\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m mnist_ld_schedule, \\\n\u001b[1;32m     15\u001b[0m     cifar_resnet50_ld_schedule  \u001b[38;5;66;03m# for pickle to be able to load LD methods\u001b[39;00m\n\u001b[1;32m     16\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmethods\u001b[39;00m\n",
      "\u001b[0;31mImportError\u001b[0m: cannot import name 'estimate_default_bound_classification' from 'modules.bound_utils' (/Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/modules/bound_utils.py)"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from nnlib.nnlib import utils\n",
    "from nnlib.nnlib.matplotlib_utils import import_matplotlib\n",
    "\n",
    "matplotlib, plt = import_matplotlib()\n",
    "%matplotlib inline\n",
    "\n",
    "from modules.bound_utils import estimate_fcmi_bound_classification, estimate_default_bound_classification, estimate_proposed_bound_classification, estimate_sgld_bound, estimate_kl_bound_classification, estimate_lg_bound_classification, estimate_interp_bound_classification\n",
    "from scripts.fcmi_train_classifier import mnist_ld_schedule, \\\n",
    "    cifar_resnet50_ld_schedule  # for pickle to be able to load LD methods\n",
    "import methods\n",
    "from torchmetrics.classification import MulticlassCalibrationError as ECE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b1ff804-bc57-401b-84fd-851116f851d8",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "n = 75\n",
    "lr = 0.001\n",
    "seed = 0\n",
    "n_S_seeds = 30\n",
    "epoch = 200 ## full: 200 ; 200 epochsに対する結果\n",
    "n_batch = 256\n",
    "\n",
    "preds = []\n",
    "labels = []\n",
    "masks = []\n",
    "\n",
    "for S_seed in range(n_S_seeds):\n",
    "    dir_name = f'n={n},lr={lr},seed={seed},S_seed={S_seed}'\n",
    "    dir_path = os.path.join(os.getcwd() + \"/results/\", \"fcmi-mnist-4vs9-CNN\", dir_name)\n",
    "    if not os.path.exists(dir_path):\n",
    "        print(f\"Did not find results for {dir_name}\")\n",
    "        continue\n",
    "    \n",
    "    with open(os.path.join(dir_path, 'saved_data.pkl'), 'rb') as f:\n",
    "        saved_data = pickle.load(f)\n",
    "    \n",
    "    model = utils.load(path=os.path.join(dir_path, 'checkpoints', f'epoch{epoch - 1}.mdl'), methods=methods, device=\"cpu\")\n",
    "    if 'all_examples_wo_data_aug' in saved_data:\n",
    "        all_examples = saved_data['all_examples_wo_data_aug']\n",
    "    else:\n",
    "        all_examples = saved_data['all_examples']\n",
    "    \n",
    "    cur_preds = utils.apply_on_dataset(model=model, dataset=all_examples, batch_size=n_batch)['pred'] ## (2*n, num_classes)\n",
    "    cur_labels = utils.apply_on_dataset(model=model, dataset=all_examples, batch_size=n_batch)['label_0'] ## (2*n, num_classes)\n",
    "    cur_mask = saved_data['mask'] #(n_labels)\n",
    "    \n",
    "    preds.append(cur_preds)\n",
    "    labels.append(cur_labels)\n",
    "    masks.append(cur_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "583448fb-4af4-4f2d-a5fa-3e70d8fa3e49",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "15"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "int(4000 ** (1/3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "82e0681d-b118-4e8a-9813-76651eee30f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "#(torch.argmax(preds[0], dim=1) != labels[0]).sum() / len(labels[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5cd9a9ba-fb90-4238-8553-3e39965bf7e4",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'n' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m/var/folders/_s/jh8l69lx0wz3rhnd7rx1w5gr0000gn/T/ipykernel_27917/1577509907.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mnum_classes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m## 以下を，各n-samplesに対して繰り返す\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      4\u001b[0m     \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m     \u001b[0mms\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mp\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmasks\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;31m## あるサンプルの mask (mask: select the train/val split (S)の際． train dataかval dataか)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'n' is not defined"
     ]
    }
   ],
   "source": [
    "num_classes = 2\n",
    "\n",
    "for idx in range(n): ## 以下を，各n-samplesに対して繰り返す\n",
    "    loss = []\n",
    "    ms = [p[idx] for p in masks] ## あるサンプルの mask (mask: select the train/val split (S)の際． train dataかval dataか)\n",
    "    ps = [p[2*idx:2*idx+2] for p in preds] ## 異なる初期値から学習させたモデルのそれぞれに対し，　n-th sampleのsupersampleに対する予測確率 (e.g.; ３０個のモデルの，Z_{0,0},Z_{0,1}に対する予測確率)\n",
    "    ls = [l[2*idx:2*idx+2] for l in labels]\n",
    "    for i in range(len(ps)):\n",
    "        #ps[i] = torch.argmax(ps[i], dim=1) ## 各モデルごとに，　supersampleに対する予測確率の高いラベル番号に変換 (supersample間の予測ラベルを算出;サンプルの変更に頑健なら確率の変化は小さいのでラベルは一致(＝汎化))\n",
    "        ps[i] = torch.max(torch.softmax(ps[i],1), dim=1).values\n",
    "        loss.append((ps[i] - ls[i])**2)\n",
    "        #loss += (ps[i] != ls[i]).sum() / 30\n",
    "        #ps[i] = num_classes * ps[i][0] + ps[i][1] ## 損失関数の離散確率変数（(0,0),(0,1), (1,0), (1,1)）\n",
    "        #ps[i] = ps[i].item()\n",
    "    #loss = torch.exp(- torch.tensor(loss)) / torch.sum(torch.exp(- torch.tensor(loss)))\n",
    "    #loss = torch.tensor(loss)\n",
    "    loss = torch.concat(loss).reshape(-1,2)\n",
    "    #loss = loss.sum(1)\n",
    "    res2 = torch.tensor(ms).bool()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "09cca5ba-c2eb-46f9-9b72-9ef85b4dca10",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8114296582961815"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mutual_info_classif(loss, ms, discrete_features=[False, False]).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 272,
   "id": "a3faf42b-1dd1-4aa6-b1d5-64745ebe2ffd",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.neighbors import KernelDensity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 281,
   "id": "a18f7c3b-03f1-4a57-b479-e7a27df9d803",
   "metadata": {},
   "outputs": [],
   "source": [
    "#KernelDensity(loss)\n",
    "#X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])\n",
    "#kde = KernelDensity(kernel='gaussian', bandwidth=0.2).fit(X)\n",
    "#kde.score_samples(X)\n",
    "kde = KernelDensity(kernel='gaussian', bandwidth=0.2).fit(loss)\n",
    "prob = np.exp(kde.score_samples(loss)) ## since kde.score_samples() returns log-likelihood"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 280,
   "id": "493f29dd-3aa7-49f1-87c5-8dc3cdfd8abd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "119.27651352241102"
      ]
     },
     "execution_count": 280,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 282,
   "id": "0af3cc0b-da2a-457c-9804-e7808911748e",
   "metadata": {},
   "outputs": [],
   "source": [
    "pa = loss[res2==True]\n",
    "pb = loss[res2==False]\n",
    "kde_pa = KernelDensity(kernel='gaussian', bandwidth=0.2).fit(pa)\n",
    "pa = np.exp(kde_pa.score_samples(pa))\n",
    "#print(kde_pa.score_samples(pa), len(kde_pa.score_samples(pa)))\n",
    "\n",
    "kde_pb = KernelDensity(kernel='gaussian', bandwidth=0.2).fit(pb)\n",
    "#print(kde_pb.score_samples(pb), len(kde_pa.score_samples(pb)))\n",
    "pb = np.exp(kde_pa.score_samples(pb))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 287,
   "id": "b04174de-89ea-4fde-bede-f441e008fe74",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-404.5135622447754"
      ]
     },
     "execution_count": 287,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#prob[res2==True] * \n",
    "prob.sum() * (np.log(prob.sum()) - np.log(pa.sum()*pb.sum()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 183,
   "id": "ee348e50-05f6-49b6-8aa7-b25b56d20580",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "#torch.max(torch.softmax(preds[0],1),1).values \n",
    "#(torch.max(torch.softmax(ps[0],1),1).values - ls[0]) ** 2\n",
    "res = (torch.exp(-torch.concat(loss)) / torch.exp(-torch.concat(loss)).sum()).reshape(-1,2)\n",
    "#res = pd.DataFrame(res)\n",
    "#res2 = pd.DataFrame(ms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 191,
   "id": "c30b5e40-61d4-4092-9fcc-fad1f0b0f59c",
   "metadata": {},
   "outputs": [],
   "source": [
    "res2 = torch.tensor(ms).bool()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 228,
   "id": "e40e890b-fe18-4b43-a4ec-4d7bffadc8e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "#res[res2==True]\n",
    "loss = torch.exp(-loss) / torch.sum(torch.exp(-loss))\n",
    "pa = loss[res2==True]\n",
    "pb = loss[res2==False]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 270,
   "id": "a0554d02-f743-44ef-91d0-a99676d72c31",
   "metadata": {},
   "outputs": [],
   "source": [
    "#prob.sum() * np.log(prob.sum()/(pa.sum()*pb.sum()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "477d6174-6937-4502-ab44-39a068b8ea5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_classes = 2\n",
    "\n",
    "for idx in range(n): ## 以下を，各n-samplesに対して繰り返す\n",
    "    ms = [p[idx] for p in masks] ## あるサンプルの mask (mask: select the train/val split (S)の際． train dataかval dataか)\n",
    "    ps = [p[2*idx:2*idx+2] for p in preds] ## 異なる初期値から学習させたモデルのそれぞれに対し，　n-th sampleのsupersampleに対する予測確率 (e.g.; ３０個のモデルの，Z_{0,0},Z_{0,1}に対する予測確率)\n",
    "    for i in range(len(ps)):\n",
    "        ps[i] = torch.argmax(ps[i], dim=1) ## 各モデルごとに，　supersampleに対する予測確率の高いラベル番号に変換 (supersample間の予測ラベルを算出;サンプルの変更に頑健なら確率の変化は小さいのでラベルは一致(＝汎化))\n",
    "        ps[i] = num_classes * ps[i][0] + ps[i][1] ## ???"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "21f86db9-fb6c-435c-89f2-2b824675438a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.08350420864501443\n"
     ]
    }
   ],
   "source": [
    "id = 50\n",
    "\n",
    "ms = [p[id] for p in masks] ## あるサンプルの mask (mask: select the train/val split (S)の際． train dataかval dataか)\n",
    "ps = [p[2*id:2*id+2] for p in preds] ## 異なる初期値から学習させたモデルのそれぞれに対し，　n-th sampleのsupersampleに対する予測確率 (e.g.; ３０個のモデルの，Z_{0,0},Z_{0,1}に対する予測確率)\n",
    "ls = [l[2*id:2*id+2] for l in labels]\n",
    "#loss = 0.\n",
    "for i in range(len(ps)):\n",
    "    ps[i] = torch.argmax(ps[i], dim=1) ## 各モデルごとに，　supersampleに対する予測確率の高い方のラベルに変換 (supersample間の予測ラベルを算出;サンプルの変更に頑健なら確率の変化は小さいのでラベルは一致(＝汎化))\n",
    "    #loss += (ps[i] != ls[i]).sum()\n",
    "    ls[i] = (ps[i] != ls[i]).sum() / 30\n",
    "    ls[i] = ls[i].item()\n",
    "    ps[i] = num_classes * ps[i][0] + ps[i][1] ## ???　--> supersampleに対する予測パターンのうち，どこに値するかを計算(２値の場合，予測は(0,0), (0,1), (1,0), (1,1)の4通り; 予測が(0,1)なら，2*0+1=1番目のパターン)\n",
    "    ps[i] = ps[i].item()\n",
    "#print(mi)\n",
    "prob = np.zeros((2, 4))\n",
    "ind = 0\n",
    "for a, b in zip(ms, ps):\n",
    "    ## 30個のモデルのうち， (mask,予測パターン)に関する同時確率．例えば，30個のモデルのうち，(mask=0,予測パターン=2)のモデルが4個あった場合の確率は， 4/30．\n",
    "    prob[a,b] += 1 / len(ms)\n",
    "    #ind += 1\n",
    "    #1.0/len(ms) ## {train or val}かどうかで(0,1)で条件付確率 (e.g., val時（ms=1）, supersampleに対する予測は（1,0） --> prob[0,3]; ２値の場合，予測は(0,0), (0,1), (1,0), (1,1)の4通り)\n",
    "#prob /= prob.sum()\n",
    "pa = np.sum(prob, axis=1) ## 各予測確率パターンに対する周辺化\n",
    "pb = np.sum(prob, axis=0) ## maskに対する周辺化 (TODO: Check this corresponds to the 0-1 loss --> maybe, this is not 0-1 loss; overall, this estimation is based on Cor. 2 of H.)\n",
    "mi = 0\n",
    "for a in range(2):\n",
    "    for b in range(4):\n",
    "        if prob[a,b] < 1e-9:\n",
    "            continue\n",
    "        mi += prob[a,b] * np.log(prob[a,b]/(pa[a]*pb[b]))\n",
    "print(mi)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "de9171ee-efd6-4c17-aaef-425b393d9584",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.08350420864501451"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "id = 50\n",
    "\n",
    "ms = [p[id] for p in masks] ## あるサンプルの mask (mask: select the train/val split (S)の際． train dataかval dataか)\n",
    "ps = [p[2*id:2*id+2] for p in preds] ## 異なる初期値から学習させたモデルのそれぞれに対し，　n-th sampleのsupersampleに対する予測確率 (e.g.; ３０個のモデルの，Z_{0,0},Z_{0,1}に対する予測確率)\n",
    "ls = [l[2*id:2*id+2] for l in labels]\n",
    "#loss = 0.\n",
    "for i in range(len(ps)):\n",
    "    ps[i] = torch.argmax(ps[i], dim=1) ## 各モデルごとに，　supersampleに対する予測確率の高い方のラベルに変換 (supersample間の予測ラベルを算出;サンプルの変更に頑健なら確率の変化は小さいのでラベルは一致(＝汎化))\n",
    "    #loss += (ps[i] != ls[i]).sum()\n",
    "    ls[i] = (ps[i] != ls[i]).sum()\n",
    "    ls[i] = ls[i].item()\n",
    "\n",
    "mutual_info_classif(ps, ms, discrete_features=[True, True]).sum() ## discrete_features detects the dtype of \"X (= loss in this case)\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba2a222a-46c2-423a-9180-a58109e42976",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 316,
   "id": "a7733363-6f18-48e0-bc7d-030a0cc28cfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "res = np.zeros(2)\n",
    "id = 0\n",
    "for a in ms:\n",
    "    res[a] += ls[id]\n",
    "    id += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 372,
   "id": "114e13df-a659-4fc2-9b2c-30f83536a0fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "#ls\n",
    "ls = np.array(ls)\n",
    "ls_dist = ls / ls.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "96b13617-7e8f-44ae-b5ee-6a7b71e3b89d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#ls_dist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "86f6ca57-5924-407b-9a9e-cfb5aebef64a",
   "metadata": {},
   "outputs": [],
   "source": [
    "prob = np.zeros((2, 4))\n",
    "ind = 0\n",
    "for a, b in zip(ms, ps):\n",
    "    ## 30個のモデルのうち， (mask,予測パターン)に関する同時確率．例えば，30個のモデルのうち，(mask=0,予測パターン=2)のモデルが4個あった場合の確率は， 4/30．\n",
    "    prob[a,b] += 1 / len(ms)\n",
    "    #ind += 1\n",
    "    #1.0/len(ms) ## {train or val}かどうかで(0,1)で条件付確率 (e.g., val時（ms=1）, supersampleに対する予測は（1,0） --> prob[0,3]; ２値の場合，予測は(0,0), (0,1), (1,0), (1,1)の4通り)\n",
    "#prob /= prob.sum()\n",
    "pa = np.sum(prob, axis=1) ## 各予測確率パターンに対する周辺化\n",
    "pb = np.sum(prob, axis=0) ## maskに対する周辺化 (TODO: Check this corresponds to the 0-1 loss --> maybe, this is not 0-1 loss; overall, this estimation is based on Cor. 2 of H.)\n",
    "mi = 0\n",
    "for a in range(2):\n",
    "    for b in range(4):\n",
    "        if prob[a,b] < 1e-9:\n",
    "            continue\n",
    "        mi += prob[a,b] * np.log(prob[a,b]/(pa[a]*pb[b]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2f234c26-846e-40cf-bbe6-10775066e73a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-1.160938590460336\n"
     ]
    }
   ],
   "source": [
    "#print(prob)\n",
    "#print(pa/30)\n",
    "#print(pb/30)\n",
    "print(mi)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 364,
   "id": "7981673b-3af2-48e3-a71c-1ede2e279601",
   "metadata": {},
   "outputs": [],
   "source": [
    "#prob[0][2] = 0\n",
    "#prob[1][2] = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0fe254e3-c381-454c-9b6c-6c03fe0979aa",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prob = np.zeros([2,4])\n",
    "prob[0][3] = 1.\n",
    "pa = np.sum(prob, axis=1) ## 各予測確率パターンに対する周辺化\n",
    "pb = np.sum(prob, axis=0) ## maskに対する周辺化 (TODO: Check this corresponds to the 0-1 loss --> maybe, this is not 0-1 loss; overall, this estimation is based on Cor. 2 of H.)\n",
    "mi = 0\n",
    "for a in range(2):\n",
    "    for b in range(4):\n",
    "        if prob[a,b] < 1e-9:\n",
    "            continue\n",
    "        mi += prob[a,b] * np.log(prob[a,b]/(pa[a]*pb[b]))\n",
    "mi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 294,
   "id": "9e8fefb9-0eb2-4860-81d5-14c4d2fb7088",
   "metadata": {},
   "outputs": [],
   "source": [
    "pa = np.sum(prob, axis=1) ## 各予測確率パターンに対する周辺化\n",
    "pb = np.sum(prob, axis=0) ## maskに対する周辺化 (TODO: Check this corresponds to the 0-1 loss --> maybe, this is not 0-1 loss; overall, this estimation is based on Cor. 2 of H.)\n",
    "mi = 0\n",
    "for a in range(2):\n",
    "    for b in range(4):\n",
    "        if prob[a,b] < 1e-9:\n",
    "            continue\n",
    "        mi += prob[a,b] * np.log(prob[a,b]/(pa[a]*pb[b]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 299,
   "id": "4fe5fa99-f316-417c-9169-16f4a22864e4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0. , 0. , 0.5, 0. ],\n",
       "       [0. , 0. , 0.5, 0. ]])"
      ]
     },
     "execution_count": 299,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 354,
   "id": "46bad22c-b3d9-410e-af0c-c22baf3516e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "#discrete_mi_est(ms, ps, nx=2, ny=num_classes**2) ## See App. B of Harutyunyan et al., (2021).\n",
    "prob = np.zeros((2, 4))\n",
    "for a, b in zip(ms, ps):\n",
    "    ## 30個の予測出力のうち， (mask,予測パターン)に関する同時確率．例えば，30個の出力のうち，(mask=0,予測パターン=2)の出力が4個あった場合の確率は， 4/30．\n",
    "    prob[a,b] += 1.0/len(ms) ## {train or val}かどうかで(0,1)で条件付確率 (e.g., val時（ms=1）, supersampleに対する予測は（1,0） --> prob[0,3]; ２値の場合，予測は(0,0), (0,1), (1,0), (1,1)の4通り)\n",
    "pa = np.sum(prob, axis=1) ## 各予測確率パターンに対する周辺化\n",
    "pb = np.sum(prob, axis=0) ## maskに対する周辺化 (TODO: Check this corresponds to the 0-1 loss --> maybe, this is not 0-1 loss; overall, this estimation is based on Cor. 2 of H.)\n",
    "mi = 0\n",
    "for a in range(2):\n",
    "    for b in range(4):\n",
    "        if prob[a,b] < 1e-9:\n",
    "            continue\n",
    "        mi += prob[a,b] * np.log(prob[a,b]/(pa[a]*pb[b]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 355,
   "id": "564ec182-2f26-405b-915d-83c5ef15953d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.08350420864501443"
      ]
     },
     "execution_count": 355,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 254,
   "id": "5647f78b-d929-4f27-aa3f-85c56896f13f",
   "metadata": {},
   "outputs": [],
   "source": [
    "id = 2\n",
    "\n",
    "ms = [p[id] for p in masks] ## あるサンプルのlabel mask (mask: select the train/val split (S)の際． train dataかval dataか)\n",
    "ps = [p[2*id:2*id+2] for p in preds] ## 異なる初期値から学習させたモデルのそれぞれに対し，　n-th sampleのsupersampleに対する予測確率 (e.g.; ３０個のモデルの，Z_{0,0},Z_{0,1}に対する予測確率)\n",
    "\n",
    "for i in range(len(ps)):\n",
    "    ps[i] = torch.argmax(ps[i], dim=1) ## 各モデルごとに，　supersampleに対する予測確率の高い方のラベルに変換 (supersample間の予測ラベルを算出;サンプルの変更に頑健なら確率の変化は小さいのでラベルは一致(＝汎化))\n",
    "#    ps[i] = num_classes * ps[i][0] + ps[i][1] ## ???　--> supersampleに対する予測パターンのうち，どこに値するかを計算(２値の場合，予測は(0,0), (0,1), (1,0), (1,1)の4通り; 予測が(0,1)なら，2*0+1=1番目のパターン)\n",
    "#    ps[i] = ps[i].item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 273,
   "id": "9d0ebce2-14df-4f2d-b993-5a7f534fda23",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([10.2729, 27.0437])"
      ]
     },
     "execution_count": 273,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "## 各mask(0,1)×予測パターン((0,0), (0,1), (1,0), (1,1)の4通り)それぞれに対して（2*4=8通り），損失間数値の\n",
    "#pl = torch.argmax(ps[0], dim=1)\n",
    "#(torch.argmax(ps[0], dim=1) - torch.max(ps[0],dim=1).values).pow(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 193,
   "id": "2fdb1d64-eda4-4d00-8d56-1fe1935416a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def discrete_mi_est(xs, ys, nx=2, ny=2):\n",
    "    prob = np.zeros((nx, ny))\n",
    "    for a, b in zip(xs, ys):\n",
    "        prob[a,b] += 1.0/len(xs)\n",
    "    pa = np.sum(prob, axis=1)\n",
    "    pb = np.sum(prob, axis=0)\n",
    "    mi = 0\n",
    "    for a in range(nx):\n",
    "        for b in range(ny):\n",
    "            if prob[a,b] < 1e-9:\n",
    "                continue\n",
    "            mi += prob[a,b] * np.log(prob[a,b]/(pa[a]*pb[b]))\n",
    "    return max(0.0, mi)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7678134a-8d36-46f3-9b07-1ecb5f40b0ac",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d35f2406-0599-452e-9461-b38dd45f8dbe",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "id": "94aa1ef0-4732-4a5e-a74c-8613b12d753a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([  0,   2,   4,   6,   8,  10,  12,  14,  16,  18,  20,  22,  24,\n",
       "        26,  28,  30,  32,  34,  36,  38,  40,  42,  44,  46,  48,  50,\n",
       "        52,  54,  56,  58,  60,  62,  64,  66,  68,  70,  72,  74,  76,\n",
       "        78,  80,  82,  84,  86,  88,  90,  92,  94,  96,  98, 100, 102,\n",
       "       104, 106, 108, 110, 112, 114, 116, 118, 120, 122, 124, 126, 128,\n",
       "       130, 132, 134, 136, 138, 140, 142, 144, 146, 148])"
      ]
     },
     "execution_count": 89,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "2 * np.arange(len(masks[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "id": "3342d12a-a53b-46fe-870a-fb7fa35d31ed",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1,\n",
       "       1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1,\n",
       "       1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0,\n",
       "       0, 0, 0, 0, 1, 1, 0, 0, 0])"
      ]
     },
     "execution_count": 90,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "masks[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 369,
   "id": "c29a99e6-1b00-44c3-8b5c-502f02dcfd63",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5"
      ]
     },
     "execution_count": 369,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0d15e0d0-b834-415a-9bdf-c7e63f3257b8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mutual Information: [0.00620743 0.04065212]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from sklearn.feature_selection import mutual_info_classif\n",
    "\n",
    "# 例として、ランダムなデータでテスト\n",
    "np.random.seed(42)\n",
    "data_size = 1000\n",
    "x = np.random.randn(data_size, 2)\n",
    "y = np.random.choice([0, 1], size=data_size)\n",
    "\n",
    "# 相互情報量を計算\n",
    "mi_values = mutual_info_classif(x, y, discrete_features=[False, False])\n",
    "print(\"Mutual Information:\", mi_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 314,
   "id": "86faa47a-a4d5-4e04-b37d-5e8f96b751fc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8114296582961815"
      ]
     },
     "execution_count": 314,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mutual_info_classif(loss, ms, discrete_features=[False, False]).sum() ## discrete_features detects the dtype of \"X (= loss in this case)\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 307,
   "id": "f0e7c69c-fc01-44f4-ba4e-1c236be674dd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dtype('int64')"
      ]
     },
     "execution_count": 307,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.array(ms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bc28df6-dfa6-4899-bf34-fdaf2fd52c9f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 328,
   "id": "e172578a-4298-46fb-8e1d-2d85f27d47dc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.08350420864501451"
      ]
     },
     "execution_count": 328,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8da6781-bf42-4e2a-b622-550d60540dea",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
