{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e15057c8-45d4-4c32-9598-6af6c7391421",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import tensorflow as tf\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import plotting_functions as p_f\n",
    "import cell_analyses as c_u\n",
    "import data_utils as d_u\n",
    "import model_utils as m_u\n",
    "\n",
    "COLOR = \"black\"\n",
    "plt.rcParams[\"text.color\"] = COLOR\n",
    "plt.rcParams[\"axes.labelcolor\"] = COLOR\n",
    "plt.rcParams[\"xtick.color\"] = COLOR\n",
    "plt.rcParams[\"ytick.color\"] = COLOR\n",
    "\n",
    "fontsize = 15\n",
    "linewidth = 3\n",
    "labelsize = 15\n",
    "legendsize = 12\n",
    "legend_loc = (0.45, 0.22)\n",
    "legend_ncol = 2\n",
    "window_len = 10\n",
    "dpi = 300\n",
    "cs = ['white', 'pink']\n",
    "\n",
    "fig_path = 'FIG PATH'\n",
    "\n",
    "save_path_net = fig_path + 'Subspace_Net/'\n",
    "save_path_ae = fig_path + 'Subspace_AE/'\n",
    "save_path_vae = fig_path + 'Subspace_VAE/'\n",
    "save_path_pattern = fig_path + 'Pattern/'\n",
    "\n",
    "data = {'sup_shal_lin_factors': {},\n",
    "        'sup_deep_nonlin_factors': {},\n",
    "        'sup_deep_lin_factors': {},\n",
    "        'sup_deep_2_subspace': {},\n",
    "        'ae_lin_factors': {},\n",
    "        'ae_nonlin_factors': {},\n",
    "        'vae_shapes3d': {},\n",
    "        'vae_shapes3d_baselines': {},\n",
    "        'vae_shapes3d_longer' : {},\n",
    "        'vae_shapes3d_baselines_longer': {},\n",
    "        'vae_dsprites_longer' : {},\n",
    "        'vae_dsprites_baselines_longer': {},\n",
    "        'ae_categorical': {},\n",
    "        'sup_shal_lin_factors_sparse': {},\n",
    "        'sup_shal_lin_factors_sparse_weights': {},\n",
    "       }\n",
    "\n",
    "loss_act = r\"$\\mathcal{L}_{activity}$,  \"\n",
    "loss_weight = r\"$\\mathcal{L}_{weight}$,  \"\n",
    "loss_nonneg = r\"$\\mathcal{L}_{nonneg}$,  \"\n",
    "loss_sparse = r\"$\\mathcal{L}_{sparse}$,  \"\n",
    "relu = \"ReLu,  \""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cdc0362b-2383-4738-b96b-b47d30d93bdc",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# SubspaceNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7db208d8-3ac2-4255-ad59-4bb354917bef",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "model_folder_name = \"Subspaces\"  # 'Group_embedder' 'Subspaces_vae' 'Pattern_Learning'\n",
    "path1 = \"SOME PATH\" + model_folder_name + \"/\"\n",
    "path2 = \"SOME PATH\" + model_folder_name + \"/path2/\"\n",
    "paths = [path2, path1]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "610d7ef5-f589-436e-a775-b5bb2d594a7b",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Supervised Shallow Net - Linear data -> Factors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b243f7d-65ba-48d7-8221-bca9f7bef2ae",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# this is one hidden layer net\n",
    "data['sup_shal_lin_factors']['info'] = {\n",
    "    relu + loss_act + loss_weight: {\n",
    "        \"2022-08-20\": [0, 1, 2],\n",
    "    },\n",
    "    loss_nonneg + loss_act + loss_weight: {\n",
    "        \"2022-08-20\": [3, 4, 5],\n",
    "    },\n",
    "    relu + loss_act: {\n",
    "        \"2022-08-20\": [9, 10, 11],\n",
    "    },\n",
    "    relu + loss_weight: {\n",
    "        \"2022-08-20\": [12, 13, 14],\n",
    "    },\n",
    "    loss_act + loss_weight: {\n",
    "        \"2022-08-20\": [6, 7, 8],\n",
    "    },\n",
    "}\n",
    "\n",
    "data['sup_shal_lin_factors'] = p_f.get_tensorboard_df(data['sup_shal_lin_factors'], paths=paths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d390d6b-1960-4870-8715-0c5d43dbc8e9",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "p_f.plot_train_curve(data['sup_shal_lin_factors'], 'metrics/discrete_mil_0', 'MIR',\\\n",
    "                     save_path_net + 'shal_lin_factors_mil', label_keep=-3, ylim=(0,1.0)) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8517772-e863-4d35-9519-aab665c2d1d7",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "p_f.plot_train_curve(data['sup_shal_lin_factors'], 'accuracies/r2', '$r^2$',\\\n",
    "                     save_path_net + 'shal_lin_factors_r2', label_keep=-3, ylim=(0.9999,1.0)) "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fea3a239-3d6d-4e4c-a0cc-25d38c8fd1e1",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Supervised Shallow Net SPARSE neurons - Linear data -> Factors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64d2b457-7369-4b6a-8be4-65228900a272",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# this is one hidden layer net\n",
    "data['sup_shal_lin_factors_sparse']['info'] = {\n",
    "    loss_weight + r\"$\\beta_{sparse}=0.0001$\": {\n",
    "        \"2022-09-27\": [0, 2, 4],\n",
    "    },\n",
    "    loss_weight + r\"$\\beta_{sparse}=0.001$\": {\n",
    "        \"2022-09-27\": [1, 3, 6],\n",
    "    },\n",
    "    loss_weight + r\"$\\beta_{sparse}=0.01$\": {\n",
    "        \"2022-09-27\": [5, 7, 8],\n",
    "    },\n",
    "    loss_weight + r\"$\\beta_{sparse}=0.1$\": {\n",
    "        \"2022-09-27\": [9, 10, 11],\n",
    "    },\n",
    "    r\"$\\beta_{sparse}=0.0001$\" : {\n",
    "        \"2022-09-27\": [12, 13, 14],\n",
    "    },\n",
    "    r\"$\\beta_{sparse}=0.001$\": {\n",
    "        \"2022-09-27\": [15, 16, 17],\n",
    "    },\n",
    "    r\"$\\beta_{sparse}=0.01$\": {\n",
    "        \"2022-09-27\": [18, 19, 21],\n",
    "    },\n",
    "    r\"$\\beta_{sparse}=0.1$\": {\n",
    "        \"2022-09-27\": [20, 22, 23],\n",
    "    },\n",
    "}\n",
    "\n",
    "data['sup_shal_lin_factors_sparse'] = p_f.get_tensorboard_df(data['sup_shal_lin_factors_sparse'], paths=paths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "674b9f0b-df65-4863-a122-f109f665607c",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "p_f.plot_train_curve(data['sup_shal_lin_factors_sparse'], 'metrics/discrete_mil_0', 'MIR',\\\n",
    "                     save_path_net + 'shal_lin_factors_sparse_mil', label_keep=None, ylim=(0,1.0), legend_loc_=(0.05, 0.8))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7498ba73-f1e6-46a4-ba86-02f46e3222c5",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Supervised Shallow Net SPARSE weights - Linear data -> Factors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f598748b-5050-4bc9-aaea-6b00c62a00d6",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# this is one hidden layer net_weights\n",
    "data['sup_shal_lin_factors_sparse_weights']['info'] = {\n",
    "    r\"$\\beta_{sparse}=0.0001$\" : {\n",
    "        \"2022-11-09\": [17, 19, 20, 22, 25],\n",
    "    },\n",
    "    r\"$\\beta_{sparse}=0.001$\": {\n",
    "        \"2022-11-09\": [15, 16, 18, 21 ,23, 24],\n",
    "    },\n",
    "    r\"$\\beta_{sparse}=0.01$\": {\n",
    "        \"2022-11-09\": [6, 7, 9, 13, 14],\n",
    "    },\n",
    "    r\"$\\beta_{sparse}=0.1$\": {\n",
    "        \"2022-11-09\": [5, 8, 10, 11, 12],\n",
    "    },\n",
    "    loss_act+ r\"$\\beta_{sparse}=0.0001$\" : {\n",
    "        \"2022-11-10\": [0, 2, 4],\n",
    "    },\n",
    "    loss_act + r\"$\\beta_{sparse}=0.001$\": {\n",
    "        \"2022-11-10\": [1, 3, 9],\n",
    "    },\n",
    "    loss_act + r\"$\\beta_{sparse}=0.01$\": {\n",
    "        \"2022-11-10\": [5, 6, 11],\n",
    "    },\n",
    "    loss_act + r\"$\\beta_{sparse}=0.1$\": {\n",
    "        \"2022-11-10\": [7, 8, 10],\n",
    "    },\n",
    "    #r\"$\\beta_{sparse}=1.0$\": {\n",
    "    #    \"2022-11-09\": [0, 1, 2, 3, 4],\n",
    "    #},\n",
    "}\n",
    "\n",
    "data['sup_shal_lin_factors_sparse_weights'] = p_f.get_tensorboard_df(data['sup_shal_lin_factors_sparse_weights'], paths=paths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4784fe3c-bc93-4c49-83ca-c4be79d9a5f4",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "p_f.plot_train_curve(data['sup_shal_lin_factors_sparse_weights'], 'metrics/discrete_mil_0', 'MIR',\\\n",
    "                     save_path_net + 'shal_lin_factors_sparse_weights_mil', label_keep=None, ylim=(0,1.0), legend_loc_=(0.05, 0.8))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29765753-acc3-4c82-97a2-fcf7a0570867",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Supervised Deep Net - Nonlinear data -> Factors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0702a388-ea78-4c08-950b-08a3d311c5f0",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# this one has large network size, WITH non-linear function on input\n",
    "data['sup_deep_nonlin_factors']['info'] = {\n",
    "    \"non-linear data factor dim = 6\": {\n",
    "        \"2022-08-20\": [15, 16],\n",
    "        \"2022-08-21\": [0, 1, 7],\n",
    "    },\n",
    "    \"non-linear data factor dim = 4\": {\n",
    "        \"2022-08-21\": [2, 3, 4, 5, 6],\n",
    "    },\n",
    "}\n",
    "\n",
    "data['sup_deep_nonlin_factors'] = p_f.get_tensorboard_df(data['sup_deep_nonlin_factors'], paths=paths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13848e05-1ec3-43fc-a2b6-de300fd37b1b",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "p_f.plot_training_curves_layers(data['sup_deep_nonlin_factors'], 'metrics/discrete_mil_', 'MIR', \\\n",
    "                            save_path_net + 'deep_nonlin_factors_mil_', label_keep=None, ylim=(0,1), legend_loc_=(0.40, 0.16))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "02be2ded-70dd-4ee3-974e-f20d3a394cad",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Supervised Deep Net - Linear data -> Factors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d47dded-6ed0-40a4-875e-4d17b381dad3",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# this one has large network size, WITHOUT non-linear function on input\n",
    "data['sup_deep_lin_factors']['info'] = {\n",
    "    \"linear data factor dim = 6\": {\n",
    "        \"2022-08-21\": [23, 24, 25, 26, 27],\n",
    "    },\n",
    "    \"linear data factor dim = 4\": {\n",
    "        \"2022-08-21\": [28, 29, 30, 31, 32],\n",
    "    },\n",
    "}\n",
    "\n",
    "data['sup_deep_lin_factors'] = p_f.get_tensorboard_df(data['sup_deep_lin_factors'], paths=paths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e6a3fb6-c945-4c1d-b898-f05126131b6e",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "p_f.plot_training_curves_layers(data['sup_deep_lin_factors'], 'metrics/discrete_mil_', 'MIR', \\\n",
    "                            save_path_net + 'deep_lin_factors_mil_', label_keep=None, ylim=(0,1), legend_loc_=(0.40, 0.16))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27d90f40-8091-4f30-a207-fc9a57ef9518",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Autoencoders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b063767b-8c94-4c1c-a2d5-124b49edcc35",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "model_folder_name = \"Subspaces_vae\"  # 'Group_embedder' 'Subspaces_vae' 'Pattern_Learning'\n",
    "path1 = \"SOME PATH\" + model_folder_name + \"/\"\n",
    "path2 = \"OTHER PATH\" + model_folder_name + \"/path2/\"\n",
    "paths = [path2, path1]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2e12536-f959-41b1-8d05-b7bad63818a7",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Autoencoder - Linear synthetic data - factors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2d761bb-32bb-474f-99ea-f300a2a7f739",
   "metadata": {
    "pycharm": {
     "name": "#%%\n",
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "# autoencoder linear synthetic data\n",
    "data['ae_lin_factors']['info'] = {\n",
    "    relu + loss_act + loss_weight: {\n",
    "        \"2022-08-21\": [ 9, 10, 11, 12, 13 ],\n",
    "    },\n",
    "    loss_nonneg + loss_act + loss_weight: {\n",
    "        \"2022-08-21\": [ 24, 25, 26, 27, 28 ],\n",
    "    },\n",
    "    relu + loss_act: {\n",
    "        \"2022-08-21\": [ 14, 15, 16, 17, 18 ],\n",
    "    },\n",
    "    relu + loss_weight: {\n",
    "        \"2022-08-21\": [ 19, 20, 21, 22, 23 ],\n",
    "    },\n",
    "    loss_act + loss_weight: {\n",
    "        \"2022-08-21\": [ 29, 30, 31, 32],\n",
    "    },\n",
    "}\n",
    "\n",
    "data['ae_lin_factors'] = p_f.get_tensorboard_df(data['ae_lin_factors'], paths=paths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce5277fd-bdfd-44fb-89c2-cda94a242dbe",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "p_f.plot_train_curve(data['ae_lin_factors'], 'metrics/discrete_mig', 'MIG',\\\n",
    "                 save_path_ae + 'lin_factors_mig', label_keep=-3, ylim=(0,0.5))                  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5535ee86-25b4-49dc-8a77-e7591607787f",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "p_f.plot_train_curve(data['ae_lin_factors'], 'metrics/discrete_mil', 'MIR',\\\n",
    "                 save_path_ae + 'lin_factors_mil', label_keep=-3, ylim=(0,1.0))  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5aaab54-0ef0-42a6-baf2-775671457d1c",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "p_f.plot_train_curve(data['ae_lin_factors'], 'accuracies/r2', '$r^2$',\\\n",
    "                 save_path_ae + 'lin_factors_r2', label_keep=-3, ylim=(0.95,1.0))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a04be453-f593-4311-9174-e3952a635c9a",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Autoencoder - Nonlinear synthetic data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac9dd5f1-da3a-47d2-a54e-4b0de2a70af0",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# autoencoder nonlinear synthetic data\n",
    "data['ae_nonlin_factors']['info'] = {\n",
    "    relu + loss_act + loss_weight: {\n",
    "        \"2022-08-21\": [ 6, 7, 8, 34, 35],\n",
    "    },\n",
    "    loss_nonneg + loss_act + loss_weight: {\n",
    "        \"2022-08-21\": [ 40, 41, 42, 43, 44],\n",
    "    },\n",
    "    relu + loss_act: {\n",
    "        \"2022-08-21\": [ 50, 51, 52, 53, 54],\n",
    "    },\n",
    "    relu + loss_weight: {\n",
    "        \"2022-08-21\": [ 55, 56, 57, 58, 59],\n",
    "    },\n",
    "    loss_act + loss_weight: {\n",
    "        \"2022-08-21\": [ 45, 46, 47, 48, 49],\n",
    "    },\n",
    "}\n",
    "\n",
    "# autoencoder nonlinear synthetic data\n",
    "data['ae_nonlin_factors']['info'] = {\n",
    "    relu + loss_act + loss_weight: {\n",
    "        \"2022-11-11\": [2, 5, 6],\n",
    "        \"2022-11-12\": [0, 4],\n",
    "    },\n",
    "    loss_nonneg + loss_act + loss_weight: {\n",
    "        \"2022-11-11\": [0, 1, 3],\n",
    "        \"2022-11-12\": [7, 8],\n",
    "    },\n",
    "    relu + loss_act: {\n",
    "        \"2022-11-11\": [ 9, 11, 12 ],\n",
    "        \"2022-11-12\": [1, 6],\n",
    "    },\n",
    "    relu + loss_weight: {\n",
    "        \"2022-11-11\": [ 10, 13, 17],\n",
    "        \"2022-11-12\": [2, 5],\n",
    "    },\n",
    "    loss_act + loss_weight: {\n",
    "        \"2022-11-11\": [ 14, 15, 16],\n",
    "        \"2022-11-12\": [3, 9],\n",
    "    },\n",
    "}\n",
    "\n",
    "data['ae_nonlin_factors'] = p_f.get_tensorboard_df(data['ae_nonlin_factors'], paths=paths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "713d667c-cd4b-448c-a012-ea356d0cf6a7",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "p_f.plot_train_curve(data['ae_nonlin_factors'], 'metrics/discrete_mig', 'MIG',\\\n",
    "                 save_path_ae + 'nonlin_factors_mig', label_keep=-3, ylim=(0,0.6)) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7bbb44e-5600-4b52-a616-ce0249a21ff9",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "p_f.plot_train_curve(data['ae_nonlin_factors'], 'metrics/discrete_mil', 'MIR',\\\n",
    "                 save_path_ae + 'nonlin_factors_mil', label_keep=-3, ylim=(0,1.0)) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c830ed4a-40e4-4884-9a98-c42e8ffcf43f",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "p_f.plot_train_curve(data['ae_nonlin_factors'], 'accuracies/r2', '$r^2$',\\\n",
    "                 save_path_ae + 'nonlin_factors_r2', label_keep=-3, ylim=(0.95,1.0)) "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c0c34af-0ea2-4daf-9ec2-e07b44b4b1a5",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### VAE - Shapes3D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45baa47e-cc6f-4271-8472-44db7026eab4",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# 500000 iterations:\n",
    "data['vae_shapes3d_longer']['info'] = {}\n",
    "data['vae_shapes3d_longer']['info'] = {\n",
    "        relu + r\"$\\beta_{VAE}=1$,  \" + r\"$\\beta_{weight}=1.0$\": {\n",
    "            \"2022-08-25\": [0, 1, 2, 3, 4], #checked\n",
    "    },\n",
    "        relu + r\"$\\beta_{VAE}=1$,  \" + r\"$\\beta_{weight}=0.1$\": {\n",
    "            \"2022-08-25\": [5, 6, 7, 8, 9], #checked\n",
    "    },\n",
    "        relu + r\"$\\beta_{VAE}=1$,  \" + r\"$\\beta_{weight}=0.01$\": {\n",
    "            \"2022-08-26\": [10, 11, 12, 13, 14], #checked\n",
    "    },\n",
    "}\n",
    "\n",
    "data['vae_shapes3d_longer'] = p_f.get_tensorboard_df(data['vae_shapes3d_longer'], min_steps=400, paths=paths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "966d9929-f3e6-43c9-87de-63a1aa40b134",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "p_f.plot_train_curve(data['vae_shapes3d_longer'], 'metrics/discrete_mig', 'MIG',\\\n",
    "                 save_path_vae + 'shapes3d_mig', label_keep=None, ylim=(0,0.7), legend_ncol_=1, legend_loc_=(0.3, 0.2)) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3e7a5a1-4377-4450-8373-03fea9774a4c",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "p_f.plot_train_curve(data['vae_shapes3d_longer'], 'metrics/discrete_mil', 'MIR',\\\n",
    "                 save_path_vae + 'shapes3d_mil', label_keep=None, ylim=(0,1.0), legend_ncol_=1, legend_loc_=(0.3, 0.2)) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08a49e37-eb88-4e8c-bc47-35431935086b",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "p_f.plot_train_curve(data['vae_shapes3d_longer'], 'accuracies/r2', '$r^2$',\\\n",
    "                 save_path_vae + 'shapes3d_r2', label_keep=None, ylim=(0.95,1.0), legend_ncol_=1, legend_loc_=(0.3, 0.2)) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b83857e-4206-4e17-80cc-e3ea5ef9308a",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "p_f.plot_train_curve(data['vae_shapes3d_longer'], 'metrics/num used latents', 'Number Used Latents',\\\n",
    "                 save_path_vae + 'shapes3d_num_latents', label_keep=None, ylim=(0,10.0), legend_ncol_=1, legend_loc_=(0.3, 0.2))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b3ddfab-842f-499a-afd0-ddffa59fd753",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### VAE - Shapes3D baselines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed250bcc-610b-4b82-89c8-de0aed023bf7",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "data['vae_shapes3d_baselines_longer']['info'] = {}\n",
    "data['vae_shapes3d_baselines_longer']['info'] = {\n",
    "    r\"$\\beta_{VAE}=16$\": {\n",
    "        \"2022-08-25\": [10, 11, 12, 13, 14],\n",
    "    },\n",
    "    relu + r\"$\\beta_{VAE}=16$\": {\n",
    "    },\n",
    "    r\"$\\beta_{VAE}=4$\": {\n",
    "        \"2022-08-25\": [20, 21, 22, 23],\n",
    "        \"2022-08-26\": [0],\n",
    "    },\n",
    "    r\"$\\beta_{VAE}=1$\": {\n",
    "        \"2022-08-26\": [5, 6, 7, 8, 9],\n",
    "    },\n",
    "    relu + r\"$\\beta_{VAE}=4$\": {\n",
    "        \"2022-08-25\": [15, 16, 17, 18, 19],\n",
    "    },\n",
    "    relu + r\"$\\beta_{VAE}=1$\": {\n",
    "        \"2022-08-26\": [1, 2, 3, 4, 15],  \n",
    "    },  \n",
    "}\n",
    "\n",
    "data['vae_shapes3d_baselines_longer'] = p_f.get_tensorboard_df(data['vae_shapes3d_baselines_longer'], min_steps=400, paths=paths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f5d2bd9-ba44-41d3-b95b-36605bd7582e",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "p_f.plot_train_curve(data['vae_shapes3d_baselines_longer'], 'metrics/discrete_mig', 'MIG',\\\n",
    "                 save_path_vae + 'shapes3d_baselines_mig', label_keep=None, ylim=(0,0.7), legend_loc_=(0.0, 0.86)) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02e3b56d-48df-40c6-aab7-86640f3f1ed8",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "p_f.plot_train_curve(data['vae_shapes3d_baselines_longer'], 'metrics/discrete_mil', 'MIR',\\\n",
    "                 save_path_vae + 'shapes3d_baselines_mil', label_keep=None, ylim=(0,1.0), legend_loc_=(0.3, 0.2)) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa410129-6503-404d-b2d9-a8064f864af4",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "p_f.plot_train_curve(data['vae_shapes3d_baselines_longer'], 'accuracies/r2', '$r^2$',\\\n",
    "                 save_path_vae + 'shapes3d_baselines_r2', label_keep=None, ylim=(0.95,1.0), legend_loc_=(0.3, 0.2)) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5074235a-11b3-41e1-b86e-3d764cbddc78",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "p_f.plot_train_curve(data['vae_shapes3d_baselines_longer'], 'metrics/num used latents', 'Number Used Latents',\\\n",
    "                 save_path_vae + 'shapes3d_baselines_num_latents', label_keep=None, ylim=(0,10.0), legend_loc_=(0.3, 0.2)) "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18ceb2c4-a599-467d-a500-c7ba4b843e38",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### MIG vs Accuracy - for all models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce782400-e752-45e6-8d37-3de81708292c",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# get final value then plot r2 vs mig for all models. and colour models\n",
    "data_1 = data['vae_shapes3d_longer']\n",
    "data_2 = data['vae_shapes3d_baselines_longer']\n",
    "dfs_all_mig_1 = [[a['metrics/discrete_mig'][:data_1['cutoff']] for a in x] for x in data_1['dfs']]\n",
    "dfs_all_mig_2 = [[a['metrics/discrete_mig'][:data_2['cutoff']] for a in x] for x in data_2['dfs']]\n",
    "final_mig_1 = [pd.concat(a, axis=1).iloc[-1:] for a in dfs_all_mig_1]\n",
    "final_mig_2 = [pd.concat(a, axis=1).iloc[-1:] for a in dfs_all_mig_2]\n",
    "dfs_all_r2_1 = [[a['accuracies/r2'][:data_1['cutoff']] for a in x] for x in data_1['dfs']]\n",
    "dfs_all_r2_2 = [[a['accuracies/r2'][:data_2['cutoff']] for a in x] for x in data_2['dfs']]\n",
    "final_r2_1 = [pd.concat(a, axis=1).iloc[-1:] for a in dfs_all_r2_1]\n",
    "final_r2_2 = [pd.concat(a, axis=1).iloc[-1:] for a in dfs_all_r2_2]\n",
    "\n",
    "for lab, mig, r2 in zip(data_1['labels'] + data_2['labels'], final_mig_1 + final_mig_2, final_r2_1 + final_r2_2):\n",
    "    plt.scatter(mig, r2, label=lab, s=80)\n",
    "plt.legend(fontsize=legendsize, ncol=2, loc='center left', bbox_to_anchor=(-0.05, 0.28))\n",
    "plt.xlabel('MIG', fontsize=fontsize)\n",
    "plt.ylabel('$r^2$', fontsize=fontsize)\n",
    "\n",
    "plt.savefig(save_path_vae + 'shapes3d_mig_r2' + \".png\", dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9d6ac87-b3da-404a-83bc-e77b27ead880",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# get final value then plot r2 vs mig for all models. and colour models\n",
    "data_1 = data['vae_shapes3d_longer']\n",
    "data_2 = data['vae_shapes3d_baselines_longer']\n",
    "dfs_all_mig_1 = [[a['metrics/discrete_mig'][:data_1['cutoff']] for a in x] for x in data_1['dfs']]\n",
    "dfs_all_mig_2 = [[a['metrics/discrete_mig'][:data_2['cutoff']] for a in x] for x in data_2['dfs']]\n",
    "final_mig_1 = [pd.concat(a, axis=1).iloc[-10:] for a in dfs_all_mig_1]\n",
    "final_mig_2 = [pd.concat(a, axis=1).iloc[-10:] for a in dfs_all_mig_2]\n",
    "dfs_all_ll_1 = [[a['losses/rec'][:data_1['cutoff']] for a in x] for x in data_1['dfs']]\n",
    "dfs_all_ll_2 = [[a['losses/rec'][:data_2['cutoff']] for a in x] for x in data_2['dfs']]\n",
    "final_ll_1 = [pd.concat(a, axis=1).iloc[-10:] for a in dfs_all_ll_1]\n",
    "final_ll_2 = [pd.concat(a, axis=1).iloc[-10:] for a in dfs_all_ll_2]\n",
    "\n",
    "for lab, mig, l_l in zip(data_1['labels'] + data_2['labels'], final_mig_1 + final_mig_2, final_ll_1 + final_ll_2):\n",
    "    plt.scatter(mig.mean(), -l_l.mean(), label=lab, s=80)\n",
    "plt.legend(fontsize=legendsize, ncol=2, loc='center left', bbox_to_anchor=(-0.05, 0.25))\n",
    "plt.xlabel('MIG', fontsize=fontsize)\n",
    "plt.ylabel('Log Liklihood', fontsize=fontsize)\n",
    "plt.savefig(save_path_vae + 'shapes3d_mig_ll' + \".png\", dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55e785c0-a51e-4618-9db3-4859563845a1",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### dsprites"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b331c96c-b864-49bc-9118-0ad72911cf09",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "data['vae_dsprites_longer']['info'] = {}\n",
    "data['vae_dsprites_longer']['info'] = {\n",
    "        relu + r\"$\\beta_{VAE}=1$,  \" + r\"$\\beta_{weight}=10.0$\": {\n",
    "           \"2022-08-28\": [20, 21, 22, 23, 24]\n",
    "    },\n",
    "        relu + r\"$\\beta_{VAE}=1$,  \" + r\"$\\beta_{weight}=3.0$\": {\n",
    "           \"2022-08-28\": [25, 26, 27, 28, 28]\n",
    "    },\n",
    "        relu + r\"$\\beta_{VAE}=1$,  \" + r\"$\\beta_{weight}=1.0$\": {\n",
    "           \"2022-08-27\": [9, 10, 18, 19, 20]\n",
    "    },\n",
    "        relu + r\"$\\beta_{VAE}=1$,  \" + r\"$\\beta_{weight}=0.3$\": {\n",
    "            \"2022-08-28\": [6, 7, 8, 9, 10],\n",
    "    },\n",
    "}\n",
    "data['vae_dsprites_longer'] = p_f.get_tensorboard_df(data['vae_dsprites_longer'], min_steps=300, paths=paths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b07eac7d-596b-4d3a-93cf-67c042acdf75",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "p_f.plot_train_curve(data['vae_dsprites_longer'], 'metrics/discrete_mig', 'MIG',\\\n",
    "                 save_path_vae + 'dsprites_mig', label_keep=None, ylim=(0,0.4), legend_ncol_=1, legend_loc_=(0.01, 0.8)) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc009c28-8ceb-461c-bb0b-365082028aa8",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "data['vae_dsprites_baselines_longer']['info'] = {}\n",
    "data['vae_dsprites_baselines_longer']['info'] = {\n",
    "    r\"$\\beta_{VAE}=16$\": {\n",
    "        \"2022-08-28\": [0,1,2,3,5],\n",
    "    },\n",
    "    relu + r\"$\\beta_{VAE}=16$\": {\n",
    "    },\n",
    "    r\"$\\beta_{VAE}=4$\": {\n",
    "        \"2022-08-27\": [15,16,17],\n",
    "        \"2022-08-28\": [12,14],\n",
    "    },\n",
    "    relu + r\"$\\beta_{VAE}=4$\": {\n",
    "        \"2022-08-27\": [12,13,14],\n",
    "        \"2022-08-28\": [11,13],\n",
    "    },\n",
    "    r\"$\\beta_{VAE}=1$\": {\n",
    "    },\n",
    "    relu + r\"$\\beta_{VAE}=16$\": {\n",
    "    },  \n",
    "}\n",
    "data['vae_dsprites_baselines_longer'] = p_f.get_tensorboard_df(data['vae_dsprites_baselines_longer'], min_steps=400, paths=paths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "311d171a-c4e6-4287-95b2-cd062b5f84ef",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "p_f.plot_train_curve(data['vae_dsprites_baselines_longer'], 'metrics/discrete_mig', 'MIG',\\\n",
    "                 save_path_vae + 'dsprites_baselines_mig', label_keep=None, ylim=(0,0.4), legend_ncol_=1, legend_loc_=(0.5, 0.5)) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "237bd6d8-f3c0-478c-b0a9-d8c81392ded7",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Scatter the last logged r^2 against the last logged MIG (both truncated at each\n",
    "# dataset's cutoff) for every run; one scatter call, and so one legend entry, per model.\n",
    "data_1, data_2 = data['vae_dsprites_longer'], data['vae_dsprites_baselines_longer']\n",
    "\n",
    "dfs_all_mig_1 = [[df['metrics/discrete_mig'][:data_1['cutoff']] for df in runs] for runs in data_1['dfs']]\n",
    "dfs_all_mig_2 = [[df['metrics/discrete_mig'][:data_2['cutoff']] for df in runs] for runs in data_2['dfs']]\n",
    "dfs_all_r2_1 = [[df['accuracies/r2'][:data_1['cutoff']] for df in runs] for runs in data_1['dfs']]\n",
    "dfs_all_r2_2 = [[df['accuracies/r2'][:data_2['cutoff']] for df in runs] for runs in data_2['dfs']]\n",
    "\n",
    "# concatenate each model's runs column-wise and keep only the final row\n",
    "final_mig_1 = [pd.concat(series, axis=1).iloc[-1:] for series in dfs_all_mig_1]\n",
    "final_mig_2 = [pd.concat(series, axis=1).iloc[-1:] for series in dfs_all_mig_2]\n",
    "final_r2_1 = [pd.concat(series, axis=1).iloc[-1:] for series in dfs_all_r2_1]\n",
    "final_r2_2 = [pd.concat(series, axis=1).iloc[-1:] for series in dfs_all_r2_2]\n",
    "\n",
    "# NOTE(review): this overwrites the global legend_loc but is not used in this cell\n",
    "legend_loc = (0.01, 0.06)\n",
    "\n",
    "for lab, mig, r2 in zip(data_1['labels'] + data_2['labels'], final_mig_1 + final_mig_2, final_r2_1 + final_r2_2):\n",
    "    plt.scatter(mig, r2, label=lab, s=80)\n",
    "plt.legend(fontsize=legendsize, loc=(-0.03, 0.17), ncol=2)\n",
    "plt.xlabel('MIG', fontsize=fontsize)\n",
    "plt.ylabel('$r^2$', fontsize=fontsize)\n",
    "\n",
    "plt.savefig(f\"{save_path_vae}dsprites_mig_r2.png\", dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c63da7b-b97a-465c-ba8d-3025aee4aa3a",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Scatter the negated reconstruction loss against MIG for every run, averaging each\n",
    "# metric over the last n_last logged steps (up to each dataset's cutoff).\n",
    "# NOTE(review): -losses/rec is labelled as a log-likelihood, which assumes the\n",
    "# reconstruction loss is a negative log-likelihood - confirm against the training code.\n",
    "n_last = 10\n",
    "data_1 = data['vae_dsprites_longer']\n",
    "data_2 = data['vae_dsprites_baselines_longer']\n",
    "dfs_all_mig_1 = [[a['metrics/discrete_mig'][:data_1['cutoff']] for a in x] for x in data_1['dfs']]\n",
    "dfs_all_mig_2 = [[a['metrics/discrete_mig'][:data_2['cutoff']] for a in x] for x in data_2['dfs']]\n",
    "final_mig_1 = [pd.concat(a, axis=1).iloc[-n_last:] for a in dfs_all_mig_1]\n",
    "final_mig_2 = [pd.concat(a, axis=1).iloc[-n_last:] for a in dfs_all_mig_2]\n",
    "dfs_all_ll_1 = [[a['losses/rec'][:data_1['cutoff']] for a in x] for x in data_1['dfs']]\n",
    "dfs_all_ll_2 = [[a['losses/rec'][:data_2['cutoff']] for a in x] for x in data_2['dfs']]\n",
    "final_ll_1 = [pd.concat(a, axis=1).iloc[-n_last:] for a in dfs_all_ll_1]\n",
    "final_ll_2 = [pd.concat(a, axis=1).iloc[-n_last:] for a in dfs_all_ll_2]\n",
    "\n",
    "for lab, mig, l_l in zip(data_1['labels'] + data_2['labels'], final_mig_1 + final_mig_2, final_ll_1 + final_ll_2):\n",
    "    # column-wise mean over the last n_last rows -> one point per run\n",
    "    plt.scatter(mig.mean(), -l_l.mean(), label=lab, s=80)\n",
    "plt.legend(fontsize=legendsize, ncol=2, loc='center left', bbox_to_anchor=(-0.05, 0.37))\n",
    "plt.xlabel('MIG', fontsize=fontsize)\n",
    "plt.ylabel('Log Likelihood', fontsize=fontsize)\n",
    "plt.savefig(save_path_vae + 'dsprites_mig_ll' + \".png\", dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2d7ebe7-4a26-4021-bc7b-a07882ad8181",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Categorical"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4de5df85-5417-4bea-a8b5-03e6d9ef8e06",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Autoencoder runs on the categorical dataset: legend label -> {date: [run indices]}\n",
    "data['ae_categorical']['info'] = {\n",
    "    relu + loss_act + loss_weight: {\n",
    "        \"2022-09-09\": [ 5, 6, 7, 8, 9 ],\n",
    "    },\n",
    "    loss_nonneg + loss_act + loss_weight: {\n",
    "        \"2022-09-09\": [ 0, 1, 2, 3, 4 ],\n",
    "    },\n",
    "    relu + loss_act: {\n",
    "        \"2022-09-09\": [ 20, 21, 22, 23, 24 ],\n",
    "    },\n",
    "    relu + loss_weight: {\n",
    "        \"2022-09-09\": [ 15, 16, 17, 18, 19 ],\n",
    "    },\n",
    "    loss_act + loss_weight: {\n",
    "        \"2022-09-09\": [ 10, 11, 12, 13, 14],\n",
    "    },\n",
    "    loss_sparse: {\n",
    "        \"2022-09-09\": [ 40, 41, 42, 43, 44],\n",
    "    },\n",
    "    loss_sparse + loss_weight: {\n",
    "        \"2022-09-09\": [ 35, 36, 37, 38, 39],\n",
    "    },\n",
    "}\n",
    "\n",
    "data['ae_categorical'] = p_f.get_tensorboard_df(data['ae_categorical'], paths=paths, min_steps=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "033c6ade-9a39-47f4-a0ab-3ce2c8ea4824",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# MIG training curves for the categorical autoencoder runs\n",
    "p_f.plot_train_curve(\n",
    "    data['ae_categorical'],\n",
    "    'metrics/discrete_mig',\n",
    "    'MIG',\n",
    "    save_path_vae + 'ae_categorical_mig',\n",
    "    label_keep=-3,\n",
    "    ylim=(0, 0.5),\n",
    "    cutoff=300_000,\n",
    "    legend_loc_=(0.05, 0.22),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ff96e21-98be-4ffa-975c-200a99b7d233",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# MIL training curves for the categorical autoencoder runs\n",
    "p_f.plot_train_curve(\n",
    "    data['ae_categorical'],\n",
    "    'metrics/discrete_mil',\n",
    "    'MIL',\n",
    "    save_path_vae + 'ae_categorical_mil',\n",
    "    label_keep=-3,\n",
    "    ylim=(0, 0.7),\n",
    "    cutoff=300_000,\n",
    "    legend_loc_=(0.05, 0.22),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ffd7f94a-54b8-4cc9-9672-df625dfeffd4",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## SUBSPACE NETWORK - MUTUAL INFO FIGS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "936ac9bd-5717-4d12-aea3-b9f287823b31",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Paths and settings for the subspace-network models\n",
    "model_folder_name = 'subspaces'\n",
    "model_type = 'subspaces'\n",
    "path1 = f'SOME PATH{model_folder_name}/'\n",
    "path2 = f'ANOTHER PATH{model_folder_name}/path2/'\n",
    "base_path = path1\n",
    "cmap = 'binary'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d191f102-9275-42f3-8389-911acb16e109",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Linear Net - Linear data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cc151bc-5055-4843-8d7f-5413c7d342cd",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "date = '2022-08-20'\n",
    "run = 1\n",
    "index = None  # checkpoint index; None selects the largest index found in model_path\n",
    "\n",
    "# Get directories for the requested run\n",
    "run_path, train_path, model_path, save_path, script_path = p_f.set_directories(date, run, base_path=base_path)\n",
    "if index is None:\n",
    "    # files containing 'index' are assumed to be named 'subspaces_<index>.index'\n",
    "    index = max(int(x.split('subspaces_')[1].split('.index')[0]) for x in os.listdir(model_path) if 'index' in x)\n",
    "    print(index)\n",
    "# Load model from file\n",
    "model, params, stored_mu, stored_du = p_f.get_model(model_path + '/subspaces_' + str(index), script_path, save_path, use_old_scripts=True, model_type=model_type)\n",
    "\n",
    "# collect data (with graph_mode switched off)\n",
    "params.graph_mode = False\n",
    "ds_batch, (ds_metric, ds_metric_2) = d_u.data_subspaces(params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e48bea46-5409-413b-9a66-ffa7aba10a2f",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "begin_layer = 1  # first layer whose activations are analysed\n",
    "# pass begin_layer through so neuron_used lines up with xs[begin_layer:] in the zip below\n",
    "neuron_used = m_u.important_neuroms(model, ds_metric['input'], ds_metric['image'], params, begin_layer=begin_layer)\n",
    "\n",
    "# per-layer activations for both metric datasets, transposed (presumably to units x samples)\n",
    "xs = model(ds_metric['input'])\n",
    "prob = xs[-1].numpy()\n",
    "xs = [x.numpy().T for x in xs]\n",
    "xs_2 = model(ds_metric_2['input'])\n",
    "xs_2 = [x.numpy().T for x in xs_2]\n",
    "\n",
    "# each compute_mig result is unpacked as (metrics, (mi_matrix, entropies))\n",
    "metrics = [c_u.compute_mig(x, ds_metric, x_2, ds_metric_2, dataset=params.dataset, neuron_used=n_u)\n",
    "           for (x, x_2, n_u) in zip(xs[begin_layer:], xs_2[begin_layer:], neuron_used)]\n",
    "mi_matrices = [a[1][0] for a in metrics]\n",
    "entropies = [a[1][1] for a in metrics]\n",
    "metrics = [a[0] for a in metrics]\n",
    "\n",
    "print(metrics)\n",
    "\n",
    "# MI matrix of the last analysed layer: latents on the y-axis, factors on the x-axis\n",
    "tick_names = [x[6:] if 'value' in x else x for x in ds_metric.keys() if x not in ['image', 'input']]\n",
    "plt.imshow(mi_matrices[-1], cmap=cmap)\n",
    "ax = plt.gca()\n",
    "plt.xticks(np.arange(len(tick_names)))\n",
    "ax.set_xticklabels(tick_names, rotation = 60, ha=\"right\")\n",
    "plt.xlabel('Factors', fontsize=fontsize)\n",
    "plt.ylabel('Latents', fontsize=fontsize)\n",
    "plt.savefig(save_path_net + 'lin_lin_MI' + \".png\", dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58325734-8cec-437f-8c89-6d80dbb98b9e",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Non-linear Net - nonlinear data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78f65ac8-9b1d-4aab-bfce-e72b7c25f2fd",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "date = '2022-08-21'\n",
    "run = 0\n",
    "index = None  # checkpoint index; None selects the largest index found in model_path\n",
    "layer = -2  # which analysed layer the MI cell below reports\n",
    "\n",
    "# Get directories for the requested run\n",
    "run_path, train_path, model_path, save_path, script_path = p_f.set_directories(date, run, base_path=base_path)\n",
    "if index is None:\n",
    "    # files containing 'index' are assumed to be named 'subspaces_<index>.index'\n",
    "    index = max(int(x.split('subspaces_')[1].split('.index')[0]) for x in os.listdir(model_path) if 'index' in x)\n",
    "    print(index)\n",
    "# Load model from file\n",
    "model, params, stored_mu, stored_du = p_f.get_model(model_path + '/subspaces_' + str(index), script_path, save_path, use_old_scripts=True, model_type=model_type)\n",
    "\n",
    "# collect data (with graph_mode switched off)\n",
    "params.graph_mode = False\n",
    "ds_batch, (ds_metric, ds_metric_2) = d_u.data_subspaces(params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "379cc50d-6077-4ca5-b507-5f05314e89c1",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "begin_layer = 1  # first layer whose activations are analysed\n",
    "# pass begin_layer through so neuron_used lines up with xs[begin_layer:] in the zip below\n",
    "neuron_used = m_u.important_neuroms(model, ds_metric['input'], ds_metric['image'], params, begin_layer=begin_layer)\n",
    "\n",
    "# per-layer activations for both metric datasets, transposed (presumably to units x samples)\n",
    "xs = model(ds_metric['input'])\n",
    "prob = xs[-1].numpy()\n",
    "xs = [x.numpy().T for x in xs]\n",
    "xs_2 = model(ds_metric_2['input'])\n",
    "xs_2 = [x.numpy().T for x in xs_2]\n",
    "\n",
    "# each compute_mig result is unpacked as (metrics, (mi_matrix, entropies))\n",
    "metrics = [c_u.compute_mig(x, ds_metric, x_2, ds_metric_2, dataset=params.dataset, neuron_used=n_u)\n",
    "           for (x, x_2, n_u) in zip(xs[begin_layer:], xs_2[begin_layer:], neuron_used)]\n",
    "mi_matrices = [a[1][0] for a in metrics]\n",
    "entropies = [a[1][1] for a in metrics]\n",
    "metrics = [a[0] for a in metrics]\n",
    "\n",
    "print(metrics[layer])\n",
    "\n",
    "# MI matrix of the layer selected by `layer`: latents on the y-axis, factors on the x-axis\n",
    "tick_names = [x[6:] if 'value' in x else x for x in ds_metric.keys() if x not in ['image', 'input']]\n",
    "plt.imshow(mi_matrices[layer], cmap=cmap)\n",
    "ax = plt.gca()\n",
    "plt.xticks(np.arange(len(tick_names)))\n",
    "ax.set_xticklabels(tick_names, rotation = 60, ha=\"right\")\n",
    "plt.xlabel('Factors', fontsize=fontsize)\n",
    "plt.ylabel('Latents', fontsize=fontsize)\n",
    "plt.savefig(save_path_net + 'nonlin_nonlin_MI' + \".png\", dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e4456a2-0b1b-412f-8b0c-9d1e4c3afb14",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Linear Net - no constraints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "622737dc-f9be-4e18-ac61-2594443455a0",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "base_path = path1\n",
    "\n",
    "date = '2022-08-20'\n",
    "run = 6\n",
    "index = None  # checkpoint index; None selects the largest index found in model_path\n",
    "\n",
    "# Get directories for the requested run\n",
    "run_path, train_path, model_path, save_path, script_path = p_f.set_directories(date, run, base_path=base_path)\n",
    "if index is None:\n",
    "    # files containing 'index' are assumed to be named 'subspaces_<index>.index'\n",
    "    index = max(int(x.split('subspaces_')[1].split('.index')[0]) for x in os.listdir(model_path) if 'index' in x)\n",
    "    print(index)\n",
    "# Load model from file\n",
    "model, params, stored_mu, stored_du = p_f.get_model(model_path + '/subspaces_' + str(index), script_path, save_path, use_old_scripts=True, model_type=model_type)\n",
    "\n",
    "# collect data (with graph_mode switched off)\n",
    "params.graph_mode = False\n",
    "ds_batch, (ds_metric, ds_metric_2) = d_u.data_subspaces(params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "017616ff-e1e6-4a8d-81fa-9c938571f83b",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "begin_layer = 1  # first layer whose activations are analysed\n",
    "# pass begin_layer through so neuron_used lines up with xs[begin_layer:] in the zip below\n",
    "neuron_used = m_u.important_neuroms(model, ds_metric['input'], ds_metric['image'], params, begin_layer=begin_layer)\n",
    "\n",
    "# per-layer activations for both metric datasets, transposed (presumably to units x samples)\n",
    "xs = model(ds_metric['input'])\n",
    "prob = xs[-1].numpy()\n",
    "xs = [x.numpy().T for x in xs]\n",
    "xs_2 = model(ds_metric_2['input'])\n",
    "xs_2 = [x.numpy().T for x in xs_2]\n",
    "\n",
    "# each compute_mig result is unpacked as (metrics, (mi_matrix, entropies))\n",
    "metrics = [c_u.compute_mig(x, ds_metric, x_2, ds_metric_2, dataset=params.dataset, neuron_used=n_u)\n",
    "           for (x, x_2, n_u) in zip(xs[begin_layer:], xs_2[begin_layer:], neuron_used)]\n",
    "mi_matrices = [a[1][0] for a in metrics]\n",
    "entropies = [a[1][1] for a in metrics]\n",
    "metrics = [a[0] for a in metrics]\n",
    "\n",
    "print(metrics)\n",
    "\n",
    "# MI matrix of the last analysed layer: latents on the y-axis, factors on the x-axis\n",
    "tick_names = [x[6:] if 'value' in x else x for x in ds_metric.keys() if x not in ['image', 'input']]\n",
    "plt.imshow(mi_matrices[-1], cmap=cmap)\n",
    "ax = plt.gca()\n",
    "plt.xticks(np.arange(len(tick_names)))\n",
    "ax.set_xticklabels(tick_names, rotation = 60, ha=\"right\")\n",
    "plt.xlabel('Factors', fontsize=fontsize)\n",
    "plt.ylabel('Latents', fontsize=fontsize)\n",
    "plt.savefig(save_path_net + 'lin_lin_no_constraints_MI' + \".png\", dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1d84f828-8e7d-4cc5-b460-66d98175ece4",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## SUBSPACE VAE - MUTUAL INFO FIGS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cdf9a25-1dfa-494c-8298-ca4637a1b06a",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Paths and settings for the subspace VAE models\n",
    "model_folder_name = 'Subspaces_vae'\n",
    "model_type = 'subspaces_vae'\n",
    "path1 = f'SOME PATH{model_folder_name}/'\n",
    "path2 = f'ANOTHER PATH{model_folder_name}/path2/'\n",
    "cmap = 'binary'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "edfe15e0-d3fe-4350-a489-6ba81f4727d3",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    },
    "tags": []
   },
   "source": [
    "### Linear DeepNet - Linear Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eef444bd-7f6b-46bd-b7f8-f6130c677fda",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "base_path = path1\n",
    "\n",
    "date = '2022-08-21'\n",
    "run = 9\n",
    "index = None  # checkpoint index; None selects the largest index found in model_path\n",
    "\n",
    "# Get directories for the requested run\n",
    "run_path, train_path, model_path, save_path, script_path = p_f.set_directories(date, run, base_path=base_path)\n",
    "if index is None:\n",
    "    # files containing 'index' are assumed to be named 'subspaces_<index>.index'\n",
    "    index = max(int(x.split('subspaces_')[1].split('.index')[0]) for x in os.listdir(model_path) if 'index' in x)\n",
    "    print(index)\n",
    "# Load model from file\n",
    "model, params, stored_mu, stored_du = p_f.get_model(model_path + '/subspaces_' + str(index), script_path, save_path, use_old_scripts=True, model_type=model_type)\n",
    "\n",
    "# Back-fill parameters missing from the stored run with the current defaults\n",
    "import parameters\n",
    "par_new = parameters.default_params_subspaces_vae()\n",
    "for key in par_new.keys():\n",
    "    try:\n",
    "        params[key]\n",
    "    except Exception:  # key missing; a bare except would also swallow KeyboardInterrupt\n",
    "        params[key] = par_new[key]\n",
    "\n",
    "# get data\n",
    "params.graph_mode = False\n",
    "# params.dataset = 'dsprites'  #'shapes3d'\n",
    "ds_batch, (ds_metric, ds_metric_2) = d_u.data_subspaces(params)\n",
    "# run model; latents are unpacked as (_, mu, logvar), each transposed\n",
    "(logits, rec), latents = model(ds_metric['image'])\n",
    "(_, mu, logvar) = [x.numpy().T for x in latents]\n",
    "(logits_2, rec_2), latents_2 = model(ds_metric_2['image'])\n",
    "(_, mu_2, logvar_2) = [x.numpy().T for x in latents_2]\n",
    "\n",
    "# when params.sample is set, a latent counts as used if mean(exp(logvar)) over axis 1 is <= 0.5\n",
    "neuron_used = np.mean(np.exp(logvar), axis=1) <= 0.5 if params.sample else None\n",
    "metrics, mi_mat_ = c_u.compute_mig(mu, ds_metric, mu_2, ds_metric_2, neuron_used=neuron_used,\n",
    "                                   dataset=params.dataset, remove_unused=True, compute_dci=False)\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8a513a4-495e-4b25-95a7-e423a66e3547",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Mutual-information heatmap: latents on the y-axis, factors on the x-axis,\n",
    "# with each column divided by the matching entry of `entropy`.\n",
    "mi_mat, entropy = mi_mat_\n",
    "mi_mat_scaled = mi_mat / entropy[None, :]\n",
    "tick_names = [\n",
    "    key[6:] if 'value' in key else key\n",
    "    for key in ds_metric.keys()\n",
    "    if key not in ['image', 'input']\n",
    "]\n",
    "\n",
    "plt.imshow(mi_mat_scaled, vmin=0.0, vmax=mi_mat_scaled.max(), cmap=cmap)\n",
    "ax = plt.gca()\n",
    "plt.xticks(np.arange(len(tick_names)))\n",
    "ax.set_xticklabels(tick_names, rotation=60, ha=\"right\")\n",
    "plt.xlabel('Factors', fontsize=fontsize)\n",
    "plt.ylabel('Latents', fontsize=fontsize)\n",
    "plt.savefig(f\"{save_path_vae}lin_lin_MI.png\", dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00c3799e-be8d-4621-aa3c-56d71abf01f2",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### NonLinear DeepNet - NonLinear Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5930c79-e6ec-4ab9-a496-e496537a481f",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "base_path = path2\n",
    "\n",
    "date = '2022-11-12'\n",
    "run = 4\n",
    "index = None  # checkpoint index; None selects the largest index found in model_path\n",
    "\n",
    "# Get directories for the requested run\n",
    "run_path, train_path, model_path, save_path, script_path = p_f.set_directories(date, run, base_path=base_path)\n",
    "if index is None:\n",
    "    # files containing 'index' are assumed to be named 'subspaces_<index>.index'\n",
    "    index = max(int(x.split('subspaces_')[1].split('.index')[0]) for x in os.listdir(model_path) if 'index' in x)\n",
    "    print(index)\n",
    "# Load model from file\n",
    "model, params, stored_mu, stored_du = p_f.get_model(model_path + '/subspaces_' + str(index), script_path, save_path, use_old_scripts=True, model_type=model_type)\n",
    "\n",
    "# Back-fill parameters missing from the stored run with the current defaults\n",
    "import parameters\n",
    "par_new = parameters.default_params_subspaces_vae()\n",
    "for key in par_new.keys():\n",
    "    try:\n",
    "        params[key]\n",
    "    except Exception:  # key missing; a bare except would also swallow KeyboardInterrupt\n",
    "        params[key] = par_new[key]\n",
    "\n",
    "# get data\n",
    "params.graph_mode = False\n",
    "# params.dataset = 'dsprites'  #'shapes3d'\n",
    "ds_batch, (ds_metric, ds_metric_2) = d_u.data_subspaces(params)\n",
    "# run model; latents are unpacked as (_, mu, logvar), each transposed\n",
    "(logits, rec), latents = model(ds_metric['image'])\n",
    "(_, mu, logvar) = [x.numpy().T for x in latents]\n",
    "(logits_2, rec_2), latents_2 = model(ds_metric_2['image'])\n",
    "(_, mu_2, logvar_2) = [x.numpy().T for x in latents_2]\n",
    "\n",
    "# when params.sample is set, a latent counts as used if mean(exp(logvar)) over axis 1 is <= 0.5\n",
    "neuron_used = np.mean(np.exp(logvar), axis=1) <= 0.5 if params.sample else None\n",
    "metrics, mi_mat_ = c_u.compute_mig(mu, ds_metric, mu_2, ds_metric_2, neuron_used=neuron_used,\n",
    "                                   dataset=params.dataset, remove_unused=True, compute_dci=False)\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93d4137a-3afc-4f40-b97f-ad869f625592",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Mutual-information heatmap: latents on the y-axis, factors on the x-axis,\n",
    "# with each column divided by the matching entry of `entropy`.\n",
    "mi_mat, entropy = mi_mat_\n",
    "mi_mat_scaled = mi_mat / entropy[None, :]\n",
    "tick_names = [\n",
    "    key[6:] if 'value' in key else key\n",
    "    for key in ds_metric.keys()\n",
    "    if key not in ['image', 'input']\n",
    "]\n",
    "\n",
    "plt.imshow(mi_mat_scaled, vmin=0.0, vmax=mi_mat_scaled.max(), cmap=cmap)\n",
    "ax = plt.gca()\n",
    "plt.xticks(np.arange(len(tick_names)))\n",
    "ax.set_xticklabels(tick_names, rotation=60, ha=\"right\")\n",
    "plt.xlabel('Factors', fontsize=fontsize)\n",
    "plt.ylabel('Latents', fontsize=fontsize)\n",
    "plt.savefig(f\"{save_path_vae}nonlin_nonlin_MI.png\", dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d973752-301f-416d-8aee-42bdbae9a856",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Shapes 3d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bcf9f7e-6a44-4aed-8bf7-96ccaba0bad5",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "base_path = path1\n",
    "\n",
    "date = '2022-08-25'\n",
    "run = 3 #33\n",
    "index = None  # checkpoint index; None selects the largest index found in model_path\n",
    "\n",
    "# Get directories for the requested run\n",
    "run_path, train_path, model_path, save_path, script_path = p_f.set_directories(date, run, base_path=base_path)\n",
    "if index is None:\n",
    "    # files containing 'index' are assumed to be named 'subspaces_<index>.index'\n",
    "    index = max(int(x.split('subspaces_')[1].split('.index')[0]) for x in os.listdir(model_path) if 'index' in x)\n",
    "    print(index)\n",
    "# Load model from file\n",
    "model, params, stored_mu, stored_du = p_f.get_model(model_path + '/subspaces_' + str(index), script_path, save_path, use_old_scripts=True, model_type=model_type)\n",
    "\n",
    "# Back-fill parameters missing from the stored run with the current defaults\n",
    "import parameters\n",
    "par_new = parameters.default_params_subspaces_vae()\n",
    "for key in par_new.keys():\n",
    "    try:\n",
    "        params[key]\n",
    "    except Exception:  # key missing; a bare except would also swallow KeyboardInterrupt\n",
    "        params[key] = par_new[key]\n",
    "\n",
    "params.graph_mode = False\n",
    "# params.dataset = 'dsprites'  #'shapes3d'\n",
    "ds_batch, (ds_metric, ds_metric_2) = d_u.data_subspaces(params)\n",
    "# run model; latents are unpacked as (_, mu, logvar), each transposed\n",
    "(logits, rec), latents = model(ds_metric['image'])\n",
    "(_, mu, logvar) = [x.numpy().T for x in latents]\n",
    "(logits_2, rec_2), latents_2 = model(ds_metric_2['image'])\n",
    "(_, mu_2, logvar_2) = [x.numpy().T for x in latents_2]\n",
    "\n",
    "# when params.sample is set, a latent counts as used if mean(exp(logvar)) over axis 1 is <= 0.5\n",
    "neuron_used = np.mean(np.exp(logvar), axis=1) <= 0.5 if params.sample else None\n",
    "metrics, mi_mat_ = c_u.compute_mig(mu, ds_metric, mu_2, ds_metric_2, neuron_used=neuron_used,\n",
    "                                   dataset=params.dataset, remove_unused=False, compute_dci=False)\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16e20f2f-efe2-4184-ad12-da689793fdec",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Mutual-information heatmap: latents on the y-axis, factors on the x-axis,\n",
    "# with each column divided by the matching entry of `entropy`.\n",
    "mi_mat, entropy = mi_mat_\n",
    "mi_mat_scaled = mi_mat / entropy[None, :]\n",
    "tick_names = [\n",
    "    key[6:] if ('value' in key or 'label' in key) else key\n",
    "    for key in ds_metric.keys()\n",
    "    if key not in ['image', 'target']\n",
    "]\n",
    "\n",
    "plt.imshow(mi_mat_scaled, vmin=0.0, vmax=mi_mat_scaled.max(), cmap=cmap)\n",
    "ax = plt.gca()\n",
    "plt.xticks(np.arange(len(tick_names)))\n",
    "ax.set_xticklabels(tick_names, rotation=60, ha=\"right\")\n",
    "plt.xlabel('Factors', fontsize=fontsize)\n",
    "plt.ylabel('Latents', fontsize=fontsize)\n",
    "plt.savefig(f\"{save_path_vae}shapes3d_MI.png\", dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bff39ffd-9310-475a-918f-7eda04f85bfd",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "sample_id = 2\n",
    "# the chosen sample's column of mu; the traversal cell below starts from this point\n",
    "latent = mu[:, sample_id]\n",
    "\n",
    "# save the chosen sample and the one two positions later as standalone example images\n",
    "for fig_num, offset in enumerate((0, 2), start=1):\n",
    "    if offset:\n",
    "        plt.figure()\n",
    "    plt.imshow(ds_metric['image'][sample_id + offset, ...].numpy())\n",
    "    plt.axis('off')\n",
    "    plt.savefig(f\"{save_path_vae}shapes3d_example_{fig_num}.png\", dpi=300, bbox_inches='tight')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4adfd4fb-01d5-4585-a6c3-715058b713a3",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Latent traversals: for each latent dimension, sweep its value over `grid` while the\n",
    "# other dimensions keep the values in `latent`, decode, and plot the decoded images in\n",
    "# one row followed by a histogram of that dimension's values in mu.\n",
    "\n",
    "import imageio\n",
    "\n",
    "grid_points = 15\n",
    "grid_start = 0 if params.relu_latent_mu else -2\n",
    "grid = np.linspace(grid_start, 2, grid_points)\n",
    "latent_all = []\n",
    "for l_dim in range(params.latent_dim):\n",
    "    mask = np.zeros((1, params.latent_dim))\n",
    "    mask[0, l_dim] = 1.0\n",
    "    latent_all.extend(tf.identity(latent - mask * latent) + mask * g_ for g_ in grid)\n",
    "latent_ = tf.concat(latent_all, axis=0)\n",
    "_, image_pred = model.decode(latent_, apply_sigmoid=params.sigmoid_output)\n",
    "\n",
    "n_cols = grid_points + 1  # one column per grid point plus a final histogram column\n",
    "plt.figure(figsize=(n_cols, params.latent_dim))\n",
    "for l_dim in range(params.latent_dim):\n",
    "    for g_point in range(n_cols):\n",
    "        plt.subplot(params.latent_dim, n_cols, l_dim * n_cols + g_point + 1)\n",
    "        if g_point < grid_points:\n",
    "            plt.imshow(image_pred[l_dim * grid_points + g_point, ...])\n",
    "        else:\n",
    "            _ = plt.hist(mu[l_dim, :], bins=np.linspace(-3, 3, num=20))\n",
    "        plt.axis('off')\n",
    "\n",
    "plt.savefig(f\"{save_path_vae}shapes3d_latent_traversal.png\", dpi=300, bbox_inches='tight')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5897c3b-3e17-4600-9f26-2208270f308b",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Animated GIF of the traversals: frame g shows, for every latent dimension, the image\n",
    "# decoded at grid point g (images come from image_pred computed in the cell above).\n",
    "from IPython.display import Image, display\n",
    "\n",
    "# dataset-specific name: a shared 'latent_traversal.gif' gets overwritten by other traversal cells\n",
    "gif_path = save_path_vae + 'shapes3d_latent_traversal.gif'\n",
    "filenames = []\n",
    "try:\n",
    "    for g_point in range(grid_points):\n",
    "        plt.figure(figsize=(params.latent_dim, 1))\n",
    "        for l_dim in range(params.latent_dim):\n",
    "            plt.subplot(1, params.latent_dim, l_dim + 1)\n",
    "            plt.imshow(image_pred[l_dim * grid_points + g_point, ...])\n",
    "            plt.axis('off')\n",
    "\n",
    "        # save this frame to a temporary PNG\n",
    "        filename = save_path_vae + f'shapes3d_frame_{g_point}.png'\n",
    "        filenames.append(filename)\n",
    "        plt.savefig(filename)\n",
    "        plt.close()\n",
    "\n",
    "    # build gif\n",
    "    with imageio.get_writer(gif_path, mode='I') as writer:\n",
    "        for filename in filenames:\n",
    "            writer.append_data(imageio.imread(filename))\n",
    "finally:\n",
    "    # remove the temporary frames even if building the GIF failed\n",
    "    for filename in filenames:\n",
    "        if os.path.exists(filename):\n",
    "            os.remove(filename)\n",
    "\n",
    "# show gif\n",
    "with open(gif_path, 'rb') as file:\n",
    "    display(Image(file.read()))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c08a5caf-ffab-49cd-96b6-f17c7b7655d4",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### dSprites"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32f21d40-293f-4409-ae58-1fd0995bc68e",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "base_path = path1\n",
    "\n",
    "date = '2022-08-27'#'2022-08-24'\n",
    "run = 9 #33\n",
    "index = None  # checkpoint index; None selects the largest index found in model_path\n",
    "\n",
    "# Get directories for the requested run\n",
    "run_path, train_path, model_path, save_path, script_path = p_f.set_directories(date, run, base_path=base_path)\n",
    "if index is None:\n",
    "    # files containing 'index' are assumed to be named 'subspaces_<index>.index'\n",
    "    index = max(int(x.split('subspaces_')[1].split('.index')[0]) for x in os.listdir(model_path) if 'index' in x)\n",
    "    print(index)\n",
    "# Load model from file\n",
    "model, params, stored_mu, stored_du = p_f.get_model(model_path + '/subspaces_' + str(index), script_path, save_path, use_old_scripts=True, model_type=model_type)\n",
    "\n",
    "# Back-fill parameters missing from the stored run with the current defaults\n",
    "import parameters\n",
    "par_new = parameters.default_params_subspaces_vae()\n",
    "for key in par_new.keys():\n",
    "    try:\n",
    "        params[key]\n",
    "    except Exception:  # key missing; a bare except would also swallow KeyboardInterrupt\n",
    "        params[key] = par_new[key]\n",
    "\n",
    "params.graph_mode = False\n",
    "# params.dataset = 'dsprites'  #'shapes3d'\n",
    "ds_batch, (ds_metric, ds_metric_2) = d_u.data_subspaces(params)\n",
    "# run model; latents are unpacked as (_, mu, logvar), each transposed\n",
    "(logits, rec), latents = model(ds_metric['image'])\n",
    "(_, mu, logvar) = [x.numpy().T for x in latents]\n",
    "(logits_2, rec_2), latents_2 = model(ds_metric_2['image'])\n",
    "(_, mu_2, logvar_2) = [x.numpy().T for x in latents_2]\n",
    "\n",
    "# when params.sample is set, a latent counts as used if mean(exp(logvar)) over axis 1 is <= 0.5\n",
    "neuron_used = np.mean(np.exp(logvar), axis=1) <= 0.5 if params.sample else None\n",
    "metrics, mi_mat_ = c_u.compute_mig(mu, ds_metric, mu_2, ds_metric_2, neuron_used=neuron_used,\n",
    "                                   dataset=params.dataset, remove_unused=False, compute_dci=False)\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "706dc578-e000-48d6-bc4e-741a0979e462",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Mutual-information heatmap: latents on the y-axis, factors on the x-axis,\n",
    "# with each column divided by the matching entry of `entropy`.\n",
    "mi_mat, entropy = mi_mat_\n",
    "mi_mat_scaled = mi_mat / entropy[None, :]\n",
    "tick_names = [\n",
    "    key[6:] if ('value' in key or 'label' in key) else key\n",
    "    for key in ds_metric.keys()\n",
    "    if key not in ['image', 'target']\n",
    "]\n",
    "\n",
    "plt.imshow(mi_mat_scaled, vmin=0.0, vmax=mi_mat_scaled.max(), cmap=cmap)\n",
    "ax = plt.gca()\n",
    "plt.xticks(np.arange(len(tick_names)))\n",
    "ax.set_xticklabels(tick_names, rotation=60, ha=\"right\")\n",
    "plt.xlabel('Factors', fontsize=fontsize)\n",
    "plt.ylabel('Latents', fontsize=fontsize)\n",
    "plt.savefig(f\"{save_path_vae}dsprites_MI.png\", dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c9e3faa-59f9-4cdb-af8f-1ef76efc5826",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "sample_id = 7\n",
    "# the chosen sample's column of mu; the traversal cell below starts from this point\n",
    "latent = mu[:, sample_id]\n",
    "\n",
    "# save the chosen sample and the one two positions later as standalone example images\n",
    "for fig_num, offset in enumerate((0, 2), start=1):\n",
    "    if offset:\n",
    "        plt.figure()\n",
    "    plt.imshow(ds_metric['image'][sample_id + offset, ...].numpy(), cmap=cmap)\n",
    "    plt.axis('off')\n",
    "    plt.savefig(f\"{save_path_vae}dsprites_example_{fig_num}.png\", dpi=300, bbox_inches='tight')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "910af1eb-8db9-44ac-82c0-18a781fc8092",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Latent traversals: for each latent dimension, sweep its value over `grid` while the\n",
    "# other dimensions keep the values in `latent`, decode, and plot the decoded images in\n",
    "# one row followed by a histogram of that dimension's values in mu.\n",
    "\n",
    "import imageio\n",
    "\n",
    "grid_points = 15\n",
    "grid_start = 0 if params.relu_latent_mu else -2\n",
    "grid = np.linspace(grid_start, 2, grid_points)\n",
    "latent_all = []\n",
    "for l_dim in range(params.latent_dim):\n",
    "    mask = np.zeros((1, params.latent_dim))\n",
    "    mask[0, l_dim] = 1.0\n",
    "    latent_all.extend(tf.identity(latent - mask * latent) + mask * g_ for g_ in grid)\n",
    "latent_ = tf.concat(latent_all, axis=0)\n",
    "_, image_pred = model.decode(latent_, apply_sigmoid=params.sigmoid_output)\n",
    "\n",
    "n_cols = grid_points + 1  # one column per grid point plus a final histogram column\n",
    "plt.figure(figsize=(n_cols, params.latent_dim))\n",
    "for l_dim in range(params.latent_dim):\n",
    "    for g_point in range(n_cols):\n",
    "        plt.subplot(params.latent_dim, n_cols, l_dim * n_cols + g_point + 1)\n",
    "        if g_point < grid_points:\n",
    "            plt.imshow(image_pred[l_dim * grid_points + g_point, ...])\n",
    "        else:\n",
    "            _ = plt.hist(mu[l_dim, :], bins=np.linspace(-3, 3, num=20))\n",
    "        plt.axis('off')\n",
    "\n",
    "plt.savefig(f\"{save_path_vae}dsprites_latent_traversal.png\", dpi=300, bbox_inches='tight')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e51badbd-1df5-4689-b25c-5f5a7c304a58",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Build an animated GIF of the latent traversal: frame g_point stacks, for every latent\n",
    "# dimension, the image decoded at grid position g_point (image_pred from the traversal cell).\n",
    "from IPython.display import Image, display\n",
    "\n",
    "gif_path = os.path.join(save_path_vae, 'latent_traversal.gif')\n",
    "filenames = []\n",
    "try:\n",
    "    for g_point in range(grid_points):\n",
    "        fig = plt.figure(figsize=(1, params.latent_dim))\n",
    "        for l_dim in range(params.latent_dim):\n",
    "            plt.subplot(params.latent_dim, 1, l_dim + 1)\n",
    "            # image_pred is ordered dimension-major: grid_points consecutive images per latent dim\n",
    "            plt.imshow(image_pred[l_dim * grid_points + g_point, ...])\n",
    "            plt.axis('off')\n",
    "\n",
    "        # save the frame to a temporary file and release the figure\n",
    "        filename = os.path.join(save_path_vae, f'{g_point}.png')\n",
    "        filenames.append(filename)\n",
    "        fig.savefig(filename)\n",
    "        plt.close(fig)\n",
    "\n",
    "    # build gif\n",
    "    with imageio.get_writer(gif_path, mode='I') as writer:\n",
    "        for filename in filenames:\n",
    "            writer.append_data(imageio.imread(filename))\n",
    "finally:\n",
    "    # remove the temporary frames even if rendering or GIF writing failed part-way\n",
    "    for filename in set(filenames):\n",
    "        if os.path.exists(filename):\n",
    "            os.remove(filename)\n",
    "\n",
    "# show gif\n",
    "with open(gif_path, 'rb') as file:\n",
    "    display(Image(file.read()))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "040f3192-c5e7-448d-be4c-4db3839249c9",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Categories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a266bdda-3103-43b4-9d58-2b54721b321b",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "base_path = path1\n",
    "\n",
    "date = '2022-09-09'\n",
    "run = 5\n",
    "index = None\n",
    "\n",
    "# Get directories for the requested run\n",
    "run_path, train_path, model_path, save_path, script_path = p_f.set_directories(date, run, base_path=base_path)\n",
    "if index is None:\n",
    "    # default to the latest checkpoint: the largest N among 'subspaces_N.index' files\n",
    "    index = max([int(x.split('subspaces_')[1].split('.index')[0]) for x in os.listdir(model_path) if 'index' in x])\n",
    "    print(index)\n",
    "# Load model from file\n",
    "model, params, stored_mu, stored_du = p_f.get_model(model_path + '/subspaces_' + str(index), script_path, save_path, use_old_scripts=True, model_type=model_type)\n",
    "\n",
    "# Fill in any parameter missing from the stored run with the current default\n",
    "# (presumably keys added after this run was trained). Only a missing key is\n",
    "# treated as 'absent'; any other error now surfaces instead of being swallowed.\n",
    "import parameters\n",
    "par_new = parameters.default_params_subspaces_vae()\n",
    "for key in par_new.keys():\n",
    "    try:\n",
    "        params[key]\n",
    "    except (KeyError, AttributeError):\n",
    "        params[key] = par_new[key]\n",
    "\n",
    "# get data\n",
    "params.graph_mode = False\n",
    "# params.dataset = 'dsprites'  #'shapes3d'\n",
    "ds_batch, (ds_metric, ds_metric_2) = d_u.data_subspaces(params)\n",
    "# run model on both metric splits; latents are converted to numpy and transposed\n",
    "(logits, rec), latents = model(ds_metric['image'])\n",
    "(_, mu, logvar) = [x.numpy().T for x in latents]\n",
    "(logits_2, rec_2), latents_2 = model(ds_metric_2['image'])\n",
    "(_, mu_2, logvar_2) = [x.numpy().T for x in latents_2]\n",
    "\n",
    "# when sampling, mark a latent as used if its mean posterior variance exp(logvar) is at most 0.5\n",
    "neuron_used = np.mean(np.exp(logvar), axis=1) <= 0.5 if params.sample else None\n",
    "metrics, mi_mat_ = c_u.compute_mig(mu, ds_metric, mu_2, ds_metric_2, neuron_used=neuron_used,\n",
    "                                   dataset=params.dataset, remove_unused=False, compute_dci=False)\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33cd58c8-bda9-4091-91b6-e7050641a26d",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Mutual information between latents (rows) and data factors (columns), each column\n",
    "# normalised by that factor's entropy\n",
    "excluded_keys = ['image', 'target', 'input']\n",
    "tick_names = []\n",
    "for key in ds_metric.keys():\n",
    "    if key in excluded_keys:\n",
    "        continue\n",
    "    # drop the first 6 characters (presumably a 'value_' / 'label_' prefix)\n",
    "    tick_names.append(key[6:] if ('value' in key or 'label' in key) else key)\n",
    "\n",
    "mi_mat, entropy = mi_mat_\n",
    "mi_mat_scaled = mi_mat / entropy[np.newaxis, :]\n",
    "plt.imshow(mi_mat_scaled, vmin=0.0, vmax=mi_mat_scaled.max(), cmap=cmap)\n",
    "ax = plt.gca()\n",
    "ax.set_xticks(np.arange(len(tick_names)))\n",
    "ax.set_xticklabels(tick_names, rotation=60, ha=\"right\")\n",
    "plt.xlabel('Factors', fontsize=fontsize)\n",
    "plt.ylabel('Latents', fontsize=fontsize)\n",
    "plt.savefig(save_path_vae + 'categories_MI' + \".png\", dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "532e9394-88b2-47cc-966e-740e6046fb5c",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d26a5c56-1660-4166-a09a-1e208c1c80b1",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "b8a8bc45-d223-49b9-a9a4-5f5cb5349af3",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Pattern Learning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d900a687-1dbd-4c71-8aea-f557ad97e4d0",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Root directories holding the pattern-learning runs ('SOME PATH' / 'OTHER PATH' are placeholders).\n",
    "# Note: this rebinds path1/path2, which earlier sections of the notebook also use.\n",
    "model_folder_name = 'Pattern_Learning'\n",
    "path1 = 'SOME PATH' + model_folder_name + '/'\n",
    "path2 = 'OTHER PATH' + model_folder_name + '/path2/'\n",
    "paths = [path2, path1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17f8d633-de4e-46fb-bcf1-91a1f0acd3a8",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Fresh container for this section; rebinds the `data` dict created at the top of the notebook.\n",
    "data = {'pattern_learning_old': {},\n",
    "        'pattern_learning': {}\n",
    "       }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20c2a7ba-ed9c-429f-8342-45dd4d110a2b",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Cosine curves"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86a667e9-eb27-4f9e-98dd-065dacf6a23b",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Runs to compare, grouped as experiment label -> {run date: [run numbers]}.\n",
    "# get_tensorboard_df then loads their logged curves from `paths` (min_steps presumably filters out short runs).\n",
    "data['pattern_learning']['info'] = {\n",
    "    'ReLu': {\n",
    "        \"2022-08-30\": [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],\n",
    "        \"2022-08-31\": [13, 14, 15, 16, 17]\n",
    "    },\n",
    "    'No ReLu': {\n",
    "        \"2022-08-30\": [5, 6, 7, 8, 9],\n",
    "        \"2022-08-31\": [7, 9, 10, 11, 12],\n",
    "    },\n",
    "}\n",
    "data['pattern_learning'] = p_f.get_tensorboard_df(data['pattern_learning'], paths=paths, min_steps=2000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c8d5e15-978f-4cb3-a6a0-29046b67074c",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Names of the series available in the loaded dataframes (to_print=False, so inspect `names` directly).\n",
    "names = p_f.show_df_names(data['pattern_learning']['dfs'], to_print=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4db2529-ebad-431f-9528-3b5b38f00589",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Training curve of the logged 'metrics/cosine' series for the runs in data['pattern_learning'].\n",
    "p_f.plot_train_curve(data['pattern_learning'], 'metrics/cosine', 'Cosine Distance',\\\n",
    "                 save_path_pattern + 'cosine', label_keep=None, ylim=(0,1.0), figsize=(4,4))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eda7d25c-4d79-4c73-9f2f-b7339cf89dd4",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Compare firing maps around object locations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5c6ac27-38d9-4f7d-8224-6083c03c8385",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# For every run of every experiment type, load the trained model and record the mean spatial\n",
    "# correlation of firing maps around object locations.\n",
    "base_path = path2\n",
    "index = None\n",
    "results = {}\n",
    "experiments = data['pattern_learning']['info']\n",
    "for experiment_type in experiments:\n",
    "    print(experiment_type)\n",
    "    corrs_for_type = results[experiment_type] = []\n",
    "    for date, runs in experiments[experiment_type].items():\n",
    "        for run in runs:\n",
    "            params, g, inputs, model, info = p_f.get_pattern_data(date, run, index, base_path)\n",
    "            object_all_type = p_f.get_object_surround(model, params)\n",
    "            mean_corr, mean_corr_cells = p_f.get_mean_spatial_corrs(object_all_type, params)\n",
    "            corrs_for_type.append(mean_corr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b5f2dc6-8295-48a6-917e-2f91583f58ca",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# One column of points per experiment type: mean object-surround spatial correlation of each run\n",
    "for experiment_type, corrs in results.items():\n",
    "    plt.scatter([experiment_type for _ in corrs], corrs, label=experiment_type)\n",
    "\n",
    "# hide the categorical x ticks (the legend identifies each column); plt.xticks('off') would\n",
    "# instead try to place a tick at a category named 'off'\n",
    "plt.xticks([])\n",
    "plt.xlabel('Model / Data type', fontsize=fontsize)\n",
    "plt.ylabel('Spatial Correlation Around Objects', fontsize=fontsize)\n",
    "plt.legend(loc='center left', bbox_to_anchor=(0.15,0.3), fontsize=legendsize)\n",
    "plt.savefig(save_path_pattern + 'object_correlations' + \".png\", dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6848d1b-7a44-4ae2-9e89-5458803a0f27",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### ReLU + factorised data (1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1a68e15-9f1b-4906-b16e-405744426d9e",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Load a single run for this section (see header) and set the prefix stored in info.save_path,\n",
    "# presumably used by the p_f plotting helpers when saving figures.\n",
    "date = '2022-08-31'\n",
    "run = 1\n",
    "index = None\n",
    "base_path = path2\n",
    "\n",
    "params, g, inputs, model, info = p_f.get_pattern_data(date, run, index, base_path)\n",
    "info.save_path = save_path_pattern + 'ReLu_FactoredData_1_'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "532b789f",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Cell activity as numpy; the grid-score cell iterates over its first axis (presumably environments).\n",
    "g_to_plot = g.numpy()\n",
    "\n",
    "# Spatial correlation of firing around object locations; mean_corr_cells is later plotted per cell.\n",
    "object_all_type = p_f.get_object_surround(model, params)\n",
    "mean_corr, mean_corr_cells = p_f.get_mean_spatial_corrs(object_all_type, params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52d4b343",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Grid score and scale of every cell in every environment; scores_mean averages the\n",
    "# per-cell grid scores over environments (ignoring NaNs).\n",
    "import cell_analyses as ca\n",
    "fit_ellipse, ring, torus = True, True, False\n",
    "\n",
    "width = params.width\n",
    "ent_or_mem = params.ent_dim\n",
    "\n",
    "scores_all = []\n",
    "for g_env in g_to_plot:\n",
    "    module_analysis = []\n",
    "    for i in range(ent_or_mem):\n",
    "        if i % 10 == 0:\n",
    "            print(str(i), end=' ')  # progress\n",
    "\n",
    "        # rate map of cell i and its spatial autocorrelogram (NaNs zeroed before scoring)\n",
    "        cell = g_env[:, i]\n",
    "        rate_map = np.reshape(cell, (params.height, params.width))\n",
    "        auto = p_f.autocorr2d_no_nans(rate_map, torus=torus)\n",
    "        auto[np.isnan(auto)] = 0\n",
    "        score, scale, theta = ca.grid_score_scale_analysis(auto, fit_ellipse=fit_ellipse, ring=ring)\n",
    "        norm_firing = np.mean(cell ** 2)\n",
    "        module_analysis.append([i, score, scale, theta, norm_firing])\n",
    "\n",
    "    # unpack the per-cell rows column by column\n",
    "    scores = np.asarray([row[1] for row in module_analysis])\n",
    "    scales = np.asarray([row[2] for row in module_analysis])\n",
    "    thetas = np.asarray([row[3] for row in module_analysis])\n",
    "    norm_firings = np.asarray([row[4] for row in module_analysis])\n",
    "\n",
    "    scores_all.append(scores)\n",
    "scores_mean = np.nanmean(scores_all, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84f3eae9-22db-4dda-9821-3dbcf0597b66",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Plot every cell's pattern for the run loaded above.\n",
    "p_f.plot_pattern_all_cells(inputs, g, params, info)       "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b750faf-2e31-413d-8da8-b27b9440b81b",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Pattern metrics for this run; spat_corr is reused in the grid-score comparison plot.\n",
    "change, spat_corr = p_f.plot_pattern_metrics(model, g, info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71f21dd7",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Grid score (left axis, blue) and object-patch correlation (right axis, green) against\n",
    "# spatial correlation, presumably one point per cell, on twin y-axes.\n",
    "fig, ax1 = plt.subplots(figsize=(4, 4))\n",
    "\n",
    "ax2 = ax1.twinx()\n",
    "ax1.scatter(spat_corr, scores_mean, c='b', marker=\"^\", label='Grid score', alpha=0.5)\n",
    "ax2.scatter(spat_corr, mean_corr_cells, c='g', marker='o', label='Object patch correlation', alpha=0.5)\n",
    "\n",
    "ax1.set_xlabel('Spatial correlation', fontsize=fontsize)\n",
    "ax1.set_ylabel('Grid score', color='b', fontsize=fontsize)\n",
    "ax2.set_ylabel('Object patch correlation', color='g', fontsize=fontsize)\n",
    "\n",
    "# index is None unless set explicitly in the loading cell; str(None) used to leak a '_None' suffix\n",
    "index_suffix = '' if index is None else '_' + str(index)\n",
    "fig.savefig(save_path_pattern + 'spat_corr_vs_grid_score_and_patch_corr' + index_suffix + \".png\", dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79a12e55-e775-4213-8991-cee4debaa14c",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### No ReLU + factorised data (1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4fd3d04-4f0f-4df6-95b2-5021aeb6eb6c",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Load a single run for this section (see header); note it is read from path1, unlike the ReLU runs.\n",
    "date = '2022-07-21'\n",
    "run = 6\n",
    "index = None\n",
    "base_path = path1\n",
    "\n",
    "params, g, inputs, model, info = p_f.get_pattern_data(date, run, index, base_path)\n",
    "info.save_path = save_path_pattern + 'NoReLu_FactoredData_1_'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ffe2d14-dc46-4eac-bb24-cb83ef07970e",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Plot every cell's pattern for the run loaded above.\n",
    "p_f.plot_pattern_all_cells(inputs, g, params, info)      "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4a3de8d-f95c-4a37-b16d-2b4aefdb3e3b",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Pattern metrics for this run; the return value is not stored, so Jupyter displays it as the cell output.\n",
    "p_f.plot_pattern_metrics(model, g, info)    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "427de361-ff1b-4883-8d13-9e51b600add7",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### ReLU + entangled data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17764c9c-c4df-454e-ac35-0c0c29cbadd9",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Load a single run for this section (see header) and set its figure filename prefix.\n",
    "# NOTE(review): 'Enatngled' looks like a typo for 'Entangled'; left as-is so existing figure filenames stay stable.\n",
    "date = '2022-08-30'\n",
    "run = 14\n",
    "index = None\n",
    "base_path = path2\n",
    "\n",
    "params, g, inputs, model, info = p_f.get_pattern_data(date, run, index, base_path)\n",
    "info.save_path = save_path_pattern + 'ReLu_EnatngledData_1_'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cf06429-b16a-4f3d-9306-3932fd7c3cee",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Plot every cell's pattern for the run loaded above.\n",
    "p_f.plot_pattern_all_cells(inputs, g, params, info)      "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "845391c0-ad50-4fd8-a473-f9661dd2dc03",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Pattern metrics for this run; the return value is not stored, so Jupyter displays it as the cell output.\n",
    "p_f.plot_pattern_metrics(model, g, info)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}