{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "False\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "# Diagnostic: report whether PyTorch can use a CUDA device in this kernel.\n",
    "print(torch.cuda.is_available())\n",
    "# Project-local classification head used by get_acc (presumably a prototypical-network head; module path is repo-specific).\n",
    "from models.proo_head import PN_head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def gen_clusters(a, sample_num):\n",
    "    mean1 = [a, 0]\n",
    "    cov1 = [[1, 0], [0, 1]]\n",
    "    data1 = np.random.multivariate_normal(mean1, cov1, sample_num)\n",
    "\n",
    "    mean2 = [-a, 0]\n",
    "    cov2 = [[1, 0], [0, 1]]\n",
    "    data2 = np.random.multivariate_normal(mean2, cov2, sample_num)\n",
    "\n",
    "    return np.round(data1, 4), np.round(data2, 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def draw_gaussian_acc(acc_dict, dis_dict, acc_all_mean, a, shot):\n",
    "    \"\"\"Plot accuracy (mean +/- std) per distance bucket and the number of tasks per bucket.\n",
    "\n",
    "    Args:\n",
    "        acc_dict: maps each distance key to the list of per-task accuracies in that bucket.\n",
    "        dis_dict: maps the same keys to the number of tasks in each bucket.\n",
    "        acc_all_mean: overall mean accuracy, drawn as a horizontal reference line.\n",
    "        a: cluster offset used to generate the tasks (shown in the title).\n",
    "        shot: number of support samples per class (shown in the title).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``acc_dict`` is empty (nothing to plot).\n",
    "    \"\"\"\n",
    "    if not acc_dict:\n",
    "        raise ValueError('acc_dict is empty; nothing to plot')\n",
    "\n",
    "    keys = sorted(acc_dict.keys())\n",
    "    acc_list = []\n",
    "    acc_std_list = []\n",
    "    dis_num = []\n",
    "    for key in keys:\n",
    "        accs = np.array(acc_dict[key])\n",
    "        acc_list.append(accs.mean())\n",
    "        acc_std_list.append(np.std(accs))\n",
    "        dis_num.append(dis_dict[key])\n",
    "        print('dis is:', key, 'acc is:', accs.mean())\n",
    "\n",
    "    dis_list = np.array(keys)\n",
    "    tick_pos = np.arange(0, 11) / 10.0\n",
    "    fig, ax = plt.subplots(figsize=(7, 3.6))\n",
    "\n",
    "    # Left axis: accuracy as a fraction (0-1.05), labelled in percent.\n",
    "    ax.errorbar(dis_list, acc_list, acc_std_list, fmt='o-', ecolor='lightcoral', color='royalblue',\n",
    "                elinewidth=1.5, ms=3.5, capsize=2, label='Accuracies with Varied Distance')\n",
    "    ax.plot([dis_list.min(), dis_list.max()], [acc_all_mean, acc_all_mean], color='orange',\n",
    "            label='Average Accuracy %s' % np.round(acc_all_mean * 100, 1))\n",
    "    ax.legend(fontsize=12, loc='center right')\n",
    "    ax.set_ylim((0.0, 1.05))\n",
    "    ax.set_xlim((dis_list.min(), dis_list.max()))\n",
    "    # set_yticks(ticks, labels, **kwargs) needs matplotlib >= 3.5; set_yticks + set_yticklabels works everywhere.\n",
    "    ax.set_yticks(tick_pos)\n",
    "    ax.set_yticklabels(['%d' % v for v in np.arange(0, 11) * 10], fontsize=10, color='indianred')\n",
    "    ax.set_ylabel('Accuracy(%)', fontsize=15, color='indianred')\n",
    "    ax.set_xlabel('Average Distance with Local Centroid', fontsize=15)\n",
    "    ax.set_title('a = %s, %s-shot' % (a, shot), fontsize=15)\n",
    "\n",
    "    # Right axis: task counts divided by count_scale so the bars share the 0-1.05 range.\n",
    "    # The scale grows in steps of 1000 so a busy bucket is not clipped at the top.\n",
    "    count_scale = 1000\n",
    "    while max(dis_num) > 1.05 * count_scale:\n",
    "        count_scale += 1000\n",
    "    ax1 = ax.twinx()\n",
    "    ax1.set_ylim((0.0, 1.05))\n",
    "    # width=0.1 assumes the distance keys are spaced 0.1 apart (get_acc rounds them to 0.1).\n",
    "    ax1.bar(dis_list, np.array(dis_num) / count_scale, alpha=0.3, width=0.1, color='g', label='# Number of Tasks')\n",
    "    ax1.set_yticks(tick_pos)\n",
    "    ax1.set_yticklabels(['%d' % v for v in np.arange(0, 11) * (count_scale // 10)], fontsize=10, color='g')\n",
    "    ax1.set_ylabel('# Number of Tasks', fontsize=15, color='g')\n",
    "    ax1.legend(fontsize=15, loc='lower right')\n",
    "\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def get_acc(a, n_shot, n_query=200, n_task=10000, seed=None):\n",
    "    \"\"\"Estimate classifier accuracy on random 2-way synthetic tasks and plot it against distance.\n",
    "\n",
    "    Each task draws two clusters with ``gen_clusters(a, n_shot + n_query)``. The first\n",
    "    ``n_shot`` points of each cluster form the support set; the remaining ``n_query``\n",
    "    points are queries. Queries are scored by ``PN_head`` (Euclidean metric) against the\n",
    "    per-class mean of the support points and labelled by argmax. Per-task accuracies are\n",
    "    grouped by the sum of the two prototype norms (rounded to 0.1) and plotted with\n",
    "    ``draw_gaussian_acc``.\n",
    "\n",
    "    Args:\n",
    "        a: cluster offset passed to ``gen_clusters``.\n",
    "        n_shot: support samples per class.\n",
    "        n_query: query samples per class.\n",
    "        n_task: number of random tasks to simulate.\n",
    "        seed: if given, seeds NumPy's global RNG first for a reproducible run.\n",
    "    \"\"\"\n",
    "    # Use the GPU when present; an unconditional .cuda() fails on CPU-only kernels.\n",
    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    if seed is not None:\n",
    "        np.random.seed(seed)\n",
    "\n",
    "    # Query order is all class-1 queries followed by all class-2 queries.\n",
    "    query_label = np.array([0] * n_query + [1] * n_query)\n",
    "    # Constructor arguments never change, so build the head once instead of once per task.\n",
    "    # NOTE(review): assumes PN_head has no per-instance randomness that re-building was meant to resample.\n",
    "    classifier = PN_head(scale_cls=1, normalize=False, metric='euclidean').to(device)\n",
    "\n",
    "    acc_list = []\n",
    "    acc_dict = {}\n",
    "    with torch.no_grad():  # evaluation only; no autograd graph is needed\n",
    "        for i in range(n_task):\n",
    "            class1, class2 = gen_clusters(a, n_shot + n_query)\n",
    "            class1 = torch.Tensor(class1).unsqueeze(0).to(device)  # (1, n_shot + n_query, n_dim)\n",
    "            class2 = torch.Tensor(class2).unsqueeze(0).to(device)\n",
    "\n",
    "            class1_s, class1_q = class1[:, :n_shot, :], class1[:, n_shot:, :]  # (batch, n_shot, n_dim)\n",
    "            class2_s, class2_q = class2[:, :n_shot, :], class2[:, n_shot:, :]\n",
    "\n",
    "            prototypes_1 = torch.mean(class1_s, dim=1, keepdim=True)  # (1, 1, n_dim)\n",
    "            prototypes_2 = torch.mean(class2_s, dim=1, keepdim=True)\n",
    "            prototype = torch.cat([prototypes_1, prototypes_2], dim=1)  # (1, 2, n_dim)\n",
    "\n",
    "            support_data = torch.cat([class1_s, class2_s], dim=1)\n",
    "            query_data = torch.cat([class1_q, class2_q], dim=1)\n",
    "            classification_scores = classifier(query_data, support_data, 2, n_shot,\n",
    "                                               prototypes=prototype)  # shape (batch, num_, n_way)\n",
    "            pred = torch.argmax(classification_scores.squeeze(0), dim=1)\n",
    "\n",
    "            acc = np.mean(pred.cpu().numpy() == query_label)\n",
    "            acc_list.append(acc)\n",
    "\n",
    "            if i % 200 == 0:\n",
    "                print('step is :', i)\n",
    "\n",
    "            # Bucket key: sum of the two prototype L2 norms, rounded to one decimal.\n",
    "            dis = torch.norm(prototypes_1, p=2) + torch.norm(prototypes_2, p=2)\n",
    "            dis = np.around(dis.item(), 1)\n",
    "            acc_dict.setdefault(dis, []).append(acc)\n",
    "\n",
    "    # Task count per bucket (equals the number of accuracies recorded for it).\n",
    "    dis_dict = {key: len(accs) for key, accs in acc_dict.items()}\n",
    "    draw_gaussian_acc(acc_dict, dis_dict, np.mean(acc_list), a, n_shot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "Cannot initialize CUDA without ATen_cuda library. PyTorch splits its backend into two shared libraries: a CPU library and a CUDA library; this error has occurred because you are trying to use some CUDA functionality, but the CUDA library has not been loaded by the dynamic linker for some reason.  The CUDA library MUST be loaded, EVEN IF you don't directly use any symbols from the CUDA library! One common culprit is a lack of -Wl,--no-as-needed in your link arguments; many dynamic linkers will delete dynamic library dependencies if you don't depend on any of their symbols.  You can check if this has occurred by using ldd on your binary to see if there is a dependency on *_cuda.so library.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-19-7982bebf666a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mget_acc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_shot\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-18-cea5b69d25c9>\u001b[0m in \u001b[0;36mget_acc\u001b[0;34m(a, n_shot, n_query)\u001b[0m\n\u001b[1;32m      6\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_task\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m         \u001b[0mclass1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclass2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgen_clusters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_shot\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mn_query\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m         \u001b[0mclass1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclass1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      9\u001b[0m         \u001b[0mclass2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclass2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     10\u001b[0m         \u001b[0mlabel1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m*\u001b[0m 
\u001b[0;34m(\u001b[0m\u001b[0mn_query\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: Cannot initialize CUDA without ATen_cuda library. PyTorch splits its backend into two shared libraries: a CPU library and a CUDA library; this error has occurred because you are trying to use some CUDA functionality, but the CUDA library has not been loaded by the dynamic linker for some reason.  The CUDA library MUST be loaded, EVEN IF you don't directly use any symbols from the CUDA library! One common culprit is a lack of -Wl,--no-as-needed in your link arguments; many dynamic linkers will delete dynamic library dependencies if you don't depend on any of their symbols.  You can check if this has occurred by using ldd on your binary to see if there is a dependency on *_cuda.so library."
     ]
    }
   ],
   "source": [
    "# Clusters centred at (0.5, 0) and (-0.5, 0); 1 support sample per class.\n",
    "get_acc(n_shot=1, a=0.5)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Clusters centred at (1, 0) and (-1, 0); 1 support sample per class.\n",
    "get_acc(n_shot=1, a=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Clusters centred at (2, 0) and (-2, 0); 1 support sample per class.\n",
    "get_acc(n_shot=1, a=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Clusters centred at (1, 0) and (-1, 0); 3 support samples per class.\n",
    "get_acc(n_shot=3, a=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Clusters centred at (1, 0) and (-1, 0); 5 support samples per class.\n",
    "get_acc(n_shot=5, a=1)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
