{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 488,
   "id": "1e1f360e-1f9a-4bdd-af01-7430bc14cfcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "from torch.distributions import Categorical\n",
    "import torch.nn.functional as F\n",
    "from torchvision import datasets, transforms\n",
    "from torch.utils.data import ConcatDataset, DataLoader\n",
    "from torch.utils.data.dataset import Subset\n",
    "import numpy as np\n",
    "from sklearn.feature_selection import mutual_info_classif\n",
    "\n",
    "from model import GaussianSQVAE, SQVAE\n",
    "\n",
    "import argparse\n",
    "from configs.defaults import get_cfgs_defaults\n",
    "\n",
    "from main import arg_parse, load_config\n",
    "from util import set_seeds, get_loader, get_loader_ecmi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4bdaa93-7439-4634-8d58-6bd3acf1c2e9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 489,
   "id": "db872fed-40db-45b4-af40-d57a6678c614",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_encode_kld(model, z_from_encoder, var_q, codebook, flg_quant_det=True):\n",
    "    \"\"\"Assign codebook entries to encoder outputs and compute the continuous KL term.\n",
    "\n",
    "    Args:\n",
    "        model: model whose `quantizer` provides `size_dict`,\n",
    "            `_calc_distance_bw_enc_codes` and `_calc_distance_bw_enc_dec`.\n",
    "        z_from_encoder: encoder output, shape (bs, dim_z, width, height).\n",
    "        var_q: posterior variance; clamped to at least 1e-10 before inversion.\n",
    "        codebook: code vectors, shape (size_dict, dim_z).\n",
    "        flg_quant_det: if True, take the highest-logit code; otherwise sample a\n",
    "            code from softmax(logit).\n",
    "\n",
    "    Returns:\n",
    "        encodings_hard: one-hot code assignments on the same device as the inputs.\n",
    "            Shape (bs*width*height, size_dict) when deterministic and\n",
    "            (bs, width, height, size_dict) when sampled -- assumes the distance\n",
    "            helper returns flattened (bs*width*height, size_dict) logits.\n",
    "        kld_continuous: unreduced output of `_calc_distance_bw_enc_dec`.\n",
    "        z_to_decoder: quantized latents, shape (bs, dim_z, width, height).\n",
    "    \"\"\"\n",
    "    bs, dim_z, width, height = z_from_encoder.shape\n",
    "    z_from_encoder_permuted = z_from_encoder.permute(0, 2, 3, 1).contiguous()\n",
    "    precision_q = 1. / torch.clamp(var_q, min=1e-10)\n",
    "\n",
    "    logit = -model.quantizer._calc_distance_bw_enc_codes(z_from_encoder_permuted, codebook, 0.5 * precision_q)\n",
    "\n",
    "    if flg_quant_det:\n",
    "        indices = torch.argmax(logit, dim=1).unsqueeze(1)\n",
    "        # Allocate on the inputs' device: a hard-coded cuda/cpu choice put this tensor\n",
    "        # on the CPU while `indices` lived on MPS, so scatter_ failed.\n",
    "        encodings_hard = torch.zeros(\n",
    "            indices.shape[0], model.quantizer.size_dict,\n",
    "            device=logit.device, dtype=codebook.dtype)\n",
    "        encodings_hard.scatter_(1, indices, 1)\n",
    "    else:\n",
    "        probabilities = torch.softmax(logit, dim=-1)\n",
    "        indices = Categorical(probabilities).sample().view(bs, width, height)\n",
    "        encodings_hard = F.one_hot(indices, num_classes=model.quantizer.size_dict).type_as(codebook)\n",
    "\n",
    "    z_quantized = torch.matmul(encodings_hard, codebook).view(bs, width, height, dim_z)\n",
    "    z_to_decoder = z_quantized.permute(0, 3, 1, 2).contiguous()\n",
    "\n",
    "    # Left unreduced (no .mean()) so callers can aggregate it themselves.\n",
    "    kld_continuous = model.quantizer._calc_distance_bw_enc_dec(z_from_encoder, z_to_decoder, 0.5 * precision_q)\n",
    "\n",
    "    return encodings_hard, kld_continuous, z_to_decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 551,
   "id": "da5087a4-2652-4fe3-a85e-239044991c87",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint to analyse; keep in sync with the --n/--seed/--S_seed/--K/--d_dict arguments.\n",
    "CKPT_NETWORK, CKPT_N, CKPT_SEED, CKPT_S_SEED, CKPT_K, CKPT_D_DICT = 'resnet', 4000, 0, 0, 128, 64\n",
    "\n",
    "res_path = os.path.join('checkpoint', 'mnist_sqvae_gaussian_1')\n",
    "dir_name = f'ecmi_{CKPT_NETWORK}, n={CKPT_N}, seed={CKPT_SEED}, S_seed={CKPT_S_SEED}, K={CKPT_K}, d_dict={CKPT_D_DICT}'\n",
    "dir_path = os.path.join(res_path, dir_name)\n",
    "if not os.path.exists(dir_path):\n",
    "    # Fail fast with the full path instead of printing and failing later on open().\n",
    "    raise FileNotFoundError(f'Did not find results for {dir_name} (expected at {dir_path})')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 539,
   "id": "c2e3060d-a890-4dc0-a586-cb4a94c421a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefer CUDA, then Apple MPS, then CPU, so the notebook also runs without an MPS device.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda')\n",
    "elif torch.backends.mps.is_available():\n",
    "    device = torch.device('mps')\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "\n",
    "state_dict = torch.load(os.path.join(dir_path, 'current.pt'), map_location=device)\n",
    "\n",
    "# Strip the leading 'module.' prefix (presumably left by a DataParallel-style wrapper\n",
    "# at save time). removeprefix only touches the start of each key, unlike str.replace.\n",
    "new_state_dict = {key.removeprefix('module.'): value for key, value in state_dict.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 526,
   "id": "3b9a96b2-c8a8-4736-9093-efdad4bd8e86",
   "metadata": {},
   "outputs": [],
   "source": [
    "#saved_model['module.encoder.res_m.block.5.running_mean']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 547,
   "id": "e59b42fe-deb3-4a47-8f7d-ee4649a584c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Options presumably mirroring main.arg_parse; the defaults select the checkpoint analysed here.\n",
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument('-c', '--config_file', type=str, help='Config file path')\n",
    "parser.add_argument('-ts', '--timestamp', type=str, help='Timestamp')\n",
    "parser.add_argument('--save', action='store_true', help='Save the results')\n",
    "parser.add_argument('--dbg', action='store_true', help='Debug mode')\n",
    "parser.add_argument('--gpu', type=int, default=0, help='GPU device ID')\n",
    "parser.add_argument('--ecmi', action='store_true', help='Use ECMI')\n",
    "parser.add_argument('--n', type=int, default=4000, help='Number of samples')\n",
    "parser.add_argument('--K', '-K', type=int, default=128, help='Number of size_dict')\n",
    "parser.add_argument('--d_dict', type=int, default=64, help='Number of dim_dict')\n",
    "parser.add_argument('--seed', type=int, default=0, help='Random seed')\n",
    "parser.add_argument('--S_seed', type=int, default=0, help='S seed')\n",
    "\n",
    "# Parse an explicit empty argv: in a notebook, sys.argv belongs to the kernel launcher\n",
    "# (typically `-f` plus a connection-file path), so only the defaults above should apply.\n",
    "args, unknown = parser.parse_known_args([])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 548,
   "id": "9a6afe75-8d77-413e-a695-c95fbc9871dc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "128"
      ]
     },
     "execution_count": 548,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the parsed codebook size (--K).\n",
    "args.K"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 528,
   "id": "b06f5854-1c69-4362-a391-79e74ba71ff5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Users/fujisawa/Dropbox/research/VAE_generalization/sqvae/vision/configs/mnist_gauss_1.yaml\n"
     ]
    }
   ],
   "source": [
    "cfgs = get_cfgs_defaults()\n",
    "args.config_file = 'mnist_gauss_1.yaml'\n",
    "args.ecmi = True\n",
    "\n",
    "# --gpu is parsed as an int, but os.environ only accepts str values.\n",
    "# NOTE(review): CUDA reads CUDA_VISIBLE_DEVICES when it initialises; if CUDA was already\n",
    "# touched earlier in this kernel, setting it here may have no effect.\n",
    "if torch.cuda.is_available():\n",
    "    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)\n",
    "cfgs, flgs = load_config(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 549,
   "id": "9930122b-30d1-4f1e-a7f5-66c867e695a6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "64"
      ]
     },
     "execution_count": 549,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the parsed code-vector dimension (--d_dict).\n",
    "args.d_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 529,
   "id": "3d9f716b-9594-439b-9dc4-50fdf4b1d1f9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 529,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the parsed GPU id (--gpu, parsed as an int).\n",
    "args.gpu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 540,
   "id": "9224acc8-b1bf-41dc-825f-fdea9d044475",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 540,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the model from the loaded config and restore the checkpoint weights;\n",
    "# load_state_dict is strict by default, so its output confirms every key matched.\n",
    "model = GaussianSQVAE(cfgs, flgs)\n",
    "model.load_state_dict(new_state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 541,
   "id": "c490a7f1-a5ad-4ea2-a30d-92c2c81be38f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Move to the analysis device and switch to eval mode so BatchNorm layers use their\n",
    "# running statistics (and stop updating them) during the evaluation loop.\n",
    "model = model.to(device).eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 486,
   "id": "29a2aee5-0be8-4c99-a697-52fcaedcafaf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CfgNode({'path': 'checkpoint', 'path_dataset': 'dataset', 'nworker': 2, 'list_dir_for_copy': ['', 'networks/'], 'dataset': CfgNode({}), 'model': CfgNode({}), 'network': CfgNode({}), 'train': CfgNode({'bs': 32, 'lr': 0.001, 'epoch_max': 100}), 'quantization': CfgNode({'temperature': CfgNode({'init': 1.0, 'decay': 1e-05, 'min': 0.0})}), 'test': CfgNode({'bs': 50}), 'flags': CfgNode({'arelbo': True, 'decay': True, 'bn': True})})"
      ]
     },
     "execution_count": 486,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 509,
   "id": "860f142f-bab0-425d-bb17-1e2187abd670",
   "metadata": {},
   "outputs": [],
   "source": [
    "#model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 545,
   "id": "40ba414b-172a-4fd7-90db-ebe7d67bf685",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CfgNode({'path': 'checkpoint/mnist_sqvae_gaussian_1/', 'path_dataset': 'dataset', 'nworker': 2, 'list_dir_for_copy': ['', 'networks/'], 'dataset': CfgNode({'name': 'MNIST', 'shape': (1, 28, 28), 'dim_x': 784}), 'model': CfgNode({'name': 'GaussianSQVAE', 'log_param_q_init': 2.302585092994046, 'param_var_q': 'gaussian_1'}), 'network': CfgNode({'name': 'resnet', 'num_rb': 2}), 'train': CfgNode({'bs': 32, 'lr': 0.001, 'epoch_max': 100, 'seed': 0}), 'quantization': CfgNode({'temperature': CfgNode({'init': 1.0, 'decay': 1e-05, 'min': 0.0}), 'size_dict': 128, 'dim_dict': 64}), 'test': CfgNode({'bs': 50}), 'flags': CfgNode({'arelbo': True, 'decay': True, 'bn': True, 'save': False, 'noprint': True, 'var_q': False}), 'path_specific': 'mnist_sqvae_gaussian_1/', 'path_data': 'checkpoint'})"
      ]
     },
     "execution_count": 545,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the merged configuration returned by load_config.\n",
    "cfgs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 427,
   "id": "dd87e79e-eb23-4991-9aa1-53fe26bb8c6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reproduce the quantizer temperature at the end of training with the trainer's schedule:\n",
    "# T(step) = max(init * exp(-decay * step), min), step = (epoch - 1) * len(train_loader) + batch_idx + 1.\n",
    "temp_cfg = cfgs.quantization.temperature\n",
    "epoch = cfgs.train.epoch_max\n",
    "# FIXME(review): this cell used `(epoch - 1) * 4000 + 249 + 1`, mixing the sample count\n",
    "# (4000) with a batch index (249). Steps are counted in batches; 250 batches per epoch\n",
    "# matches the last batch index of 249 -- confirm against the training loader.\n",
    "steps_per_epoch = 250\n",
    "step = (epoch - 1) * steps_per_epoch + (steps_per_epoch - 1) + 1\n",
    "temperature_current = np.max([temp_cfg.init * np.exp(-temp_cfg.decay * step), temp_cfg.min])\n",
    "model.quantizer.set_temperature(temperature_current)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 428,
   "id": "6a742c76-00a5-43a0-902a-eee13da452d9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.019015516028502245"
      ]
     },
     "execution_count": 428,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Temperature currently set on the quantizer (calc_encode_kld does not use it).\n",
    "model.quantizer.temperature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 510,
   "id": "5da2649b-5bed-4562-8e53-b74c52652c7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = 'MNIST'\n",
    "preproc_transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "])\n",
    "\n",
    "# Look the dataset class up by name instead of eval()-ing a constructed string.\n",
    "trainval_dataset = getattr(datasets, dataset)(\n",
    "    os.path.join(cfgs.path_dataset, '{}/'.format(dataset)),\n",
    "    train=True, download=True, transform=preproc_transform,\n",
    ")\n",
    "\n",
    "# Draw 2n distinct training examples. The global NumPy RNG is reseeded here, so these\n",
    "# indices and any later np.random draws are reproducible when run in order.\n",
    "all_size = len(trainval_dataset)\n",
    "np.random.seed(args.seed)\n",
    "include_indices = np.random.choice(range(all_size), size=2 * args.n, replace=False)\n",
    "\n",
    "all_examples = Subset(trainval_dataset, include_indices)\n",
    "# shuffle=False keeps batch order aligned with include_indices (row i = example i).\n",
    "train_val_loader = DataLoader(\n",
    "    all_examples, batch_size=cfgs.train.bs, shuffle=False,\n",
    "    num_workers=cfgs.nworker, pin_memory=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 511,
   "id": "9a855bfe-32d3-4b73-ac2b-7228e0445142",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pair up examples (2i, 2i+1) and flip a fair coin per pair: subset1 takes one member\n",
    "# of each pair and subset2 the other. This continues the global NumPy RNG stream seeded\n",
    "# in the data-loading cell, so it must run right after that cell to be reproducible.\n",
    "mask = np.random.randint(2, size=(args.n,)) ## Ber(1/2)\n",
    "subset1_indices = 2*np.arange(args.n) + mask\n",
    "subset2_indices = 2*np.arange(args.n) + (1-mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 512,
   "id": "6d95dff4-7b74-4d61-ac27-de5d59859eb1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CfgNode({'arelbo': True, 'decay': True, 'bn': True, 'save': False, 'noprint': True, 'var_q': False})"
      ]
     },
     "execution_count": 512,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the flags returned by load_config.\n",
    "flgs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 543,
   "id": "e5a78f36-664f-4ffd-8d74-70ac21e29297",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mps:0\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "Placeholder storage has not been allocated on MPS device!",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[543], line 14\u001b[0m\n\u001b[1;32m     11\u001b[0m z \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mencoder(x)\n\u001b[1;32m     12\u001b[0m \u001b[38;5;66;03m#print(z.device)\u001b[39;00m\n\u001b[1;32m     13\u001b[0m \u001b[38;5;66;03m#enc_hard, kld, z_to_decoder = calc_encode_kld(model, z, param_q, model.codebook, False)\u001b[39;00m\n\u001b[0;32m---> 14\u001b[0m enc_hard, kld, z_quantized \u001b[38;5;241m=\u001b[39m \u001b[43mcalc_encode_kld\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mz\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparam_q\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcodebook\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m     15\u001b[0m x_reconst \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mdecoder(z_to_decoder)\n\u001b[1;32m     16\u001b[0m loss_list\u001b[38;5;241m.\u001b[39mappend(F\u001b[38;5;241m.\u001b[39mmse_loss(x_reconst, x, reduction\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mnone\u001b[39m\u001b[38;5;124m'\u001b[39m)\u001b[38;5;241m.\u001b[39mview(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, x\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m*\u001b[39mx\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m2\u001b[39m]\u001b[38;5;241m*\u001b[39mx\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m3\u001b[39m])\u001b[38;5;241m.\u001b[39msum(\u001b[38;5;241m1\u001b[39m))\n",
      "Cell \u001b[0;32mIn[489], line 14\u001b[0m, in \u001b[0;36mcalc_encode_kld\u001b[0;34m(model, z_from_encoder, var_q, codebook, flg_quant_det)\u001b[0m\n\u001b[1;32m     12\u001b[0m     indices \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39margmax(logit, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m     13\u001b[0m     encodings_hard \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mzeros(indices\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m], model\u001b[38;5;241m.\u001b[39mquantizer\u001b[38;5;241m.\u001b[39msize_dict, device\u001b[38;5;241m=\u001b[39mdevice)\n\u001b[0;32m---> 14\u001b[0m     \u001b[43mencodings_hard\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscatter_\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mindices\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m     15\u001b[0m     \u001b[38;5;66;03m#avg_probs = torch.mean(encodings_hard, dim=0)\u001b[39;00m\n\u001b[1;32m     16\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m     17\u001b[0m     dist \u001b[38;5;241m=\u001b[39m Categorical(probabilities)\n",
      "\u001b[0;31mRuntimeError\u001b[0m: Placeholder storage has not been allocated on MPS device!"
     ]
    }
   ],
   "source": [
    "# Posterior variance for the 'gaussian_1' parameterization:\n",
    "# param_q = exp(log_var_q) + exp(log_param_q_scalar) with log_var_q fixed to 0.\n",
    "log_var_q = torch.tensor([0.0], device=device)\n",
    "param_q = log_var_q.exp() + model.log_param_q_scalar.exp()\n",
    "\n",
    "enc_hard_batches, kld_batches, loss_batches = [], [], []\n",
    "with torch.no_grad():\n",
    "    for x, _ in train_val_loader:\n",
    "        x = x.to(device)\n",
    "        z = model.encoder(x)\n",
    "        # Deterministic (argmax) quantization; pass False to sample codes instead.\n",
    "        enc_hard, kld, z_to_decoder = calc_encode_kld(model, z, param_q, model.codebook, True)\n",
    "        x_reconst = model.decoder(z_to_decoder)\n",
    "        # Per-example summed squared reconstruction error.\n",
    "        loss_batches.append(F.mse_loss(x_reconst, x, reduction='none').flatten(1).sum(1))\n",
    "        enc_hard_batches.append(enc_hard)\n",
    "        kld_batches.append(kld)\n",
    "\n",
    "# Concatenate on the CPU so results can be indexed with NumPy arrays and passed to\n",
    "# scikit-learn; building into fresh lists keeps this cell safe to re-run.\n",
    "kld_list = torch.cat(kld_batches).cpu()\n",
    "enc_hard_list = torch.cat(enc_hard_batches).cpu()\n",
    "loss_list = torch.cat(loss_batches).cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa1d7233-3dd7-4257-9501-a5d54cee4c56",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 460,
   "id": "766b24c4-89be-4f18-8a45-67d335211f11",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(-0.0910)"
      ]
     },
     "execution_count": 460,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Difference in mean per-example reconstruction loss: subset1 minus subset2.\n",
    "loss_list[subset1_indices].mean() - loss_list[subset2_indices].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 461,
   "id": "3d45ec36-4450-493b-83df-55356621fd86",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(178.2786)"
      ]
     },
     "execution_count": 461,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean of all collected KL terms.\n",
    "kld_list.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 448,
   "id": "aa9cecaa-4d58-4dfb-99a3-601651f3adea",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(104.8762)"
      ]
     },
     "execution_count": 448,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean reconstruction loss over subset1 (a[0]); use a[1] for subset2.\n",
    "a = [subset1_indices, subset2_indices]\n",
    "loss_list[a[0]].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 455,
   "id": "b136b795-d5a6-4066-8138-96a6331f6cfa",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 455,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# WARNING(review): overwrites args.seed after the data cell already used the previous\n",
    "# seed; re-running earlier cells will now sample a different subset.\n",
    "args.seed = 2\n",
    "args.seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 261,
   "id": "4292fdfd-48c2-4c83-af7a-2cf7812c45c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Membership label per example of the 2n-sample set: 0 for subset1, 1 for subset2.\n",
    "# Entries start at zero, so only the subset2 positions need to be written.\n",
    "ms = torch.zeros(2 * args.n)\n",
    "ms[subset2_indices] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 262,
   "id": "e39b4b9a-d131-4671-b789-0735b1f969b6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8000])"
      ]
     },
     "execution_count": 262,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Check the label vector has one entry per example (2n).\n",
    "ms.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 369,
   "id": "f41d053d-0f12-43f0-85e6-a729efabfefe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# scikit-learn needs host NumPy arrays: one row of flattened one-hot codes per example.\n",
    "enc_features = enc_hard_list.reshape(2 * args.n, -1).cpu().numpy()\n",
    "membership = ms.cpu().numpy()\n",
    "# random_state fixes the small noise scikit-learn adds to continuous features, so the\n",
    "# estimate is reproducible across runs.\n",
    "# NOTE(review): the one-hot features are discrete; discrete_features=True may be the\n",
    "# more appropriate estimator here.\n",
    "mi_per_feature = mutual_info_classif(enc_features, membership, discrete_features=False, random_state=0)\n",
    "cur_mi = max(0.0, float(mi_per_feature.sum()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 370,
   "id": "c41f272e-bace-4651-8560-96a981ce3601",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0017176026033617656"
      ]
     },
     "execution_count": 370,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Normalise the MI by the number of examples, 2n (8000 for n = 4000).\n",
    "cur_mi / (2 * args.n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 296,
   "id": "ff799199-27a6-41e9-898d-98cf1330c741",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9048374180359595"
      ]
     },
     "execution_count": 296,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch check of the schedule max(init * exp(-decay * step), min) at step 10000.\n",
    "np.max([1.0 * np.exp(-1e-05*10000), 0.0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 372,
   "id": "d68d2864-ab6a-43a5-99be-d821da4bfdf4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.01901532587429273"
      ]
     },
     "execution_count": 372,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Re-display the quantizer temperature.\n",
    "model.quantizer.temperature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 362,
   "id": "46e22961-2995-4691-9673-addc0e16833b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.01901532587429273"
      ]
     },
     "execution_count": 362,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch recomputation of the end-of-training temperature.\n",
    "# FIXME(review): batch index 250 here vs 249 in the earlier temperature cell, and 4000 is\n",
    "# a sample count rather than len(train_loader) (batches per epoch).\n",
    "epoch = 100\n",
    "#step = (epoch - 1) * len(self.train_loader) + batch_idx + 1\n",
    "step = (epoch - 1) * 4000 + 250 + 1\n",
    "np.max([1.0 * np.exp(-1e-05*step), 0.0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 606,
   "id": "aa113dd8-f32e-4bb0-824e-de7b74ec06c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "\n",
    "class NestedDict(dict):\n",
    "    \"\"\"dict that creates, stores and returns an empty NestedDict for a missing key.\n",
    "\n",
    "    NOTE(review): presumably the pickled results contain NestedDict instances, which\n",
    "    requires this class to exist in the notebook namespace before unpickling.\n",
    "    \"\"\"\n",
    "\n",
    "    def __missing__(self, key):\n",
    "        self[key] = type(self)()\n",
    "        return self[key]\n",
    "\n",
    "\n",
    "# Results variant to load; variants referenced so far: '' (base), '_16', '_32', '_64'.\n",
    "RESULTS_SUFFIX = '_16'\n",
    "results_file = os.path.join(res_path, f'results_ecmi_mnist_sqvae_gaussian_1{RESULTS_SUFFIX}.pkl')\n",
    "# pickle.load can execute arbitrary code: only load result files produced by this project.\n",
    "with open(results_file, 'rb') as f:\n",
    "    res = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 607,
   "id": "3394412b-e882-4d8e-9e6c-379e360f8469",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.008802414"
      ]
     },
     "execution_count": 607,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train-minus-test loss for n=4000, first entry (index 0) of the loaded results.\n",
    "res[4000][0]['train_loss'] - res[4000][0]['test_loss']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 586,
   "id": "a1f72dc2-9504-4fc0-aed6-3f1e20830d28",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.038240433"
      ]
     },
     "execution_count": 586,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): duplicate of the previous cell; its saved output is stale (from an earlier\n",
    "# execution, presumably with a different results file loaded).\n",
    "res[4000][0]['train_loss'] - res[4000][0]['test_loss']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 608,
   "id": "77909284-096e-4825-9288-eb6d91ca7a07",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{4000: [{'ecmi': 1.6712565922739504,\n",
       "   'kl': 807614.06,\n",
       "   'train_loss': 3.6135154,\n",
       "   'test_loss': 3.604713},\n",
       "  {'ecmi': 1.7084151000625796,\n",
       "   'kl': 812559.2,\n",
       "   'train_loss': 3.6981559,\n",
       "   'test_loss': 3.7066467},\n",
       "  {'ecmi': 1.643432013733395,\n",
       "   'kl': 835984.5,\n",
       "   'train_loss': 3.6630695,\n",
       "   'test_loss': 3.712674}]}"
      ]
     },
     "execution_count": 608,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0406313-b60b-4693-b0f6-1800327a55cf",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
