{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/eavnjeong/anaconda3/envs/ebm/lib/python3.7/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from tool.args import get_general_args\n",
    "from tool.util import init_wandb\n",
    "from train.mlbase import MLBase\n",
    "from evaluate.evaluator import Evaluator\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from data.dl_getter import DATASETS, n_cls, sh, input_range\n",
    "\n",
    "import pandas as pd\n",
    "import argparse\n",
    "\n",
    "import numpy as np\n",
    "import sys\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from tool.util import set_seed, bool_flag\n",
    "from datetime import datetime\n",
    "\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = get_general_args(is_nb=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_seed(args.random_seed)\n",
    "args.bsz = 100\n",
    "args.bsz_vl = 100\n",
    "args.arch = 'wrn-28-10'\n",
    "args.method = 'evaluate'\n",
    "args.wandb_entity = 'eavnjeong'\n",
    "args.exp_load = 'analysis'\n",
    "args.head = \"ep\"\n",
    "args.cos_sq = True\n",
    "args.alpha = 0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "init_wandb(args)\n",
    "eval = Evaluator(MLBase((args)))\n",
    "\n",
    "model = eval.model\n",
    "tr_dl = eval.tr_dl\n",
    "vl_dl = eval.vl_dl\n",
    "\n",
    "tr_dl.dataset.transforms = vl_dl.dataset.transform\n",
    "tr_dl.dataset.transform = vl_dl.dataset.transform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9482"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "with torch.no_grad():\n",
    "    correct = 0\n",
    "\n",
    "    for x, y in vl_dl:\n",
    "        x = x.cuda()\n",
    "        y = y.cuda()\n",
    "        out = model(x)\n",
    "        # out = torch.matmul(out, model.head.num.ms.T)\n",
    "        correct += (out.argmax(dim=1) == y).sum().item()\n",
    "correct/10000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    in_latent, in_label = [], []\n",
    "    model.eval()\n",
    "    for x, y in tr_dl:\n",
    "        x = x.cuda()\n",
    "        latent = model.enc(x)\n",
    "        in_latent.append(latent)\n",
    "        in_label.append(y)\n",
    "in_latent = torch.cat(in_latent)\n",
    "in_label = torch.cat(in_label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from data.dl_getter import get_transform\n",
    "import torchvision as tv\n",
    "from torch.utils.data import DataLoader\n",
    "import torchvision.transforms as tr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_workers = 4\n",
    "\n",
    "svhn_ds = tv.datasets.SVHN(\n",
    "        root=\"~/data\", transform=vl_dl.dataset.transform,\n",
    "        download=False, split=\"train\")\n",
    "cifar100_ds = tv.datasets.CIFAR100(\n",
    "        root=\"~/data\", transform=vl_dl.dataset.transform,\n",
    "        download=False, train=True)\n",
    "celeba_ds = tv.datasets.CelebA(\n",
    "        root=\"~/data\", download=False, split=\"train\",\n",
    "        transform=tr.Compose([tr.Resize(32), vl_dl.dataset.transform]))\n",
    "\n",
    "svhn_dl = DataLoader(\n",
    "    svhn_ds, batch_size=100, shuffle=True,\n",
    "    num_workers=num_workers, drop_last=False)\n",
    "cifar100_dl = DataLoader(\n",
    "    cifar100_ds, batch_size=100, shuffle=True,\n",
    "    num_workers=num_workers, drop_last=False)\n",
    "celeba_dl = DataLoader(\n",
    "    celeba_ds, batch_size=100, shuffle=True,\n",
    "    num_workers=num_workers, drop_last=False)\n",
    "cifar10_interp = DataLoader(\n",
    "    tv.datasets.CIFAR10(\n",
    "        root=\"~/data\", transform=get_transform('cifar10'),\n",
    "        download=False, train=False),\n",
    "    batch_size=100, shuffle=False,\n",
    "    num_workers=num_workers, drop_last=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sigma = 0.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    odd_latent_svhn, svhn_labels = [], []\n",
    "    ood_latent_cifar100, cifar100_labels = [], []\n",
    "    ood_latent_celeba, celeba_labels = [], []\n",
    "    ood_latent_cifer10_interp, cifar10_interp_labels = [], []\n",
    "\n",
    "    for x, y in svhn_dl:\n",
    "        x = x.cuda()\n",
    "        latent = model.enc(x)\n",
    "        odd_latent_svhn.append(latent)\n",
    "        svhn_labels.append(y)\n",
    "    for x, y in cifar100_dl:\n",
    "        x = x.cuda()\n",
    "        latent = model.enc(x)\n",
    "        ood_latent_cifar100.append(latent)\n",
    "        cifar100_labels.append(y)\n",
    "    for x, y in celeba_dl:\n",
    "        x = x.cuda()\n",
    "        latent = model.enc(x)\n",
    "        ood_latent_celeba.append(latent)\n",
    "        celeba_labels.append(y)\n",
    "    for i, (x, y) in enumerate(cifar10_interp):\n",
    "        x = x.cuda()\n",
    "        if i > 0:\n",
    "            x_mix = (x + last_batch) / 2 + sigma * torch.randn_like(x)\n",
    "            latent = model.enc(x_mix)\n",
    "            ood_latent_cifer10_interp.append(latent)\n",
    "            cifar10_interp_labels.append(y)\n",
    "        last_batch = x\n",
    "\n",
    "ood_latent_svhn = torch.cat(odd_latent_svhn)\n",
    "svhn_labels = torch.cat(svhn_labels)\n",
    "ood_latent_cifar100 = torch.cat(ood_latent_cifar100)\n",
    "cifar100_labels = torch.cat(cifar100_labels)\n",
    "ood_latent_celeba = torch.cat(ood_latent_celeba)\n",
    "celeba_labels = torch.cat(celeba_labels)\n",
    "ood_latent_cifer10_interp = torch.cat(ood_latent_cifer10_interp)\n",
    "cifar10_interp_labels = torch.cat(cifar10_interp_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([50000, 640]),\n",
       " torch.Size([73257, 640]),\n",
       " torch.Size([50000, 640]),\n",
       " torch.Size([162770, 640]),\n",
       " torch.Size([9900, 640]))"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "in_latent.shape, \\\n",
    "ood_latent_svhn.shape,\\\n",
    "ood_latent_cifar100.shape, \\\n",
    "ood_latent_celeba.shape, \\\n",
    "ood_latent_cifer10_interp.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9953428506851196 0.001597103662788868 0.014472606591880322 0.009502946399152279\n"
     ]
    }
   ],
   "source": [
    "origin = 0\n",
    "softmax = nn.Softmax(1)\n",
    "w = model.head.num.ms\n",
    "lower_bound = np.linspace(0, 1, 21)[:-1]\n",
    "\n",
    "tmp_latent = in_latent - origin\n",
    "m = [tmp_latent[in_label == lbl].mean(0) for lbl in range(10)]\n",
    "m = torch.stack(m)\n",
    "# top1_logits, top2_logits = [], []\n",
    "prob1, prob2 = [], []\n",
    "\n",
    "top1_cos_theta_mean, top1_cos_theta_std = [], []\n",
    "top2_cos_theta_mean, top2_cos_theta_std = [], []\n",
    "\n",
    "top1_logits_mean, top2_logits_mean = [], []\n",
    "top1_logits_std, top2_logits_std = [], []\n",
    "counts = []\n",
    "z_norm_means, z_norm_stds = [], []\n",
    "\n",
    "top1_dist_mean, top2_dist_mean = [], []\n",
    "top1_dist_std, top2_dist_std = [], []\n",
    "\n",
    "with torch.no_grad():\n",
    "    # [in_latent, ood_latent_svhn, ood_latent_cifar100, ood_latent_celeba, ood_cifar10_interp]\n",
    "    for latent in [in_latent]:\n",
    "        latent = latent - origin\n",
    "\n",
    "        logits = model.head(latent)\n",
    "        prob, pred = torch.topk(softmax(logits), 2)\n",
    "        logits_k, _ = torch.topk(logits, 2)\n",
    "\n",
    "        print(prob[:, 0].mean().item(), prob[:, 1].mean().item(), \\\n",
    "            prob[:, 0].std().item(), prob[:, 1].std().item())\n",
    "\n",
    "        top1_prob = prob[:, 0].detach().cpu()\n",
    "        top2_prob = prob[:, 1].detach().cpu()\n",
    "        top1_logits = logits_k[:, 0].detach().cpu()\n",
    "        top2_logits = logits_k[:, 1].detach().cpu()\n",
    "        # norm\n",
    "        z_norm = latent.norm(dim=1)\n",
    "        # cos_theta, zn : (50000, 640)\n",
    "        zn = F.normalize(latent, dim=1)\n",
    "        # wn : (10, 640)\n",
    "        wn = F.normalize(w, dim=1)\n",
    "        cos_thetas = torch.matmul(zn, wn.transpose(0, 1))\n",
    "        # dist\n",
    "        diff = latent.unsqueeze(dim=1) - m.unsqueeze(dim=0)\n",
    "        dist = diff.norm(dim=-1)\n",
    "\n",
    "        for lb in lower_bound:\n",
    "            select_index = (lb < top1_prob) & (top1_prob <= lb + 0.05)\n",
    "            if select_index.sum().item() == 0:\n",
    "                top1_logits_mean.append(0); top2_logits_mean.append(0)\n",
    "                top1_logits_std.append(0); top2_logits_std.append(0)\n",
    "                z_norm_means.append(0); z_norm_stds.append(0) \n",
    "                counts.append(0)\n",
    "                top1_cos_theta_mean.append(0); top2_cos_theta_mean.append(0)\n",
    "                top1_cos_theta_std.append(0); top2_cos_theta_std.append(0)\n",
    "                top1_dist_mean.append(0); top2_dist_mean.append(0)\n",
    "                top1_dist_std.append(0); top2_dist_std.append(0)\n",
    "                continue\n",
    "            top1_pred= pred[:, 0][select_index]\n",
    "            top2_pred = pred[:, 1][select_index]\n",
    "            \n",
    "            top1_logits_mean.append(top1_logits[select_index].mean().item())\n",
    "            top2_logits_mean.append(top2_logits[select_index].mean().item())\n",
    "            top1_logits_std.append(top1_logits[select_index].std().item())\n",
    "            top2_logits_std.append(top2_logits[select_index].std().item())\n",
    "\n",
    "            z_norm_means.append(z_norm[select_index].mean().item())\n",
    "            z_norm_stds.append(z_norm[select_index].std().item())\n",
    "            counts.append(select_index.sum().item())\n",
    "\n",
    "            top1_cos = torch.gather(cos_thetas[select_index], 1, top1_pred.unsqueeze(dim=1))\n",
    "            top2_cos = torch.gather(cos_thetas[select_index], 1, top2_pred.unsqueeze(dim=1))\n",
    "            top1_cos_theta_mean.append(top1_cos.mean().item())\n",
    "            top2_cos_theta_mean.append(top2_cos.mean().item())\n",
    "            top1_cos_theta_std.append(top1_cos.std().item())\n",
    "            top2_cos_theta_std.append(top2_cos.std().item())\n",
    "\n",
    "            top1_dist = torch.gather(dist[select_index], 1, top1_pred.unsqueeze(dim=1))\n",
    "            top2_dist = torch.gather(dist[select_index], 1, top2_pred.unsqueeze(dim=1))\n",
    "            top1_dist_mean.append(top1_dist.mean().item())\n",
    "            top2_dist_mean.append(top2_dist.mean().item())\n",
    "            top1_dist_std.append(top1_dist.std().item())\n",
    "            top2_dist_std.append(top2_dist.std().item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/eavnjeong/anaconda3/envs/ebm/lib/python3.7/site-packages/ipykernel_launcher.py:30: FutureWarning: In a future version of pandas all arguments of concat except for the argument 'objs' will be keyword-only\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>1</td>\n",
       "      <td>3.033127</td>\n",
       "      <td>3.001996</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.477551</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.580285</td>\n",
       "      <td>0.562632</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.597315</td>\n",
       "      <td>0.558496</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>1</td>\n",
       "      <td>3.500773</td>\n",
       "      <td>3.444402</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.295817</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.599378</td>\n",
       "      <td>0.616049</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.517174</td>\n",
       "      <td>0.579657</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>2</td>\n",
       "      <td>3.504164</td>\n",
       "      <td>3.037607</td>\n",
       "      <td>0.067069</td>\n",
       "      <td>0.430121</td>\n",
       "      <td>0.392021</td>\n",
       "      <td>0.103156</td>\n",
       "      <td>0.621852</td>\n",
       "      <td>0.580426</td>\n",
       "      <td>0.006648</td>\n",
       "      <td>0.037298</td>\n",
       "      <td>0.841209</td>\n",
       "      <td>0.989688</td>\n",
       "      <td>0.470500</td>\n",
       "      <td>0.637170</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>2</td>\n",
       "      <td>4.237620</td>\n",
       "      <td>3.883000</td>\n",
       "      <td>0.215789</td>\n",
       "      <td>0.243340</td>\n",
       "      <td>0.295608</td>\n",
       "      <td>0.102214</td>\n",
       "      <td>0.687061</td>\n",
       "      <td>0.642770</td>\n",
       "      <td>0.021354</td>\n",
       "      <td>0.036293</td>\n",
       "      <td>1.315311</td>\n",
       "      <td>0.513292</td>\n",
       "      <td>0.135540</td>\n",
       "      <td>0.004725</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>7</td>\n",
       "      <td>4.390718</td>\n",
       "      <td>3.697689</td>\n",
       "      <td>0.670432</td>\n",
       "      <td>1.019858</td>\n",
       "      <td>0.353820</td>\n",
       "      <td>0.109969</td>\n",
       "      <td>0.678105</td>\n",
       "      <td>0.626867</td>\n",
       "      <td>0.048933</td>\n",
       "      <td>0.084800</td>\n",
       "      <td>0.447032</td>\n",
       "      <td>0.790346</td>\n",
       "      <td>0.072043</td>\n",
       "      <td>0.382372</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>8</td>\n",
       "      <td>4.506901</td>\n",
       "      <td>3.515653</td>\n",
       "      <td>1.002059</td>\n",
       "      <td>1.406306</td>\n",
       "      <td>0.371374</td>\n",
       "      <td>0.054938</td>\n",
       "      <td>0.688596</td>\n",
       "      <td>0.608008</td>\n",
       "      <td>0.069415</td>\n",
       "      <td>0.126359</td>\n",
       "      <td>0.473931</td>\n",
       "      <td>0.776000</td>\n",
       "      <td>0.062361</td>\n",
       "      <td>0.317787</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>9</td>\n",
       "      <td>4.747839</td>\n",
       "      <td>3.613601</td>\n",
       "      <td>0.656549</td>\n",
       "      <td>1.040155</td>\n",
       "      <td>0.457703</td>\n",
       "      <td>0.326819</td>\n",
       "      <td>0.713963</td>\n",
       "      <td>0.616824</td>\n",
       "      <td>0.051364</td>\n",
       "      <td>0.090938</td>\n",
       "      <td>0.508973</td>\n",
       "      <td>0.791469</td>\n",
       "      <td>0.197320</td>\n",
       "      <td>0.310444</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>7</td>\n",
       "      <td>4.706829</td>\n",
       "      <td>3.251742</td>\n",
       "      <td>0.574373</td>\n",
       "      <td>0.814042</td>\n",
       "      <td>0.619353</td>\n",
       "      <td>0.408343</td>\n",
       "      <td>0.709786</td>\n",
       "      <td>0.596784</td>\n",
       "      <td>0.044749</td>\n",
       "      <td>0.073835</td>\n",
       "      <td>0.554542</td>\n",
       "      <td>0.968843</td>\n",
       "      <td>0.271013</td>\n",
       "      <td>0.386053</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>16</td>\n",
       "      <td>4.956493</td>\n",
       "      <td>3.103607</td>\n",
       "      <td>0.580142</td>\n",
       "      <td>0.878752</td>\n",
       "      <td>0.430077</td>\n",
       "      <td>0.162216</td>\n",
       "      <td>0.726047</td>\n",
       "      <td>0.576487</td>\n",
       "      <td>0.045147</td>\n",
       "      <td>0.082745</td>\n",
       "      <td>0.558313</td>\n",
       "      <td>0.895276</td>\n",
       "      <td>0.294310</td>\n",
       "      <td>0.316218</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>37</td>\n",
       "      <td>5.221251</td>\n",
       "      <td>3.062369</td>\n",
       "      <td>0.567641</td>\n",
       "      <td>0.907085</td>\n",
       "      <td>0.404092</td>\n",
       "      <td>0.117924</td>\n",
       "      <td>0.742506</td>\n",
       "      <td>0.567660</td>\n",
       "      <td>0.038220</td>\n",
       "      <td>0.089039</td>\n",
       "      <td>0.419016</td>\n",
       "      <td>0.809167</td>\n",
       "      <td>0.082516</td>\n",
       "      <td>0.314309</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>75</td>\n",
       "      <td>5.592988</td>\n",
       "      <td>3.045088</td>\n",
       "      <td>0.630330</td>\n",
       "      <td>1.022942</td>\n",
       "      <td>0.438536</td>\n",
       "      <td>0.200547</td>\n",
       "      <td>0.770494</td>\n",
       "      <td>0.564139</td>\n",
       "      <td>0.043757</td>\n",
       "      <td>0.098550</td>\n",
       "      <td>0.441106</td>\n",
       "      <td>0.825059</td>\n",
       "      <td>0.114369</td>\n",
       "      <td>0.319321</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>227</td>\n",
       "      <td>5.870104</td>\n",
       "      <td>2.482575</td>\n",
       "      <td>0.460818</td>\n",
       "      <td>0.830529</td>\n",
       "      <td>0.450656</td>\n",
       "      <td>0.203869</td>\n",
       "      <td>0.791094</td>\n",
       "      <td>0.509363</td>\n",
       "      <td>0.031784</td>\n",
       "      <td>0.086040</td>\n",
       "      <td>0.427498</td>\n",
       "      <td>0.816557</td>\n",
       "      <td>0.134800</td>\n",
       "      <td>0.309614</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>49608</td>\n",
       "      <td>8.312694</td>\n",
       "      <td>0.825794</td>\n",
       "      <td>0.506107</td>\n",
       "      <td>0.551967</td>\n",
       "      <td>0.870464</td>\n",
       "      <td>0.458559</td>\n",
       "      <td>0.947195</td>\n",
       "      <td>0.283532</td>\n",
       "      <td>0.031249</td>\n",
       "      <td>0.090445</td>\n",
       "      <td>0.316978</td>\n",
       "      <td>1.120722</td>\n",
       "      <td>0.184774</td>\n",
       "      <td>0.474792</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        0         0         0         0         0         0         0  \\\n",
       "0       0  0.000000  0.000000  0.000000  0.000000  0.000000  0.000000   \n",
       "1       0  0.000000  0.000000  0.000000  0.000000  0.000000  0.000000   \n",
       "2       0  0.000000  0.000000  0.000000  0.000000  0.000000  0.000000   \n",
       "3       0  0.000000  0.000000  0.000000  0.000000  0.000000  0.000000   \n",
       "4       0  0.000000  0.000000  0.000000  0.000000  0.000000  0.000000   \n",
       "5       1  3.033127  3.001996       NaN       NaN  0.477551       NaN   \n",
       "6       0  0.000000  0.000000  0.000000  0.000000  0.000000  0.000000   \n",
       "7       0  0.000000  0.000000  0.000000  0.000000  0.000000  0.000000   \n",
       "8       1  3.500773  3.444402       NaN       NaN  0.295817       NaN   \n",
       "9       2  3.504164  3.037607  0.067069  0.430121  0.392021  0.103156   \n",
       "10      2  4.237620  3.883000  0.215789  0.243340  0.295608  0.102214   \n",
       "11      7  4.390718  3.697689  0.670432  1.019858  0.353820  0.109969   \n",
       "12      8  4.506901  3.515653  1.002059  1.406306  0.371374  0.054938   \n",
       "13      9  4.747839  3.613601  0.656549  1.040155  0.457703  0.326819   \n",
       "14      7  4.706829  3.251742  0.574373  0.814042  0.619353  0.408343   \n",
       "15     16  4.956493  3.103607  0.580142  0.878752  0.430077  0.162216   \n",
       "16     37  5.221251  3.062369  0.567641  0.907085  0.404092  0.117924   \n",
       "17     75  5.592988  3.045088  0.630330  1.022942  0.438536  0.200547   \n",
       "18    227  5.870104  2.482575  0.460818  0.830529  0.450656  0.203869   \n",
       "19  49608  8.312694  0.825794  0.506107  0.551967  0.870464  0.458559   \n",
       "\n",
       "           0         0         0         0         0         0         0  \\\n",
       "0   0.000000  0.000000  0.000000  0.000000  0.000000  0.000000  0.000000   \n",
       "1   0.000000  0.000000  0.000000  0.000000  0.000000  0.000000  0.000000   \n",
       "2   0.000000  0.000000  0.000000  0.000000  0.000000  0.000000  0.000000   \n",
       "3   0.000000  0.000000  0.000000  0.000000  0.000000  0.000000  0.000000   \n",
       "4   0.000000  0.000000  0.000000  0.000000  0.000000  0.000000  0.000000   \n",
       "5   0.580285  0.562632       NaN       NaN  0.597315  0.558496       NaN   \n",
       "6   0.000000  0.000000  0.000000  0.000000  0.000000  0.000000  0.000000   \n",
       "7   0.000000  0.000000  0.000000  0.000000  0.000000  0.000000  0.000000   \n",
       "8   0.599378  0.616049       NaN       NaN  0.517174  0.579657       NaN   \n",
       "9   0.621852  0.580426  0.006648  0.037298  0.841209  0.989688  0.470500   \n",
       "10  0.687061  0.642770  0.021354  0.036293  1.315311  0.513292  0.135540   \n",
       "11  0.678105  0.626867  0.048933  0.084800  0.447032  0.790346  0.072043   \n",
       "12  0.688596  0.608008  0.069415  0.126359  0.473931  0.776000  0.062361   \n",
       "13  0.713963  0.616824  0.051364  0.090938  0.508973  0.791469  0.197320   \n",
       "14  0.709786  0.596784  0.044749  0.073835  0.554542  0.968843  0.271013   \n",
       "15  0.726047  0.576487  0.045147  0.082745  0.558313  0.895276  0.294310   \n",
       "16  0.742506  0.567660  0.038220  0.089039  0.419016  0.809167  0.082516   \n",
       "17  0.770494  0.564139  0.043757  0.098550  0.441106  0.825059  0.114369   \n",
       "18  0.791094  0.509363  0.031784  0.086040  0.427498  0.816557  0.134800   \n",
       "19  0.947195  0.283532  0.031249  0.090445  0.316978  1.120722  0.184774   \n",
       "\n",
       "           0  \n",
       "0   0.000000  \n",
       "1   0.000000  \n",
       "2   0.000000  \n",
       "3   0.000000  \n",
       "4   0.000000  \n",
       "5        NaN  \n",
       "6   0.000000  \n",
       "7   0.000000  \n",
       "8        NaN  \n",
       "9   0.637170  \n",
       "10  0.004725  \n",
       "11  0.382372  \n",
       "12  0.317787  \n",
       "13  0.310444  \n",
       "14  0.386053  \n",
       "15  0.316218  \n",
       "16  0.314309  \n",
       "17  0.319321  \n",
       "18  0.309614  \n",
       "19  0.474792  "
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_counts = pd.DataFrame(counts)\n",
    "df_top1_mean = pd.DataFrame([top1_logits_mean]).T\n",
    "df_top2_mean = pd.DataFrame([top2_logits_mean]).T\n",
    "df_top1_std = pd.DataFrame([top1_logits_std]).T\n",
    "df_top2_std = pd.DataFrame([top2_logits_std]).T\n",
    "\n",
    "df_norm_mean = pd.DataFrame([z_norm_means]).T\n",
    "df_norm_std = pd.DataFrame([z_norm_stds]).T\n",
    "\n",
    "df_top1_cos_theta_mean = pd.DataFrame([top1_cos_theta_mean]).T\n",
    "df_top2_cos_theta_mean = pd.DataFrame([top2_cos_theta_mean]).T\n",
    "df_top1_cos_theta_std = pd.DataFrame([top1_cos_theta_std]).T\n",
    "df_top2_cos_theta_std = pd.DataFrame([top2_cos_theta_std]).T\n",
    "\n",
    "df_top1_dist_mean = pd.DataFrame([top1_dist_mean]).T\n",
    "df_top2_dist_mean = pd.DataFrame([top2_dist_mean]).T\n",
    "df_top1_dist_std = pd.DataFrame([top1_dist_std]).T\n",
    "df_top2_dist_std = pd.DataFrame([top2_dist_std]).T\n",
    "\n",
    "\n",
    "pd.concat(\n",
    "    [\n",
    "     df_counts, \n",
    "     df_top1_mean, df_top2_mean,\n",
    "     df_top1_std, df_top2_std, \n",
    "     df_norm_mean, df_norm_std,\n",
    "     df_top1_cos_theta_mean, df_top2_cos_theta_mean,\n",
    "     df_top1_cos_theta_std, df_top2_cos_theta_std,\n",
    "     df_top1_dist_mean, df_top2_dist_mean,\n",
    "     df_top1_dist_std, df_top2_dist_std], 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Non Natural"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "N = []\n",
    "U = []\n",
    "OODomain = []\n",
    "Constant = []\n",
    "\n",
    "with torch.no_grad():\n",
    "    model.eval()\n",
    "    for x, y in tqdm(vl_dl):\n",
    "        x = x.cuda()\n",
    "        Ns = torch.randn_like(x).cuda()\n",
    "        Us = torch.empty_like(x).uniform_(-1, 1).cuda()\n",
    "        OODomains = torch.empty_like(x).uniform_(-10, 10)\n",
    "\n",
    "        num_images = x.shape[0]      \n",
    "        pixels = torch.rand((num_images, 3), dtype=torch.float32)  \n",
    "        images = torch.ones((num_images, 32, 32, 3), dtype=torch.float32)\n",
    "        for i in range(num_images):\n",
    "            images[i] *= pixels[i]\n",
    "        images = images.permute(0, 3, 1, 2).cuda()\n",
    "\n",
    "        N.append(model.enc(Ns))\n",
    "        U.append(model.enc(Us))\n",
    "        OODomain.append(model.enc(OODomains))\n",
    "        Constant.append(model.enc(images))\n",
    "\n",
    "N = torch.cat(N)\n",
    "U = torch.cat(U)\n",
    "OODomain = torch.cat(OODomain)\n",
    "Constant = torch.cat(Constant)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "origin = 0\n",
    "softmax = nn.Softmax(1)\n",
    "w = model.head.num.ms\n",
    "lower_bound = np.linspace(0, 1, 21)[:-1]\n",
    "\n",
    "tmp_latent = in_latent - origin\n",
    "m = [tmp_latent[in_label == lbl].mean(0) for lbl in range(10)]\n",
    "m = torch.stack(m)\n",
    "# top1_logits, top2_logits = [], []\n",
    "prob1, prob2 = [], []\n",
    "\n",
    "top1_cos_theta_mean, top1_cos_theta_std = [], []\n",
    "top2_cos_theta_mean, top2_cos_theta_std = [], []\n",
    "\n",
    "top1_logits_mean, top2_logits_mean = [], []\n",
    "top1_logits_std, top2_logits_std = [], []\n",
    "counts = []\n",
    "z_norm_means, z_norm_stds = [], []\n",
    "\n",
    "top1_dist_mean, top2_dist_mean = [], []\n",
    "top1_dist_std, top2_dist_std = [], []\n",
    "\n",
    "with torch.no_grad():\n",
    "    # [N, U, OODomain, Constant]\n",
    "    for latent in [Constant]:\n",
    "        latent = latent - origin\n",
    "\n",
    "        logits = model.head(latent)\n",
    "        prob, pred = torch.topk(softmax(logits), 2)\n",
    "        logits_k, _ = torch.topk(logits, 2)\n",
    "\n",
    "        print(prob[:, 0].mean().item(), prob[:, 1].mean().item(), \\\n",
    "            prob[:, 0].std().item(), prob[:, 1].std().item())\n",
    "\n",
    "        top1_prob = prob[:, 0].detach().cpu()\n",
    "        top2_prob = prob[:, 1].detach().cpu()\n",
    "        top1_logits = logits_k[:, 0].detach().cpu()\n",
    "        top2_logits = logits_k[:, 1].detach().cpu()\n",
    "        # norm\n",
    "        z_norm = latent.norm(dim=1)\n",
    "        # cos_theta, zn : (50000, 640)\n",
    "        zn = F.normalize(latent, dim=1)\n",
    "        # wn : (10, 640)\n",
    "        wn = F.normalize(w, dim=1)\n",
    "        cos_thetas = torch.matmul(zn, wn.transpose(0, 1))\n",
    "        # dist\n",
    "        diff = latent.unsqueeze(dim=1) - m.unsqueeze(dim=0)\n",
    "        dist = diff.norm(dim=-1)\n",
    "\n",
    "        for lb in lower_bound:\n",
    "            select_index = (lb < top1_prob) & (top1_prob <= lb + 0.05)\n",
    "            if select_index.sum().item() == 0:\n",
    "                top1_logits_mean.append(0); top2_logits_mean.append(0)\n",
    "                top1_logits_std.append(0); top2_logits_std.append(0)\n",
    "                z_norm_means.append(0); z_norm_stds.append(0) \n",
    "                counts.append(0)\n",
    "                top1_cos_theta_mean.append(0); top2_cos_theta_mean.append(0)\n",
    "                top1_cos_theta_std.append(0); top2_cos_theta_std.append(0)\n",
    "                top1_dist_mean.append(0); top2_dist_mean.append(0)\n",
    "                top1_dist_std.append(0); top2_dist_std.append(0)\n",
    "                continue\n",
    "            top1_pred= pred[:, 0][select_index]\n",
    "            top2_pred = pred[:, 1][select_index]\n",
    "            \n",
    "            top1_logits_mean.append(top1_logits[select_index].mean().item())\n",
    "            top2_logits_mean.append(top2_logits[select_index].mean().item())\n",
    "            top1_logits_std.append(top1_logits[select_index].std().item())\n",
    "            top2_logits_std.append(top2_logits[select_index].std().item())\n",
    "\n",
    "            z_norm_means.append(z_norm[select_index].mean().item())\n",
    "            z_norm_stds.append(z_norm[select_index].std().item())\n",
    "            counts.append(select_index.sum().item())\n",
    "\n",
    "            top1_cos = torch.gather(cos_thetas[select_index], 1, top1_pred.unsqueeze(dim=1))\n",
    "            top2_cos = torch.gather(cos_thetas[select_index], 1, top2_pred.unsqueeze(dim=1))\n",
    "            top1_cos_theta_mean.append(top1_cos.mean().item())\n",
    "            top2_cos_theta_mean.append(top2_cos.mean().item())\n",
    "            top1_cos_theta_std.append(top1_cos.std().item())\n",
    "            top2_cos_theta_std.append(top2_cos.std().item())\n",
    "\n",
    "            top1_dist = torch.gather(dist[select_index], 1, top1_pred.unsqueeze(dim=1))\n",
    "            top2_dist = torch.gather(dist[select_index], 1, top2_pred.unsqueeze(dim=1))\n",
    "            top1_dist_mean.append(top1_dist.mean().item())\n",
    "            top2_dist_mean.append(top2_dist.mean().item())\n",
    "            top1_dist_std.append(top1_dist.std().item())\n",
    "            top2_dist_std.append(top2_dist.std().item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_counts = pd.DataFrame(counts)\n",
    "df_top1_mean = pd.DataFrame([top1_logits_mean]).T\n",
    "df_top2_mean = pd.DataFrame([top2_logits_mean]).T\n",
    "df_top1_std = pd.DataFrame([top1_logits_std]).T\n",
    "df_top2_std = pd.DataFrame([top2_logits_std]).T\n",
    "\n",
    "df_norm_mean = pd.DataFrame([z_norm_means]).T\n",
    "df_norm_std = pd.DataFrame([z_norm_stds]).T\n",
    "\n",
    "df_top1_cos_theta_mean = pd.DataFrame([top1_cos_theta_mean]).T\n",
    "df_top2_cos_theta_mean = pd.DataFrame([top2_cos_theta_mean]).T\n",
    "df_top1_cos_theta_std = pd.DataFrame([top1_cos_theta_std]).T\n",
    "df_top2_cos_theta_std = pd.DataFrame([top2_cos_theta_std]).T\n",
    "\n",
    "df_top1_dist_mean = pd.DataFrame([top1_dist_mean]).T\n",
    "df_top2_dist_mean = pd.DataFrame([top2_dist_mean]).T\n",
    "df_top1_dist_std = pd.DataFrame([top1_dist_std]).T\n",
    "df_top2_dist_std = pd.DataFrame([top2_dist_std]).T\n",
    "\n",
    "\n",
    "pd.concat(\n",
    "    [\n",
    "     df_counts, \n",
    "     df_top1_mean, df_top2_mean,\n",
    "     df_top1_std, df_top2_std, \n",
    "     df_norm_mean, df_norm_std,\n",
    "     df_top1_cos_theta_mean, df_top2_cos_theta_mean,\n",
    "     df_top1_cos_theta_std, df_top2_cos_theta_std,\n",
    "     df_top1_dist_mean, df_top2_dist_mean,\n",
    "     df_top1_dist_std, df_top2_dist_std], 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ebm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
