{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c5ab8377",
   "metadata": {},
   "outputs": [],
   "source": [
    "import platform\n",
    "if 'mac' in platform.platform():\n",
    "    BASE_DIR = \"/Users/USER/vrtopc/\"\n",
    "    DATA_DIR = \"/media/data/vrtopc\"\n",
    "else:\n",
    "    BASE_DIR = \"/home/USER/vr_to_pc/\"\n",
    "    DATA_DIR = \"/media/data/vrtopc\"\n",
    "\n",
    "import sys\n",
    "sys.path.append(BASE_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1413c425",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "eeadacc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "from utils.metrics import get_spatial_correlation\n",
    "\n",
    "import scipy\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "import yaml"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8606be7e",
   "metadata": {},
   "source": [
    "### Params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ce009fdc",
   "metadata": {},
   "outputs": [],
   "source": [
    "WITH_GRID_CELLS = True"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05c613f7",
   "metadata": {},
   "source": [
    "# Processed data for the Science and Muessig papers"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7a2a1f0",
   "metadata": {},
   "source": [
    "### Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0407688c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.spatial_units import RateMaps, PolarMaps\n",
    "\n",
    "N_SAMPLES_POS = RateMaps.N_SAMPLES_POS\n",
    "PLACE_SI_TH = RateMaps.PLACE_SI_TH\n",
    "\n",
    "N_SAMPLES_THET = PolarMaps.N_SAMPLES_THET\n",
    "HD_SI_TH = PolarMaps.HD_SI_TH\n",
    "HD_RVL_TH = PolarMaps.HD_RVL_TH\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d24f394",
   "metadata": {},
   "outputs": [],
   "source": [
    "edge_n_bins = 4\n",
    "\n",
    "ONLY_2ND_TRIAL = True\n",
    "\n",
    "AGES_TO_REMOVE = list(range(26, 32 +1))\n",
    "\n",
    "SAVE_PLOTS = False\n",
    "SAVE_DIR = None"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "168265bd",
   "metadata": {},
   "source": [
    "### Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "52eb0aa2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['r101_p20', 'r104_p26', 'r112_p40', 'r115_p24', 'r118_p24']"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ratnames_science = []\n",
    "ratnames_muessig = []\n",
    "\n",
    "data_dir = os.path.join(DATA_DIR, 'real_data', 'science2010_data_struct')\n",
    "data = {}\n",
    "for file in os.listdir(data_dir):\n",
    "    if file.endswith('.mat'):\n",
    "        name = file.split('.')[0].lower()\n",
    "        if 'shuffled' in name:\n",
    "            name += '_science2010'\n",
    "        else:\n",
    "            ratnames_science.append(name)\n",
    "        data[name] = scipy.io.loadmat(os.path.join(data_dir, file))\n",
    "\n",
    "data_dir = os.path.join(DATA_DIR, 'real_data', 'muessig_data_struct')\n",
    "for file in os.listdir(data_dir):\n",
    "    if file.endswith('.mat'):\n",
    "        name = file.split('.')[0].lower()\n",
    "        if 'shuffled' in name:\n",
    "            name += '_muessig'\n",
    "        else:\n",
    "            ratnames_muessig.append(name)\n",
    "            \n",
    "        if name in data.keys():\n",
    "            raise ValueError(f\"Duplicate file name: {name}\")\n",
    "        data[name] = scipy.io.loadmat(os.path.join(data_dir, file))\n",
    "\n",
    "sorted(list(data.keys()))[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "abf44aab",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "r1308_d4\n",
      "\t1 trial(s)\n",
      "\n",
      "r1526_p18\n",
      "\t1 trial(s)\n",
      "\n",
      "r1343_d4\n",
      "\t1 trial(s)\n",
      "\n",
      "r1526_p23\n",
      "\t1 trial(s)\n",
      "\n",
      "r1343_d1\n",
      "\t1 trial(s)\n",
      "\n",
      "r1589_p22\n",
      "\t1 trial(s)\n",
      "\n",
      "r1589_p28\n",
      "\t1 trial(s)\n",
      "\n",
      "r1333_d1\n",
      "\t1 trial(s)\n",
      "\n",
      "r1477_p29\n",
      "\t1 trial(s)\n",
      "\n",
      "r1552_p22\n",
      "\t1 trial(s)\n",
      "\n",
      "r1637_p23\n",
      "\t1 trial(s)\n",
      "\n",
      "r1588_p20\n",
      "\t1 trial(s)\n",
      "\n",
      "r1589_p23\n",
      "\t1 trial(s)\n",
      "\n",
      "r1589_p27\n",
      "\t1 trial(s)\n",
      "\n",
      "r1588_p22\n",
      "\t1 trial(s)\n",
      "\n",
      "r1515_p23\n",
      "\t1 trial(s)\n",
      "\n",
      "r1589_p17\n",
      "\t1 trial(s)\n",
      "\n",
      "r1589_p26\n",
      "\t1 trial(s)\n",
      "\n",
      "r1552_p17\n",
      "\t1 trial(s)\n",
      "\n",
      "r1308_d1\n",
      "\t1 trial(s)\n",
      "\n",
      "r1526_p20\n",
      "\t1 trial(s)\n",
      "\n",
      "r1589_p21\n",
      "\t1 trial(s)\n",
      "\n",
      "r1552_p16_1\n",
      "\t1 trial(s)\n",
      "\n",
      "r1589_p25\n",
      "\t1 trial(s)\n",
      "\n",
      "r1333_d2\n",
      "\t1 trial(s)\n",
      "\n",
      "r1628_p22_ca1\n",
      "\t1 trial(s)\n",
      "\n",
      "r1589_p19\n",
      "\t1 trial(s)\n",
      "\n",
      "r1515_p22\n",
      "\t1 trial(s)\n",
      "\n",
      "r1526_p21\n",
      "\t1 trial(s)\n",
      "\n",
      "r1588_p24\n",
      "\t1 trial(s)\n",
      "\n",
      "r1474_p25\n",
      "\t1 trial(s)\n",
      "\n",
      "r1552_p16_2\n",
      "\t1 trial(s)\n",
      "\n",
      "r1588_p16\n",
      "\t1 trial(s)\n",
      "\n",
      "shuffled_metrics_science2010\n",
      "File name shuffled_metrics_science2010 does not start with 'r', skipping\n",
      "shuffled_metrics_adult_science2010\n",
      "File name shuffled_metrics_adult_science2010 does not start with 'r', skipping\n",
      "r1588_p21\n",
      "\t1 trial(s)\n",
      "\n",
      "r1589_p24\n",
      "\t1 trial(s)\n",
      "\n",
      "r1546_p17\n",
      "\t1 trial(s)\n",
      "\n",
      "r1490_p22\n",
      "\t1 trial(s)\n",
      "\n",
      "r1590_p22\n",
      "\t1 trial(s)\n",
      "\n",
      "r1474_p27\n",
      "\t1 trial(s)\n",
      "\n",
      "r1546_p19\n",
      "\t1 trial(s)\n",
      "\n",
      "r1526_p19\n",
      "\t1 trial(s)\n",
      "\n",
      "r1526_p22\n",
      "\t1 trial(s)\n",
      "\n",
      "r1498_p21\n",
      "\t1 trial(s)\n",
      "\n",
      "r1552_p18\n",
      "\t1 trial(s)\n",
      "\n",
      "r1262_d1\n",
      "\t1 trial(s)\n",
      "\n",
      "r1552_p19\n",
      "\t1 trial(s)\n",
      "\n",
      "r1552_p21\n",
      "\t1 trial(s)\n",
      "\n",
      "r1588_p19\n",
      "\t1 trial(s)\n",
      "\n",
      "r1589_p18\n",
      "\t1 trial(s)\n",
      "\n",
      "r1637_p24\n",
      "\t0 trial(s)\n",
      "\n",
      "r1546_p20\n",
      "\t1 trial(s)\n",
      "\n",
      "r1628_p24_ca1\n",
      "\t1 trial(s)\n",
      "\n",
      "r1512_p18\n",
      "\t1 trial(s)\n",
      "\n",
      "r1515_p20\n",
      "\t1 trial(s)\n",
      "\n",
      "r1588_p23_2\n",
      "\t1 trial(s)\n",
      "\n",
      "r1589_p20\n",
      "\t1 trial(s)\n",
      "\n",
      "r1552_p20\n",
      "\t1 trial(s)\n",
      "\n",
      "r1588_p17\n",
      "\t1 trial(s)\n",
      "\n",
      "r1515_p21\n",
      "\t1 trial(s)\n",
      "\n",
      "r1262_d3\n",
      "\t1 trial(s)\n",
      "\n",
      "r1588_p23_1\n",
      "\t1 trial(s)\n",
      "\n",
      "r1588_p18\n",
      "\t1 trial(s)\n",
      "\n",
      "r1589_p16\n",
      "\t1 trial(s)\n",
      "\n",
      "r1546_p18\n",
      "\t1 trial(s)\n",
      "\n",
      "r732_p26\n",
      "\t0 trial(s)\n",
      "\n",
      "r572_p20\n",
      "\t0 trial(s)\n",
      "\n",
      "r1783_p20\n",
      "\t1 trial(s)\n",
      "\n",
      "r1776_p18\n",
      "\t1 trial(s)\n",
      "\n",
      "r1770_p19\n",
      "\t1 trial(s)\n",
      "\n",
      "r14_p18\n",
      "\t1 trial(s)\n",
      "\n",
      "r1776_p20\n",
      "\t1 trial(s)\n",
      "\n",
      "r67_p17\n",
      "\t1 trial(s)\n",
      "\n",
      "r379_p28\n",
      "\t1 trial(s)\n",
      "\n",
      "r98_p19\n",
      "\t1 trial(s)\n",
      "\n",
      "r1770_p18\n",
      "\t1 trial(s)\n",
      "\n",
      "r44_p40\n",
      "\t1 trial(s)\n",
      "\n",
      "r711_p22\n",
      "\t0 trial(s)\n",
      "\n",
      "r118_p26\n",
      "\t1 trial(s)\n",
      "\n",
      "r716_p26\n",
      "\t0 trial(s)\n",
      "\n",
      "r1783_p15\n",
      "\t1 trial(s)\n",
      "\n",
      "r732_p31\n",
      "\t0 trial(s)\n",
      "\n",
      "r739_p30\n",
      "\t0 trial(s)\n",
      "\n",
      "r566_p17\n",
      "\t1 trial(s)\n",
      "\n",
      "r378_p30\n",
      "\t1 trial(s)\n",
      "\n",
      "r72_p20\n",
      "\t1 trial(s)\n",
      "\n",
      "r66_p18\n",
      "\t1 trial(s)\n",
      "\n",
      "r663_p19\n",
      "\t1 trial(s)\n",
      "\n",
      "r66_p16\n",
      "\t1 trial(s)\n",
      "\n",
      "r1770_p15\n",
      "\t1 trial(s)\n",
      "\n",
      "r572_p22\n",
      "\t0 trial(s)\n",
      "\n",
      "r1771_p14\n",
      "\t1 trial(s)\n",
      "\n",
      "r1917_p19\n",
      "\t1 trial(s)\n",
      "\n",
      "r631_p18\n",
      "\t1 trial(s)\n",
      "\n",
      "r573_p20\n",
      "\t0 trial(s)\n",
      "\n",
      "r596_p22\n",
      "\t1 trial(s)\n",
      "\n",
      "r86_p18\n",
      "\t1 trial(s)\n",
      "\n",
      "r733_p28\n",
      "\t0 trial(s)\n",
      "\n",
      "r1776_p16\n",
      "\t1 trial(s)\n",
      "\n",
      "r574_p18\n",
      "\t1 trial(s)\n",
      "\n",
      "r85_p17\n",
      "\t1 trial(s)\n",
      "\n",
      "r129_p40\n",
      "\t1 trial(s)\n",
      "\n",
      "r586_p24\n",
      "\t0 trial(s)\n",
      "\n",
      "r27_p40\n",
      "\t1 trial(s)\n",
      "\n",
      "r710_p22\n",
      "\t0 trial(s)\n",
      "\n",
      "r97_p17\n",
      "\t1 trial(s)\n",
      "\n",
      "r378_p29\n",
      "\t1 trial(s)\n",
      "\n",
      "r724_p25\n",
      "\t0 trial(s)\n",
      "\n",
      "r739_p29\n",
      "\t0 trial(s)\n",
      "\n",
      "r65_p19\n",
      "\t1 trial(s)\n",
      "\n",
      "r72_p21\n",
      "\t1 trial(s)\n",
      "\n",
      "r65_p17\n",
      "\t1 trial(s)\n",
      "\n",
      "r97_p21\n",
      "\t1 trial(s)\n",
      "\n",
      "r731_p30\n",
      "\t0 trial(s)\n",
      "\n",
      "r28_p16\n",
      "\t1 trial(s)\n",
      "\n",
      "r66_p21\n",
      "\t1 trial(s)\n",
      "\n",
      "r573_p19\n",
      "\t1 trial(s)\n",
      "\n",
      "r86_p19\n",
      "\t1 trial(s)\n",
      "\n",
      "r678_p40\n",
      "\t1 trial(s)\n",
      "\n",
      "r579_p22\n",
      "\t0 trial(s)\n",
      "\n",
      "r86_p20\n",
      "\t1 trial(s)\n",
      "\n",
      "r659_p22\n",
      "\t1 trial(s)\n",
      "\n",
      "r98_p18\n",
      "\t1 trial(s)\n",
      "\n",
      "r661_p40\n",
      "\t1 trial(s)\n",
      "\n",
      "r716_p30\n",
      "\t0 trial(s)\n",
      "\n",
      "r28_p17\n",
      "\t1 trial(s)\n",
      "\n",
      "r98_p17\n",
      "\t1 trial(s)\n",
      "\n",
      "r631_p21\n",
      "\t1 trial(s)\n",
      "\n",
      "r76_p40\n",
      "\t1 trial(s)\n",
      "\n",
      "r14_p17\n",
      "\t1 trial(s)\n",
      "\n",
      "r713_p29\n",
      "\t0 trial(s)\n",
      "\n",
      "r710_p21\n",
      "\t0 trial(s)\n",
      "\n",
      "r1776_p21\n",
      "\t1 trial(s)\n",
      "\n",
      "r67_p20\n",
      "\t1 trial(s)\n",
      "\n",
      "r1770_p17\n",
      "\t1 trial(s)\n",
      "\n",
      "r1770_p20\n",
      "\t1 trial(s)\n",
      "\n",
      "r45_p40\n",
      "\t1 trial(s)\n",
      "\n",
      "r474_p40\n",
      "\t1 trial(s)\n",
      "\n",
      "r738_p26\n",
      "\t0 trial(s)\n",
      "\n",
      "r726_p24\n",
      "\t0 trial(s)\n",
      "\n",
      "r14_p20\n",
      "\t1 trial(s)\n",
      "\n",
      "r658_p17\n",
      "\t1 trial(s)\n",
      "\n",
      "r1919_p19\n",
      "\t1 trial(s)\n",
      "\n",
      "r4_p15\n",
      "\t1 trial(s)\n",
      "\n",
      "r115_p24\n",
      "\t1 trial(s)\n",
      "\n",
      "r378_p27\n",
      "\t1 trial(s)\n",
      "\n",
      "r631_p20\n",
      "\t1 trial(s)\n",
      "\n",
      "r13_p18\n",
      "\t1 trial(s)\n",
      "\n",
      "r73_p20\n",
      "\t1 trial(s)\n",
      "\n",
      "shuffled_metrics_muessig\n",
      "File name shuffled_metrics_muessig does not start with 'r', skipping\n",
      "r97_p19\n",
      "\t1 trial(s)\n",
      "\n",
      "r1776_p17\n",
      "\t1 trial(s)\n",
      "\n",
      "r379_p26\n",
      "\t1 trial(s)\n",
      "\n",
      "r87_p17\n",
      "\t1 trial(s)\n",
      "\n",
      "r73_p24\n",
      "\t1 trial(s)\n",
      "\n",
      "r572_p19\n",
      "\t1 trial(s)\n",
      "\n",
      "r97_p20\n",
      "\t1 trial(s)\n",
      "\n",
      "r1783_p18\n",
      "\t1 trial(s)\n",
      "\n",
      "r1770_p16\n",
      "\t1 trial(s)\n",
      "\n",
      "r71_p22\n",
      "\t1 trial(s)\n",
      "\n",
      "r87_p18\n",
      "\t1 trial(s)\n",
      "\n",
      "r85_p20\n",
      "\t1 trial(s)\n",
      "\n",
      "r1771_p17\n",
      "\t1 trial(s)\n",
      "\n",
      "r574_p21\n",
      "\t1 trial(s)\n",
      "\n",
      "r75_p16\n",
      "\t1 trial(s)\n",
      "\n",
      "r574_p20\n",
      "\t0 trial(s)\n",
      "\n",
      "r15_p18\n",
      "\t1 trial(s)\n",
      "\n",
      "r87_p19\n",
      "\t1 trial(s)\n",
      "\n",
      "r715_p27\n",
      "\t0 trial(s)\n",
      "\n",
      "r731_p32\n",
      "\t0 trial(s)\n",
      "\n",
      "r1770_p14\n",
      "\t1 trial(s)\n",
      "\n",
      "r65_p20\n",
      "\t1 trial(s)\n",
      "\n",
      "r96_p16\n",
      "\t1 trial(s)\n",
      "\n",
      "r96_p20\n",
      "\t1 trial(s)\n",
      "\n",
      "r1843_p40\n",
      "\t1 trial(s)\n",
      "\n",
      "r1783_p16\n",
      "\t1 trial(s)\n",
      "\n",
      "r1771_p16\n",
      "\t1 trial(s)\n",
      "\n",
      "r739_p31\n",
      "\t0 trial(s)\n",
      "\n",
      "r710_p20\n",
      "\t0 trial(s)\n",
      "\n",
      "r566_p25\n",
      "\t1 trial(s)\n",
      "\n",
      "r101_p20\n",
      "\t1 trial(s)\n",
      "\n",
      "r737_p26\n",
      "\t0 trial(s)\n",
      "\n",
      "r596_p23\n",
      "\t1 trial(s)\n",
      "\n",
      "r1776_p19\n",
      "\t1 trial(s)\n",
      "\n",
      "r4_p14\n",
      "\t1 trial(s)\n",
      "\n",
      "r118_p25\n",
      "\t1 trial(s)\n",
      "\n",
      "r716_p25\n",
      "\t0 trial(s)\n",
      "\n",
      "r737_p25\n",
      "\t0 trial(s)\n",
      "\n",
      "r14_p19\n",
      "\t1 trial(s)\n",
      "\n",
      "r1919_p15\n",
      "\t1 trial(s)\n",
      "\n",
      "r86_p22\n",
      "\t1 trial(s)\n",
      "\n",
      "r733_p27\n",
      "\t0 trial(s)\n",
      "\n",
      "r659_p21\n",
      "\t1 trial(s)\n",
      "\n",
      "r1919_p18\n",
      "\t1 trial(s)\n",
      "\n",
      "r673_p40\n",
      "\t1 trial(s)\n",
      "\n",
      "r738_p27\n",
      "\t0 trial(s)\n",
      "\n",
      "r96_p18\n",
      "\t1 trial(s)\n",
      "\n",
      "r118_p27\n",
      "\t1 trial(s)\n",
      "\n",
      "r711_p23\n",
      "\t0 trial(s)\n",
      "\n",
      "r1783_p19\n",
      "\t1 trial(s)\n",
      "\n",
      "r104_p26\n",
      "\t1 trial(s)\n",
      "\n",
      "r32_p40\n",
      "\t1 trial(s)\n",
      "\n",
      "r2_p40\n",
      "\t1 trial(s)\n",
      "\n",
      "r86_p21\n",
      "\t1 trial(s)\n",
      "\n",
      "r1776_p15\n",
      "\t1 trial(s)\n",
      "\n",
      "r98_p20\n",
      "\t1 trial(s)\n",
      "\n",
      "r112_p40\n",
      "\t1 trial(s)\n",
      "\n",
      "r118_p24\n",
      "\t1 trial(s)\n",
      "\n",
      "r574_p19\n",
      "\t0 trial(s)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "def get(x):\n",
    "    return x[0][0]\n",
    "\n",
    "data_dict = {}\n",
    "ratnames_old = []\n",
    "\n",
    "for k in data.keys():\n",
    "    print(k)\n",
    "\n",
    "    if not k.startswith('r'):\n",
    "        print(f\"File name {k} does not start with 'r', skipping\")\n",
    "        continue\n",
    "\n",
    "    ratname = k\n",
    "    \n",
    "    if ratname not in data_dict.keys():\n",
    "        data_dict[ratname] = {}\n",
    "    \n",
    "    d = get(data[k]['tmpS'])\n",
    "    d_keys = list(d.dtype.names)\n",
    "\n",
    "    dataset = d[d_keys.index('dataset')][0].split('_')[-1]\n",
    "    ratnames_old.append(f\"{ratname}_{dataset}\")\n",
    "\n",
    "    # sample rate is always 50 Hz\n",
    "    ages = d[d_keys.index('age')][0] # age 40 denotes adult\n",
    "    ages = [a if a<100 else 40 for a in ages]\n",
    "\n",
    "    sample_rates = d[d_keys.index('sampleRate')][0]\n",
    "    env_types = d[d_keys.index('envType')][0]\n",
    "    ppm = d[d_keys.index('ppm')][0]\n",
    "    spike_times = d[d_keys.index('spikeTimes')][0]\n",
    "    is_cs_neuron = d[d_keys.index('isCSNeuron')][0]\n",
    "    has_min_freq = d[d_keys.index('hasMinFreq')][0]\n",
    "    pos = d[d_keys.index('positions')][0]\n",
    "    hd = d[d_keys.index('directions')][0] # degrees\n",
    "    speed = d[d_keys.index('speed')][0] # cm/s\n",
    "\n",
    "    rate_maps = d[d_keys.index('rateMaps')][0]\n",
    "    pos_occ = d[d_keys.index('posOccMaps')][0]\n",
    "    si = d[d_keys.index('SI')][0]\n",
    "    rate_maps_corr = d[d_keys.index('corrRateMaps')][0]\n",
    "    si_corr = d[d_keys.index('SICorr')][0]\n",
    "    rate_maps_hd8 = d[d_keys.index('rateMapsHD8')][0]\n",
    "    rate_maps_hd4 = d[d_keys.index('rateMapsHD4')][0]\n",
    "    \n",
    "    polar_maps = d[d_keys.index('polarMaps')][0]\n",
    "    si_pm = d[d_keys.index('dirSI')][0]\n",
    "    rvl = d[d_keys.index('rvLength')][0]\n",
    "    hd_occ = d[d_keys.index('dirOccMaps')][0]\n",
    "    polar_maps_corr = d[d_keys.index('corrPolarMaps')][0]\n",
    "    si_pm_corr = d[d_keys.index('dirSICorr')][0]\n",
    "    rvl_corr = d[d_keys.index('rvLengthCorr')][0]\n",
    "    polar_maps_pred = d[d_keys.index('predPolarMaps')][0]\n",
    "    dis_ratios = d[d_keys.index('disRatios')][0]\n",
    "\n",
    "    # there are always max 3 trials per day\n",
    "    n_trials = 0\n",
    "    for trial_n in range(len(ages)):\n",
    "        if ONLY_2ND_TRIAL and (trial_n != 1) : continue # keep second trial\n",
    "\n",
    "        rms = rate_maps[trial_n]\n",
    "        sis = si[trial_n]\n",
    "        rms_corr = rate_maps_corr[trial_n]\n",
    "        sis_corr = si_corr[trial_n]\n",
    "        rms_hd8 = rate_maps_hd8[trial_n]\n",
    "        rms_hd4 = rate_maps_hd4[trial_n]\n",
    "\n",
    "        pms = polar_maps[trial_n]\n",
    "        sis_pm = si_pm[trial_n]\n",
    "        rvls = rvl[trial_n]\n",
    "        pms_corr = polar_maps_corr[trial_n]\n",
    "        sis_pm_corr = si_pm_corr[trial_n]\n",
    "        rvls_corr = rvl_corr[trial_n]\n",
    "        drs = dis_ratios[trial_n]\n",
    "        pms_pred = polar_maps_pred[trial_n]\n",
    "        if rms.shape[-1] == 0 or pms.shape[-1] == 0:\n",
    "            print(f\"Skipping trial {trial_n} because rate maps or polar maps are empty\")\n",
    "            continue\n",
    "\n",
    "        t = {}\n",
    "        age = ages[trial_n]\n",
    "        env = env_types[trial_n][0]\n",
    "        p = pos[trial_n]\n",
    "    \n",
    "        if np.isnan(age) and (len(env) == 0) and (p.shape[-1] == 0):\n",
    "            continue\n",
    "        n_trials += 1\n",
    "\n",
    "        age = int(age)\n",
    "        if age not in data_dict[ratname].keys():\n",
    "            data_dict[ratname][age] = {}\n",
    "            data_dict[ratname][age]['trials'] = []\n",
    "\n",
    "        t['name'] = n_trials\n",
    "        t['environment'] = env\n",
    "        t['ppm'] = ppm\n",
    "        t['sample_rate'] = sample_rates[trial_n] # Hz\n",
    "        t['positions'] = p\n",
    "        t['x'] = p[:,0]\n",
    "        t['y'] = p[:,1]\n",
    "        if ratname in ratnames_science:\n",
    "            t['spike_times'] = spike_times[trial_n][0]\n",
    "            t['is_cs_neuron'] = is_cs_neuron[trial_n][0]\n",
    "            t['has_min_freq'] = has_min_freq[trial_n][0]\n",
    "        else:\n",
    "            t['spike_times'] = spike_times[trial_n].squeeze() if len(spike_times[trial_n])>1 else spike_times[trial_n][0]\n",
    "            t['is_cs_neuron'] = is_cs_neuron[trial_n].squeeze() if len(is_cs_neuron[trial_n])>1 else is_cs_neuron[trial_n][0]\n",
    "            t['has_min_freq'] = has_min_freq[trial_n].squeeze() if len(has_min_freq[trial_n])>1 else has_min_freq[trial_n][0]\n",
    "\n",
    "        t['speed'] = speed[trial_n].squeeze()/100 # m/s\n",
    "        t['hd'] = hd[trial_n].squeeze()\n",
    "        t['duration'] = len(t['x'])/t['sample_rate']\n",
    "\n",
    "        # convert to (n_units, n_bins, n_bins)\n",
    "        t['rate_maps'] = np.array([rms[idx][0] for idx in range(len(rms))])\n",
    "        t['si'] = np.array([sis[idx][0] for idx in range(len(sis))])\n",
    "        t['rate_maps_corr'] = np.array([rms_corr[idx][0] for idx in range(len(rms_corr))])\n",
    "        t['si_corr'] = np.array([sis_corr[idx][0] for idx in range(len(sis_corr))])\n",
    "\n",
    "        rms_hd_np8 = np.zeros((t['rate_maps'].shape[0], 8, t['rate_maps'].shape[-1], t['rate_maps'].shape[-1]))\n",
    "        rms_hd_np4 = np.zeros((t['rate_maps'].shape[0], 4, t['rate_maps'].shape[-1], t['rate_maps'].shape[-1]))\n",
    "        for j in range(8):\n",
    "            if j < 4:\n",
    "                rms_hd_np4[:, j, ...] = np.array(\n",
    "                    [get(rms_hd4[idx])[j] for idx in range(len(rms_hd4))]\n",
    "                )\n",
    "            rms_hd_np8[:, j, ...] = np.array(\n",
    "                [get(rms_hd8[idx])[j] for idx in range(len(rms_hd8))]\n",
    "            )\n",
    "        # convert to (n_units, 8, n_bins, n_bins)\n",
    "        t['rate_maps_hd8'] = rms_hd_np8\n",
    "        t['rate_maps_hd4'] = rms_hd_np4\n",
    "\n",
    "        # convert to (n_units, n_bins)\n",
    "        t['polar_maps'] = np.array([pms[idx][0] for idx in range(len(pms))])[..., 0]\n",
    "        t['si_pm'] = np.array([sis_pm[idx][0] for idx in range(len(sis_pm))])\n",
    "        t['rvl'] = np.array([rvls[idx][0] for idx in range(len(rvls))])\n",
    "        t['polar_maps_corr'] = np.array([pms_corr[idx][0] for idx in range(len(pms_corr))])[..., 0]\n",
    "        t['si_pm_corr'] = np.array([sis_pm_corr[idx][0] for idx in range(len(sis_pm_corr))])\n",
    "        t['rvl_corr'] = np.array([rvls_corr[idx][0] for idx in range(len(rvls_corr))])\n",
    "\n",
    "        t['polar_maps_pred'] = np.array([pms_pred[idx][0] for idx in range(len(pms_pred))])[..., 0]\n",
    "        t['dis_ratios'] = np.array(drs[:,0])\n",
    "\n",
    "        if t['rate_maps'].shape[0] != t['polar_maps'].shape[0]:\n",
    "            raise ValueError(f\"Rate maps ({t['rate_maps'].shape}) and polar maps ({t['polar_maps'].shape}) have different number of units\")\n",
    "\n",
    "        t['pos_occ'] = get(pos_occ[trial_n])\n",
    "        t['hd_occ'] = np.array(hd_occ[trial_n][:,0])\n",
    "\n",
    "        data_dict[ratname][age]['trials'].append(t)\n",
    "    print(f\"\\t{n_trials} trial(s)\")\n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "75cf24d7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------\n",
      "r1308_d4\n",
      "ages [40]\n",
      "\n",
      "--------------------------------\n",
      "r1526_p18\n",
      "ages [18]\n",
      "\n",
      "--------------------------------\n",
      "r1343_d4\n",
      "ages [40]\n",
      "\n",
      "--------------------------------\n",
      "r1526_p23\n",
      "ages [23]\n",
      "\n",
      "--------------------------------\n",
      "r1343_d1\n",
      "ages [40]\n",
      "\n",
      "--------------------------------\n",
      "r1589_p22\n",
      "ages [22]\n",
      "\n",
      "--------------------------------\n",
      "r1589_p28\n",
      "ages [28]\n",
      "\n",
      "--------------------------------\n",
      "r1333_d1\n",
      "ages [40]\n",
      "\n",
      "--------------------------------\n",
      "r1477_p29\n",
      "ages [29]\n",
      "\n",
      "--------------------------------\n",
      "r1552_p22\n",
      "ages [22]\n",
      "\n",
      "--------------------------------\n",
      "r1637_p23\n",
      "ages [23]\n",
      "\n",
      "--------------------------------\n",
      "r1588_p20\n",
      "ages [20]\n",
      "\n",
      "--------------------------------\n",
      "r1589_p23\n",
      "ages [23]\n",
      "\n",
      "--------------------------------\n",
      "r1589_p27\n",
      "ages [27]\n",
      "\n",
      "--------------------------------\n",
      "r1588_p22\n",
      "ages [22]\n",
      "\n",
      "--------------------------------\n",
      "r1515_p23\n",
      "ages [23]\n",
      "\n",
      "--------------------------------\n",
      "r1589_p17\n",
      "ages [17]\n",
      "\n",
      "--------------------------------\n",
      "r1589_p26\n",
      "ages [26]\n",
      "\n",
      "--------------------------------\n",
      "r1552_p17\n",
      "ages [17]\n",
      "\n",
      "--------------------------------\n",
      "r1308_d1\n",
      "ages [40]\n",
      "\n",
      "--------------------------------\n",
      "r1526_p20\n",
      "ages [20]\n",
      "\n",
      "--------------------------------\n",
      "r1589_p21\n",
      "ages [21]\n",
      "\n",
      "--------------------------------\n",
      "r1552_p16_1\n",
      "ages [16]\n",
      "\n",
      "--------------------------------\n",
      "r1589_p25\n",
      "ages [25]\n",
      "\n",
      "--------------------------------\n",
      "r1333_d2\n",
      "ages [40]\n",
      "\n",
      "--------------------------------\n",
      "r1628_p22_ca1\n",
      "ages [22]\n",
      "\n",
      "--------------------------------\n",
      "r1589_p19\n",
      "ages [19]\n",
      "\n",
      "--------------------------------\n",
      "r1515_p22\n",
      "ages [22]\n",
      "\n",
      "--------------------------------\n",
      "r1526_p21\n",
      "ages [21]\n",
      "\n",
      "--------------------------------\n",
      "r1588_p24\n",
      "ages [24]\n",
      "\n",
      "--------------------------------\n",
      "r1474_p25\n",
      "ages [25]\n",
      "\n",
      "--------------------------------\n",
      "r1552_p16_2\n",
      "ages [16]\n",
      "\n",
      "--------------------------------\n",
      "r1588_p16\n",
      "ages [16]\n",
      "\n",
      "--------------------------------\n",
      "r1588_p21\n",
      "ages [21]\n",
      "\n",
      "--------------------------------\n",
      "r1589_p24\n",
      "ages [24]\n",
      "\n",
      "--------------------------------\n",
      "r1546_p17\n",
      "ages [17]\n",
      "\n",
      "--------------------------------\n",
      "r1490_p22\n",
      "ages [22]\n",
      "\n",
      "--------------------------------\n",
      "r1590_p22\n",
      "ages [22]\n",
      "\n",
      "--------------------------------\n",
      "r1474_p27\n",
      "ages [27]\n",
      "\n",
      "--------------------------------\n",
      "r1546_p19\n",
      "ages [19]\n",
      "\n",
      "--------------------------------\n",
      "r1526_p19\n",
      "ages [19]\n",
      "\n",
      "--------------------------------\n",
      "r1526_p22\n",
      "ages [22]\n",
      "\n",
      "--------------------------------\n",
      "r1498_p21\n",
      "ages [21]\n",
      "\n",
      "--------------------------------\n",
      "r1552_p18\n",
      "ages [18]\n",
      "\n",
      "--------------------------------\n",
      "r1262_d1\n",
      "ages [40]\n",
      "\n",
      "--------------------------------\n",
      "r1552_p19\n",
      "ages [19]\n",
      "\n",
      "--------------------------------\n",
      "r1552_p21\n",
      "ages [21]\n",
      "\n",
      "--------------------------------\n",
      "r1588_p19\n",
      "ages [19]\n",
      "\n",
      "--------------------------------\n",
      "r1589_p18\n",
      "ages [18]\n",
      "\n",
      "--------------------------------\n",
      "r1637_p24\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r1546_p20\n",
      "ages [20]\n",
      "\n",
      "--------------------------------\n",
      "r1628_p24_ca1\n",
      "ages [24]\n",
      "\n",
      "--------------------------------\n",
      "r1512_p18\n",
      "ages [18]\n",
      "\n",
      "--------------------------------\n",
      "r1515_p20\n",
      "ages [20]\n",
      "\n",
      "--------------------------------\n",
      "r1588_p23_2\n",
      "ages [23]\n",
      "\n",
      "--------------------------------\n",
      "r1589_p20\n",
      "ages [20]\n",
      "\n",
      "--------------------------------\n",
      "r1552_p20\n",
      "ages [20]\n",
      "\n",
      "--------------------------------\n",
      "r1588_p17\n",
      "ages [17]\n",
      "\n",
      "--------------------------------\n",
      "r1515_p21\n",
      "ages [21]\n",
      "\n",
      "--------------------------------\n",
      "r1262_d3\n",
      "ages [40]\n",
      "\n",
      "--------------------------------\n",
      "r1588_p23_1\n",
      "ages [23]\n",
      "\n",
      "--------------------------------\n",
      "r1588_p18\n",
      "ages [18]\n",
      "\n",
      "--------------------------------\n",
      "r1589_p16\n",
      "ages [16]\n",
      "\n",
      "--------------------------------\n",
      "r1546_p18\n",
      "ages [18]\n",
      "\n",
      "--------------------------------\n",
      "r732_p26\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r572_p20\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r1783_p20\n",
      "ages [20]\n",
      "\n",
      "--------------------------------\n",
      "r1776_p18\n",
      "ages [18]\n",
      "\n",
      "--------------------------------\n",
      "r1770_p19\n",
      "ages [19]\n",
      "\n",
      "--------------------------------\n",
      "r14_p18\n",
      "ages [18]\n",
      "\n",
      "--------------------------------\n",
      "r1776_p20\n",
      "ages [20]\n",
      "\n",
      "--------------------------------\n",
      "r67_p17\n",
      "ages [17]\n",
      "\n",
      "--------------------------------\n",
      "r379_p28\n",
      "ages [28]\n",
      "\n",
      "--------------------------------\n",
      "r98_p19\n",
      "ages [19]\n",
      "\n",
      "--------------------------------\n",
      "r1770_p18\n",
      "ages [18]\n",
      "\n",
      "--------------------------------\n",
      "r44_p40\n",
      "ages [40]\n",
      "\n",
      "--------------------------------\n",
      "r711_p22\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r118_p26\n",
      "ages [26]\n",
      "\n",
      "--------------------------------\n",
      "r716_p26\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r1783_p15\n",
      "ages [15]\n",
      "\n",
      "--------------------------------\n",
      "r732_p31\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r739_p30\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r566_p17\n",
      "ages [17]\n",
      "\n",
      "--------------------------------\n",
      "r378_p30\n",
      "ages [30]\n",
      "\n",
      "--------------------------------\n",
      "r72_p20\n",
      "ages [20]\n",
      "\n",
      "--------------------------------\n",
      "r66_p18\n",
      "ages [18]\n",
      "\n",
      "--------------------------------\n",
      "r663_p19\n",
      "ages [19]\n",
      "\n",
      "--------------------------------\n",
      "r66_p16\n",
      "ages [16]\n",
      "\n",
      "--------------------------------\n",
      "r1770_p15\n",
      "ages [15]\n",
      "\n",
      "--------------------------------\n",
      "r572_p22\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r1771_p14\n",
      "ages [14]\n",
      "\n",
      "--------------------------------\n",
      "r1917_p19\n",
      "ages [19]\n",
      "\n",
      "--------------------------------\n",
      "r631_p18\n",
      "ages [18]\n",
      "\n",
      "--------------------------------\n",
      "r573_p20\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r596_p22\n",
      "ages [22]\n",
      "\n",
      "--------------------------------\n",
      "r86_p18\n",
      "ages [18]\n",
      "\n",
      "--------------------------------\n",
      "r733_p28\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r1776_p16\n",
      "ages [16]\n",
      "\n",
      "--------------------------------\n",
      "r574_p18\n",
      "ages [18]\n",
      "\n",
      "--------------------------------\n",
      "r85_p17\n",
      "ages [17]\n",
      "\n",
      "--------------------------------\n",
      "r129_p40\n",
      "ages [40]\n",
      "\n",
      "--------------------------------\n",
      "r586_p24\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r27_p40\n",
      "ages [40]\n",
      "\n",
      "--------------------------------\n",
      "r710_p22\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r97_p17\n",
      "ages [17]\n",
      "\n",
      "--------------------------------\n",
      "r378_p29\n",
      "ages [29]\n",
      "\n",
      "--------------------------------\n",
      "r724_p25\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r739_p29\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r65_p19\n",
      "ages [19]\n",
      "\n",
      "--------------------------------\n",
      "r72_p21\n",
      "ages [21]\n",
      "\n",
      "--------------------------------\n",
      "r65_p17\n",
      "ages [17]\n",
      "\n",
      "--------------------------------\n",
      "r97_p21\n",
      "ages [21]\n",
      "\n",
      "--------------------------------\n",
      "r731_p30\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r28_p16\n",
      "ages [16]\n",
      "\n",
      "--------------------------------\n",
      "r66_p21\n",
      "ages [21]\n",
      "\n",
      "--------------------------------\n",
      "r573_p19\n",
      "ages [19]\n",
      "\n",
      "--------------------------------\n",
      "r86_p19\n",
      "ages [19]\n",
      "\n",
      "--------------------------------\n",
      "r678_p40\n",
      "ages [40]\n",
      "\n",
      "--------------------------------\n",
      "r579_p22\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r86_p20\n",
      "ages [20]\n",
      "\n",
      "--------------------------------\n",
      "r659_p22\n",
      "ages [22]\n",
      "\n",
      "--------------------------------\n",
      "r98_p18\n",
      "ages [18]\n",
      "\n",
      "--------------------------------\n",
      "r661_p40\n",
      "ages [40]\n",
      "\n",
      "--------------------------------\n",
      "r716_p30\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r28_p17\n",
      "ages [17]\n",
      "\n",
      "--------------------------------\n",
      "r98_p17\n",
      "ages [17]\n",
      "\n",
      "--------------------------------\n",
      "r631_p21\n",
      "ages [21]\n",
      "\n",
      "--------------------------------\n",
      "r76_p40\n",
      "ages [40]\n",
      "\n",
      "--------------------------------\n",
      "r14_p17\n",
      "ages [17]\n",
      "\n",
      "--------------------------------\n",
      "r713_p29\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r710_p21\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r1776_p21\n",
      "ages [21]\n",
      "\n",
      "--------------------------------\n",
      "r67_p20\n",
      "ages [20]\n",
      "\n",
      "--------------------------------\n",
      "r1770_p17\n",
      "ages [17]\n",
      "\n",
      "--------------------------------\n",
      "r1770_p20\n",
      "ages [20]\n",
      "\n",
      "--------------------------------\n",
      "r45_p40\n",
      "ages [40]\n",
      "\n",
      "--------------------------------\n",
      "r474_p40\n",
      "ages [40]\n",
      "\n",
      "--------------------------------\n",
      "r738_p26\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r726_p24\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r14_p20\n",
      "ages [20]\n",
      "\n",
      "--------------------------------\n",
      "r658_p17\n",
      "ages [17]\n",
      "\n",
      "--------------------------------\n",
      "r1919_p19\n",
      "ages [19]\n",
      "\n",
      "--------------------------------\n",
      "r4_p15\n",
      "ages [15]\n",
      "\n",
      "--------------------------------\n",
      "r115_p24\n",
      "ages [24]\n",
      "\n",
      "--------------------------------\n",
      "r378_p27\n",
      "ages [27]\n",
      "\n",
      "--------------------------------\n",
      "r631_p20\n",
      "ages [20]\n",
      "\n",
      "--------------------------------\n",
      "r13_p18\n",
      "ages [18]\n",
      "\n",
      "--------------------------------\n",
      "r73_p20\n",
      "ages [20]\n",
      "\n",
      "--------------------------------\n",
      "r97_p19\n",
      "ages [19]\n",
      "\n",
      "--------------------------------\n",
      "r1776_p17\n",
      "ages [17]\n",
      "\n",
      "--------------------------------\n",
      "r379_p26\n",
      "ages [26]\n",
      "\n",
      "--------------------------------\n",
      "r87_p17\n",
      "ages [17]\n",
      "\n",
      "--------------------------------\n",
      "r73_p24\n",
      "ages [24]\n",
      "\n",
      "--------------------------------\n",
      "r572_p19\n",
      "ages [19]\n",
      "\n",
      "--------------------------------\n",
      "r97_p20\n",
      "ages [20]\n",
      "\n",
      "--------------------------------\n",
      "r1783_p18\n",
      "ages [18]\n",
      "\n",
      "--------------------------------\n",
      "r1770_p16\n",
      "ages [16]\n",
      "\n",
      "--------------------------------\n",
      "r71_p22\n",
      "ages [22]\n",
      "\n",
      "--------------------------------\n",
      "r87_p18\n",
      "ages [18]\n",
      "\n",
      "--------------------------------\n",
      "r85_p20\n",
      "ages [20]\n",
      "\n",
      "--------------------------------\n",
      "r1771_p17\n",
      "ages [17]\n",
      "\n",
      "--------------------------------\n",
      "r574_p21\n",
      "ages [21]\n",
      "\n",
      "--------------------------------\n",
      "r75_p16\n",
      "ages [16]\n",
      "\n",
      "--------------------------------\n",
      "r574_p20\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r15_p18\n",
      "ages [18]\n",
      "\n",
      "--------------------------------\n",
      "r87_p19\n",
      "ages [19]\n",
      "\n",
      "--------------------------------\n",
      "r715_p27\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r731_p32\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r1770_p14\n",
      "ages [14]\n",
      "\n",
      "--------------------------------\n",
      "r65_p20\n",
      "ages [20]\n",
      "\n",
      "--------------------------------\n",
      "r96_p16\n",
      "ages [16]\n",
      "\n",
      "--------------------------------\n",
      "r96_p20\n",
      "ages [20]\n",
      "\n",
      "--------------------------------\n",
      "r1843_p40\n",
      "ages [40]\n",
      "\n",
      "--------------------------------\n",
      "r1783_p16\n",
      "ages [16]\n",
      "\n",
      "--------------------------------\n",
      "r1771_p16\n",
      "ages [16]\n",
      "\n",
      "--------------------------------\n",
      "r739_p31\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r710_p20\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r566_p25\n",
      "ages [25]\n",
      "\n",
      "--------------------------------\n",
      "r101_p20\n",
      "ages [20]\n",
      "\n",
      "--------------------------------\n",
      "r737_p26\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r596_p23\n",
      "ages [23]\n",
      "\n",
      "--------------------------------\n",
      "r1776_p19\n",
      "ages [19]\n",
      "\n",
      "--------------------------------\n",
      "r4_p14\n",
      "ages [14]\n",
      "\n",
      "--------------------------------\n",
      "r118_p25\n",
      "ages [25]\n",
      "\n",
      "--------------------------------\n",
      "r716_p25\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r737_p25\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r14_p19\n",
      "ages [19]\n",
      "\n",
      "--------------------------------\n",
      "r1919_p15\n",
      "ages [15]\n",
      "\n",
      "--------------------------------\n",
      "r86_p22\n",
      "ages [22]\n",
      "\n",
      "--------------------------------\n",
      "r733_p27\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r659_p21\n",
      "ages [21]\n",
      "\n",
      "--------------------------------\n",
      "r1919_p18\n",
      "ages [18]\n",
      "\n",
      "--------------------------------\n",
      "r673_p40\n",
      "ages [40]\n",
      "\n",
      "--------------------------------\n",
      "r738_p27\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r96_p18\n",
      "ages [18]\n",
      "\n",
      "--------------------------------\n",
      "r118_p27\n",
      "ages [27]\n",
      "\n",
      "--------------------------------\n",
      "r711_p23\n",
      "ages []\n",
      "\n",
      "--------------------------------\n",
      "r1783_p19\n",
      "ages [19]\n",
      "\n",
      "--------------------------------\n",
      "r104_p26\n",
      "ages [26]\n",
      "\n",
      "--------------------------------\n",
      "r32_p40\n",
      "ages [40]\n",
      "\n",
      "--------------------------------\n",
      "r2_p40\n",
      "ages [40]\n",
      "\n",
      "--------------------------------\n",
      "r86_p21\n",
      "ages [21]\n",
      "\n",
      "--------------------------------\n",
      "r1776_p15\n",
      "ages [15]\n",
      "\n",
      "--------------------------------\n",
      "r98_p20\n",
      "ages [20]\n",
      "\n",
      "--------------------------------\n",
      "r112_p40\n",
      "ages [40]\n",
      "\n",
      "--------------------------------\n",
      "r118_p24\n",
      "ages [24]\n",
      "\n",
      "--------------------------------\n",
      "r574_p19\n",
      "ages []\n",
      "\n",
      "--------------------------------\n"
     ]
    }
   ],
   "source": [
    "# List every key of data_dict with the ages it holds, and gather all of those\n",
    "# ages (duplicates kept, in iteration order) into the flat list `ages`.\n",
    "ages = []\n",
    "print('--------------------------------')\n",
    "for k, v in data_dict.items():\n",
    "    ages_tmp = list(v.keys())\n",
    "    # one call emits: the key, 'ages [...]', a blank line, then the separator\n",
    "    print(k, f'ages {ages_tmp}', '', '--------------------------------', sep='\\n')\n",
    "    ages.extend(ages_tmp)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "246ebe2f",
   "metadata": {},
   "source": [
    "### Shuffled Threshold Extraction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8c228e90",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "shuffled_metrics_science2010\n",
      "shuffled_metrics_adult_science2010\n",
      "shuffled_metrics_muessig\n"
     ]
    }
   ],
   "source": [
    "# Collect, per dataset, the shuffled-null threshold of each metric keyed by age:\n",
    "#   metrics_shuffle_th_<dataset>[metric][age] = threshold\n",
    "# with metric in {'SI', 'dirSI', 'rvLength'}.\n",
    "metrics_shuffle_th_science = {}\n",
    "metrics_shuffle_th_muessig = {}\n",
    "\n",
    "# Value of 14 + 2*ageGroup that marks the group stored under ADULT_AGE\n",
    "# (presumably the adult group in the .mat files -- TODO confirm).\n",
    "ADULT_AGE_SENTINEL = 100\n",
    "ADULT_AGE = 40\n",
    "\n",
    "for filename in data.keys():\n",
    "    if not filename.startswith('shuffled'):\n",
    "        continue\n",
    "    print(filename)\n",
    "\n",
    "    # Choose the destination dict once per file instead of duplicating the\n",
    "    # whole body per dataset; 'science' takes precedence if both tags appear.\n",
    "    name_lower = filename.lower()\n",
    "    if 'science' in name_lower:\n",
    "        metrics_shuffle_th = metrics_shuffle_th_science\n",
    "    elif 'muessig' in name_lower:\n",
    "        metrics_shuffle_th = metrics_shuffle_th_muessig\n",
    "    else:\n",
    "        continue  # unknown dataset: its thresholds would not be stored anywhere\n",
    "\n",
    "    shuffled_si = data[filename]['shuffledSIByAge']\n",
    "    shuffled_si_pm = data[filename]['shuffledDirSIByAge']\n",
    "    shuffled_rvl = data[filename]['shuffledRVLByAge']\n",
    "    for idx in range(len(shuffled_si)):\n",
    "        for m, k in zip([shuffled_si, shuffled_si_pm, shuffled_rvl], ['SI', 'dirSI', 'rvLength']):\n",
    "            m_curr = get(m[idx][0])\n",
    "            keys = list(m_curr.dtype.names)\n",
    "            age_group = get(m_curr[keys.index('ageGroup')])\n",
    "            th = get(m_curr[keys.index(k+'Threshold')])\n",
    "\n",
    "            th_by_age = metrics_shuffle_th.setdefault(k, {})\n",
    "            age = 14 + age_group*2\n",
    "            if age != ADULT_AGE_SENTINEL:\n",
    "                # one threshold per age group, applied to both age and age+1\n",
    "                th_by_age[age] = th\n",
    "                th_by_age[age+1] = th\n",
    "            else:\n",
    "                th_by_age[ADULT_AGE] = th\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d0ebd18",
   "metadata": {},
   "source": [
    "### Activity Extraction\n",
    "\n",
    "A neuron is considered genuinely tuned to direction if its corrected polar map still passes the inclusion criterion (RVL or KLD) and its Pearson correlation with the uncorrected polar map is higher than 0.5."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f2e202b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def normalize_rate_maps(rate_maps):\n",
    "    \"\"\"Min-max normalize every map in a stack independently to [0, 1].\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    rate_maps : array_like, shape (n_maps, ...)\n",
    "        Stack of maps (e.g. one per cell). Axis 0 indexes the map; all\n",
    "        remaining axes are reduced when computing each map's min and max,\n",
    "        so any map shape works, not only (N_SAMPLES_POS, N_SAMPLES_POS).\n",
    "        NaN bins are ignored by the min/max and stay NaN in the output.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray\n",
    "        Same shape as the input, with (map - min) / (max - min) per map.\n",
    "        A uniform map (min == max) or an all-NaN map comes back all-NaN and\n",
    "        numpy emits a RuntimeWarning; callers should check for that case.\n",
    "    \"\"\"\n",
    "    rate_maps = np.asarray(rate_maps)\n",
    "    # keepdims leaves the reduced axes as length-1 dims, so min/max broadcast\n",
    "    # against the stack without tiling them to a hard-coded map size.\n",
    "    reduce_axes = tuple(range(1, rate_maps.ndim))\n",
    "    rate_maps_min = np.nanmin(rate_maps, axis=reduce_axes, keepdims=True)\n",
    "    rate_maps_max = np.nanmax(rate_maps, axis=reduce_axes, keepdims=True)\n",
    "    return (rate_maps - rate_maps_min) / (rate_maps_max - rate_maps_min)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "8265f1fb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Rat r101_p20\n",
      "\tAge 20\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (8, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (8, 60)\n",
      "\n",
      "Rat r104_p26\n",
      "\tAge 26 in ages to remove, skipping\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Rat r112_p40\n",
      "\tAge 40\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (27, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (27, 60)\n",
      "\n",
      "Rat r115_p24\n",
      "\tAge 24\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (9, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (9, 60)\n",
      "\n",
      "Rat r118_p24\n",
      "\tAge 24\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (21, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (21, 60)\n",
      "\n",
      "Rat r118_p25\n",
      "\tAge 25\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (11, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (11, 60)\n",
      "\n",
      "Rat r118_p26\n",
      "\tAge 26 in ages to remove, skipping\n",
      "\n",
      "Rat r118_p27\n",
      "\tAge 27 in ages to remove, skipping\n",
      "\n",
      "Rat r1262_d1\n",
      "\tAge 40\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (13, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (13, 60)\n",
      "\n",
      "Rat r1262_d3\n",
      "\tAge 40\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (12, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (12, 60)\n",
      "\n",
      "Rat r129_p40\n",
      "\tAge 40\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (9, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (9, 60)\n",
      "\n",
      "Rat r1308_d1\n",
      "\tAge 40\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (14, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (14, 60)\n",
      "\n",
      "Rat r1308_d4\n",
      "\tAge 40\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (11, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (11, 60)\n",
      "\n",
      "Rat r1333_d1\n",
      "\tAge 40\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (14, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (14, 60)\n",
      "\n",
      "Rat r1333_d2\n",
      "\tAge 40\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (28, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (28, 60)\n",
      "\n",
      "Rat r1343_d1\n",
      "\tAge 40\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (18, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (18, 60)\n",
      "\n",
      "Rat r1343_d4\n",
      "\tAge 40\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (10, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (10, 60)\n",
      "\n",
      "Rat r13_p18\n",
      "\tAge 18\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (10, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (10, 60)\n",
      "\n",
      "Rat r1474_p25\n",
      "\tAge 25\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (3, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (3, 60)\n",
      "\n",
      "Rat r1474_p27\n",
      "\tAge 27 in ages to remove, skipping\n",
      "\n",
      "Rat r1477_p29\n",
      "\tAge 29 in ages to remove, skipping\n",
      "\n",
      "Rat r1490_p22\n",
      "\tAge 22\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (6, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (6, 60)\n",
      "\n",
      "Rat r1498_p21\n",
      "\tAge 21\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (11, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (11, 60)\n",
      "\n",
      "Rat r14_p17\n",
      "\tAge 17\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (18, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (18, 60)\n",
      "\n",
      "Rat r14_p18\n",
      "\tAge 18\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (24, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (24, 60)\n",
      "\n",
      "Rat r14_p19\n",
      "\tAge 19\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (26, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (26, 60)\n",
      "\n",
      "Rat r14_p20\n",
      "\tAge 20\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (33, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (33, 60)\n",
      "\n",
      "Rat r1512_p18\n",
      "\tAge 18\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (12, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (12, 60)\n",
      "\n",
      "Rat r1515_p20\n",
      "\tAge 20\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (12, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (12, 60)\n",
      "\n",
      "Rat r1515_p21\n",
      "\tAge 21\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (9, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (9, 60)\n",
      "\n",
      "Rat r1515_p22\n",
      "\tAge 22\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (11, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (11, 60)\n",
      "\n",
      "Rat r1515_p23\n",
      "\tAge 23\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (5, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (5, 60)\n",
      "\n",
      "Rat r1526_p18\n",
      "\tAge 18\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (7, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (7, 60)\n",
      "\n",
      "Rat r1526_p19\n",
      "\tAge 19\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (14, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (14, 60)\n",
      "\n",
      "Rat r1526_p20\n",
      "\tAge 20\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (9, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (9, 60)\n",
      "\n",
      "Rat r1526_p21\n",
      "\tAge 21\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (8, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (8, 60)\n",
      "\n",
      "Rat r1526_p22\n",
      "\tAge 22\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (9, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (9, 60)\n",
      "\n",
      "Rat r1526_p23\n",
      "\tAge 23\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (7, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (7, 60)\n",
      "\n",
      "Rat r1546_p17\n",
      "\tAge 17\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (2, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (2, 60)\n",
      "\n",
      "Rat r1546_p18\n",
      "\tAge 18\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (17, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (17, 60)\n",
      "\n",
      "Rat r1546_p19\n",
      "\tAge 19\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (30, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (30, 60)\n",
      "\n",
      "Rat r1546_p20\n",
      "\tAge 20\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (14, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (14, 60)\n",
      "\n",
      "Rat r1552_p16_1\n",
      "\tAge 16\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (1, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (1, 60)\n",
      "\n",
      "Rat r1552_p16_2\n",
      "\tAge 16\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (9, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (9, 60)\n",
      "\n",
      "Rat r1552_p17\n",
      "\tAge 17\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (36, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (36, 60)\n",
      "\n",
      "Rat r1552_p18\n",
      "\tAge 18\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (23, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (23, 60)\n",
      "\n",
      "Rat r1552_p19\n",
      "\tAge 19\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (27, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (27, 60)\n",
      "\n",
      "Rat r1552_p20\n",
      "\tAge 20\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (18, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (18, 60)\n",
      "\n",
      "Rat r1552_p21\n",
      "\tAge 21\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (17, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (17, 60)\n",
      "\n",
      "Rat r1552_p22\n",
      "\tAge 22\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (16, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (16, 60)\n",
      "\n",
      "Rat r1588_p16\n",
      "\tAge 16\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (1, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (1, 60)\n",
      "\n",
      "Rat r1588_p17\n",
      "\tAge 17\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (27, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (27, 60)\n",
      "\n",
      "Rat r1588_p18\n",
      "\tAge 18\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (11, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (11, 60)\n",
      "\n",
      "Rat r1588_p19\n",
      "\tAge 19\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (14, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (14, 60)\n",
      "\n",
      "Rat r1588_p20\n",
      "\tAge 20\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (13, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (13, 60)\n",
      "\n",
      "Rat r1588_p21\n",
      "\tAge 21\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (16, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (16, 60)\n",
      "\n",
      "Rat r1588_p22\n",
      "\tAge 22\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (9, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (9, 60)\n",
      "\n",
      "Rat r1588_p23_1\n",
      "\tAge 23\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (6, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (6, 60)\n",
      "\n",
      "Rat r1588_p23_2\n",
      "\tAge 23\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (9, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (9, 60)\n",
      "\n",
      "Rat r1588_p24\n",
      "\tAge 24\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (5, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (5, 60)\n",
      "\n",
      "Rat r1589_p16\n",
      "\tAge 16\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (2, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (2, 60)\n",
      "\n",
      "Rat r1589_p17\n",
      "\tAge 17\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (7, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (7, 60)\n",
      "\n",
      "Rat r1589_p18\n",
      "\tAge 18\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (10, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (10, 60)\n",
      "\n",
      "Rat r1589_p19\n",
      "\tAge 19\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (9, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (9, 60)\n",
      "\n",
      "Rat r1589_p20\n",
      "\tAge 20\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (10, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (10, 60)\n",
      "\n",
      "Rat r1589_p21\n",
      "\tAge 21\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (13, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (13, 60)\n",
      "\n",
      "Rat r1589_p22\n",
      "\tAge 22\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (16, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (16, 60)\n",
      "\n",
      "Rat r1589_p23\n",
      "\tAge 23\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (15, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (15, 60)\n",
      "\n",
      "Rat r1589_p24\n",
      "\tAge 24\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (11, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (11, 60)\n",
      "\n",
      "Rat r1589_p25\n",
      "\tAge 25\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (6, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (6, 60)\n",
      "\n",
      "Rat r1589_p26\n",
      "\tAge 26 in ages to remove, skipping\n",
      "\n",
      "Rat r1589_p27\n",
      "\tAge 27 in ages to remove, skipping\n",
      "\n",
      "Rat r1589_p28\n",
      "\tAge 28 in ages to remove, skipping\n",
      "\n",
      "Rat r1590_p22\n",
      "\tAge 22\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (1, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (1, 60)\n",
      "\n",
      "Rat r15_p18\n",
      "\tAge 18\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (22, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (22, 60)\n",
      "\n",
      "Rat r1628_p22_ca1\n",
      "\tAge 22\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (5, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (5, 60)\n",
      "\n",
      "Rat r1628_p24_ca1\n",
      "\tAge 24\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (3, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (3, 60)\n",
      "\n",
      "Rat r1637_p23\n",
      "\tAge 23\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (1, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (1, 60)\n",
      "\n",
      "Rat r1637_p24\n",
      "\n",
      "Rat r1770_p14\n",
      "\tAge 14\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (1, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (1, 60)\n",
      "\n",
      "Rat r1770_p15\n",
      "\tAge 15\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (8, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (8, 60)\n",
      "\n",
      "Rat r1770_p16\n",
      "\tAge 16\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (15, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (15, 60)\n",
      "\n",
      "Rat r1770_p17\n",
      "\tAge 17\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (14, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (14, 60)\n",
      "\n",
      "Rat r1770_p18\n",
      "\tAge 18\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (13, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (13, 60)\n",
      "\n",
      "Rat r1770_p19\n",
      "\tAge 19\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (7, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (7, 60)\n",
      "\n",
      "Rat r1770_p20\n",
      "\tAge 20\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (2, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (2, 60)\n",
      "\tSkipping trial because all rate maps are uniform\n",
      "\n",
      "Rat r1771_p14\n",
      "\tAge 14\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (4, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (4, 60)\n",
      "\tSkipping trial because all rate maps are uniform\n",
      "\n",
      "Rat r1771_p16\n",
      "\tAge 16\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (5, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (5, 60)\n",
      "\n",
      "Rat r1771_p17\n",
      "\tAge 17\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (10, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (10, 60)\n",
      "\n",
      "Rat r1776_p15\n",
      "\tAge 15\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (17, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (17, 60)\n",
      "\n",
      "Rat r1776_p16\n",
      "\tAge 16\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (9, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (9, 60)\n",
      "\n",
      "Rat r1776_p17\n",
      "\tAge 17\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (8, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (8, 60)\n",
      "\n",
      "Rat r1776_p18\n",
      "\tAge 18\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (7, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (7, 60)\n",
      "\n",
      "Rat r1776_p19\n",
      "\tAge 19\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (7, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (7, 60)\n",
      "\n",
      "Rat r1776_p20\n",
      "\tAge 20\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (7, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (7, 60)\n",
      "\n",
      "Rat r1776_p21\n",
      "\tAge 21\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (3, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (3, 60)\n",
      "\n",
      "Rat r1783_p15\n",
      "\tAge 15\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (11, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (11, 60)\n",
      "\n",
      "Rat r1783_p16\n",
      "\tAge 16\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (13, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (13, 60)\n",
      "\n",
      "Rat r1783_p18\n",
      "\tAge 18\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (13, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (13, 60)\n",
      "\n",
      "Rat r1783_p19\n",
      "\tAge 19\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (8, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (8, 60)\n",
      "\n",
      "Rat r1783_p20\n",
      "\tAge 20\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (3, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (3, 60)\n",
      "\n",
      "Rat r1843_p40\n",
      "\tAge 40\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (10, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (10, 60)\n",
      "\n",
      "Rat r1917_p19\n",
      "\tAge 19\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (8, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (8, 60)\n",
      "\n",
      "Rat r1919_p15\n",
      "\tAge 15\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (18, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (18, 60)\n",
      "\n",
      "Rat r1919_p18\n",
      "\tAge 18\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (5, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (5, 60)\n",
      "\n",
      "Rat r1919_p19\n",
      "\tAge 19\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (8, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (8, 60)\n",
      "\n",
      "Rat r27_p40\n",
      "\tAge 40\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (18, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (18, 60)\n",
      "\n",
      "Rat r28_p16\n",
      "\tAge 16\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (12, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (12, 60)\n",
      "\n",
      "Rat r28_p17\n",
      "\tAge 17\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (20, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (20, 60)\n",
      "\n",
      "Rat r2_p40\n",
      "\tAge 40\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (9, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (9, 60)\n",
      "\n",
      "Rat r32_p40\n",
      "\tAge 40\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (22, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (22, 60)\n",
      "\n",
      "Rat r378_p27\n",
      "\tAge 27 in ages to remove, skipping\n",
      "\n",
      "Rat r378_p29\n",
      "\tAge 29 in ages to remove, skipping\n",
      "\n",
      "Rat r378_p30\n",
      "\tAge 30 in ages to remove, skipping\n",
      "\n",
      "Rat r379_p26\n",
      "\tAge 26 in ages to remove, skipping\n",
      "\n",
      "Rat r379_p28\n",
      "\tAge 28 in ages to remove, skipping\n",
      "\n",
      "Rat r44_p40\n",
      "\tAge 40\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (27, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (27, 60)\n",
      "\n",
      "Rat r45_p40\n",
      "\tAge 40\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (27, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (27, 60)\n",
      "\n",
      "Rat r474_p40\n",
      "\tAge 40\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (10, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (10, 60)\n",
      "\n",
      "Rat r4_p14\n",
      "\tAge 14\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (46, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (46, 60)\n",
      "\n",
      "Rat r4_p15\n",
      "\tAge 15\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (21, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (21, 60)\n",
      "\n",
      "Rat r566_p17\n",
      "\tAge 17\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (52, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (52, 60)\n",
      "\n",
      "Rat r566_p25\n",
      "\tAge 25\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (15, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (15, 60)\n",
      "\n",
      "Rat r572_p19\n",
      "\tAge 19\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (65, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (65, 60)\n",
      "\n",
      "Rat r572_p20\n",
      "\n",
      "Rat r572_p22\n",
      "\n",
      "Rat r573_p19\n",
      "\tAge 19\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (29, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (29, 60)\n",
      "\n",
      "Rat r573_p20\n",
      "\n",
      "Rat r574_p18\n",
      "\tAge 18\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (26, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (26, 60)\n",
      "\n",
      "Rat r574_p19\n",
      "\n",
      "Rat r574_p20\n",
      "\n",
      "Rat r574_p21\n",
      "\tAge 21\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (29, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (29, 60)\n",
      "\n",
      "Rat r579_p22\n",
      "\n",
      "Rat r586_p24\n",
      "\n",
      "Rat r596_p22\n",
      "\tAge 22\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (27, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (27, 60)\n",
      "\n",
      "Rat r596_p23\n",
      "\tAge 23\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (13, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (13, 60)\n",
      "\n",
      "Rat r631_p18\n",
      "\tAge 18\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (35, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (35, 60)\n",
      "\n",
      "Rat r631_p20\n",
      "\tAge 20\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (47, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (47, 60)\n",
      "\n",
      "Rat r631_p21\n",
      "\tAge 21\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (40, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (40, 60)\n",
      "\n",
      "Rat r658_p17\n",
      "\tAge 17\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (44, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (44, 60)\n",
      "\n",
      "Rat r659_p21\n",
      "\tAge 21\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (28, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (28, 60)\n",
      "\n",
      "Rat r659_p22\n",
      "\tAge 22\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (39, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (39, 60)\n",
      "\n",
      "Rat r65_p17\n",
      "\tAge 17\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (12, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (12, 60)\n",
      "\n",
      "Rat r65_p19\n",
      "\tAge 19\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (5, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (5, 60)\n",
      "\n",
      "Rat r65_p20\n",
      "\tAge 20\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (3, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (3, 60)\n",
      "\n",
      "Rat r661_p40\n",
      "\tAge 40\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (56, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (56, 60)\n",
      "\n",
      "Rat r663_p19\n",
      "\tAge 19\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (37, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (37, 60)\n",
      "\n",
      "Rat r66_p16\n",
      "\tAge 16\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (13, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (13, 60)\n",
      "\n",
      "Rat r66_p18\n",
      "\tAge 18\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (21, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (21, 60)\n",
      "\n",
      "Rat r66_p21\n",
      "\tAge 21\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (10, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (10, 60)\n",
      "\n",
      "Rat r673_p40\n",
      "\tAge 40\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (82, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (82, 60)\n",
      "\n",
      "Rat r678_p40\n",
      "\tAge 40\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (109, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (109, 60)\n",
      "\n",
      "Rat r67_p17\n",
      "\tAge 17\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (31, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (31, 60)\n",
      "\n",
      "Rat r67_p20\n",
      "\tAge 20\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (10, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (10, 60)\n",
      "\n",
      "Rat r710_p20\n",
      "\n",
      "Rat r710_p21\n",
      "\n",
      "Rat r710_p22\n",
      "\n",
      "Rat r711_p22\n",
      "\n",
      "Rat r711_p23\n",
      "\n",
      "Rat r713_p29\n",
      "\n",
      "Rat r715_p27\n",
      "\n",
      "Rat r716_p25\n",
      "\n",
      "Rat r716_p26\n",
      "\n",
      "Rat r716_p30\n",
      "\n",
      "Rat r71_p22\n",
      "\tAge 22\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (5, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (5, 60)\n",
      "\n",
      "Rat r724_p25\n",
      "\n",
      "Rat r726_p24\n",
      "\n",
      "Rat r72_p20\n",
      "\tAge 20\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (8, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (8, 60)\n",
      "\n",
      "Rat r72_p21\n",
      "\tAge 21\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (4, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (4, 60)\n",
      "\n",
      "Rat r731_p30\n",
      "\n",
      "Rat r731_p32\n",
      "\n",
      "Rat r732_p26\n",
      "\n",
      "Rat r732_p31\n",
      "\n",
      "Rat r733_p27\n",
      "\n",
      "Rat r733_p28\n",
      "\n",
      "Rat r737_p25\n",
      "\n",
      "Rat r737_p26\n",
      "\n",
      "Rat r738_p26\n",
      "\n",
      "Rat r738_p27\n",
      "\n",
      "Rat r739_p29\n",
      "\n",
      "Rat r739_p30\n",
      "\n",
      "Rat r739_p31\n",
      "\n",
      "Rat r73_p20\n",
      "\tAge 20\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (13, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (13, 60)\n",
      "\n",
      "Rat r73_p24\n",
      "\tAge 24\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (10, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (10, 60)\n",
      "\n",
      "Rat r75_p16\n",
      "\tAge 16\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (16, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (16, 60)\n",
      "\n",
      "Rat r76_p40\n",
      "\tAge 40\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (21, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (21, 60)\n",
      "\n",
      "Rat r85_p17\n",
      "\tAge 17\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (19, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (19, 60)\n",
      "\n",
      "Rat r85_p20\n",
      "\tAge 20\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (16, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (16, 60)\n",
      "\n",
      "Rat r86_p18\n",
      "\tAge 18\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (14, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (14, 60)\n",
      "\n",
      "Rat r86_p19\n",
      "\tAge 19\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (11, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (11, 60)\n",
      "\n",
      "Rat r86_p20\n",
      "\tAge 20\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (14, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (14, 60)\n",
      "\n",
      "Rat r86_p21\n",
      "\tAge 21\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (18, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (18, 60)\n",
      "\n",
      "Rat r86_p22\n",
      "\tAge 22\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (8, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (8, 60)\n",
      "\n",
      "Rat r87_p17\n",
      "\tAge 17\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (16, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (16, 60)\n",
      "\n",
      "Rat r87_p18\n",
      "\tAge 18\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (11, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (11, 60)\n",
      "\n",
      "Rat r87_p19\n",
      "\tAge 19\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (18, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (18, 60)\n",
      "\n",
      "Rat r96_p16\n",
      "\tAge 16\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (22, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (22, 60)\n",
      "\n",
      "Rat r96_p18\n",
      "\tAge 18\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (16, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (16, 60)\n",
      "\n",
      "Rat r96_p20\n",
      "\tAge 20\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (11, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (11, 60)\n",
      "\n",
      "Rat r97_p17\n",
      "\tAge 17\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (15, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (15, 60)\n",
      "\n",
      "Rat r97_p19\n",
      "\tAge 19\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (21, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (21, 60)\n",
      "\n",
      "Rat r97_p20\n",
      "\tAge 20\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (16, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (16, 60)\n",
      "\n",
      "Rat r97_p21\n",
      "\tAge 21\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (12, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (12, 60)\n",
      "\n",
      "Rat r98_p17\n",
      "\tAge 17\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (11, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (11, 60)\n",
      "\n",
      "Rat r98_p18\n",
      "\tAge 18\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (25, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (25, 60)\n",
      "\n",
      "Rat r98_p19\n",
      "\tAge 19\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (21, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (21, 60)\n",
      "\n",
      "Rat r98_p20\n",
      "\tAge 20\n",
      "\t(n_cells, n_samples_pos, n_samples_pos): (27, 25, 25)\n",
      "\t(n_cells, N_SAMPLES_THET): (27, 60)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from utils.spatial_units import RateMaps, PolarMaps\n",
    "\n",
    "place_units = RateMaps(positions=None, env_dim=0)\n",
    "hd_units = PolarMaps(thetas=None)\n",
    "\n",
    "data_dict_age = {}\n",
    "perc_kept = []\n",
    "\n",
    "for k, v in sorted(data_dict.items()):\n",
    "    print(f\"Rat {k}\")\n",
    "    if k in ratnames_science:\n",
    "        metrics_shuffle_th = metrics_shuffle_th_science\n",
    "    elif k in ratnames_muessig:\n",
    "        metrics_shuffle_th = metrics_shuffle_th_muessig\n",
    "    else:\n",
    "        raise ValueError(f\"Rat {k} not found in science or muessig data\")\n",
    "    \n",
    "    for age in sorted(v.keys()):\n",
    "        if age in AGES_TO_REMOVE:\n",
    "            print(f\"\\tAge {age} in ages to remove, skipping\")\n",
    "            continue\n",
    "        \n",
    "        exp = v[age] # get the experiment for this rat's age\n",
    "        print(f\"\\tAge {age}\")\n",
    "        if (age in data_dict_age.keys()) and (k in data_dict_age[age].keys()):\n",
    "            print(f\"\\tSkipping because already processed\")\n",
    "            continue\n",
    "\n",
    "        if age not in data_dict_age.keys(): # initialize all dict if first exp for this age\n",
    "            data_dict_age[age] = {}\n",
    "        \n",
    "        data_dict_age[age][k] = {}\n",
    "        for k_tmp in [\n",
    "            'positions', 'hd', 'speed', 'spike_times', 'sample_rate',\n",
    "            'rate_maps', 'pos_occ', 'rate_maps_corr',\n",
    "            'si_matlab', 'si_corr_matlab', 'si_rm', 'si_rm_corr',\n",
    "            'selected_place_units', 'n_fields',\n",
    "            'single_field_dim', 'pu_flipped', 'pu_field_flipped',\n",
    "            'rate_maps_hd8', 'rate_maps_hd4',\n",
    "            'polar_maps', 'hd_occ', 'polar_maps_corr',\n",
    "            'si_pm_matlab', 'rvl_matlab', 'si_pm', 'rvl_pm', 'si_pm_corr_matlab', 'rvl_corr_matlab',\n",
    "            'rvl_pm_corr', 'rvangle_pm', 'rvangle_pm_corr',\n",
    "            'selected_hd_units', 'selected_place_hd_units',\n",
    "            'polar_maps_pred', 'dis_ratios',\n",
    "            'trial_start_idx'\n",
    "        ]:\n",
    "            data_dict_age[age][k][k_tmp] = []\n",
    "\n",
    "        rate_maps_all = []\n",
    "        polar_maps_all = []\n",
    "        indices_to_keep = None\n",
    "        trial_start_idx = 0\n",
    "        for trial in exp['trials']:\n",
    "            if trial['environment'] != 'hp' and trial['environment'] != 'fam':\n",
    "                raise ValueError(f\"\\tEnvironment is {trial['environment']} instead of hp or fam\")\n",
    "\n",
    "            rate_maps = trial['rate_maps']\n",
    "            if rate_maps.shape[1] != N_SAMPLES_POS or rate_maps.shape[2] != N_SAMPLES_POS:\n",
    "                raise ValueError(f\"\\t\\tRate maps shape is {rate_maps.shape} instead of (n_cells, {N_SAMPLES_POS}, {N_SAMPLES_POS})\")\n",
    "            rate_maps_all.append(rate_maps.copy())\n",
    "\n",
    "            si_matlab = trial['si']\n",
    "            rate_maps_corr = trial['rate_maps_corr']\n",
    "            si_corr_matlab = trial['si_corr']\n",
    "            rate_maps_hd8 = trial['rate_maps_hd8']\n",
    "            rate_maps_hd4 = trial['rate_maps_hd4']\n",
    "\n",
    "            pos_occ = trial['pos_occ']\n",
    "\n",
    "            polar_maps = trial['polar_maps']\n",
    "            if polar_maps.shape[1] != N_SAMPLES_THET:\n",
    "                raise ValueError(f\"\\t\\Polar maps shape is {polar_maps.shape} instead of (n_cells, {N_SAMPLES_THET})\")\n",
    "            polar_maps_all.append(polar_maps.copy())\n",
    "\n",
    "            si_pm_matlab = trial['si_pm']\n",
    "            rvl_matlab = trial['rvl']\n",
    "            polar_maps_corr = trial['polar_maps_corr']\n",
    "            si_pm_corr_matlab = trial['si_pm_corr']\n",
    "            rvl_corr_matlab = trial['rvl_corr']\n",
    "            polar_maps_pred = trial['polar_maps_pred']\n",
    "            dis_ratios = trial['dis_ratios']\n",
    "\n",
    "            hd_occ = trial['hd_occ']\n",
    "\n",
    "            print(f\"\\t(n_cells, n_samples_pos, n_samples_pos): {rate_maps.shape}\")\n",
    "            print(f\"\\t(n_cells, N_SAMPLES_THET): {polar_maps.shape}\")\n",
    "\n",
    "            # keep only Complex Spike neurons\n",
    "            idx_to_keep = np.logical_and(\n",
    "                trial['is_cs_neuron'] == 1, trial['has_min_freq'] == 1\n",
    "            )\n",
    "            if isinstance(idx_to_keep, np.bool) : idx_to_keep = np.array([idx_to_keep])\n",
    "            assert len(idx_to_keep) == rate_maps.shape[0]\n",
    "            idx_to_keep = np.where(idx_to_keep)[0] # convert mask to indices\n",
    "            if len(idx_to_keep) == 0:\n",
    "                print(f\"\\tSkipping trial because all rate maps are uniform\")\n",
    "                continue\n",
    "\n",
    "            positions = trial['positions']\n",
    "            hd = trial['hd'].astype(np.float64)\n",
    "            speed = trial['speed']\n",
    "            spike_times = trial['spike_times']\n",
    "            sample_rate = trial['sample_rate']\n",
    "            \n",
    "            perc_kept.append(len(idx_to_keep)/rate_maps.shape[0]*100)\n",
    "            rate_maps = rate_maps[idx_to_keep]\n",
    "            si_matlab = si_matlab[idx_to_keep]\n",
    "            rate_maps_corr = rate_maps_corr[idx_to_keep]\n",
    "            si_corr_matlab = si_corr_matlab[idx_to_keep]\n",
    "            rate_maps_hd8 = rate_maps_hd8[idx_to_keep]\n",
    "            rate_maps_hd4 = rate_maps_hd4[idx_to_keep]\n",
    "\n",
    "            polar_maps = polar_maps[idx_to_keep]\n",
    "            si_pm_matlab = si_pm_matlab[idx_to_keep]\n",
    "            rvl_matlab = rvl_matlab[idx_to_keep]\n",
    "            polar_maps_corr = polar_maps_corr[idx_to_keep]\n",
    "            si_pm_corr_matlab = si_pm_corr_matlab[idx_to_keep]\n",
    "            rvl_corr_matlab = rvl_corr_matlab[idx_to_keep]\n",
    "            polar_maps_pred = polar_maps_pred[idx_to_keep]\n",
    "            dis_ratios = dis_ratios[idx_to_keep]\n",
    "\n",
    "            if indices_to_keep is None: indices_to_keep = idx_to_keep\n",
    "            else: indices_to_keep = np.intersect1d(indices_to_keep, idx_to_keep)\n",
    "            \n",
    "            rate_maps_unnorm = rate_maps.copy()\n",
    "            rate_maps_corr_unnorm = rate_maps_corr.copy()\n",
    "            rate_maps = normalize_rate_maps(rate_maps)\n",
    "            rate_maps_corr = normalize_rate_maps(rate_maps_corr)\n",
    "            \n",
    "            si_rm = place_units.calculate_metrics(rate_maps, pos_occ)\n",
    "            si_rm_corr = place_units.calculate_metrics(rate_maps_corr, pos_occ)\n",
    "            \n",
    "            n_fields, rm_fields = place_units.rate_maps_field_detection(rate_maps, rate_maps, rate_maps)\n",
    "\n",
    "            selected_place_units = place_units.get_place_cells_indices(rate_maps, si_matlab)\n",
    "\n",
    "            if len(selected_place_units) > 0:\n",
    "                single_field_dim = np.array([\n",
    "                    np.sum(np.nansum(np.array(fields), axis=0)>0) for i, fields in enumerate(rm_fields)\n",
    "                    if fields and i in selected_place_units\n",
    "                ])\n",
    "                pu_flipped = place_units.rm_flipped(rate_maps, filter_indices=selected_place_units)\n",
    "\n",
    "                rm_fields_selected = [f for i, f in enumerate(rm_fields) if (i in selected_place_units) and n_fields[i] > 0]\n",
    "                if len(rm_fields_selected) == 0:\n",
    "                    print(f\"\\tSkipping avg rate map field because no selected fields\")\n",
    "                    continue\n",
    "                pu_field_flipped = place_units.rm_field_flipped(rm_fields_selected)\n",
    "\n",
    "                for k_tmp in ['single_field_dim', 'pu_flipped', 'pu_field_flipped']:\n",
    "                    data_dict_age[age][k][k_tmp].append(locals()[k_tmp])\n",
    "            \n",
    "            si_pm, rvl_pm, rvangle_pm = hd_units.calculate_metrics(polar_maps.copy(), hd_occ)\n",
    "            _, rvl_pm_corr, rvangle_pm_corr = hd_units.calculate_metrics(polar_maps_corr.copy(), hd_occ)\n",
    "\n",
    "            selected_hd_units = np.array([\n",
    "                idx for idx in range(polar_maps.shape[0]) if\n",
    "                (not np.isnan(si_pm_matlab[idx])) and (not np.isnan(rvl_matlab[idx])) and\n",
    "                (not np.isnan(si_pm_corr_matlab[idx])) and (not np.isnan(rvl_corr_matlab[idx])) and\n",
    "                ((si_pm_matlab[idx] > metrics_shuffle_th['dirSI'][age]) or (rvl_matlab[idx] > metrics_shuffle_th['rvLength'][age])) and\n",
    "                ((si_pm_corr_matlab[idx] > metrics_shuffle_th['dirSI'][age]) or (rvl_corr_matlab[idx] > metrics_shuffle_th['rvLength'][age])) and\n",
    "                (get_spatial_correlation(polar_maps[idx], polar_maps_corr[idx], return_pvalue=False) > 0.5)# and\n",
    "            ], dtype=np.int32)\n",
    "\n",
    "            selected_place_hd_units = np.intersect1d(selected_place_units, selected_hd_units, assume_unique=True)\n",
    "            selected_place_units = np.setdiff1d(selected_place_units, selected_place_hd_units, assume_unique=True)\n",
    "            selected_hd_units = np.setdiff1d(selected_hd_units, selected_place_hd_units, assume_unique=True)\n",
    "            \n",
    "            selected_place_units += trial_start_idx\n",
    "            selected_hd_units += trial_start_idx\n",
    "            selected_place_hd_units += trial_start_idx\n",
    "\n",
    "            for k_tmp in [\n",
    "                'positions', 'hd', 'speed', 'spike_times', 'sample_rate',\n",
    "                'rate_maps', 'rate_maps_corr', 'pos_occ',\n",
    "                'si_matlab', 'si_corr_matlab', 'si_rm', 'si_rm_corr',\n",
    "                'selected_place_units', 'n_fields',\n",
    "                'rate_maps_hd8', 'rate_maps_hd4',\n",
    "                'polar_maps', 'polar_maps_corr', 'hd_occ',\n",
    "                'si_pm_matlab', 'rvl_matlab', 'si_pm', 'rvl_pm', 'si_pm_corr_matlab', 'rvl_corr_matlab',\n",
    "                'rvl_pm_corr', 'rvangle_pm', 'rvangle_pm_corr',\n",
    "                'selected_hd_units', 'selected_place_hd_units',\n",
    "                'polar_maps_pred', 'dis_ratios',\n",
    "                'trial_start_idx'\n",
    "            ]:\n",
    "                data_dict_age[age][k][k_tmp].append(locals()[k_tmp])\n",
    "\n",
    "            trial_start_idx += len(idx_to_keep)\n",
    "\n",
    "        if trial_start_idx == 0:\n",
    "            data_dict_age[age].pop(k)\n",
    "        else:\n",
    "\n",
    "            for k_tmp in [\n",
    "                'rate_maps', 'rate_maps_corr', 'pos_occ',\n",
    "                'si_matlab', 'si_corr_matlab', 'si_rm', 'si_rm_corr',\n",
    "                'selected_place_units', 'n_fields',\n",
    "                'rate_maps_hd8', 'rate_maps_hd4',\n",
    "                'polar_maps', 'polar_maps_corr', 'hd_occ',\n",
    "                'si_pm_matlab', 'rvl_matlab', 'si_pm', 'rvl_pm', 'si_pm_corr_matlab', 'rvl_corr_matlab',\n",
    "                'rvl_pm_corr', 'rvangle_pm', 'rvangle_pm_corr',\n",
    "                'selected_hd_units', 'selected_place_hd_units',\n",
    "                'polar_maps_pred', 'dis_ratios',\n",
    "            ]:\n",
    "                try:\n",
    "                    data_dict_age[age][k][k_tmp] = np.concatenate(data_dict_age[age][k][k_tmp])\n",
    "                except ValueError as e:\n",
    "                    if \"zero-dimensional\" in str(e):\n",
    "                        data_dict_age[age][k][k_tmp] = np.array(data_dict_age[age][k][k_tmp])\n",
    "                    elif \"need at least one\" in str(e):\n",
    "                        continue\n",
    "                    else:\n",
    "                        raise e\n",
    "            \n",
    "    print(flush=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "62bef3ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "ages = sorted(data_dict_age.keys())\n",
    "n_ages = len(ages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "d24d07a3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Age 14:\n",
      "\tPlace units SI: 0.607 (min 0.393)\n",
      "\tHD units RVL: 0.317 (min 0.292)\n",
      "\tHD units SI: 0.624 (min 0.539)\n",
      "\n",
      "Age 15:\n",
      "\tPlace units SI: 0.619 (min 0.369)\n",
      "\tHD units RVL: 0.227 (min 0.174)\n",
      "\tHD units SI: 0.303 (min 0.237)\n",
      "\n",
      "Age 16:\n",
      "\tPlace units SI: 0.814 (min 0.372)\n",
      "\tHD units RVL: 0.355 (min 0.245)\n",
      "\tHD units SI: 0.372 (min 0.208)\n",
      "\n",
      "Age 17:\n",
      "\tPlace units SI: 0.662 (min 0.365)\n",
      "\tHD units RVL: 0.329 (min 0.066)\n",
      "\tHD units SI: 0.334 (min 0.113)\n",
      "\n",
      "Age 18:\n",
      "\tPlace units SI: 0.657 (min 0.353)\n",
      "\tHD units RVL: 0.344 (min 0.047)\n",
      "\tHD units SI: 0.321 (min 0.134)\n",
      "\n",
      "Age 19:\n",
      "\tPlace units SI: 0.745 (min 0.346)\n",
      "\tHD units RVL: 0.352 (min 0.060)\n",
      "\tHD units SI: 0.322 (min 0.146)\n",
      "\n",
      "Age 20:\n",
      "\tPlace units SI: 0.747 (min 0.364)\n",
      "\tHD units RVL: 0.352 (min 0.113)\n",
      "\tHD units SI: 0.360 (min 0.139)\n",
      "\n",
      "Age 21:\n",
      "\tPlace units SI: 0.753 (min 0.370)\n",
      "\tHD units RVL: 0.392 (min 0.271)\n",
      "\tHD units SI: 0.356 (min 0.133)\n",
      "\n",
      "Age 22:\n",
      "\tPlace units SI: 0.783 (min 0.363)\n",
      "\tHD units RVL: 0.340 (min 0.158)\n",
      "\tHD units SI: 0.298 (min 0.131)\n",
      "\n",
      "Age 23:\n",
      "\tPlace units SI: 0.779 (min 0.382)\n",
      "\tHD units RVL: 0.343 (min 0.255)\n",
      "\tHD units SI: 0.265 (min 0.144)\n",
      "\n",
      "Age 24:\n",
      "\tPlace units SI: 0.883 (min 0.379)\n",
      "\tHD units RVL: 0.405 (min 0.074)\n",
      "\tHD units SI: 0.461 (min 0.240)\n",
      "\n",
      "Age 25:\n",
      "\tPlace units SI: 0.947 (min 0.338)\n",
      "\tHD units RVL: 0.552 (min 0.408)\n",
      "\tHD units SI: 0.683 (min 0.358)\n",
      "\n",
      "Age 40:\n",
      "\tPlace units SI: 1.320 (min 0.362)\n",
      "\tHD units RVL: 0.490 (min 0.026)\n",
      "\tHD units SI: 0.571 (min 0.124)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "si_rm_selected = {}\n",
    "si_pm_selected = {}\n",
    "rvl_pm_selected = {}\n",
    "\n",
    "for age in ages:\n",
    "    for rat in data_dict_age[age].keys():\n",
    "        si_rm = data_dict_age[age][rat]['si_rm']\n",
    "        si_pm = data_dict_age[age][rat]['si_pm']\n",
    "        rvl_pm = data_dict_age[age][rat]['rvl_pm']\n",
    "\n",
    "        selected_place_units = data_dict_age[age][rat]['selected_place_units']\n",
    "        selected_hd_units = data_dict_age[age][rat]['selected_hd_units']\n",
    "        selected_place_hd_units = data_dict_age[age][rat]['selected_place_hd_units']\n",
    "\n",
    "        selected_units_rm = np.concatenate([selected_place_units, selected_place_hd_units])\n",
    "        selected_units_pm = np.concatenate([selected_hd_units, selected_place_hd_units])\n",
    "\n",
    "        if age not in si_rm_selected.keys():\n",
    "            si_rm_selected[age] = []\n",
    "            si_pm_selected[age] = []\n",
    "            rvl_pm_selected[age] = []\n",
    "        si_rm_selected[age].append(si_rm[selected_units_rm])\n",
    "        si_pm_selected[age].append(si_pm[selected_units_pm])\n",
    "        rvl_pm_selected[age].append(rvl_pm[selected_units_pm])\n",
    "\n",
    "for age in si_pm_selected.keys():\n",
    "    si_rm_selected[age] = np.concatenate(si_rm_selected[age])\n",
    "    si_pm_selected[age] = np.concatenate(si_pm_selected[age])\n",
    "    rvl_pm_selected[age] = np.concatenate(rvl_pm_selected[age])\n",
    "\n",
    "    print(f\"Age {age}:\")\n",
    "    if len(si_rm_selected[age]) > 0:\n",
    "        print(f\"\\tPlace units SI: {np.nanmean(si_rm_selected[age]):.3f} (min {np.min(si_rm_selected[age]):.3f})\")\n",
    "    if len(rvl_pm_selected[age]) > 0:\n",
    "        print(f\"\\tHD units RVL: {np.nanmean(rvl_pm_selected[age]):.3f} (min {np.min(rvl_pm_selected[age]):.3f})\")\n",
    "        print(f\"\\tHD units SI: {np.nanmean(si_pm_selected[age]):.3f} (min {np.min(si_pm_selected[age]):.3f})\")\n",
    "    print(flush=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14fdd440",
   "metadata": {},
   "source": [
    "# Load clustering data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8f1d9c5",
   "metadata": {},
   "source": [
    "#### Params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "f741d6a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "BY = 'day'\n",
    "SEED = 7\n",
    "CLUSTERALGO = 'gm'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "139b6533",
   "metadata": {},
   "source": [
    "#### Loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "42ea6abb",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_dir = os.path.join(DATA_DIR, 'cluster_locomotion', f'by_{BY}')\n",
    "df_data = pd.read_pickle(os.path.join(df_dir, f'data_{SEED}.pkl'))\n",
    "\n",
    "c_idx_col = f'cluster_idx_{CLUSTERALGO}'\n",
    "df_data = df_data[df_data[c_idx_col] != -1].reset_index()\n",
    "df_data.loc[df_data['age'] == 100, 'age'] = 40"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "40f01ac1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Excluding 0.0% (rat, age) pairs because they have multiple clusters\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>cluster_idx_gm</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dataset</th>\n",
       "      <th>rat</th>\n",
       "      <th>age</th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">muessig_data_struct</th>\n",
       "      <th>r101_p20</th>\n",
       "      <th>20</th>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>r115_p24</th>\n",
       "      <th>24</th>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>r118_p24</th>\n",
       "      <th>24</th>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>r118_p25</th>\n",
       "      <th>25</th>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>r129_p40</th>\n",
       "      <th>40</th>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                  cluster_idx_gm\n",
       "dataset             rat      age                \n",
       "muessig_data_struct r101_p20 20                2\n",
       "                    r115_p24 24                2\n",
       "                    r118_p24 24                2\n",
       "                    r118_p25 25                2\n",
       "                    r129_p40 40                3"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import re\n",
    "\n",
    "if 'exp' in BY or 'day' in BY:\n",
    "    df_data['dataset'] = df_data['rat'].map(lambda x: re.sub(\"['() ]\", \"\", x.split(',')[0]))\n",
    "    df_data = df_data[\n",
    "        (df_data['dataset'].str.contains('science2010')) | (df_data['dataset'].str.contains('muessig'))\n",
    "    ]\n",
    "    df_data['rat'] = df_data['rat'].map(lambda x: re.sub(\"['() ]\", \"\", x.split(',')[1]))\n",
    "\n",
    "    # this dataset give us a cluster for each (rat, age, trial)\n",
    "    # we want to get a cluster for each (rat, age)\n",
    "    # we exclude the (rat, age) where there are multiple clusters\n",
    "    df_data_exclude = df_data.groupby(['dataset', 'rat', 'age']).agg({c_idx_col: 'nunique'})\n",
    "    df_data_exclude = df_data_exclude[df_data_exclude[c_idx_col] > 1].reset_index()\n",
    "\n",
    "    print(\n",
    "        f\"Excluding {len(df_data_exclude)/len(df_data.reset_index().groupby(['rat', 'age']).count())*100:.1f}% \"+\n",
    "        \"(rat, age) pairs because they have multiple clusters\"\n",
    "    )\n",
    "\n",
    "    df_merge = df_data.merge(df_data_exclude[['dataset', 'rat', 'age']], on=['dataset', 'rat', 'age'], how='left', indicator=True)\n",
    "    df_data = df_merge[df_merge['_merge'] == 'left_only'].drop(columns=['_merge'])\n",
    "\n",
    "    # keep first cluster index because they are all the same after previous operation\n",
    "    df_data = df_data.groupby(['dataset', 'rat', 'age']).agg({c_idx_col: lambda x: list(x)[0]})\n",
    "elif 'age' in BY:\n",
    "    df_data = df_data[['age', c_idx_col]]\n",
    "    \n",
    "df_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "12599993",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_data = df_data.reset_index()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77013296",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>rat</th>\n",
       "      <th>age</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>cluster_idx_gm</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>31</td>\n",
       "      <td>16.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>102</td>\n",
       "      <td>20.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>9</td>\n",
       "      <td>40.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                rat   age\n",
       "cluster_idx_gm           \n",
       "1                31  16.0\n",
       "2               102  20.0\n",
       "3                 9  40.0"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_data_merge = df_data.groupby(c_idx_col).agg({'rat': 'count', 'age': 'median'}).reset_index()\n",
    "df_data_merge = df_data_merge[df_data_merge['rat'] > 2]\n",
    "df_data = pd.merge(\n",
    "    df_data, df_data_merge[c_idx_col], on=c_idx_col, how='inner'\n",
    ")\n",
    "df_data.groupby(c_idx_col).agg({'rat': 'count', 'age': 'median'})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "5024155d",
   "metadata": {},
   "outputs": [],
   "source": [
    "clusters = [c if c < 3 else 'Adult' for c in sorted(df_data[c_idx_col].unique())]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0728406f",
   "metadata": {},
   "source": [
    "# DEV model data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e456a763",
   "metadata": {},
   "source": [
    "### Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5134e063",
   "metadata": {},
   "outputs": [],
   "source": [
    "args_dev = { # DEVELOPMENT\n",
    "    'behaviour' : ['crawl', 'walk', 'run', 'adult', 'adult'],\n",
    "    'pretrained_behav' : ['crawl', 'crawl,walk', 'crawl,walk,run', 'crawl,walk,run,adult', 'adult'],\n",
    "    'env' : 'box_messy',\n",
    "    'env_dim': 0.635,\n",
    "    'name_prefix': None,\n",
    "    'pretrained_model_folder': False,\n",
    "    'moredata': None,\n",
    "    'n_gridcells': [0,0,0,0,25], # with GC\n",
    "    'gridcells_softmax': [False,False,False,False,True], # with GC\n",
    "    'gridcells_modules': [None,None,None,None,[0.2,0.4,0.6]], # with GC\n",
    "    'gridcells_orientations': [None,None,None,None,[0.1]], # with GC\n",
    "    'n_future_pred' : 1,\n",
    "    'frame_subsampling': 4,\n",
    "    'stride' : 10,\n",
    "    'reset_hidden_at': [None,None,None,None,10], # with GC\n",
    "    'bptt_steps' : 9,\n",
    "    'latent_dim' : 500,\n",
    "    'bias': False,\n",
    "    'dropouts': '[0,0,0]',\n",
    "    'nonlinearity' : 'sigmoid',\n",
    "    'hidden_reg' : 0.,\n",
    "    'weights_reg' : 0.,\n",
    "    'seed': 1,\n",
    "    'epoch' : None,\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6bf1b84",
   "metadata": {},
   "source": [
    "### Define Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "2fb4badd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Comparison will be done on the following parameters:\n",
      "\tbehaviour: ['crawl', 'walk', 'run', 'adult', 'adult']\n",
      "\tpretrained_behav: ['crawl', 'crawl,walk', 'crawl,walk,run', 'crawl,walk,run,adult', 'adult']\n",
      "\tn_gridcells: [0, 0, 0, 0, 25]\n",
      "\tgridcells_softmax: [False, False, False, False, True]\n",
      "\tgridcells_modules: [None, None, None, None, [0.2, 0.4, 0.6]]\n",
      "\tgridcells_orientations: [None, None, None, None, [0.1]]\n",
      "\treset_hidden_at: [None, None, None, None, 10]\n"
     ]
    }
   ],
   "source": [
    "n_compare = None\n",
    "print(\"Comparison will be done on the following parameters:\")\n",
    "for k, a in args_dev.items():\n",
    "    if isinstance(a, list):\n",
    "        print(f\"\\t{k}: {a}\")\n",
    "        if n_compare is None:\n",
    "            n_compare = len(a)\n",
    "        elif n_compare != len(a):\n",
    "            raise ValueError(\"All lists must have the same length\")\n",
    "\n",
    "if n_compare is None:\n",
    "    raise ValueError(\"At least one argument must be a list to make a comparison\")\n",
    "\n",
    "for k, a in args_dev.items():\n",
    "    if not isinstance(a, list):\n",
    "        args_dev[k] = [a] * n_compare"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "0b3d174c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Comparing the activity from the following directories:\n",
      "/media/data/vrtopc/box/crawl/predictions/box_messy/vanilla/RNN_f1_w9_st10_fss4_do[0,0,0]_lat500_nlsigmoid_hreg0.0_wreg0.0_s01/act_crawl_epoch1500\n",
      "/media/data/vrtopc/box/walk/predictions/box_messy/crawl/RNN_f1_w9_st10_fss4_do[0,0,0]_lat500_nlsigmoid_hreg0.0_wreg0.0_s01/act_walk_epoch1500\n",
      "/media/data/vrtopc/box/run/predictions/box_messy/crawl_walk/RNN_f1_w9_st10_fss4_do[0,0,0]_lat500_nlsigmoid_hreg0.0_wreg0.0_s01/act_run_epoch1500\n",
      "/media/data/vrtopc/box/adult/predictions/box_messy/crawl_walk_run/RNN_f1_w9_st10_fss4_do[0,0,0]_lat500_nlsigmoid_hreg0.0_wreg0.0_s01/act_adult_epoch1500\n",
      "/media/data/vrtopc/box/adult/predictions/box_messy/vanilla/RNN_gridcellssm25_mod[0.2,0.4,0.6]_ori[0.1]_reset10_f1_w9_st10_fss4_do[0,0,0]_lat500_nlsigmoid_hreg0.0_wreg0.0_s01/act_adult_epoch1500\n"
     ]
    }
   ],
   "source": [
    "from utils.trainer import RNNTrainer\n",
    "\n",
    "activity_dirs_dev = []\n",
    "models_dev = []\n",
    "\n",
    "print(\"Comparing the activity from the following directories:\")\n",
    "for i in range(n_compare):\n",
    "    model_name = RNNTrainer.define_model_name({k: v[i] for k, v in args_dev.items()})\n",
    "    \n",
    "    env_shape = args_dev['env'][i].split('_')[0]\n",
    "    trained_behav_list = args_dev['pretrained_behav'][i].split(',')\n",
    "    behav = trained_behav_list.pop()\n",
    "    if len(trained_behav_list)>0:\n",
    "        folder_name = '_'.join(trained_behav_list)\n",
    "    else:\n",
    "        folder_name = \"vanilla\"\n",
    "    exp_dir = os.path.join(\n",
    "        DATA_DIR, env_shape, behav, \"predictions\", args_dev['env'][i],\n",
    "        folder_name, model_name\n",
    "    )\n",
    "\n",
    "    activity_dir = f\"act_{args_dev['behaviour'][i]}_epoch\"\n",
    "    if args_dev['epoch'][i] is not None:\n",
    "        epoch = args_dev['epoch'][i]\n",
    "    else:\n",
    "        dirs = [d for d in os.listdir(exp_dir) if re.match(rf\"{activity_dir}\\d+\", d)]\n",
    "        epoch = max([int(re.findall(r'\\d+', d)[-1]) for d in dirs])\n",
    "\n",
    "    models_dev.append(\n",
    "        torch.load(\n",
    "            os.path.join(exp_dir, f\"rnn_epoch{epoch}.pth\"),\n",
    "            weights_only=False,\n",
    "            map_location=torch.device(DEVICE)\n",
    "        ).to(DEVICE)\n",
    "    )\n",
    "    activity_dirs_dev.append(os.path.join(exp_dir, f\"{activity_dir}{epoch}\"))\n",
    "\n",
    "    print(activity_dirs_dev[-1])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15ab26de",
   "metadata": {},
   "source": [
    "# ROC model data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2d49524",
   "metadata": {},
   "source": [
    "### Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78a60af8",
   "metadata": {},
   "outputs": [],
   "source": [
    "args_compare = (\n",
    "    # ###################################################################################\n",
    "    # ########################## DIS-AGREEING MODELS ####################################\n",
    "    # ###################################################################################\n",
    "    { # RATE OF CHANGE\n",
    "        'behaviour' : ['crawl', 'crawl', 'crawl', 'crawl', 'crawl'],\n",
    "        'pretrained_behav' : ['crawl', 'crawl', 'crawl', 'crawl', 'crawl'],\n",
    "        'env' : 'box_messy',\n",
    "        'env_dim': 0.635,\n",
    "        'name_prefix': None,\n",
    "        'pretrained_model_folder': [False,True,True,True,False],\n",
    "        'moredata': None,\n",
    "        'n_gridcells': [0,0,0,0,25], # with GC\n",
    "        'gridcells_softmax': [False,False,False,False,True], # with GC\n",
    "        'gridcells_modules': [None,None,None,None,[0.2,0.4,0.6]], # with GC\n",
    "        'gridcells_orientations': [None,None,None,None,[0.1]], # with GC\n",
    "        'n_future_pred' : 1,\n",
    "        'frame_subsampling': 4,\n",
    "        'stride' : [10, 20, 25, 30, 30],\n",
    "        'reset_hidden_at': [None,None,None,None,10], # with GC\n",
    "        'bptt_steps' : 9,\n",
    "        'latent_dim' : 500,\n",
    "        'bias': False,\n",
    "        'dropouts': '[0,0,0]',\n",
    "        'nonlinearity' : 'sigmoid',\n",
    "        'hidden_reg' : 0.,\n",
    "        'weights_reg' : 0.,\n",
    "        'seed': 1,\n",
    "        'epoch' : None,\n",
    "    }\n",
    "    # { # CRAWL WITH MORE DATA\n",
    "    #     'behaviour' : ['crawl', 'crawl', 'crawl', 'crawl', 'crawl'],\n",
    "    #     'pretrained_behav' : ['crawl', 'crawl', 'crawl', 'crawl', 'crawl'],\n",
    "    #     'env' : 'box_messy',\n",
    "    #     'env_dim': 0.635,\n",
    "    #     'name_prefix': None,\n",
    "    #     'pretrained_model_folder': [False,True,True,True,False],\n",
    "    #     'moredata': [None, 1, 2, 3, 3],\n",
    "    #     'n_gridcells': [0,0,0,0,25], # with GC\n",
    "    #     'gridcells_softmax': [False,False,False,False,True], # with GC\n",
    "    #     'gridcells_modules': [None,None,None,None,[0.2,0.4,0.6]], # with GC\n",
    "    #     'gridcells_orientations': [None,None,None,None,[0.1]], # with GC\n",
    "    #     'n_future_pred' : 1,\n",
    "    #     'frame_subsampling': 4,\n",
    "    #     'stride' : 10,\n",
    "    #     'reset_hidden_at': [None,None,None,None,10], # with GC\n",
    "    #     'bptt_steps' : 9,\n",
    "    #     'latent_dim' : 500,\n",
    "    #     'bias': False,\n",
    "    #     'dropouts': '[0,0,0]',\n",
    "    #     'nonlinearity' : 'sigmoid',\n",
    "    #     'hidden_reg' : 0.,\n",
    "    #     'weights_reg' : 0.,\n",
    "    #     'seed': 1,\n",
    "    #     'epoch' : None,\n",
    "    # }\n",
    "    # { # REVERSE TRAINING\n",
    "    #     'behaviour' : ['adult', 'run', 'walk', 'crawl', 'crawl'],\n",
    "    #     'pretrained_behav' : ['adult', 'adult,run', 'adult,run,walk', 'adult,run,walk,crawl', 'crawl'],\n",
    "    #     'env' : 'box_messy',\n",
    "    #     'env_dim': 0.635,\n",
    "    #     'name_prefix': None,\n",
    "    #     'pretrained_model_folder': False,\n",
    "    #     'moredata': None,\n",
    "    #     'n_gridcells': [0,0,0,0,25], # with GC\n",
    "    #     'gridcells_softmax': [False,False,False,False,True], # with GC\n",
    "    #     'gridcells_modules': [None,None,None,None,[0.2,0.4,0.6]], # with GC\n",
    "    #     'gridcells_orientations': [None,None,None,None,[0.1]], # with GC\n",
    "    #     'n_future_pred' : 1,\n",
    "    #     'frame_subsampling': 4,\n",
    "    #     'stride' : 10,\n",
    "    #     'reset_hidden_at': [None,None,None,None,10], # with GC\n",
    "    #     'bptt_steps' : 9,\n",
    "    #     'latent_dim' : 500,\n",
    "    #     'bias': False,\n",
    "    #     'dropouts': '[0,0,0]',\n",
    "    #     'nonlinearity' : 'sigmoid',\n",
    "    #     'hidden_reg' : 0.,\n",
    "    #     'weights_reg' : 0.,\n",
    "    #     'seed': 1,\n",
    "    #     'epoch' : None,\n",
    "    # }\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "465c686e",
   "metadata": {},
   "source": [
    "### Define Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "5fa6ec2a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Comparison will be done on the following parameters:\n",
      "\tbehaviour: ['crawl', 'crawl', 'crawl', 'crawl', 'crawl']\n",
      "\tpretrained_behav: ['crawl', 'crawl', 'crawl', 'crawl', 'crawl']\n",
      "\tpretrained_model_folder: [False, True, True, True, False]\n",
      "\tn_gridcells: [0, 0, 0, 0, 25]\n",
      "\tgridcells_softmax: [False, False, False, False, True]\n",
      "\tgridcells_modules: [None, None, None, None, [0.2, 0.4, 0.6]]\n",
      "\tgridcells_orientations: [None, None, None, None, [0.1]]\n",
      "\tstride: [10, 20, 25, 30, 30]\n",
      "\treset_hidden_at: [None, None, None, None, 10]\n"
     ]
    }
   ],
   "source": [
    "n_compare = None\n",
    "print(\"Comparison will be done on the following parameters:\")\n",
    "for k, a in args_compare.items():\n",
    "    if isinstance(a, list):\n",
    "        print(f\"\\t{k}: {a}\")\n",
    "        if n_compare is None:\n",
    "            n_compare = len(a)\n",
    "        elif n_compare != len(a):\n",
    "            raise ValueError(\"All lists must have the same length\")\n",
    "\n",
    "if n_compare is None:\n",
    "    raise ValueError(\"At least one argument must be a list to make a comparison\")\n",
    "\n",
    "for k, a in args_compare.items():\n",
    "    if not isinstance(a, list):\n",
    "        args_compare[k] = [a] * n_compare"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "38d1d8fd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Comparing the activity from the following directories:\n",
      "/media/data/vrtopc/box/crawl/predictions/box_messy/vanilla/RNN_f1_w9_st10_fss4_do[0,0,0]_lat500_nlsigmoid_hreg0.0_wreg0.0_s01/act_crawl_epoch1500\n",
      "/media/data/vrtopc/box/crawl/predictions/box_messy/vanilla/RNN_ft_f1_w9_st20_fss4_do[0,0,0]_lat500_nlsigmoid_hreg0.0_wreg0.0_s01/act_crawl_epoch1500\n",
      "/media/data/vrtopc/box/crawl/predictions/box_messy/vanilla/RNN_ft_f1_w9_st25_fss4_do[0,0,0]_lat500_nlsigmoid_hreg0.0_wreg0.0_s01/act_crawl_epoch1500\n",
      "/media/data/vrtopc/box/crawl/predictions/box_messy/vanilla/RNN_ft_f1_w9_st30_fss4_do[0,0,0]_lat500_nlsigmoid_hreg0.0_wreg0.0_s01/act_crawl_epoch1500\n",
      "/media/data/vrtopc/box/crawl/predictions/box_messy/vanilla/RNN_gridcellssm25_mod[0.2,0.4,0.6]_ori[0.1]_reset10_f1_w9_st30_fss4_do[0,0,0]_lat500_nlsigmoid_hreg0.0_wreg0.0_s01/act_crawl_epoch1500\n"
     ]
    }
   ],
   "source": [
    "from utils.trainer import RNNTrainer\n",
    "\n",
    "activity_dirs_roc = []\n",
    "models_roc = []\n",
    "\n",
    "print(\"Comparing the activity from the following directories:\")\n",
    "for i in range(n_compare):\n",
    "    model_name = RNNTrainer.define_model_name({k: v[i] for k, v in args_compare.items()})\n",
    "    \n",
    "    env_shape = args_compare['env'][i].split('_')[0]\n",
    "    trained_behav_list = args_compare['pretrained_behav'][i].split(',')\n",
    "    behav = trained_behav_list.pop()\n",
    "    if len(trained_behav_list)>0:\n",
    "        folder_name = '_'.join(trained_behav_list)\n",
    "    else:\n",
    "        folder_name = \"vanilla\"\n",
    "    exp_dir = os.path.join(\n",
    "        DATA_DIR, env_shape, behav, \"predictions\", args_compare['env'][i],\n",
    "        folder_name, model_name\n",
    "    )\n",
    "\n",
    "    activity_dir = f\"act_{args_compare['behaviour'][i]}_epoch\"\n",
    "    if args_compare['epoch'][i] is not None:\n",
    "        epoch = args_compare['epoch'][i]\n",
    "    else:\n",
    "        dirs = [d for d in os.listdir(exp_dir) if re.match(rf\"{activity_dir}\\d+\", d)]\n",
    "        epoch = max([int(re.findall(r'\\d+', d)[-1]) for d in dirs])\n",
    "\n",
    "    models_roc.append(\n",
    "        torch.load(\n",
    "            os.path.join(exp_dir, f\"rnn_epoch{epoch}.pth\"),\n",
    "            weights_only=False,\n",
    "            map_location=torch.device(DEVICE)\n",
    "        ).to(DEVICE)\n",
    "    )\n",
    "    activity_dirs_roc.append(os.path.join(exp_dir, f\"{activity_dir}{epoch}\"))\n",
    "\n",
    "    print(activity_dirs_roc[-1])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "860f6ea9",
   "metadata": {},
   "source": [
    "# Extract values"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3519417",
   "metadata": {},
   "source": [
    "### Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "6970a158",
   "metadata": {},
   "outputs": [],
   "source": [
    "sir_dict_real = {}\n",
    "rvl_dict_real = {}\n",
    "sid_dict_real = {}\n",
    "\n",
    "for age in ages:\n",
    "    for rat in data_dict_age[age].keys():\n",
    "        # Determine dataset and rat key for matching in df_data\n",
    "        dataset = 'science2010_data_struct' if rat in ratnames_science else 'muessig_data_struct'\n",
    "        k = '_'.join(rat.split('_')[:2])\n",
    "        \n",
    "        # Get cluster index from df_data\n",
    "        c = df_data[\n",
    "            (df_data['dataset'] == dataset) &\n",
    "            (df_data['age'] == age) &\n",
    "            (df_data['rat'] == k)\n",
    "        ]['cluster_idx_gm'].values\n",
    "        \n",
    "        if len(c) == 0:\n",
    "            c = df_data[\n",
    "                (df_data['dataset'] == dataset) &\n",
    "                (df_data['age'] == age) &\n",
    "                (df_data['rat'].str.contains(k.split('_')[0]))\n",
    "            ]['cluster_idx_gm'].values\n",
    "            \n",
    "        if len(c) == 0 : continue\n",
    "\n",
    "        assert len(np.unique(c)) == 1\n",
    "        c = c[0]\n",
    "        \n",
    "        # Initialize cluster in data dictionary if not exists\n",
    "        if c not in sir_dict_real.keys():\n",
    "            sir_dict_real[c] = []\n",
    "            rvl_dict_real[c] = []\n",
    "            sid_dict_real[c] = []\n",
    "        \n",
    "        sir_dict_real[c].extend(list(data_dict_age[age][rat]['si_rm']))\n",
    "        rvl_dict_real[c].extend(list(data_dict_age[age][rat]['rvl_pm']))\n",
    "        sid_dict_real[c].extend(list(data_dict_age[age][rat]['si_pm']))\n",
    "\n",
    "sir_real = []\n",
    "rvl_real = []\n",
    "sid_real = []\n",
    "for c in sorted(sir_dict_real.keys()):\n",
    "    sir_real.append(sir_dict_real[c])\n",
    "    rvl_real.append(rvl_dict_real[c])\n",
    "    sid_real.append(sid_dict_real[c])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "682e4a77",
   "metadata": {},
   "outputs": [],
   "source": [
    "sir_model_dev = []\n",
    "rvl_model_dev = []\n",
    "sid_model_dev = []\n",
    "\n",
    "for ad in activity_dirs_dev[1:]:\n",
    "    sir = np.load(os.path.join(ad, 'place', \"si.npy\"))\n",
    "    rvl = np.load(os.path.join(ad, \"hd\", \"rvl.npy\"))\n",
    "    sid = np.load(os.path.join(ad, \"hd\", \"si.npy\"))\n",
    "\n",
    "    sir_model_dev.append(sir)\n",
    "    rvl_model_dev.append(rvl)\n",
    "    sid_model_dev.append(sid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "44101d12",
   "metadata": {},
   "outputs": [],
   "source": [
    "sir_model_roc = []\n",
    "rvl_model_roc = []\n",
    "sid_model_roc = []\n",
    "\n",
    "for ad in activity_dirs_roc[1:]:\n",
    "    sir_curr = np.load(os.path.join(ad, 'place', \"si.npy\"))\n",
    "    sir_model_roc.append(np.nan_to_num(sir_curr, nan=np.nanmean(sir_curr)))\n",
    "    rvl_curr = np.load(os.path.join(ad, \"hd\", \"rvl.npy\"))\n",
    "    rvl_model_roc.append(np.nan_to_num(rvl_curr, nan=np.nanmean(rvl_curr)))\n",
    "    sid_curr = np.load(os.path.join(ad, \"hd\", \"si.npy\"))\n",
    "    sid_model_roc.append(np.nan_to_num(sid_curr, nan=np.nanmean(sid_curr)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "138fbf33",
   "metadata": {},
   "source": [
    "### Percentage of place, HD, place+HD cells"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "46796363",
   "metadata": {},
   "outputs": [],
   "source": [
    "LATENT_DIM = 500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "78ae3bd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "pc_perc_real_dict = {}\n",
    "pc_n_real_dict = {}\n",
    "\n",
    "phdc_perc_real_dict = {}\n",
    "phdc_n_real_dict = {}\n",
    "\n",
    "hdc_perc_real_dict = {}\n",
    "hdc_n_real_dict = {}\n",
    "\n",
    "n_real_dict = {}\n",
    "\n",
    "for age in ages:\n",
    "    for rat in data_dict_age[age].keys():\n",
    "        dataset = 'science2010_data_struct' if rat in ratnames_science else 'muessig_data_struct'\n",
    "        k = '_'.join(rat.split('_')[:2])\n",
    "        c = df_data[\n",
    "            (df_data['dataset'] == dataset) &\n",
    "            (df_data['age'] == age) &\n",
    "            (df_data['rat'] == k)\n",
    "        ]['cluster_idx_gm'].values\n",
    "\n",
    "        if len(c) == 0:\n",
    "            c = df_data[\n",
    "                (df_data['dataset'] == dataset) &\n",
    "                (df_data['age'] == age) &\n",
    "                (df_data['rat'].str.contains(k.split('_')[0]))\n",
    "            ]['cluster_idx_gm'].values\n",
    "            \n",
    "        if len(c) == 0 : continue\n",
    "\n",
    "        assert len(np.unique(c)) == 1\n",
    "        c = c[0]\n",
    "\n",
    "        if c not in pc_perc_real_dict.keys():\n",
    "            pc_perc_real_dict[c] = []\n",
    "            pc_n_real_dict[c] = []\n",
    "            phdc_perc_real_dict[c] = []\n",
    "            phdc_n_real_dict[c] = []\n",
    "            hdc_perc_real_dict[c] = []\n",
    "            hdc_n_real_dict[c] = []\n",
    "            n_real_dict[c] = []\n",
    "\n",
    "        selected_place_units = data_dict_age[age][rat]['selected_place_units']\n",
    "        selected_hd_units = data_dict_age[age][rat]['selected_hd_units']\n",
    "        selected_place_hd_units = data_dict_age[age][rat]['selected_place_hd_units']\n",
    "        n_cells = len(data_dict_age[age][rat]['rate_maps'])\n",
    "        n_real_dict[c].append(n_cells)\n",
    "\n",
    "        n_pc = len(selected_place_units) + len(selected_place_hd_units)\n",
    "        pc_n_real_dict[c].append(n_pc)\n",
    "        pc_perc_real_dict[c].append(n_pc / n_cells)\n",
    "\n",
    "        n_phdc = len(selected_place_hd_units)\n",
    "        phdc_n_real_dict[c].append(n_phdc)\n",
    "        phdc_perc_real_dict[c].append(n_phdc / n_cells)\n",
    "\n",
    "        n_hdc = (len(selected_hd_units) + len(selected_place_hd_units))\n",
    "        hdc_n_real_dict[c].append(n_hdc)\n",
    "        hdc_perc_real_dict[c].append(n_hdc / n_cells)\n",
    "\n",
    "pc_perc_real = []\n",
    "pc_n_real = []\n",
    "\n",
    "phdc_perc_real = []\n",
    "phdc_n_real = []\n",
    "\n",
    "hdc_perc_real = []\n",
    "hdc_n_real = []\n",
    "\n",
    "n_real = []\n",
    "for c in sorted(pc_perc_real_dict.keys()):\n",
    "    pc_perc_real.append(pc_perc_real_dict[c])\n",
    "    pc_n_real.append(pc_n_real_dict[c])\n",
    "    hdc_perc_real.append(hdc_perc_real_dict[c])\n",
    "    hdc_n_real.append(hdc_n_real_dict[c])\n",
    "    phdc_perc_real.append(phdc_perc_real_dict[c])\n",
    "    phdc_n_real.append(phdc_n_real_dict[c])\n",
    "\n",
    "    n_real.append(n_real_dict[c])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "cbcac2de",
   "metadata": {},
   "outputs": [],
   "source": [
    "pc_perc_model_dev = []\n",
    "phdc_perc_model_dev = []\n",
    "hdc_perc_model_dev = []\n",
    "\n",
    "for ad in activity_dirs_dev[1:]:\n",
    "    selected_place_units = np.load(os.path.join(ad, \"indices_place_cells.npy\"))\n",
    "    selected_place_hd_units = np.load(os.path.join(ad, \"indices_conjunctive_cells.npy\"))\n",
    "    selected_hd_units = np.load(os.path.join(ad, \"indices_hd_cells.npy\"))\n",
    "\n",
    "    pc_perc_model_dev.append(\n",
    "        (len(selected_place_units)+len(selected_place_hd_units)) /\n",
    "        LATENT_DIM\n",
    "    )\n",
    "    phdc_perc_model_dev.append(\n",
    "        len(selected_place_hd_units) /\n",
    "        LATENT_DIM\n",
    "    )\n",
    "    hdc_perc_model_dev.append(\n",
    "        (len(selected_hd_units)+len(selected_place_hd_units)) /\n",
    "        LATENT_DIM\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "6b350227",
   "metadata": {},
   "outputs": [],
   "source": [
    "pc_perc_model_roc = []\n",
    "phdc_perc_model_roc = []\n",
    "hdc_perc_model_roc = []\n",
    "\n",
    "for ad in activity_dirs_roc[1:]:\n",
    "    selected_place_units = np.load(os.path.join(ad, \"indices_place_cells.npy\"))\n",
    "    selected_place_hd_units = np.load(os.path.join(ad, \"indices_conjunctive_cells.npy\"))\n",
    "    selected_hd_units = np.load(os.path.join(ad, \"indices_hd_cells.npy\"))\n",
    "\n",
    "    pc_perc_model_roc.append(\n",
    "        (len(selected_place_units)+len(selected_place_hd_units)) /\n",
    "        LATENT_DIM\n",
    "    )\n",
    "    phdc_perc_model_roc.append(\n",
    "        len(selected_place_hd_units) /\n",
    "        LATENT_DIM\n",
    "    )\n",
    "    hdc_perc_model_roc.append(\n",
    "        (len(selected_hd_units)+len(selected_place_hd_units)) /\n",
    "        LATENT_DIM\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4b6cd1f",
   "metadata": {},
   "source": [
    "# Likelihood test"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8586d5bc",
   "metadata": {},
   "source": [
    "### Generate probability distributions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "16fa7fe4",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_BINS = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "eddfde30",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _histograms_on_shared_range(metric_real, metric_model_dev, metric_model_roc, n_bins, c, density):\n",
    "    \"\"\"Histogram the dev and roc model values of condition `c` on one common binning.\n",
    "\n",
    "    The range spans the min/max over the real, dev and roc values of condition `c`, so every\n",
    "    real observation falls inside the bins later used to score it.\n",
    "\n",
    "    Returns:\n",
    "        (hist_dev, hist_roc, edges) - both histograms share `edges` (n_bins + 1 values).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the dev and roc bin edges differ.\n",
    "    \"\"\"\n",
    "    real_c, dev_c, roc_c = metric_real[c], metric_model_dev[c], metric_model_roc[c]\n",
    "    value_range = (\n",
    "        min(np.min(real_c), np.min(dev_c), np.min(roc_c)),\n",
    "        max(np.max(real_c), np.max(dev_c), np.max(roc_c)),\n",
    "    )\n",
    "    hist_dev, edges_dev = np.histogram(dev_c, bins=n_bins, range=value_range, density=density)\n",
    "    hist_roc, edges_roc = np.histogram(roc_c, bins=n_bins, range=value_range, density=density)\n",
    "    if not np.allclose(edges_dev, edges_roc):\n",
    "        raise ValueError(\"Metric edges are not the same for dev and roc\")\n",
    "    return hist_dev, hist_roc, edges_dev\n",
    "\n",
    "\n",
    "def _stable_softmax(x):\n",
    "    \"\"\"Softmax of `x`; shifting by the max first keeps np.exp from overflowing on large counts.\"\"\"\n",
    "    weights = np.exp(x - np.max(x))\n",
    "    return weights / np.sum(weights)\n",
    "\n",
    "\n",
    "def calculate_prob_distribs(metric_real, metric_model_dev, metric_model_roc, n_bins, c):\n",
    "    \"\"\"Turn the dev/roc model histograms of condition `c` into discrete probability distributions.\n",
    "\n",
    "    Args:\n",
    "        metric_real, metric_model_dev, metric_model_roc: indexable by condition; entry `c` holds\n",
    "            the metric values of the real data, the dev model and the roc model.\n",
    "        n_bins: number of equal-width bins.\n",
    "        c: condition index.\n",
    "\n",
    "    Returns:\n",
    "        (pd_dev, pd_roc, edges): per-bin probabilities (each sums to 1) and the shared bin edges.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the bin edges differ or a distribution does not sum to 1.\n",
    "    \"\"\"\n",
    "    hist_dev, hist_roc, edges = _histograms_on_shared_range(\n",
    "        metric_real, metric_model_dev, metric_model_roc, n_bins, c, density=False\n",
    "    )\n",
    "    # Counts are mapped to probabilities with a softmax. Subtracting the max is mathematically a\n",
    "    # no-op but prevents np.exp(count) from overflowing to inf (and NaN probabilities) once a bin\n",
    "    # holds more than ~709 samples.\n",
    "    # NOTE(review): a softmax over raw counts is not the empirical distribution hist / hist.sum()\n",
    "    # - confirm this is the intended model.\n",
    "    pd_dev = _stable_softmax(hist_dev)\n",
    "    pd_roc = _stable_softmax(hist_roc)\n",
    "    if not (np.isclose(np.sum(pd_dev), 1) and np.isclose(np.sum(pd_roc), 1)):\n",
    "        raise ValueError(\"Probability distribution does not sum to 1\")\n",
    "    return pd_dev, pd_roc, edges\n",
    "\n",
    "\n",
    "def calculate_pdfs(metric_real, metric_model_dev, metric_model_roc, n_bins, c):\n",
    "    \"\"\"Density-normalised histograms of the dev and roc model values of condition `c`.\n",
    "\n",
    "    Uses the same binning as `calculate_prob_distribs`; each returned pdf integrates to 1 over\n",
    "    the bins. Returns (pdf_dev, pdf_roc, edges) and raises ValueError if the edges differ.\n",
    "    \"\"\"\n",
    "    return _histograms_on_shared_range(\n",
    "        metric_real, metric_model_dev, metric_model_roc, n_bins, c, density=True\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "10802abc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _drop_penultimate(values, n_expected):\n",
    "    \"\"\"Return `values` without its second-to-last entry.\n",
    "\n",
    "    A list that already has at most `n_expected` entries is returned unchanged, so re-running\n",
    "    this cell does not silently drop a further condition.\n",
    "    \"\"\"\n",
    "    if len(values) <= n_expected:\n",
    "        return values\n",
    "    return values[:-2] + [values[-1]]\n",
    "\n",
    "\n",
    "if WITH_GRID_CELLS:\n",
    "    # Drop the second-to-last model condition; afterwards each model list is indexed with the\n",
    "    # same condition index `c` as its real-data counterpart in the cells below.\n",
    "    # NOTE(review): which condition the dropped entry corresponds to is not visible here - confirm.\n",
    "    sir_model_dev = _drop_penultimate(sir_model_dev, len(sir_real))\n",
    "    rvl_model_dev = _drop_penultimate(rvl_model_dev, len(rvl_real))\n",
    "    sid_model_dev = _drop_penultimate(sid_model_dev, len(sid_real))\n",
    "    pc_perc_model_dev = _drop_penultimate(pc_perc_model_dev, len(pc_perc_real))\n",
    "    phdc_perc_model_dev = _drop_penultimate(phdc_perc_model_dev, len(phdc_perc_real))\n",
    "    hdc_perc_model_dev = _drop_penultimate(hdc_perc_model_dev, len(hdc_perc_real))\n",
    "\n",
    "    sir_model_roc = _drop_penultimate(sir_model_roc, len(sir_real))\n",
    "    rvl_model_roc = _drop_penultimate(rvl_model_roc, len(rvl_real))\n",
    "    sid_model_roc = _drop_penultimate(sid_model_roc, len(sid_real))\n",
    "    pc_perc_model_roc = _drop_penultimate(pc_perc_model_roc, len(pc_perc_real))\n",
    "    phdc_perc_model_roc = _drop_penultimate(phdc_perc_model_roc, len(phdc_perc_real))\n",
    "    hdc_perc_model_roc = _drop_penultimate(hdc_perc_model_roc, len(hdc_perc_real))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "c0eef834",
   "metadata": {},
   "outputs": [],
   "source": [
    "sir_pd_model_dev, rvl_pd_model_dev, sid_pd_model_dev = [], [], []\n",
    "sir_pd_model_roc, rvl_pd_model_roc, sid_pd_model_roc = [], [], []\n",
    "sir_edges, rvl_edges, sid_edges = [], [], []\n",
    "\n",
    "# Per metric: the (real, dev model, roc model) inputs and the (dev pd, roc pd, edges) output lists.\n",
    "prob_jobs = (\n",
    "    ((sir_real, sir_model_dev, sir_model_roc), (sir_pd_model_dev, sir_pd_model_roc, sir_edges)),\n",
    "    ((rvl_real, rvl_model_dev, rvl_model_roc), (rvl_pd_model_dev, rvl_pd_model_roc, rvl_edges)),\n",
    "    ((sid_real, sid_model_dev, sid_model_roc), (sid_pd_model_dev, sid_pd_model_roc, sid_edges)),\n",
    ")\n",
    "\n",
    "for c in range(len(sir_real)):\n",
    "    for inputs, sinks in prob_jobs:\n",
    "        for sink, value in zip(sinks, calculate_prob_distribs(*inputs, N_BINS, c)):\n",
    "            sink.append(value)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21182a59",
   "metadata": {},
   "source": [
    "### Calculate log-likelihoods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "6280cf0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_loglikelihood(metric_real, pd_model_dev, pd_model_roc, metric_edges):\n",
    "    \"\"\"Log-likelihood of the real metric values under the dev and roc binned distributions.\n",
    "\n",
    "    Each real value is mapped to its bin in `metric_edges` (values on or above the last edge go\n",
    "    to the last bin) and scored with the log of that bin's probability.\n",
    "\n",
    "    Args:\n",
    "        metric_real: 1-D array of real metric values.\n",
    "        pd_model_dev, pd_model_roc: per-bin probabilities, one entry per bin of `metric_edges`.\n",
    "        metric_edges: bin edges (number of bins + 1 values).\n",
    "\n",
    "    Returns:\n",
    "        (ll_dev, ll_roc, perc_inf_dev, perc_inf_roc): log-likelihoods summed over the values whose\n",
    "        bin has non-zero probability, and the percentage of values that fell into a\n",
    "        zero-probability bin (log = -inf) and were therefore left out of the sum.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if a real value lies below the first edge.\n",
    "    \"\"\"\n",
    "    bin_idx_real = np.digitize(metric_real, metric_edges) - 1\n",
    "    if np.any(bin_idx_real < 0):\n",
    "        raise ValueError(\"-1 should not be in bin_idx_real\")\n",
    "    # np.digitize puts values >= the last edge one past the final bin; fold them into it.\n",
    "    bin_idx_real = np.clip(bin_idx_real, a_min=None, a_max=len(pd_model_dev) - 1)\n",
    "\n",
    "    # log(0) = -inf is expected for empty bins, so silence the divide-by-zero warning.\n",
    "    with np.errstate(divide=\"ignore\"):\n",
    "        ll_dev = np.log(pd_model_dev[bin_idx_real])\n",
    "        ll_roc = np.log(pd_model_roc[bin_idx_real])\n",
    "\n",
    "    is_inf_dev = np.isneginf(ll_dev)\n",
    "    is_inf_roc = np.isneginf(ll_roc)\n",
    "    perc_inf_dev = np.sum(is_inf_dev) / len(ll_dev) * 100\n",
    "    perc_inf_roc = np.sum(is_inf_roc) / len(ll_roc) * 100\n",
    "\n",
    "    return np.sum(ll_dev[~is_inf_dev]), np.sum(ll_roc[~is_inf_roc]), perc_inf_dev, perc_inf_roc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "360a3cfa",
   "metadata": {},
   "outputs": [],
   "source": [
    "sir_ll_dev, sir_ll_roc = [], []\n",
    "sir_pinf_dev, sir_pinf_roc = [], []\n",
    "rvl_ll_dev, rvl_ll_roc = [], []\n",
    "rvl_pinf_dev, rvl_pinf_roc = [], []\n",
    "sid_ll_dev, sid_ll_roc = [], []\n",
    "sid_pinf_dev, sid_pinf_roc = [], []\n",
    "\n",
    "# Per metric: its inputs (real values, dev/roc distributions, bin edges) and the lists that\n",
    "# collect (ll_dev, ll_roc, pinf_dev, pinf_roc) for every condition.\n",
    "ll_jobs = (\n",
    "    ((sir_real, sir_pd_model_dev, sir_pd_model_roc, sir_edges),\n",
    "     (sir_ll_dev, sir_ll_roc, sir_pinf_dev, sir_pinf_roc)),\n",
    "    ((rvl_real, rvl_pd_model_dev, rvl_pd_model_roc, rvl_edges),\n",
    "     (rvl_ll_dev, rvl_ll_roc, rvl_pinf_dev, rvl_pinf_roc)),\n",
    "    ((sid_real, sid_pd_model_dev, sid_pd_model_roc, sid_edges),\n",
    "     (sid_ll_dev, sid_ll_roc, sid_pinf_dev, sid_pinf_roc)),\n",
    ")\n",
    "\n",
    "for c in range(len(sir_real)):\n",
    "    for (real, pd_dev, pd_roc, edges), sinks in ll_jobs:\n",
    "        results = calculate_loglikelihood(\n",
    "            metric_real=real[c],\n",
    "            pd_model_dev=pd_dev[c],\n",
    "            pd_model_roc=pd_roc[c],\n",
    "            metric_edges=edges[c],\n",
    "        )\n",
    "        for sink, value in zip(sinks, results):\n",
    "            sink.append(value)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "506b055b",
   "metadata": {},
   "source": [
    "### Compare log-likelihoods"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6cf4150",
   "metadata": {},
   "source": [
    "$SI_r$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "12ce9848",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of conditions for which a log-likelihood was computed.\n",
    "len(sir_ll_roc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "cc439c97",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "C=0, LL dev=-11,731, LL roc=-12,490 !!! DEV WINS\n",
      "C=1, LL dev=-30,930, LL roc=-48,815 !!! DEV WINS\n",
      "C=2, LL dev=-9,028, LL roc=-16,235 !!! DEV WINS\n",
      "TOT, LL dev=-51,689, LL roc=-77,540 \t!!! DEV WINS\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Dev vs roc log-likelihood of the real SI_r values, per condition and summed over conditions.\n",
    "ll_dev_tot = 0\n",
    "ll_tot_roc = 0\n",
    "for c in range(len(sir_real)):\n",
    "    dev_c, roc_c = sir_ll_dev[c], sir_ll_roc[c]\n",
    "    ll_dev_tot += dev_c\n",
    "    ll_tot_roc += roc_c\n",
    "    verdict = '!!! DEV WINS' if dev_c > roc_c else '??? ROC WINS'\n",
    "    print(f'C={c}, LL dev={dev_c:,.0f}, LL roc={roc_c:,.0f}', verdict)\n",
    "\n",
    "verdict = '!!! DEV WINS' if ll_dev_tot > ll_tot_roc else '??? ROC WINS'\n",
    "print(f'TOT, LL dev={ll_dev_tot:,.0f}, LL roc={ll_tot_roc:,.0f}', '\\t' + verdict)\n",
    "print(flush=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d407c9f2",
   "metadata": {},
   "source": [
    "$RVL$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "aaeaf1a5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "C=0, LL dev=-4,900, LL roc=-5,696 !!! DEV WINS\n",
      "C=1, LL dev=-12,354, LL roc=-14,075 !!! DEV WINS\n",
      "C=2, LL dev=-2,805, LL roc=-4,336 !!! DEV WINS\n",
      "TOT, LL dev=-20,059, LL roc=-24,107 \t!!! DEV WINS\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Dev vs roc log-likelihood of the real RVL values, per condition and summed over conditions.\n",
    "ll_dev_tot = 0\n",
    "ll_tot_roc = 0\n",
    "for c in range(len(rvl_real)):\n",
    "    dev_c, roc_c = rvl_ll_dev[c], rvl_ll_roc[c]\n",
    "    ll_dev_tot += dev_c\n",
    "    ll_tot_roc += roc_c\n",
    "    verdict = '!!! DEV WINS' if dev_c > roc_c else '??? ROC WINS'\n",
    "    print(f'C={c}, LL dev={dev_c:,.0f}, LL roc={roc_c:,.0f}', verdict)\n",
    "\n",
    "verdict = '!!! DEV WINS' if ll_dev_tot > ll_tot_roc else '??? ROC WINS'\n",
    "print(f'TOT, LL dev={ll_dev_tot:,.0f}, LL roc={ll_tot_roc:,.0f}', '\\t' + verdict)\n",
    "print(flush=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b646218",
   "metadata": {},
   "source": [
    "$SI_d$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "1d626db1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "C=0, LL dev=-69,064, LL roc=-83,141 !!! DEV WINS\n",
      "C=1, LL dev=-157,763, LL roc=-144,059 ??? ROC WINS\n",
      "C=2, LL dev=-31,564, LL roc=-41,968 !!! DEV WINS\n",
      "TOT, LL dev=-258,391, LL roc=-269,168 \t!!! DEV WINS\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Dev vs roc log-likelihood of the real SI_d values, per condition and summed over conditions.\n",
    "ll_dev_tot = 0\n",
    "ll_tot_roc = 0\n",
    "for c in range(len(sid_real)):\n",
    "    dev_c, roc_c = sid_ll_dev[c], sid_ll_roc[c]\n",
    "    ll_dev_tot += dev_c\n",
    "    ll_tot_roc += roc_c\n",
    "    verdict = '!!! DEV WINS' if dev_c > roc_c else '??? ROC WINS'\n",
    "    print(f'C={c}, LL dev={dev_c:,.0f}, LL roc={roc_c:,.0f}', verdict)\n",
    "\n",
    "verdict = '!!! DEV WINS' if ll_dev_tot > ll_tot_roc else '??? ROC WINS'\n",
    "print(f'TOT, LL dev={ll_dev_tot:,.0f}, LL roc={ll_tot_roc:,.0f}', '\\t' + verdict)\n",
    "print(flush=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ed085a8",
   "metadata": {},
   "source": [
    "### Calculate BIC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "0833a60e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Free-parameter counts used in the BIC penalty term dof * log(n).\n",
    "# With equal values the penalty terms cancel, so ΔBIC = BIC_roc - BIC_dev = 2 * (LL_dev - LL_roc).\n",
    "# NOTE(review): the rationale for 4 free parameters per model is not documented here - confirm.\n",
    "dof_dev = 4\n",
    "dof_roc = 4"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "593d7ff6",
   "metadata": {},
   "source": [
    "$SI_r$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "769c1c5e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "C=0, L dev=-11,731, L roc=-12,490,\t\tΔBIC = 1,518\n",
      "C=1, L dev=-30,930, L roc=-48,815,\t\tΔBIC = 35,770\n",
      "C=2, L dev=-9,028, L roc=-16,235,\t\tΔBIC = 14,414\n",
      "TOT, LL dev=-51,689, LL roc=-77,540,\t\tΔBIC = 51,702\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# BIC of the dev and roc binned SI_r distributions, evaluated on the real SI_r values.\n",
    "# ΔBIC = BIC_roc - BIC_dev, so positive values favour the dev model.\n",
    "# n is the number of real observations the log-likelihood sums over (not the number of model\n",
    "# samples that built the histograms).\n",
    "# NOTE(review): real values that fell into zero-probability bins are excluded from the LL sum\n",
    "# but still counted in n.\n",
    "ll_dev_tot = 0\n",
    "ll_roc_tot = 0\n",
    "n_tot = 0\n",
    "for c in range(len(sir_real)):\n",
    "    ll_dev = sir_ll_dev[c]\n",
    "    ll_roc = sir_ll_roc[c]\n",
    "    n = len(sir_real[c])\n",
    "    ll_dev_tot += ll_dev\n",
    "    ll_roc_tot += ll_roc\n",
    "    n_tot += n\n",
    "\n",
    "    bic_dev = -2*ll_dev + dof_dev*np.log(n)\n",
    "    bic_roc = -2*ll_roc + dof_roc*np.log(n)\n",
    "    delta_bic = bic_roc - bic_dev\n",
    "    print(f'C={c}, L dev={ll_dev:,.0f}, L roc={ll_roc:,.0f}', f'ΔBIC = {delta_bic:,.0f}', sep=',\\t\\t')\n",
    "\n",
    "bic_dev = -2*ll_dev_tot + dof_dev*np.log(n_tot)\n",
    "bic_roc = -2*ll_roc_tot + dof_roc*np.log(n_tot)\n",
    "delta_bic = bic_roc - bic_dev\n",
    "print(f'TOT, LL dev={ll_dev_tot:,.0f}, LL roc={ll_roc_tot:,.0f}', f'ΔBIC = {delta_bic:,.0f}', sep=',\\t\\t')\n",
    "print(flush=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c9083b1",
   "metadata": {},
   "source": [
    "$RVL$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "8da3aa60",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "C=0, L dev=-4,900, L roc=-5,696,\t\tΔBIC = 1,592\n",
      "C=1, L dev=-12,354, L roc=-14,075,\t\tΔBIC = 3,442\n",
      "C=2, L dev=-2,805, L roc=-4,336,\t\tΔBIC = 3,062\n",
      "TOT, LL dev=-20,059, LL roc=-24,107,\t\tΔBIC = 8,096\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# BIC of the dev and roc binned RVL distributions, evaluated on the real RVL values.\n",
    "# ΔBIC = BIC_roc - BIC_dev, so positive values favour the dev model.\n",
    "# n is the number of real observations the log-likelihood sums over (not the number of model\n",
    "# samples that built the histograms).\n",
    "# NOTE(review): real values that fell into zero-probability bins are excluded from the LL sum\n",
    "# but still counted in n.\n",
    "ll_dev_tot = 0\n",
    "ll_roc_tot = 0\n",
    "n_tot = 0\n",
    "for c in range(len(rvl_real)):\n",
    "    ll_dev = rvl_ll_dev[c]\n",
    "    ll_roc = rvl_ll_roc[c]\n",
    "    n = len(rvl_real[c])\n",
    "    ll_dev_tot += ll_dev\n",
    "    ll_roc_tot += ll_roc\n",
    "    n_tot += n\n",
    "\n",
    "    bic_dev = -2*ll_dev + dof_dev*np.log(n)\n",
    "    bic_roc = -2*ll_roc + dof_roc*np.log(n)\n",
    "    delta_bic = bic_roc - bic_dev\n",
    "    print(f'C={c}, L dev={ll_dev:,.0f}, L roc={ll_roc:,.0f}', f'ΔBIC = {delta_bic:,.0f}', sep=',\\t\\t')\n",
    "\n",
    "bic_dev = -2*ll_dev_tot + dof_dev*np.log(n_tot)\n",
    "bic_roc = -2*ll_roc_tot + dof_roc*np.log(n_tot)\n",
    "delta_bic = bic_roc - bic_dev\n",
    "print(f'TOT, LL dev={ll_dev_tot:,.0f}, LL roc={ll_roc_tot:,.0f}', f'ΔBIC = {delta_bic:,.0f}', sep=',\\t\\t')\n",
    "print(flush=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6bb7b5f",
   "metadata": {},
   "source": [
    "$SI_d$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "bc6756c1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "C=0, L dev=-69,064, L roc=-83,141,\t\tΔBIC = 28,154\n",
      "C=1, L dev=-157,763, L roc=-144,059,\t\tΔBIC = -27,408\n",
      "C=2, L dev=-31,564, L roc=-41,968,\t\tΔBIC = 20,808\n",
      "TOT, LL dev=-258,391, LL roc=-269,168,\t\tΔBIC = 21,554\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# BIC of the dev and roc binned SI_d distributions, evaluated on the real SI_d values.\n",
    "# ΔBIC = BIC_roc - BIC_dev, so positive values favour the dev model.\n",
    "# n is the number of real observations the log-likelihood sums over (not the number of model\n",
    "# samples that built the histograms).\n",
    "# NOTE(review): real values that fell into zero-probability bins are excluded from the LL sum\n",
    "# but still counted in n.\n",
    "ll_dev_tot = 0\n",
    "ll_roc_tot = 0\n",
    "n_tot = 0\n",
    "for c in range(len(sid_real)):\n",
    "    ll_dev = sid_ll_dev[c]\n",
    "    ll_roc = sid_ll_roc[c]\n",
    "    n = len(sid_real[c])\n",
    "    ll_dev_tot += ll_dev\n",
    "    ll_roc_tot += ll_roc\n",
    "    n_tot += n\n",
    "\n",
    "    bic_dev = -2*ll_dev + dof_dev*np.log(n)\n",
    "    bic_roc = -2*ll_roc + dof_roc*np.log(n)\n",
    "    delta_bic = bic_roc - bic_dev\n",
    "    print(f'C={c}, L dev={ll_dev:,.0f}, L roc={ll_roc:,.0f}', f'ΔBIC = {delta_bic:,.0f}', sep=',\\t\\t')\n",
    "\n",
    "bic_dev = -2*ll_dev_tot + dof_dev*np.log(n_tot)\n",
    "bic_roc = -2*ll_roc_tot + dof_roc*np.log(n_tot)\n",
    "delta_bic = bic_roc - bic_dev\n",
    "print(f'TOT, LL dev={ll_dev_tot:,.0f}, LL roc={ll_roc_tot:,.0f}', f'ΔBIC = {delta_bic:,.0f}', sep=',\\t\\t')\n",
    "print(flush=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc82a1b9",
   "metadata": {},
   "source": [
    "# Binomial log-likelihood"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "6e4a36e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.special import comb  # kept so `comb` stays available to later cells\n",
    "from scipy.stats import binom\n",
    "\n",
    "def calculate_binomial_loglikelihood(cell_perc_model, cell_n_real, tot_n_cells_real):\n",
    "    \"\"\"Summed binomial log-likelihood of observed cell counts under a model cell fraction.\n",
    "\n",
    "    Computes sum_i log Binom(k_i; N_i, p) with k = `cell_n_real`, N = `tot_n_cells_real` and\n",
    "    p = `cell_perc_model`. scipy's `binom.logpmf` evaluates log C(N, k) through log-gamma and the\n",
    "    p-terms through xlogy, so the result stays finite where `np.log(comb(N, k))` overflows\n",
    "    (N above roughly 1000) and is 0 rather than NaN for the 0 * log(0) terms at p = 0 or 1.\n",
    "\n",
    "    Args:\n",
    "        cell_perc_model: model fraction p in [0, 1].\n",
    "        cell_n_real: observed counts k (assumed integer-valued; non-integer k has zero mass).\n",
    "        tot_n_cells_real: totals N, broadcastable against `cell_n_real`.\n",
    "\n",
    "    Returns:\n",
    "        The log-likelihood summed over all entries (float).\n",
    "    \"\"\"\n",
    "    return np.sum(binom.logpmf(cell_n_real, tot_n_cells_real, cell_perc_model))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "df60657f",
   "metadata": {},
   "outputs": [],
   "source": [
    "pc_bll_dev, pc_bll_roc = [], []\n",
    "hdc_bll_dev, hdc_bll_roc = [], []\n",
    "phdc_bll_dev, phdc_bll_roc = [], []\n",
    "\n",
    "# Per cell type: (real counts, dev fractions, roc fractions, dev output list, roc output list).\n",
    "bll_jobs = (\n",
    "    (pc_n_real, pc_perc_model_dev, pc_perc_model_roc, pc_bll_dev, pc_bll_roc),\n",
    "    (hdc_n_real, hdc_perc_model_dev, hdc_perc_model_roc, hdc_bll_dev, hdc_bll_roc),\n",
    "    (phdc_n_real, phdc_perc_model_dev, phdc_perc_model_roc, phdc_bll_dev, phdc_bll_roc),\n",
    ")\n",
    "\n",
    "for c in range(len(pc_perc_real)):\n",
    "    totals = np.array(n_real[c])\n",
    "    for counts_real, perc_dev, perc_roc, out_dev, out_roc in bll_jobs:\n",
    "        counts = np.array(counts_real[c])\n",
    "        out_dev.append(calculate_binomial_loglikelihood(perc_dev[c], counts, totals))\n",
    "        out_roc.append(calculate_binomial_loglikelihood(perc_roc[c], counts, totals))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3f72a9c",
   "metadata": {},
   "source": [
    "### Compare binomial log-likelihoods"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2fe4edbb",
   "metadata": {},
   "source": [
    "Perc. Place cells"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "4767ef48",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "C=0, BLL dev=-83, BLL roc=-130 !!! DEV WINS\n",
      "C=1, BLL dev=-294, BLL roc=-471 !!! DEV WINS\n",
      "C=2, BLL dev=-101, BLL roc=-186 !!! DEV WINS\n",
      "TOT, LL dev=-478, LL roc=-787 \t!!! DEV WINS\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Dev vs roc binomial log-likelihood of the real place-cell counts, per condition and summed.\n",
    "bll_dev_tot = 0\n",
    "bll_tot_roc = 0\n",
    "for c in range(len(pc_perc_real)):\n",
    "    bll_dev, bll_roc = pc_bll_dev[c], pc_bll_roc[c]\n",
    "    bll_dev_tot += bll_dev\n",
    "    bll_tot_roc += bll_roc\n",
    "    verdict = '!!! DEV WINS' if bll_dev > bll_roc else '??? ROC WINS'\n",
    "    print(f'C={c}, BLL dev={bll_dev:,.0f}, BLL roc={bll_roc:,.0f}', verdict)\n",
    "\n",
    "verdict = '!!! DEV WINS' if bll_dev_tot > bll_tot_roc else '??? ROC WINS'\n",
    "print(f'TOT, LL dev={bll_dev_tot:,.0f}, LL roc={bll_tot_roc:,.0f}', '\\t' + verdict)\n",
    "print(flush=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e33d620",
   "metadata": {},
   "source": [
    "Perc. HD cells"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "5ce075a0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "C=0, BLL dev=-51, BLL roc=-78 !!! DEV WINS\n",
      "C=1, BLL dev=-173, BLL roc=-254 !!! DEV WINS\n",
      "C=2, BLL dev=-21, BLL roc=-30 !!! DEV WINS\n",
      "TOT, LL dev=-244, LL roc=-362 \t!!! DEV WINS\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Dev vs roc binomial log-likelihood of the real HD-cell counts, per condition and summed.\n",
    "bll_dev_tot = 0\n",
    "bll_tot_roc = 0\n",
    "for c in range(len(hdc_perc_real)):\n",
    "    bll_dev, bll_roc = hdc_bll_dev[c], hdc_bll_roc[c]\n",
    "    bll_dev_tot += bll_dev\n",
    "    bll_tot_roc += bll_roc\n",
    "    verdict = '!!! DEV WINS' if bll_dev > bll_roc else '??? ROC WINS'\n",
    "    print(f'C={c}, BLL dev={bll_dev:,.0f}, BLL roc={bll_roc:,.0f}', verdict)\n",
    "\n",
    "verdict = '!!! DEV WINS' if bll_dev_tot > bll_tot_roc else '??? ROC WINS'\n",
    "print(f'TOT, LL dev={bll_dev_tot:,.0f}, LL roc={bll_tot_roc:,.0f}', '\\t' + verdict)\n",
    "print(flush=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d41cbef",
   "metadata": {},
   "source": [
    "Perc. Place+HD cells"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "8eb2f1a8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "C=0, BLL dev=-52, BLL roc=-81 !!! DEV WINS\n",
      "C=1, BLL dev=-185, BLL roc=-334 !!! DEV WINS\n",
      "C=2, BLL dev=-19, BLL roc=-46 !!! DEV WINS\n",
      "TOT, LL dev=-257, LL roc=-462 \t!!! DEV WINS\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Dev vs roc binomial log-likelihood of the real place+HD-cell counts, per condition and summed.\n",
    "bll_dev_tot = 0\n",
    "bll_tot_roc = 0\n",
    "for c in range(len(phdc_perc_real)):\n",
    "    bll_dev, bll_roc = phdc_bll_dev[c], phdc_bll_roc[c]\n",
    "    bll_dev_tot += bll_dev\n",
    "    bll_tot_roc += bll_roc\n",
    "    verdict = '!!! DEV WINS' if bll_dev > bll_roc else '??? ROC WINS'\n",
    "    print(f'C={c}, BLL dev={bll_dev:,.0f}, BLL roc={bll_roc:,.0f}', verdict)\n",
    "\n",
    "verdict = '!!! DEV WINS' if bll_dev_tot > bll_tot_roc else '??? ROC WINS'\n",
    "print(f'TOT, LL dev={bll_dev_tot:,.0f}, LL roc={bll_tot_roc:,.0f}', '\\t' + verdict)\n",
    "print(flush=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa72aa70",
   "metadata": {},
   "source": [
    "### Calculate BIC (binomial models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "03ad4ca8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Free-parameter counts for the BIC penalty term dof * log(n) of the binomial models.\n",
    "# With equal values the penalty terms cancel, so ΔBIC = BIC_roc - BIC_dev = 2 * (BLL_dev - BLL_roc).\n",
    "# NOTE(review): the rationale for 4 free parameters per model is not documented here - confirm.\n",
    "dof_dev = 4\n",
    "dof_roc = 4"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84899453",
   "metadata": {},
   "source": [
    "Perc. Place cells"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "b123cac9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "C=0, L dev=-83, L roc=-130,\t\tΔBIC = 94\n",
      "C=1, L dev=-294, L roc=-471,\t\tΔBIC = 354\n",
      "C=2, L dev=-101, L roc=-186,\t\tΔBIC = 170\n",
      "TOT, LL dev=-478, LL roc=-787,\t\tΔBIC = 618\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# BIC of the dev and roc binomial models of the place-cell counts; ΔBIC = BIC_roc - BIC_dev,\n",
    "# so positive values favour the dev model.\n",
    "# NOTE(review): n is the number of real SI_r values of condition c (presumably one per recorded\n",
    "# cell) - confirm this is the intended number of observations for the binomial likelihood.\n",
    "bll_dev_tot = 0\n",
    "bll_roc_tot = 0\n",
    "n_tot = 0\n",
    "for c in range(len(pc_perc_real)):\n",
    "    bll_dev, bll_roc = pc_bll_dev[c], pc_bll_roc[c]\n",
    "    n = len(sir_real[c])\n",
    "    bll_dev_tot, bll_roc_tot, n_tot = bll_dev_tot + bll_dev, bll_roc_tot + bll_roc, n_tot + n\n",
    "\n",
    "    bic_dev = -2*bll_dev + dof_dev*np.log(n)\n",
    "    bic_roc = -2*bll_roc + dof_roc*np.log(n)\n",
    "    delta_bic = bic_roc - bic_dev\n",
    "    print(f'C={c}, L dev={bll_dev:,.0f}, L roc={bll_roc:,.0f}', f'ΔBIC = {delta_bic:,.0f}', sep=',\\t\\t')\n",
    "\n",
    "bic_dev = -2*bll_dev_tot + dof_dev*np.log(n_tot)\n",
    "bic_roc = -2*bll_roc_tot + dof_roc*np.log(n_tot)\n",
    "delta_bic = bic_roc - bic_dev\n",
    "print(f'TOT, LL dev={bll_dev_tot:,.0f}, LL roc={bll_roc_tot:,.0f}', f'ΔBIC = {delta_bic:,.0f}', sep=',\\t\\t')\n",
    "print(flush=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2d61f98",
   "metadata": {},
   "source": [
    "Perc. HD cells"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "85e3bae7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "C=0, L dev=-51, L roc=-78,\t\tΔBIC = 54\n",
      "C=1, L dev=-173, L roc=-254,\t\tΔBIC = 162\n",
      "C=2, L dev=-21, L roc=-30,\t\tΔBIC = 19\n",
      "TOT, LL dev=-244, LL roc=-362,\t\tΔBIC = 235\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# BIC of the dev and roc binomial models of the HD-cell counts; ΔBIC = BIC_roc - BIC_dev,\n",
    "# so positive values favour the dev model.\n",
    "# NOTE(review): n is the number of real SI_r values of condition c (presumably one per recorded\n",
    "# cell) - confirm this is the intended number of observations for the binomial likelihood.\n",
    "bll_dev_tot = 0\n",
    "bll_roc_tot = 0\n",
    "n_tot = 0\n",
    "for c in range(len(hdc_perc_real)):\n",
    "    bll_dev, bll_roc = hdc_bll_dev[c], hdc_bll_roc[c]\n",
    "    n = len(sir_real[c])\n",
    "    bll_dev_tot, bll_roc_tot, n_tot = bll_dev_tot + bll_dev, bll_roc_tot + bll_roc, n_tot + n\n",
    "\n",
    "    bic_dev = -2*bll_dev + dof_dev*np.log(n)\n",
    "    bic_roc = -2*bll_roc + dof_roc*np.log(n)\n",
    "    delta_bic = bic_roc - bic_dev\n",
    "    print(f'C={c}, L dev={bll_dev:,.0f}, L roc={bll_roc:,.0f}', f'ΔBIC = {delta_bic:,.0f}', sep=',\\t\\t')\n",
    "\n",
    "bic_dev = -2*bll_dev_tot + dof_dev*np.log(n_tot)\n",
    "bic_roc = -2*bll_roc_tot + dof_roc*np.log(n_tot)\n",
    "delta_bic = bic_roc - bic_dev\n",
    "print(f'TOT, LL dev={bll_dev_tot:,.0f}, LL roc={bll_roc_tot:,.0f}', f'ΔBIC = {delta_bic:,.0f}', sep=',\\t\\t')\n",
    "print(flush=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af1c9238",
   "metadata": {},
   "source": [
    "Perc. Place+HD cells"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "4b334d2f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "C=0, L dev=-52, L roc=-81,\t\tΔBIC = 59\n",
      "C=1, L dev=-185, L roc=-334,\t\tΔBIC = 298\n",
      "C=2, L dev=-19, L roc=-46,\t\tΔBIC = 53\n",
      "TOT, LL dev=-257, LL roc=-462,\t\tΔBIC = 410\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# BIC of the dev and roc binomial models of the place+HD-cell counts; ΔBIC = BIC_roc - BIC_dev,\n",
    "# so positive values favour the dev model.\n",
    "# NOTE(review): n is the number of real SI_r values of condition c (presumably one per recorded\n",
    "# cell) - confirm this is the intended number of observations for the binomial likelihood.\n",
    "bll_dev_tot = 0\n",
    "bll_roc_tot = 0\n",
    "n_tot = 0\n",
    "for c in range(len(phdc_perc_real)):\n",
    "    bll_dev, bll_roc = phdc_bll_dev[c], phdc_bll_roc[c]\n",
    "    n = len(sir_real[c])\n",
    "    bll_dev_tot, bll_roc_tot, n_tot = bll_dev_tot + bll_dev, bll_roc_tot + bll_roc, n_tot + n\n",
    "\n",
    "    bic_dev = -2*bll_dev + dof_dev*np.log(n)\n",
    "    bic_roc = -2*bll_roc + dof_roc*np.log(n)\n",
    "    delta_bic = bic_roc - bic_dev\n",
    "    print(f'C={c}, L dev={bll_dev:,.0f}, L roc={bll_roc:,.0f}', f'ΔBIC = {delta_bic:,.0f}', sep=',\\t\\t')\n",
    "\n",
    "bic_dev = -2*bll_dev_tot + dof_dev*np.log(n_tot)\n",
    "bic_roc = -2*bll_roc_tot + dof_roc*np.log(n_tot)\n",
    "delta_bic = bic_roc - bic_dev\n",
    "print(f'TOT, LL dev={bll_dev_tot:,.0f}, LL roc={bll_roc_tot:,.0f}', f'ΔBIC = {delta_bic:,.0f}', sep=',\\t\\t')\n",
    "print(flush=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29995c6a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab257a91",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "vrtopc",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
