{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Polarity is all you need to learn and transfer faster\n",
    "This notebook is a tutorial for reproducing the computer vision results in **Polarity is all you need to learn and transfer faster**.  \n",
    "* Data used in this notebook can be downloaded from OSF via this anonymized [link](https://osf.io/f9wtc/?view_only=61b71c37306a41209da0eb1c35dbf8d0).\n",
    "* Pre-trained AlexNet weights were obtained from [here](https://www.cs.toronto.edu/~guerzhoy/tf_alexnet/). Please download the tf2 version before running the experiment code.\n",
    "* All experiments and analyses were performed within the following Docker environment, which can be set up as follows:\n",
    "```setup\n",
    "docker build -t weightpolarity .\n",
    "docker run --gpus all -v ${PWD}:/workspace --name weightpolarity weightpolarity\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Computer vision experiments"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Experiments"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* The experiments are set up to run in the background, in parallel.\n",
    "* You may change the training sample sizes (`sample_size_list`) and the number of epochs (`num_epoch`) in `batchcis.py`; the current settings reproduce the paper.\n",
    "* We suggest running the lines below from the command line."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! python batchcis.py --num_run=20 --baseFName=\"fashion_mnist\" --num_epoch=100 --ckpt_freq=100 --doRandInit=2\n",
    "! python batchcis.py --num_run=20 --baseFName=\"cifar10\" --num_epoch=100 --ckpt_freq=100 --doRandInit=2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Consolidate all data\n",
    "* The two cells below collect all experimental data into an `epoch_acc_loss.pkl` file; these files are also available from the OSF [link](https://osf.io/f9wtc/?view_only=61b71c37306a41209da0eb1c35dbf8d0) above. The first cell is for Fashion-MNIST, the second for CIFAR-10."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The --baseFName path below is machine-specific; point it at your own fashion_mnist results\n",
    "! python doPlots.py --num_run=20 --baseFName='/Users/alice/cis/wp_CV/fashion_mnist' --resetType='posRand' --numEpoch=100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The --baseFName path below is machine-specific; point it at your own cifar10 results\n",
    "! python doPlots.py --num_run=20 --baseFName='/Users/alice/cis/wp_CV/cifar10' --resetType='posRand' --numEpoch=100"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Preparation: load packages and helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# white background\n",
    "plt.rcParams.update({\n",
    "    \"lines.color\": \"black\",\n",
    "    \"patch.edgecolor\": \"black\",\n",
    "    \"text.color\": \"black\",\n",
    "    \"axes.facecolor\": \"white\",\n",
    "    \"axes.edgecolor\": \"black\",\n",
    "    \"axes.labelcolor\": \"black\",\n",
    "    \"xtick.color\": \"black\",\n",
    "    \"ytick.color\": \"black\",\n",
    "    \"grid.color\": \"lightgray\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from plots import dict_to_query, plot_median_plus_example, simpleaxis, plot_sem, plot_diff_plus_mannwhitneyu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
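    "# Encode the sign quadrant of (x, y) as an integer:\n",
    "# 3 = (x>0, y>0), 2 = (x>0, y<=0), 1 = (x<=0, y>0), 0 = (x<=0, y<=0)\n",
    "def paint_color(x, y):\n",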
    "    if x > 0:\n",
    "        if y > 0:\n",
    "            return np.array(3)\n",
    "        else:\n",
    "            return np.array(2)\n",
    "    else:\n",
    "        if y > 0:\n",
    "            return np.array(1)\n",
    "        else:\n",
    "            return np.array(0)\n",
    "np_paint_color = np.vectorize(paint_color, otypes = [np.ndarray])"
   ]
  },
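  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`np.vectorize` calls `paint_color` once per element, so it is essentially a Python loop. A minimal sketch of an array-based equivalent follows; `paint_color_fast` is a hypothetical helper (not part of `plots.py`) that computes the same quadrant code as `2*[x > 0] + [y > 0]` and returns a plain integer array rather than an object array."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def paint_color_fast(x, y):\n",
    "    # Hypothetical vectorized counterpart of paint_color:\n",
    "    # 2*[x > 0] + [y > 0] gives 3 (+,+), 2 (+,-), 1 (-,+), 0 (-,-); zero counts as non-positive\n",
    "    x = np.asarray(x)\n",
    "    y = np.asarray(y)\n",
    "    return 2 * (x > 0).astype(int) + (y > 0).astype(int)\n",
    "\n",
    "# Check a few points, including the (0, 0) edge case\n",
    "xs = np.array([1.0, 1.0, -1.0, -1.0, 0.0])\n",
    "ys = np.array([1.0, -1.0, 1.0, -1.0, 0.0])\n",
    "codes = paint_color_fast(xs, ys)\n",
    "print(codes)  # -> [3 2 1 0 0]"
   ]
  },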
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Load Data\n",
    "* Here we use the experimental data presented in the main text of the manuscript; each experiment was run for 20 repeats."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metrics = pd.read_pickle('epoch_acc_loss.pkl')  # consolidated pandas DataFrame produced by doPlots.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Defined manually here (rather than read from the loaded data) so that the settings can be cross-checked against it\n",
    "config = {}\n",
    "config['batch_size'] = 1000\n",
    "\n",
    "doWeightFreeze_list = [True, False]\n",
    "sample_size_list = np.concatenate(([100,250,500,750], np.arange(1,7)*1000)) # fashion-MNIST\n",
    "# sample_size_list = np.array([100,250,500,750,1000,2500,5000,10000,25000,50000]) # cifar10\n",
    "num_epoch_list = [50] * len(sample_size_list)\n",
    "networkType_list = ['pretrained', 'vanilla', 'finetune']\n",
    "\n",
    "numVal = 1000\n",
    "numRun = 20 \n",
    "resetType_list = ['posRand']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. **DATA: Statistical Efficiency** (first column of Figures 2 and 3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Define helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index of the first epoch at which the loss curve x reaches its minimum\n",
    "def get_converge_epoch(x):\n",
    "    epoch_num = np.squeeze(np.where(x == np.amin(x)))\n",
    "    if not len(epoch_num.shape)==0:\n",
    "        epoch_num = epoch_num[0]\n",
    "    return epoch_num"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_acc_at_conver(val_acc, val_loss):\n",
    "    # val_acc, val_loss: arrays of shape (numRun, numEpoch), one row per run\n",
    "    epoch_at_conver = np.apply_along_axis(get_converge_epoch, 1, val_loss)\n",
    "    acc_at_conver = np.squeeze(np.take_along_axis(val_acc, np.expand_dims(epoch_at_conver, axis=1), axis=1))\n",
    "    # print(epoch_at_conver.shape)\n",
    "    # print(acc_at_conver.shape)    \n",
    "    return acc_at_conver, epoch_at_conver"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_conver(metrics, doWeightFreeze_list, resetType_list, networkType, numRun):\n",
    "    conver = pd.DataFrame(index=range(len(doWeightFreeze_list)*len(resetType_list)), columns=[\"typeStr\", \"resetType\", \"networkType\", \"acc\", \"epoch\"])\n",
    "\n",
    "    for ridx, resetType in enumerate(resetType_list):\n",
    "        for didx, doWeightFreeze in enumerate(doWeightFreeze_list):\n",
    "            df_loc = ridx*len(doWeightFreeze_list)+didx\n",
    "            if doWeightFreeze:\n",
    "                typeStr = 'freeze'\n",
    "                typeStr_dict = '\\'freeze\\''\n",
    "            else:\n",
    "                typeStr = 'liquid'\n",
    "                typeStr_dict = '\\'liquid\\''\n",
    "            conver['typeStr'][df_loc] = typeStr\n",
    "            conver['resetType'][df_loc] = resetType\n",
    "            conver['networkType'][df_loc] = networkType\n",
    "\n",
    "            # one list entry per training-sample size\n",
    "            page_dict = {'networkType':'\\''+ networkType +'\\'', 'typeStr':typeStr_dict, 'resetType':'\\''+ resetType +'\\''}\n",
    "            val_loss = metrics.query(dict_to_query(page_dict))['validation_loss'].tolist()\n",
    "            val_acc = metrics.query(dict_to_query(page_dict))['validation_acc'].tolist()\n",
    "            # print(val_acc[0].shape)\n",
    "\n",
    "            acc_at_conver = np.empty(shape=(len(sample_size_list), numRun))\n",
    "            epoch_at_conver = np.empty(shape=(len(sample_size_list), numRun))\n",
    "            for idx, (this_val_loss, this_val_acc) in enumerate(zip(val_loss, val_acc)):\n",
    "                this_acc_at_conver, this_epoch_at_conver = get_acc_at_conver(this_val_acc, this_val_loss)\n",
    "                acc_at_conver[idx, :] = this_acc_at_conver\n",
    "                epoch_at_conver[idx, :] = this_epoch_at_conver\n",
    "                # print('sample size %d is done' % idx)\n",
    "            conver['acc'][df_loc] = np.transpose(acc_at_conver)\n",
    "            conver['epoch'][df_loc] = np.transpose(epoch_at_conver)\n",
    "    return conver"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_stack_flops(flops, typeStr, acc, colName, networkType_list):\n",
    "    return np.squeeze(np.stack([np.stack(flops.query(dict_to_query({'typeStr':'\\''+typeStr+'\\'', 'acc':acc, 'networkType':'\\''+networkType +'\\''}))[colName].tolist()) for networkType in networkType_list]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_stack_conver(conver, typeStr, colName, networkType_list):\n",
    "    return np.squeeze(np.stack([np.stack(conver.query(dict_to_query({'typeStr':'\\''+typeStr+'\\'', 'networkType':'\\''+networkType +'\\''}))[colName].tolist()) for networkType in networkType_list]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### The cell below plots the first column of Figure 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fontsize=20\n",
    "resetType_list = ['posRand']\n",
    "for resetType in resetType_list:\n",
    "    conver_smart = get_conver(metrics, doWeightFreeze_list, resetType_list, 'pretrained', numRun)\n",
    "    conver_dumb = get_conver(metrics, doWeightFreeze_list, resetType_list, 'vanilla', numRun)\n",
    "    conver_finetune = get_conver(metrics, doWeightFreeze_list, resetType_list, 'finetune', numRun)\n",
    "    fig = plt.figure(figsize=(5,4))\n",
    "\n",
    "    thisAx = plt.gca()\n",
    "    doExample = False\n",
    "    x_coord = np.array(sample_size_list)\n",
    "    xlim=[np.min(x_coord), np.max(x_coord)]\n",
    "    ylim=0\n",
    "    color_list = ['r','b','g']\n",
    "\n",
    "    for didx, doWeightFreeze in enumerate(doWeightFreeze_list):\n",
    "        if doWeightFreeze:\n",
    "            typeStr = 'Freeze sufficient-Polarity'\n",
    "            typeStr_dict = '\\'freeze\\''\n",
    "        else:\n",
    "            typeStr = 'Fluid'\n",
    "            typeStr_dict = '\\'liquid\\''\n",
    "        \n",
    "        print(typeStr)\n",
    "        x_mat = conver_smart.query(dict_to_query({'typeStr':typeStr_dict, 'resetType':'\\''+resetType+'\\''}))['acc'].tolist()[0]\n",
    "        plot_median_plus_example((1-x_mat)*100, x_coord, ylim, xlim, typeStr, doExample, thisAx, color=color_list[didx])\n",
    "        # plot_sem((1-x_mat)*100, x_coord, ylim, xlim, typeStr, thisAx, color=color_list[didx])\n",
    "\n",
    "    x_mat = conver_dumb.query(dict_to_query({'typeStr':'\\'freeze\\'', 'resetType':'\\''+resetType+'\\''}))['acc'].tolist()[0]\n",
    "    plot_median_plus_example((1-x_mat)*100, x_coord, ylim, xlim, 'Freeze RAND-Polarity', doExample, thisAx, color=color_list[2])\n",
    "    # plot_sem((1-x_mat)*100, x_coord, ylim, xlim, 'Freeze-uninformed', thisAx, color=color_list[2])\n",
    "\n",
    "    # plt.plot([100,50000], [50,50], '--', color=[.5,.5,.5])\n",
    "    plt.plot([100,6000], [20,20], '--', color=[.5,.5,.5])\n",
    "\n",
    "    simpleaxis(thisAx)\n",
    "    thisAx.set_xscale('log')\n",
    "    thisAx.set_xticks([100,1000,6000], ['%d' % x for x in [100,1000,6000]], fontsize=fontsize)\n",
    "    thisAx.set_yticks([0,25,50], ['%d' % x for x in [0,25,50]], fontsize=fontsize)\n",
    "    # thisAx.set_xticks([100,5000,50000], ['%d' % x for x in [100,5000,50000]], fontsize=fontsize)\n",
    "    # thisAx.set_yticks([0,50,100], ['%d' % x for x in [0,50,100]], fontsize=fontsize)\n",
    "    # thisAx.set_ylim([0,40])\n",
    "    thisAx.set_xlabel('# samples', fontsize=fontsize)\n",
    "    thisAx.set_ylabel('Validation error %', fontsize=fontsize)\n",
    "    handles, labels = thisAx.get_legend_handles_labels()\n",
    "    fig.text(0.5,1,'Statistical Efficiency', ha='center', fontsize=fontsize)\n",
    "\n",
    "    by_label = {labels[i]:handles[i] for i in [0,1,2]}\n",
    "    lgd = thisAx.legend(by_label.values(), by_label.keys(), loc='upper left', bbox_to_anchor=(1,.99), fancybox=True, framealpha=0, fontsize=fontsize)\n",
    "    plt.savefig(os.path.join(os.getcwd(), 'sample_efficiency.png'), dpi=400, bbox_inches = 'tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### The cell below plots the first column of Supp Figure A.3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fontsize=20\n",
    "resetType_list = ['posRand']\n",
    "for resetType in resetType_list:\n",
    "    conver_smart = get_conver(metrics, doWeightFreeze_list, resetType_list, 'pretrained', numRun)\n",
    "    conver_dumb = get_conver(metrics, doWeightFreeze_list, resetType_list, 'vanilla', numRun)\n",
    "    conver_finetune = get_conver(metrics, doWeightFreeze_list, resetType_list, 'finetune', numRun)\n",
    "    fig = plt.figure(figsize=(5,4))\n",
    "\n",
    "    thisAx = plt.gca()\n",
    "    doExample = False\n",
    "    x_coord = np.array(sample_size_list)\n",
    "    xlim=[np.min(x_coord), np.max(x_coord)]\n",
    "    ylim=0\n",
    "    color_list = ['r','b', [0,1,1], [1,.5,0]]\n",
    "\n",
    "    for didx, doWeightFreeze in enumerate(doWeightFreeze_list):\n",
    "        if doWeightFreeze:\n",
    "            typeStr = 'Freeze'\n",
    "            typeStr_dict = '\\'freeze\\''\n",
    "        else:\n",
    "            typeStr = 'Fluid'\n",
    "            typeStr_dict = '\\'liquid\\''\n",
    "        \n",
    "        print(typeStr)\n",
    "        x_mat = conver_smart.query(dict_to_query({'typeStr':typeStr_dict, 'resetType':'\\''+resetType+'\\''}))['acc'].tolist()[0]\n",
    "        if doWeightFreeze:\n",
    "            typeStrTmp = typeStr+' IN-Polarity'\n",
    "        else:\n",
    "            typeStrTmp = typeStr\n",
    "        plot_median_plus_example((1-x_mat)*100, x_coord, ylim, xlim, typeStrTmp, doExample, thisAx, color=color_list[didx])\n",
    "        # plot_sem((1-x_mat)*100, x_coord, ylim, xlim, typeStr, thisAx, color=color_list[didx])\n",
    "\n",
    "        x_mat = conver_finetune.query(dict_to_query({'typeStr':typeStr_dict, 'resetType':'\\''+resetType+'\\''}))['acc'].tolist()[0]\n",
    "        plot_median_plus_example((1-x_mat)*100, x_coord, ylim, xlim, typeStr+' IN-Weight', doExample, thisAx, color=color_list[didx+2], lineType = '--')\n",
    "        # plot_sem((1-x_mat)*100, x_coord, ylim, xlim, typeStr, thisAx, color=color_list[didx], lineType = '--')\n",
    "\n",
    "    # plt.plot([100,50000], [50,50], '--', color=[.5,.5,.5])\n",
    "    # plt.plot([100,6000], [20,20], '--', color=[.5,.5,.5])\n",
    "\n",
    "    simpleaxis(thisAx)\n",
    "    thisAx.set_xscale('log')\n",
    "    thisAx.set_xticks([100,1000,6000], ['%d' % x for x in [100,1000,6000]], fontsize=fontsize)\n",
    "    thisAx.set_yticks([0,25,50], ['%d' % x for x in [0,25,50]], fontsize=fontsize)\n",
    "    # thisAx.set_xticks([100,5000,50000], ['%d' % x for x in [100,5000,50000]], fontsize=fontsize)\n",
    "    # thisAx.set_yticks([0,50,100], ['%d' % x for x in [0,50,100]], fontsize=fontsize)\n",
    "    # thisAx.set_ylim([0,40])\n",
    "    thisAx.set_xlabel('# samples', fontsize=fontsize)\n",
    "    thisAx.set_ylabel('Validation error %', fontsize=fontsize)\n",
    "    handles, labels = thisAx.get_legend_handles_labels()\n",
    "    fig.text(0.5,1,'Statistical Efficiency', ha='center', fontsize=fontsize)\n",
    "\n",
    "    by_label = {labels[i]:handles[i] for i in [0,3,1,2]}\n",
    "    lgd = thisAx.legend(by_label.values(), by_label.keys(), loc='upper left', bbox_to_anchor=(1,.99), fancybox=True, framealpha=0, fontsize=fontsize)\n",
    "    plt.savefig(os.path.join(os.getcwd(), 'sample_efficiency_finetune.png'), dpi=400, bbox_inches = 'tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### The cell below plots the first column of Figure 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fontsize=20\n",
    "resetType_list = ['posRand']\n",
    "for resetType in resetType_list:\n",
    "    conver_smart = get_conver(metrics, doWeightFreeze_list, resetType_list, 'pretrained', numRun)\n",
    "    conver_dumb = get_conver(metrics, doWeightFreeze_list, resetType_list, 'vanilla', numRun)\n",
    "    conver_finetune = get_conver(metrics, doWeightFreeze_list, resetType_list, 'finetune', numRun)\n",
    "    fig = plt.figure(figsize=(5,4))\n",
    "\n",
    "    thisAx = plt.gca()\n",
    "    doExample = False\n",
    "    x_coord = np.array(sample_size_list)\n",
    "    xlim=[np.min(x_coord), np.max(x_coord)]\n",
    "    ylim=0\n",
    "    color_list = ['b', [0,1,1], [1,.5,0]]\n",
    "\n",
    "    base_line = conver_smart.query(dict_to_query({'typeStr':'\\'freeze\\'', 'resetType':'\\''+resetType+'\\''}))['acc'].tolist()[0]\n",
    "\n",
    "    y_mat = conver_smart.query(dict_to_query({'typeStr':'\\'liquid\\'', 'resetType':'\\''+resetType+'\\''}))['acc'].tolist()[0]\n",
    "    plot_diff_plus_mannwhitneyu((1-base_line)*100, (1-y_mat)*100, x_coord, ylim, xlim, 'Liquid RAND-Polarity', thisAx, color=color_list[0], linewidth=4)\n",
    "\n",
    "    y_mat = conver_finetune.query(dict_to_query({'typeStr':'\\'freeze\\'', 'resetType':'\\''+resetType+'\\''}))['acc'].tolist()[0]\n",
    "    plot_diff_plus_mannwhitneyu((1-base_line)*100, (1-y_mat)*100, x_coord, ylim, xlim, 'Freeze IN-Weight', thisAx, color=color_list[1], linewidth=4)\n",
    "\n",
    "    y_mat = conver_finetune.query(dict_to_query({'typeStr':'\\'liquid\\'', 'resetType':'\\''+resetType+'\\''}))['acc'].tolist()[0]\n",
    "    plot_diff_plus_mannwhitneyu((1-base_line)*100, (1-y_mat)*100, x_coord, ylim, xlim, 'Liquid IN-Weight', thisAx, color=color_list[2], linewidth=4)\n",
    "\n",
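    "    # The x-axis settings below are for the CIFAR-10 sample sizes; use the commented lines instead when sample_size_list holds the Fashion-MNIST sizes\n",
    "    plt.plot([90,50000], [0,0], '--', color=[.5,.5,.5])\n",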
    "    # plt.plot([90,6000], [0,0], '--', color=[.5,.5,.5])\n",
    "\n",
    "    simpleaxis(thisAx)\n",
    "    thisAx.set_xscale('log')\n",
    "    # thisAx.set_xticks([100,1000,6000], ['%d' % x for x in [100,1000,6000]], fontsize=fontsize)\n",
    "    # thisAx.set_yticks([0,10], ['%d' % x for x in [0,10]], fontsize=fontsize)\n",
    "    # thisAx.set_xlim([90,6000])\n",
    "    thisAx.set_xticks([100,5000,50000], ['%d' % x for x in [100,5000,50000]], fontsize=fontsize)\n",
    "    thisAx.set_yticks([0,25], ['%d' % x for x in [0,25]], fontsize=fontsize)\n",
    "    thisAx.set_xlim([90,50000])\n",
    "    # thisAx.set_ylim([0,40])\n",
    "    thisAx.set_xlabel('# samples', fontsize=fontsize)\n",
    "    thisAx.set_ylabel(r'$\\Delta$ Validation error %', fontsize=fontsize)\n",
    "    handles, labels = thisAx.get_legend_handles_labels()\n",
    "    fig.text(0.5,1,'Statistical Efficiency', ha='center', fontsize=fontsize)\n",
    "\n",
    "    by_label = {labels[i]:handles[i] for i in [2,1,0]}\n",
    "    lgd = thisAx.legend(by_label.values(), by_label.keys(), loc='upper left', bbox_to_anchor=(1,.99), fancybox=True, framealpha=0, fontsize=fontsize)\n",
    "    plt.savefig(os.path.join(os.getcwd(), 'sample_efficiency_test.png'), dpi=400, bbox_inches = 'tight')"
   ]
  },
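  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`plot_diff_plus_mannwhitneyu` lives in `plots.py`, which is not shown here. As a hedged sketch of the kind of comparison its name suggests (not its actual implementation), the cell below computes, for each training-set size, the difference in median validation error between a condition and the frozen baseline, plus a two-sided Mann-Whitney U test across runs via `scipy.stats.mannwhitneyu`. `compare_to_baseline`, the sign convention (condition minus baseline) and the synthetic data are assumptions for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.stats import mannwhitneyu\n",
    "\n",
    "def compare_to_baseline(base_err, other_err):\n",
    "    # base_err, other_err: arrays of shape (num_runs, num_sample_sizes),\n",
    "    # e.g. validation error (%) per run and training-set size\n",
    "    # Returns, per sample size, median(other) - median(base) and a two-sided Mann-Whitney U p-value\n",
    "    diffs, pvals = [], []\n",
    "    for col in range(base_err.shape[1]):\n",
    "        diffs.append(np.median(other_err[:, col]) - np.median(base_err[:, col]))\n",
    "        res = mannwhitneyu(other_err[:, col], base_err[:, col], alternative='two-sided')\n",
    "        pvals.append(res.pvalue)\n",
    "    return np.array(diffs), np.array(pvals)\n",
    "\n",
    "# Synthetic data: 20 runs x 3 sample sizes, with made-up shifts of 0, 5 and 10 points\n",
    "rng = np.random.default_rng(0)\n",
    "base_err = rng.normal(20, 1, size=(20, 3))\n",
    "other_err = base_err + np.array([0.0, 5.0, 10.0])\n",
    "diffs, pvals = compare_to_baseline(base_err, other_err)\n",
    "print(np.round(diffs, 2), pvals)"
   ]
  },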
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. **TIME: Computational Efficiency** (third column of Figures 2 and 3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Define helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
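    "# Index of the first epoch at which the accuracy curve x reaches cutoff (NaN if it never does)\n",
    "def get_learnt_epoch(x, cutoff):\n",
    "    epoch_num = np.squeeze(np.where(x >= cutoff))\n",
    "    if not len(epoch_num.shape) == 0: # zero or several epochs reach the cutoff\n",
    "        if epoch_num.shape[0]==0:\n",
    "            epoch_num = np.nan # cutoff never reached\n",
    "        else:\n",
    "            epoch_num = epoch_num[0].astype(np.float64)\n",
    "    else: # exactly one epoch reaches the cutoff\n",
    "        epoch_num = epoch_num.astype(np.float64)\n",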
    "    return epoch_num"
   ]
  },
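  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`get_learnt_epoch` returns the index of the first epoch whose validation accuracy is at or above `cutoff`, or NaN if the run never gets there. A minimal vectorized sketch of the same rule for a whole (runs x epochs) matrix uses a boolean mask and `argmax`, which returns the position of the first `True`. `first_epoch_reaching` is a hypothetical name and the accuracy values are made up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def first_epoch_reaching(val_acc, cutoff):\n",
    "    # val_acc: array of shape (num_runs, num_epochs), one row per run\n",
    "    reached = val_acc >= cutoff                          # True where the run is at or above the cutoff\n",
    "    first = reached.argmax(axis=1).astype(np.float64)   # index of the first True (0 if there is none)\n",
    "    first[~reached.any(axis=1)] = np.nan                 # runs that never reach the cutoff\n",
    "    return first\n",
    "\n",
    "val_acc = np.array([[0.40, 0.55, 0.70],\n",
    "                    [0.50, 0.60, 0.65],\n",
    "                    [0.30, 0.35, 0.40]])\n",
    "print(first_epoch_reaching(val_acc, 0.5))"
   ]
  },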
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_flops(metrics, doWeightFreeze_list, acc_cutoff_list, resetType_list, networkType, numRun):\n",
    "    flops = pd.DataFrame(index=range(len(doWeightFreeze_list)*len(acc_cutoff_list)*len(resetType_list)), columns=[\"typeStr\", \"resetType\", \"networkType\", \"acc\", \"epoch\", \"flops\"])\n",
    "    counter = 0\n",
    "\n",
    "    for resetType in resetType_list:\n",
    "        for acc_cutoff in acc_cutoff_list:\n",
    "            for doWeightFreeze in doWeightFreeze_list:\n",
    "                if doWeightFreeze:\n",
    "                    typeStr = 'freeze'\n",
    "                    typeStr_dict = '\\'freeze\\''\n",
    "                else:\n",
    "                    typeStr = 'liquid'\n",
    "                    typeStr_dict = '\\'liquid\\''\n",
    "                flops['typeStr'][counter] = typeStr\n",
    "                flops['acc'][counter] = acc_cutoff\n",
    "                flops['resetType'][counter] = resetType\n",
    "                flops['networkType'][counter] = networkType\n",
    "\n",
    "                # one list entry per training-sample size\n",
    "                page_dict = {'networkType':'\\'' + networkType + '\\'', 'typeStr':typeStr_dict, 'resetType':'\\'' + resetType + '\\''} #, 'train_sample':sample_size_list[8]\n",
    "                val_acc = metrics.query(dict_to_query(page_dict))['validation_acc'].tolist()\n",
    "                # print(val_acc[0].shape)\n",
    "\n",
    "                epoch_at_learnt = np.empty(shape=(len(sample_size_list), numRun), dtype=np.float64)\n",
    "                for idx, this_val_acc in enumerate(val_acc):\n",
    "                    this_epoch_at_learnt = np.apply_along_axis(get_learnt_epoch, 1, this_val_acc, acc_cutoff)\n",
    "                    epoch_at_learnt[idx, :] = np.squeeze(this_epoch_at_learnt)\n",
    "                    del this_epoch_at_learnt\n",
    "                flops['epoch'][counter] = np.transpose(epoch_at_learnt)\n",
    "                flops['flops'][counter] = flops['epoch'][counter]*62388354 # hSize_infocus*3\n",
    "                counter+=1\n",
    "        # print('@acc %1.2f is done' % acc_cutoff)\n",
    "    return flops"
   ]
  },
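  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`get_flops` converts epochs into FLOPs by multiplying by the constant 62388354 used above, so runs that never reach the cutoff stay NaN in both the `epoch` and `flops` columns. When summarizing such a matrix, plain `np.median` returns NaN as soon as one run is NaN, whereas `np.nanmedian` ignores those runs; which behaviour `plot_median_plus_example` uses is not shown in this notebook. The toy matrix below is made up for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "FLOPS_PER_EPOCH = 62388354  # constant used in get_flops above\n",
    "\n",
    "# Toy epochs-to-cutoff matrix: 3 runs x 2 sample sizes; NaN = cutoff never reached\n",
    "epochs = np.array([[4.0, 2.0],\n",
    "                   [6.0, np.nan],\n",
    "                   [5.0, 3.0]])\n",
    "flops = epochs * FLOPS_PER_EPOCH  # NaN propagates through the product\n",
    "\n",
    "print(np.median(epochs, axis=0))     # second column is NaN because one run is NaN\n",
    "print(np.nanmedian(epochs, axis=0))  # ignores NaN runs: medians 5.0 and 2.5\n",
    "print(np.isnan(epochs).sum(axis=0))  # runs per sample size that never reached the cutoff"
   ]
  },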
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### The cell below plots the third column of Figure 2 and Supp Figure A.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# acc_cutoff_list=[.5,.6,.75, 0.8, .82, .85]\n",
    "acc_cutoff_list=[.45,.5,.55,.6,.75, 0.8, .82, .85]\n",
    "# acc_cutoff_list=[.3,.4,.5,.6]\n",
    "colName = 'epoch'\n",
    "doTitle = False\n",
    "resetType_list = ['posRand']\n",
    "for resetType in resetType_list:\n",
    "    flops_smart = get_flops(metrics, doWeightFreeze_list, acc_cutoff_list, resetType_list, 'pretrained', numRun)\n",
    "    flops_dumb = get_flops(metrics, doWeightFreeze_list, acc_cutoff_list, resetType_list, 'vanilla', numRun)\n",
    "    flops_finetune = get_flops(metrics, doWeightFreeze_list, acc_cutoff_list, resetType_list, 'finetune', numRun)\n",
    "\n",
    "    fontsize=20\n",
    "    doExample = False\n",
    "    x_coord = np.array(sample_size_list)\n",
    "    xlim=[100, np.max(x_coord)]\n",
    "    ylim=0\n",
    "    lineTypeList = ['-', '--']\n",
    "    lineType = '-'\n",
    "    color_list = ['r','b','g']\n",
    "\n",
    "    for acc_cutoff in acc_cutoff_list:\n",
    "        fig = plt.figure(figsize=(5,4))\n",
    "        thisAx = plt.gca()\n",
    "    \n",
    "        for doWeightFreeze, color in zip(doWeightFreeze_list, color_list[:2]):\n",
    "            if doWeightFreeze:\n",
    "                typeStr = 'Freeze_informed'\n",
    "                typeStr_dict = '\\'freeze\\''\n",
    "            else:\n",
    "                typeStr = 'Fluid'\n",
    "                typeStr_dict = '\\'liquid\\''\n",
    "\n",
    "            x_mat = flops_smart.query(dict_to_query({'typeStr':typeStr_dict, 'acc':acc_cutoff, 'resetType': '\\''+ resetType + '\\''}))[colName].tolist()[0]\n",
    "            plot_median_plus_example(x_mat, x_coord, ylim, xlim, typeStr, doExample, thisAx, color=color, lineType=lineType)\n",
    "\n",
    "        x_mat = flops_dumb.query(dict_to_query({'typeStr':'\\'freeze\\'', 'acc':acc_cutoff, 'resetType': '\\''+ resetType + '\\''}))[colName].tolist()[0]\n",
    "        plot_median_plus_example(x_mat, x_coord, ylim, xlim, 'Freeze_uninformed', doExample, thisAx, color=color_list[2], lineType=lineType)\n",
    "\n",
    "        simpleaxis(thisAx)\n",
    "        thisAx.set_xscale('log')\n",
    "        thisAx.set_xlim(xlim)\n",
    "        thisAx.set_xticks([100,1000,6000], ['%d' % x for x in [100,1000,6000]], fontsize=fontsize)\n",
    "        # thisAx.set_xticks([100,5000,50000], ['%d' % x for x in [100,5000,50000]], fontsize=fontsize)\n",
    "\n",
    "        y_min, y_max = thisAx.get_ylim()\n",
    "        thisAx.set_yticks([0,round(y_max, -1)/2,round(y_max, -1)], ['%d' % x for x in [0,round(y_max, -1)/2,round(y_max, -1)]], fontsize=fontsize)\n",
    "        thisAx.set_xlabel('# samples', fontsize=fontsize)\n",
    "        thisAx.set_ylabel('#'+colName+'s to\\n reach %d%% accuracy' % int(acc_cutoff*100), fontsize=fontsize)\n",
    "\n",
    "        if doTitle:\n",
    "            y_min, y_max = thisAx.get_ylim()\n",
    "            thisAx.text(100, y_max*1.05, resetType,fontweight=\"bold\",fontsize=fontsize, ha='center')\n",
    "\n",
    "        handles, labels = thisAx.get_legend_handles_labels()\n",
    "        fig.text(0.5,1,'Computational Efficiency', ha='center', fontsize=fontsize)\n",
    "\n",
    "        # by_label = {labels[i]:handles[i] for i in [0,2,1]}\n",
    "        # lgd = thisAx.legend(by_label.values(), by_label.keys(), loc='upper left', bbox_to_anchor=(1,.99), fancybox=True, framealpha=0, fontsize=25)\n",
    "        plt.savefig(os.path.join(os.getcwd(), 'computational_efficiency_%1.2f.png' % acc_cutoff), dpi=400, bbox_inches = 'tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### The cell below plots the third column of Supp Figure A.3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# acc_cutoff_list=[.5,.6,.75, 0.8, .82, .85]\n",
    "acc_cutoff_list=[.45,.5,.55,.6,.75, 0.8, .82, .85]\n",
    "# acc_cutoff_list=[.3,.4,.5,.6]\n",
    "colName = 'epoch'\n",
    "doTitle = False\n",
    "resetType_list = ['posRand']\n",
    "for resetType in resetType_list:\n",
    "    flops_smart = get_flops(metrics, doWeightFreeze_list, acc_cutoff_list, resetType_list, 'pretrained', numRun)\n",
    "    flops_dumb = get_flops(metrics, doWeightFreeze_list, acc_cutoff_list, resetType_list, 'vanilla', numRun)\n",
    "    flops_finetune = get_flops(metrics, doWeightFreeze_list, acc_cutoff_list, resetType_list, 'finetune', numRun)\n",
    "\n",
    "    fontsize=20\n",
    "    doExample = False\n",
    "    x_coord = np.array(sample_size_list)\n",
    "    xlim=[100, np.max(x_coord)]\n",
    "    ylim=0\n",
    "    lineTypeList = ['-', '--']\n",
    "    lineType = '-'\n",
    "    color_list = ['r','b', [0,1,1], [1,.5,0]]\n",
    "    \n",
    "    for acc_cutoff in acc_cutoff_list:\n",
    "        fig = plt.figure(figsize=(5,4))\n",
    "        thisAx = plt.gca()\n",
    "    \n",
    "        for didx, doWeightFreeze in enumerate(doWeightFreeze_list):\n",
    "            if doWeightFreeze:\n",
    "                typeStr = 'Freeze_informed'\n",
    "                typeStr_dict = '\\'freeze\\''\n",
    "            else:\n",
    "                typeStr = 'Fluid'\n",
    "                typeStr_dict = '\\'liquid\\''\n",
    "\n",
    "            x_mat = flops_smart.query(dict_to_query({'typeStr':typeStr_dict, 'acc':acc_cutoff, 'resetType': '\\''+ resetType + '\\''}))[colName].tolist()[0]\n",
    "            plot_median_plus_example(x_mat, x_coord, ylim, xlim, typeStr, doExample, thisAx, color=color_list[didx], lineType=lineType)\n",
    "            x_mat = flops_finetune.query(dict_to_query({'typeStr':typeStr_dict, 'acc':acc_cutoff, 'resetType': '\\''+ resetType + '\\''}))[colName].tolist()[0]\n",
    "            plot_median_plus_example(x_mat, x_coord, ylim, xlim, typeStr+'-finetune', doExample, thisAx, color=color_list[didx+2], lineType='--')\n",
    "\n",
    "        simpleaxis(thisAx)\n",
    "        thisAx.set_xscale('log')\n",
    "        thisAx.set_xlim(xlim)\n",
    "        thisAx.set_xticks([100,1000,6000], ['%d' % x for x in [100,1000,6000]], fontsize=fontsize)\n",
    "        # thisAx.set_xticks([100,5000,50000], ['%d' % x for x in [100,5000,50000]], fontsize=fontsize)\n",
    "\n",
    "        y_min, y_max = thisAx.get_ylim()\n",
    "        thisAx.set_yticks([0,round(y_max, -1)/2,round(y_max, -1)], ['%d' % x for x in [0,round(y_max, -1)/2,round(y_max, -1)]], fontsize=fontsize)\n",
    "        thisAx.set_xlabel('# samples', fontsize=fontsize)\n",
    "        thisAx.set_ylabel('#'+colName+'s to\\n reach %d%% accuracy' % int(acc_cutoff*100), fontsize=fontsize)\n",
    "\n",
    "        if doTitle:\n",
    "            y_min, y_max = thisAx.get_ylim()\n",
    "            thisAx.text(100, y_max*1.05, resetType,fontweight=\"bold\",fontsize=fontsize, ha='center')\n",
    "\n",
    "        handles, labels = thisAx.get_legend_handles_labels()\n",
    "        fig.text(0.5,1,'Computational Efficiency', ha='center', fontsize=fontsize)\n",
    "\n",
    "        # by_label = {labels[i]:handles[i] for i in [0,2,3,1]}\n",
    "        # lgd = thisAx.legend(by_label.values(), by_label.keys(), loc='upper left', bbox_to_anchor=(1,.99), fancybox=True, framealpha=0, fontsize=25)\n",
    "        plt.savefig(os.path.join(os.getcwd(), 'computational_efficiency_%1.2f_finetune.png' % acc_cutoff), dpi=400, bbox_inches = 'tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### The cell below plots the third column of Figure 3 and Supp Figure A.4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# acc_cutoff_list=[.5,.6,.75, 0.8, .82, .85]\n",
    "acc_cutoff_list=[.45,.5,.55,.6,.75, 0.8, .82, .85]\n",
    "# acc_cutoff_list=[.3,.4,.5,.6]\n",
    "colName = 'epoch'\n",
    "doTitle = False\n",
    "resetType_list = ['posRand']\n",
    "for resetType in resetType_list:\n",
    "    flops_smart = get_flops(metrics, doWeightFreeze_list, acc_cutoff_list, resetType_list, 'pretrained', numRun)\n",
    "    flops_dumb = get_flops(metrics, doWeightFreeze_list, acc_cutoff_list, resetType_list, 'vanilla', numRun)\n",
    "    flops_finetune = get_flops(metrics, doWeightFreeze_list, acc_cutoff_list, resetType_list, 'finetune', numRun)\n",
    "\n",
    "    fontsize=20\n",
    "    doExample = False\n",
    "    x_coord = np.array(sample_size_list)\n",
    "    xlim=[100, np.max(x_coord)]\n",
    "    ylim=0\n",
    "    lineTypeList = ['-', '--']\n",
    "    lineType = '-'\n",
    "    color_list = ['b', [0,1,1], [1,.5,0]]\n",
    "    \n",
    "    for acc_cutoff in acc_cutoff_list:\n",
    "        fig = plt.figure(figsize=(5,4))\n",
    "        thisAx = plt.gca()\n",
    "\n",
    "        base_line = flops_smart.query(dict_to_query({'typeStr':'\\'freeze\\'', 'acc':acc_cutoff, 'resetType': '\\''+ resetType + '\\''}))[colName].tolist()[0]\n",
    "\n",
    "        y_mat = flops_smart.query(dict_to_query({'typeStr':'\\'liquid\\'', 'acc':acc_cutoff, 'resetType': '\\''+ resetType + '\\''}))[colName].tolist()[0]\n",
    "        plot_diff_plus_mannwhitneyu(base_line, y_mat, x_coord, ylim, xlim, 'Liquid RAND-Polarity', thisAx, color=color_list[0], linewidth=4)\n",
    "\n",
    "        y_mat = flops_finetune.query(dict_to_query({'typeStr':'\\'freeze\\'', 'acc':acc_cutoff, 'resetType': '\\''+ resetType + '\\''}))[colName].tolist()[0]\n",
    "        plot_diff_plus_mannwhitneyu(base_line, y_mat, x_coord, ylim, xlim, 'Freeze IN-Weight', thisAx, color=color_list[1], linewidth=4)\n",
    "\n",
    "        y_mat = flops_finetune.query(dict_to_query({'typeStr':'\\'liquid\\'', 'acc':acc_cutoff, 'resetType': '\\''+ resetType + '\\''}))[colName].tolist()[0]\n",
    "        plot_diff_plus_mannwhitneyu(base_line, y_mat, x_coord, ylim, xlim, 'Liquid IN-Weight', thisAx, color=color_list[2], linewidth=4)\n",
    "\n",
    "        # plt.plot([100,50000], [0,0], '--', color=[.5,.5,.5])\n",
    "        plt.plot([100,6000], [0,0], '--', color=[.5,.5,.5])\n",
    "\n",
    "        simpleaxis(thisAx)\n",
    "        thisAx.set_xscale('log')\n",
    "        thisAx.set_xlim(xlim)\n",
    "        thisAx.set_xticks([100,1000,6000], ['%d' % x for x in [100,1000,6000]], fontsize=fontsize)\n",
    "        # thisAx.set_xticks([100,5000,50000], ['%d' % x for x in [100,5000,50000]], fontsize=fontsize)\n",
    "\n",
    "        y_min, y_max = thisAx.get_ylim()\n",
    "        thisAx.set_yticks([0,round(y_max, -1)/2,round(y_max, -1)], ['%d' % x for x in [0,round(y_max, -1)/2,round(y_max, -1)]], fontsize=fontsize)\n",
    "        thisAx.set_xlabel('# samples', fontsize=fontsize)\n",
    "        thisAx.set_ylabel(r'$\\Delta$ #'+colName+'s to\\n reach %d%% accuracy' % int(acc_cutoff*100), fontsize=fontsize)\n",
    "\n",
    "        if doTitle:\n",
    "            y_min, y_max = thisAx.get_ylim()\n",
    "            thisAx.text(100, y_max*1.05, resetType,fontweight=\"bold\",fontsize=fontsize, ha='center')\n",
    "\n",
    "        handles, labels = thisAx.get_legend_handles_labels()\n",
    "        fig.text(0.5,1,'Computational Efficiency', ha='center', fontsize=fontsize)\n",
    "\n",
    "        # by_label = {labels[i]:handles[i] for i in [0,2,3,1]}\n",
    "        # lgd = thisAx.legend(by_label.values(), by_label.keys(), loc='upper left', bbox_to_anchor=(1,.99), fancybox=True, framealpha=0, fontsize=25)\n",
    "        plt.savefig(os.path.join(os.getcwd(), 'computational_efficiency_diff_%1.2f_finetune.png' % acc_cutoff), dpi=400, bbox_inches = 'tight')"
   ]
  },
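  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The difference panels above compare each condition with `base_line` (the `'freeze'` rows of `flops_smart`) through `plot_diff_plus_mannwhitneyu`, which is defined earlier in the notebook and not reproduced here. The sketch below is only an illustration under assumptions: it supposes the comparison amounts to a per-sample-size mean difference in epochs-to-cutoff plus a two-sided Mann-Whitney U test (`scipy.stats.mannwhitneyu`), as the helper's name suggests. The function `diff_and_pvalues`, the (sample sizes x runs) array layout and the synthetic numbers are hypothetical, not the paper's data or implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.stats import mannwhitneyu\n",
    "\n",
    "def diff_and_pvalues(base_line, y_mat):\n",
    "    # Hypothetical helper, not the notebook's plot_diff_plus_mannwhitneyu.\n",
    "    # Assumed layout: base_line and y_mat are (n_sample_sizes, n_runs) arrays of epochs-to-cutoff.\n",
    "    base_line = np.asarray(base_line, dtype=float)\n",
    "    y_mat = np.asarray(y_mat, dtype=float)\n",
    "    # Mean difference (condition minus baseline) at each sample size\n",
    "    mean_diff = y_mat.mean(axis=1) - base_line.mean(axis=1)\n",
    "    # Two-sided Mann-Whitney U test between the two groups of runs, per sample size\n",
    "    p_values = np.array([mannwhitneyu(y, b, alternative='two-sided').pvalue\n",
    "                         for y, b in zip(y_mat, base_line)])\n",
    "    return mean_diff, p_values\n",
    "\n",
    "# Synthetic placeholder data: 3 sample sizes x 20 runs\n",
    "rng = np.random.default_rng(0)\n",
    "base = rng.integers(40, 60, size=(3, 20))\n",
    "cond = base - 30  # every run needs 30 fewer epochs than its baseline counterpart\n",
    "mean_diff, p_values = diff_and_pvalues(base, cond)\n",
    "print(mean_diff)\n",
    "print(p_values)"
   ]
  },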
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5. **PROBABILITY** - Figures 2 & 3, second column"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load helper function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_prob_sim(metrics, numRun, typeStr, sample_size_list, acc_cutoff):\n",
    "    # For each train_sample: fraction of runs whose validation accuracy exceeds acc_cutoff at any epoch (resetType is fixed to 'posRand')\n",
    "    return np.stack([[np.sum(np.any(this_metric>acc_cutoff, axis=1))/numRun for this_metric in metrics.query(dict_to_query({'resetType':'\\'posRand\\'', 'typeStr':'\\''+typeStr+'\\'', 'train_sample':train_sample}))['validation_acc'].tolist()] for train_sample in sample_size_list])"
   ]
  },
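  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`get_prob_sim` and the plotting cells that follow score a run as successful if its validation accuracy exceeds `acc_cutoff` at any epoch, then divide the number of successful runs by `numRun`. The toy sketch below isolates that rule on a hand-written accuracy matrix (runs x epochs). `prob_reached` is a hypothetical name, and it divides by the number of rows, which matches `numRun` only when every run is present."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def prob_reached(run_acc, acc_cutoff):\n",
    "    # run_acc: (num_runs, num_epochs) validation accuracies for one condition\n",
    "    run_acc = np.asarray(run_acc)\n",
    "    # A run counts once if any epoch strictly exceeds the cutoff\n",
    "    return np.sum(np.any(run_acc > acc_cutoff, axis=1)) / run_acc.shape[0]\n",
    "\n",
    "# Hand-written example: 4 runs x 3 epochs\n",
    "run_acc = np.array([[0.40, 0.55, 0.70],\n",
    "                    [0.30, 0.42, 0.48],\n",
    "                    [0.50, 0.60, 0.65],\n",
    "                    [0.20, 0.35, 0.44]])\n",
    "print(prob_reached(run_acc, 0.5))   # 0.5 (runs 1 and 3)\n",
    "print(prob_reached(run_acc, 0.45))  # 0.75 (runs 1, 2 and 3)"
   ]
  },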
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### The cell below plots Figure 2 (second column) and Supp. Figure A.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "resetType_list = ['posRand']\n",
    "lineStyle_list = ['-']\n",
    "acc_cutoff_list=[.45,.5,.55,.6,.75, 0.8, .82, .85]\n",
    "# acc_cutoff_list=[.5,.6,.75, 0.8, .82,.85]\n",
    "# acc_cutoff_list=[.3,.4,.5,.6]\n",
    "\n",
    "for acc_cutoff in acc_cutoff_list:\n",
    "    for resetType in resetType_list:\n",
    "        fig = plt.figure(figsize=(5,4))\n",
    "        fontsize=20\n",
    "        x_l = sample_size_list\n",
    "\n",
    "        y_l = {}\n",
    "        y_l['Freeze_informed'] = [np.sum(np.any(this_metric>acc_cutoff, axis=1))/numRun for this_metric in metrics.query(dict_to_query({'networkType':'\\'pretrained\\'', 'resetType':'\\''+resetType+'\\'', 'typeStr':'\\'freeze\\''}))['validation_acc'].tolist()]\n",
    "        y_l['Fluid'] = [np.sum(np.any(this_metric>acc_cutoff, axis=1))/numRun for this_metric in metrics.query(dict_to_query({'networkType':'\\'pretrained\\'', 'resetType':'\\''+resetType+'\\'', 'typeStr':'\\'liquid\\''}))['validation_acc'].tolist()]\n",
    "\n",
    "        plt.plot(x_l, np.array(y_l['Freeze_informed'])*100, label='Freeze_informed', color='r', linewidth=4)\n",
    "        plt.plot(x_l, np.array(y_l['Fluid'])*100, label='Fluid', color='b', linewidth=4)\n",
    "\n",
    "        for lineStyle in lineStyle_list:\n",
    "            y_l['Freeze_uninformed'] = [np.sum(np.any(this_metric>acc_cutoff, axis=1))/numRun for this_metric in metrics.query(dict_to_query({'networkType':'\\'vanilla\\'', 'resetType':'\\''+resetType+'\\'', 'typeStr':'\\'freeze\\''}))['validation_acc'].tolist()]\n",
    "            plt.plot(x_l, np.array(y_l['Freeze_uninformed'])*100, lineStyle, label='Freeze_uninformed', color='g', linewidth=4)\n",
    "\n",
    "        plt.xlabel('# samples', fontsize=fontsize)\n",
    "        plt.ylabel('%% trials reached\\n %d%% validation accuracy' % int(acc_cutoff*100), fontsize=fontsize)\n",
    "\n",
    "        plt.gca().set_xscale('log')\n",
    "        plt.gca().set_xticks([100,1000, 6000], ['%d' % x for x in [100,1000, 6000]], fontsize=fontsize)#x_coord\n",
    "        # plt.gca().set_xticks([100,5000, 50000], ['%d' % x for x in [100,5000, 50000]], fontsize=fontsize)#x_coord\n",
    "        plt.gca().set_xlim([100, max(sample_size_list)])\n",
    "        plt.yticks([0,100], fontsize=fontsize)\n",
    "        # plt.grid(True)\n",
    "        plt.gca().spines['top'].set_visible(False)\n",
    "        plt.gca().spines['right'].set_visible(False)\n",
    "        # plt.gca().legend(loc='upper right', bbox_to_anchor=[1,.95], fancybox=True, framealpha=0)\n",
    "        handles, labels = plt.gca().get_legend_handles_labels()\n",
    "        rect = plt.gca().patch\n",
    "        rect.set_alpha(0)\n",
    "        # fig.text(.5,1,'$P$(learning XOR)', ha='center', fontsize=fontsize)\n",
    "\n",
    "        # by_label = {labels[i]:handles[i] for i in [0,1,2]}\n",
    "        # lgd = plt.gca().legend(by_label.values(), by_label.keys(), loc='upper left', bbox_to_anchor=(1,.99), fancybox=True, framealpha=0, fontsize=fontsize)\n",
    "        plt.savefig(os.path.join(os.getcwd(), 'prob_learning_%1.2f.png' % acc_cutoff), dpi=400, bbox_inches = 'tight')"
   ]
  },
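  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cells in this section select rows with `metrics.query(dict_to_query({...}))` and wrap string values in escaped quotes (e.g. `'\\''+resetType+'\\''`). `dict_to_query` is defined earlier in the notebook and is not shown here; the sketch below assumes it joins `key==value` terms with `&` for `pandas.DataFrame.query`, which would explain why string values must carry their own quotes. `dict_to_query_sketch` and the toy table `metrics_toy` are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "def dict_to_query_sketch(conditions):\n",
    "    # Hypothetical stand-in for dict_to_query: 'key==value' terms joined by '&'\n",
    "    return ' & '.join('%s==%s' % (key, value) for key, value in conditions.items())\n",
    "\n",
    "# Toy metrics table; validation_acc holds one (runs x epochs) array per row\n",
    "metrics_toy = pd.DataFrame([\n",
    "    {'networkType': 'pretrained', 'typeStr': 'freeze', 'validation_acc': np.array([[0.40, 0.60], [0.30, 0.55]])},\n",
    "    {'networkType': 'pretrained', 'typeStr': 'liquid', 'validation_acc': np.array([[0.50, 0.70], [0.20, 0.40]])},\n",
    "    {'networkType': 'vanilla', 'typeStr': 'freeze', 'validation_acc': np.array([[0.10, 0.30], [0.20, 0.35]])},\n",
    "])\n",
    "\n",
    "# repr('freeze') yields the quoted literal 'freeze', so the query compares against a string\n",
    "query = dict_to_query_sketch({'networkType': repr('pretrained'), 'typeStr': repr('freeze')})\n",
    "print(query)  # networkType=='pretrained' & typeStr=='freeze'\n",
    "selected = metrics_toy.query(query)['validation_acc'].tolist()\n",
    "# Same success rule as above: any epoch above a 0.5 cutoff\n",
    "probs = [np.sum(np.any(acc > 0.5, axis=1)) / acc.shape[0] for acc in selected]\n",
    "print(probs)"
   ]
  },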
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### The cell below plots Figure 3 (second column) and Supp. Figures A.3-4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "resetType_list = ['posRand']\n",
    "lineStyle_list = ['-']\n",
    "# acc_cutoff_list is reused from the Figure 2 cell above\n",
    "# acc_cutoff_list=[.5,.6,.75, 0.8, .82,.85]\n",
    "# acc_cutoff_list=[.3,.4,.5,.6]\n",
    "\n",
    "for acc_cutoff in acc_cutoff_list:\n",
    "    for resetType in resetType_list:\n",
    "        fig = plt.figure(figsize=(5,4))\n",
    "        fontsize=20\n",
    "        x_l = sample_size_list\n",
    "\n",
    "        y_l = {}\n",
    "        y_l['Freeze_informed'] = [np.sum(np.any(this_metric>acc_cutoff, axis=1))/numRun for this_metric in metrics.query(dict_to_query({'networkType':'\\'pretrained\\'', 'resetType':'\\''+resetType+'\\'', 'typeStr':'\\'freeze\\''}))['validation_acc'].tolist()]\n",
    "        y_l['Fluid'] = [np.sum(np.any(this_metric>acc_cutoff, axis=1))/numRun for this_metric in metrics.query(dict_to_query({'networkType':'\\'pretrained\\'', 'resetType':'\\''+resetType+'\\'', 'typeStr':'\\'liquid\\''}))['validation_acc'].tolist()]\n",
    "\n",
    "        plt.plot(x_l, np.array(y_l['Freeze_informed'])*100, label='Freeze_informed', color='r', linewidth=4)\n",
    "        plt.plot(x_l, np.array(y_l['Fluid'])*100, label='Fluid', color='b', linewidth=4)\n",
    "\n",
    "        y_l['Freeze_finetune'] = [np.sum(np.any(this_metric>acc_cutoff, axis=1))/numRun for this_metric in metrics.query(dict_to_query({'networkType':'\\'finetune\\'', 'resetType':'\\''+resetType+'\\'', 'typeStr':'\\'freeze\\''}))['validation_acc'].tolist()]\n",
    "        y_l['Fluid_finetune'] = [np.sum(np.any(this_metric>acc_cutoff, axis=1))/numRun for this_metric in metrics.query(dict_to_query({'networkType':'\\'finetune\\'', 'resetType':'\\''+resetType+'\\'', 'typeStr':'\\'liquid\\''}))['validation_acc'].tolist()]\n",
    "        plt.plot(x_l, np.array(y_l['Freeze_finetune'])*100, label='Freeze_finetune', color=[0,1,1], linestyle = '--', linewidth=4)\n",
    "        plt.plot(x_l, np.array(y_l['Fluid_finetune'])*100, label='Fluid_finetune', color=[1,.5,0], linestyle = '--', linewidth=4)\n",
    "\n",
    "        plt.xlabel('# samples', fontsize=fontsize)\n",
    "        plt.ylabel('%% trials reached\\n %d%% validation accuracy' % int(acc_cutoff*100), fontsize=fontsize)\n",
    "\n",
    "        plt.gca().set_xscale('log')\n",
    "        plt.gca().set_xticks([100,1000, 6000], ['%d' % x for x in [100,1000, 6000]], fontsize=fontsize)#x_coord\n",
    "        # plt.gca().set_xticks([100,5000, 50000], ['%d' % x for x in [100,5000, 50000]], fontsize=fontsize)#x_coord\n",
    "        plt.yticks([0,100], fontsize=fontsize)\n",
    "        plt.gca().set_xlim([100, max(sample_size_list)])\n",
    "        # plt.grid(True)\n",
    "        plt.gca().spines['top'].set_visible(False)\n",
    "        plt.gca().spines['right'].set_visible(False)\n",
    "        # plt.gca().legend(loc='upper right', bbox_to_anchor=[1,.95], fancybox=True, framealpha=0)\n",
    "        handles, labels = plt.gca().get_legend_handles_labels()\n",
    "        rect = plt.gca().patch\n",
    "        rect.set_alpha(0)\n",
    "        # fig.text(.5,1,'$P$(learning XOR)', ha='center', fontsize=fontsize)\n",
    "\n",
    "        # by_label = {labels[i]:handles[i] for i in [0,1,2,3]}\n",
    "        # lgd = plt.gca().legend(by_label.values(), by_label.keys(), loc='upper left', bbox_to_anchor=(1,.99), fancybox=True, framealpha=0, fontsize=fontsize)\n",
    "        plt.savefig(os.path.join(os.getcwd(), 'prob_learning_%1.2f_finetune.png' % acc_cutoff), dpi=400, bbox_inches = 'tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.13 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
