{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f098133-34f1-45a9-b398-0063427c5177",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prerequisite: run disentangled_rnn_train_clean before running this notebook.\n",
    "# Repeat the training for different values of \"condition_type\" in disentangled_rnn_params, and over several seeds."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ef20828-c26d-429a-951d-fb1909da20d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from skimage.transform import resize\n",
    "import itertools\n",
    "from itertools import repeat \n",
    "from disentangled_rnn_utils import DotDict as Dd\n",
    "\n",
    "your_path = 'YOUR PATH'\n",
    "\n",
    "import seaborn\n",
    "seaborn.set_style(style='white')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83ee2e00-1c3c-406a-9acb-d187dc140bfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "def find_most_recent(file_list, must_contain=None, cant_contain=None, recent=-1):\n",
    "    \"\"\"\n",
    "    Return the iteration number found in the matching file names, as a string.\n",
    "\n",
    "    Names are expected to look like X_n[.Y optional]; the first run of digits in a name\n",
    "    is taken as its iteration number n.\n",
    "\n",
    "    Args:\n",
    "        file_list: iterable of file-name strings to search.\n",
    "        must_contain: optional list of substrings; a name is kept only if it contains at\n",
    "            least one of them. None disables this filter.\n",
    "        cant_contain: optional list of substrings; a name is dropped if it contains any of\n",
    "            them. None disables this filter.\n",
    "        recent: index into the ascending list of distinct iteration numbers. The default\n",
    "            -1 selects the highest one, -2 the second highest, and so on.\n",
    "\n",
    "    Returns:\n",
    "        The selected iteration number as a string, or None when no name both passes the\n",
    "        filters and contains a number.\n",
    "\n",
    "    Raises:\n",
    "        IndexError: if `recent` is out of range for the iteration numbers that were found.\n",
    "    \"\"\"\n",
    "    first_number = re.compile(r'\\d+')\n",
    "    iter_numbers = set()\n",
    "    for name in file_list:\n",
    "        if cant_contain is not None and any(token in name for token in cant_contain):\n",
    "            continue\n",
    "        if must_contain is not None and not any(token in name for token in must_contain):\n",
    "            continue\n",
    "        match = first_number.search(name)\n",
    "        # A matching name without any digits carries no iteration number: skip it rather\n",
    "        # than failing with an IndexError on an empty list of matches.\n",
    "        if match is not None:\n",
    "            iter_numbers.add(int(match.group()))\n",
    "    if not iter_numbers:\n",
    "        return None\n",
    "    # Index into the sorted distinct iterations, so recent=-1 is the latest one\n",
    "    return str(sorted(iter_numbers)[recent])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d3d7481-47e8-4842-a0d2-4ea6648ffab0",
   "metadata": {},
   "outputs": [],
   "source": [
    "date = '2024-08-23'  # YOUR DATE FOR SUMMARIES\n",
    "\n",
    "# One entry per retained run, grouped by the run's condition_type: the swept parameter\n",
    "# value (intercept_x or prop_batch_oversample) and the loaded `mir` object for that run.\n",
    "x_intercepts_2 = []\n",
    "oversamples = []\n",
    "oversamples_2 = []\n",
    "mirs_corner_both = []\n",
    "mirs_corr = []\n",
    "mirs_corr_2 = []\n",
    "\n",
    "# Sorted so that the collected lists have the same order on every execution\n",
    "for run in sorted(os.listdir(your_path + date)):\n",
    "    path = your_path + date + '/' + run + '/save'\n",
    "    # Skip stray files and runs that never created a save directory\n",
    "    if not os.path.isdir(path):\n",
    "        continue\n",
    "    saved_files = os.listdir(path)\n",
    "    # Only runs with more than one saved file are considered\n",
    "    if len(saved_files) <= 1:\n",
    "        continue\n",
    "\n",
    "    param = Dd(np.load(path + '/params.npy', allow_pickle=True).item())\n",
    "    # Latest iteration among the pickles, ignoring any file whose name contains 'par'\n",
    "    index = find_most_recent(saved_files, ['.pickle'], ['par'])\n",
    "    if index is None:\n",
    "        # No numbered pickle in this run, so there is nothing to load\n",
    "        continue\n",
    "    # NOTE: unpickling can execute arbitrary code; only point your_path at runs you produced\n",
    "    with open(path + '/mirs_all_' + index + '.pickle', 'rb') as handle:\n",
    "        mir = pickle.load(handle)\n",
    "\n",
    "    condition_type = param.data.condition_type\n",
    "\n",
    "    # find run of particular point\n",
    "    # NOTE(review): the meaning of mir[2][7][-1] and of the 0.5 / 0.01 cut-offs is not visible here - confirm\n",
    "    if condition_type == 'corner_cut_both' and mir[2][7][-1] < 0.5 and mir[2][-1] < 0.01:\n",
    "        continue\n",
    "\n",
    "    # Drop runs whose mir[2][-1] value is above 0.08\n",
    "    if mir[2][-1] > 0.08:\n",
    "        continue\n",
    "\n",
    "    if condition_type == 'corner_cut_both':\n",
    "        x_intercepts_2.append(param.data.intercept_x)\n",
    "        mirs_corner_both.append(mir)\n",
    "    elif condition_type == 'oversample_diagonal_2':\n",
    "        oversamples.append(param.data.prop_batch_oversample)\n",
    "        mirs_corr.append(mir)\n",
    "    elif condition_type == 'oversample_diagonals':\n",
    "        oversamples_2.append(param.data.prop_batch_oversample)\n",
    "        mirs_corr_2.append(mir)\n",
    "    else:\n",
    "        raise ValueError(param.data.condition_type)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d150f264-2014-484a-8809-72886d398ff3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.gaussian_process import GaussianProcessRegressor\n",
    "from sklearn.gaussian_process.kernels import RBF\n",
    "\n",
    "mir_index = 0\n",
    "source_mi_index = -1\n",
    "\n",
    "\n",
    "def _source_mi_and_ncmi(mirs):\n",
    "    \"\"\"Return two arrays over mirs: x[2][source_mi_index] and x[2][mir_index][-1] for each x.\"\"\"\n",
    "    source_multi_info = np.array([x[2][source_mi_index] for x in mirs])\n",
    "    ncmi = np.array([x[2][mir_index][-1] for x in mirs])\n",
    "    return source_multi_info, ncmi\n",
    "\n",
    "\n",
    "corner_both_source_multi_info, corner_both_ncmi = _source_mi_and_ncmi(mirs_corner_both)\n",
    "corr_source_multi_info, corr_ncmi = _source_mi_and_ncmi(mirs_corr)\n",
    "corr_2_source_multi_info, corr_2_ncmi = _source_mi_and_ncmi(mirs_corr_2)\n",
    "\n",
    "# One tuple per condition: (x values, y values, colour, legend label, noise std for the GP).\n",
    "# Keeping the noise std next to its condition means no entry can be silently dropped by zip.\n",
    "conditions = [\n",
    "    (corner_both_source_multi_info, corner_both_ncmi, 'orange', 'Corner Cut Both', 0.08),\n",
    "    (corr_source_multi_info, corr_ncmi, 'green', 'Diagonal', 0.04),\n",
    "    (corr_2_source_multi_info, corr_2_ncmi, 'purple', 'Both diagonals', 0.04),\n",
    "]\n",
    "\n",
    "for source_multi_info, y_train, color, label, noise_std in conditions:\n",
    "    # The regressor takes a 2-D input, so add a trailing axis to the absolute x values\n",
    "    X_train = np.abs(source_multi_info[:, None])\n",
    "    try:\n",
    "        start = X_train.min()\n",
    "        stop = X_train.max()\n",
    "        kernel = 1 * RBF(length_scale=1.0, length_scale_bounds=(1e-3, 1e2))\n",
    "        # alpha is the noise variance added to the kernel diagonal; random_state fixes the\n",
    "        # optimizer restarts so the fitted curve is the same on every execution\n",
    "        gaussian_process = GaussianProcessRegressor(kernel=kernel, alpha=noise_std**2,\n",
    "                                                    n_restarts_optimizer=9, random_state=0)\n",
    "        gaussian_process.fit(X_train, y_train)\n",
    "\n",
    "        X = np.linspace(start=start, stop=stop, num=1_000).reshape(-1, 1)\n",
    "        mean_prediction, std_prediction = gaussian_process.predict(X, return_std=True)\n",
    "\n",
    "        plt.scatter(X_train, y_train, label=label, c=color, s=10)\n",
    "        plt.plot(X, mean_prediction, c=color)\n",
    "        # Shaded band: mean +/- 1.96 standard deviations\n",
    "        plt.fill_between(X.ravel(), mean_prediction - 1.96 * std_prediction,\n",
    "                         mean_prediction + 1.96 * std_prediction, alpha=0.2, color=color)\n",
    "    except ValueError as err:\n",
    "        # Raised e.g. when no runs were collected for this condition (min of an empty array).\n",
    "        # The condition is still skipped, but the reason is reported instead of hidden.\n",
    "        print(f'Skipping {label}: {err}')\n",
    "plt.ylim(0, 1)\n",
    "plt.legend(fontsize=16)\n",
    "plt.xlabel(\"Normalised source multiinformation\", fontsize=16)\n",
    "plt.ylabel(\"Linear Conditional InfoM\", fontsize=16)\n",
    "\n",
    "plt.savefig('modularirt_vs_corr')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33d8d955-7a3b-4144-b4b4-c882b5d725a3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
