{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe5a643b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import nibabel.freesurfer.io as fsio\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from scipy.optimize import curve_fit\n",
    "\n",
    "from src.utils import load_pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c1391bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project style sheet plus the 8-digit hex colours shared by the figures in this notebook.\n",
    "plt.style.use('../style/plots.mplstyle')\n",
    "palette = ['#a4c8ffff', '#ffa5acff', '#afffa6ff', '#d8a6ffff', '#ffd0a5ff', '#d7d7d7ff']\n",
    "\n",
    "# Black, 1.5-wide line styling per boxplot property group; the box is keyed by\n",
    "# 'edgecolor', every other group by 'color'.\n",
    "seaborn_props = {'boxprops': {'edgecolor': 'black', 'linewidth': 1.5}}\n",
    "seaborn_props.update({\n",
    "    group: {'color': 'black', 'linewidth': 1.5}\n",
    "    for group in ('medianprops', 'whiskerprops', 'capprops')\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51757aa8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory that contains the 'datasets' folder. The `...` is a placeholder that has to be\n",
    "# replaced with a real path; fail with an explicit message instead of the TypeError that\n",
    "# os.path.join raises for an Ellipsis argument.\n",
    "path_root = ...\n",
    "if path_root is ...:\n",
    "    raise ValueError('Set `path_root` to the directory that contains the datasets folder.')\n",
    "\n",
    "path_datasets = os.path.join(path_root, 'datasets')\n",
    "assert os.path.exists(path_datasets), path_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f19a720",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = 'THINGS-fMRI'\n",
    "subject_id = 'S2'\n",
    "num_neighbors = 50\n",
    "split = 'train'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "541847a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset root and the per-subject results folder. Each assert reports the path it\n",
    "# could not find, so a failure is self-explanatory.\n",
    "path_dataset = os.path.join(path_datasets, dataset)\n",
    "assert os.path.exists(path_dataset), path_dataset\n",
    "\n",
    "path_results = os.path.join(path_dataset, 'results', subject_id)\n",
    "assert os.path.exists(path_results), path_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b743728f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-hemisphere annotation files with the visual-area labels of this subject.\n",
    "path_label_dir = os.path.join(path_dataset, 'freesurfer', subject_id, 'label')\n",
    "path_annot_left = os.path.join(path_label_dir, 'lh.visual.annot')\n",
    "path_annot_right = os.path.join(path_label_dir, 'rh.visual.annot')\n",
    "assert os.path.exists(path_annot_left), path_annot_left\n",
    "assert os.path.exists(path_annot_right), path_annot_right\n",
    "\n",
    "# read_annot returns three values; the middle one is not needed here.\n",
    "labels_vertices_left, _, names_vertices_left = fsio.read_annot(path_annot_left)\n",
    "labels_vertices_right, _, names_vertices_right = fsio.read_annot(path_annot_right)\n",
    "\n",
    "# Names arrive as bytes; keep only the second '_'-separated token of each one.\n",
    "names_vertices_left = np.array([name.decode('utf-8').split('_')[1] for name in names_vertices_left])\n",
    "names_vertices_right = np.array([name.decode('utf-8').split('_')[1] for name in names_vertices_right])\n",
    "\n",
    "# Both hemispheres must list the same areas in the same order. array_equal also returns\n",
    "# False (instead of failing inside the comparison) when the two lists differ in length.\n",
    "assert np.array_equal(names_vertices_left, names_vertices_right), 'area names differ between hemispheres'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a01d967",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct label values per hemisphere. np.unique returns them sorted and the first\n",
    "# (smallest) value is dropped -- presumably the entry for unlabelled vertices (TODO confirm).\n",
    "unique_labels_left = np.unique(labels_vertices_left)[1:]\n",
    "# Must come from the RIGHT hemisphere; deriving it from the left labels makes the\n",
    "# comparison below trivially true.\n",
    "unique_labels_right = np.unique(labels_vertices_right)[1:]\n",
    "\n",
    "assert np.array_equal(unique_labels_left, unique_labels_right), 'label sets differ between hemispheres'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21764c19",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_subsample = 100\n",
    "total_subsample = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d663f58",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pickled geodesic data of this subject, stored in the results folder.\n",
    "file_geodesic = f'{subject_id}_geodesic_matrix.pickle'\n",
    "path_geodesic = os.path.join(path_results, file_geodesic)\n",
    "assert os.path.exists(path_geodesic), path_geodesic\n",
    "\n",
    "dict_geodesic = load_pickle(path_geodesic)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b4e39fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "names_areas = np.array(dict_geodesic['areas'])\n",
    "matrix_geodesic = dict_geodesic['matrix']\n",
    "\n",
    "# Position of V1 among the areas; exactly one match is required for .item() below.\n",
    "index_V1 = np.where(names_areas == 'V1')[0]\n",
    "assert index_V1.size == 1, 'expected exactly one V1 entry among the geodesic areas'\n",
    "\n",
    "# Matrix entry at (V1, area) for every area, keyed by area name.\n",
    "dict_distances = {\n",
    "    area: matrix_geodesic[index_V1, np.where(names_areas == area)[0]].item()\n",
    "    for area in names_areas\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff37ddb7",
   "metadata": {},
   "source": [
    "## Linear Dimensionality"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6f1b7ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dimensionality estimates of the 'lpca' estimator for the chosen split and neighbourhood size.\n",
    "name_estimator = 'lpca'\n",
    "file_dims_ED = f'{subject_id}_dimensionality_{split}_{num_neighbors}_{name_estimator}.pickle'\n",
    "path_dict_dims_ED = os.path.join(path_results, 'dimensionality', file_dims_ED)\n",
    "# Report the missing path on failure, like the geodesic-matrix check does.\n",
    "assert os.path.exists(path_dict_dims_ED), path_dict_dims_ED\n",
    "dict_dims_ED = load_pickle(path_dict_dims_ED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87148803",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(9876)\n",
    "\n",
    "dims_areas_linear = []\n",
    "\n",
    "# Pairs the i-th unique label with the i-th area name -- assumes both sequences share\n",
    "# the same order (TODO confirm).\n",
    "for label, name in zip(unique_labels_left, names_vertices_left):\n",
    "    # Pool the per-vertex estimates of this area across both hemispheres.\n",
    "    indices_area_left = np.where(labels_vertices_left == label)[0]\n",
    "    indices_area_right = np.where(labels_vertices_right == label)[0]\n",
    "    vertices_area = np.concatenate(\n",
    "        (dict_dims_ED['lh'][indices_area_left], dict_dims_ED['rh'][indices_area_right]), axis=0)\n",
    "    num_vertices_area = vertices_area.shape[0]\n",
    "\n",
    "    # Draw `total_subsample` subsamples of `num_subsample` vertices each, without replacement,\n",
    "    # and keep the mean of every draw. The repetition count used to reuse `num_subsample`,\n",
    "    # leaving `total_subsample` unused; both are 100, so the results are unchanged.\n",
    "    # NOTE(review): confirm which constant is meant as the count and which as the size.\n",
    "    for _ in range(total_subsample):\n",
    "        sampled_indices = np.random.choice(num_vertices_area, size=num_subsample, replace=False)\n",
    "        dims_areas_linear.append({\n",
    "            'Area': name,\n",
    "            'Linear Dimensionality': vertices_area[sampled_indices].mean(),\n",
    "            'Distance (mm)': dict_distances[name],\n",
    "        })\n",
    "\n",
    "df_dims_areas_linear = pd.DataFrame(dims_areas_linear)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca170f67",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_dims_areas_linear"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fdb70b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows of every area, selected once and kept in the order of names_vertices_left.\n",
    "rows_per_area = [\n",
    "    df_dims_areas_linear[df_dims_areas_linear['Area'] == area]\n",
    "    for area in names_vertices_left\n",
    "]\n",
    "\n",
    "# Mean and standard deviation over the subsample means, plus the area's distance value.\n",
    "mean_dim = np.array([rows['Linear Dimensionality'].mean() for rows in rows_per_area])\n",
    "std_dim = np.array([rows['Linear Dimensionality'].std() for rows in rows_per_area])\n",
    "distance = np.array([rows['Distance (mm)'].mean() for rows in rows_per_area])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52aacf48",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert np.all(names_vertices_left == names_vertices_right)\n",
    "assert np.all(unique_labels_left == unique_labels_right)\n",
    "\n",
    "\n",
    "def _positions_of(rois):\n",
    "    \"\"\"Return the positions in names_vertices_left whose area name is listed in `rois`.\"\"\"\n",
    "    return np.array([i for i, area in enumerate(names_vertices_left) if area in rois])\n",
    "\n",
    "\n",
    "# ROI groups used to colour and select areas in the figures below.\n",
    "primary_rois = ['V1']\n",
    "early_rois = ['V2', 'V3', 'V4']\n",
    "ventral_rois = ['V8', 'FFC', 'PIT', 'VMV1', 'VMV3', 'VMV2', 'VVC']\n",
    "dorsal_rois = ['V6', 'V3A', 'V7', 'IPS1', 'V3B', 'V6A']\n",
    "mt_rois = ['MST', 'LO1', 'LO2', 'MT', 'PH', 'V4t', 'FST', 'V3CD', 'LO3']\n",
    "\n",
    "indices_primary = _positions_of(primary_rois)\n",
    "indices_early = _positions_of(early_rois)\n",
    "indices_ventral = _positions_of(ventral_rois)\n",
    "indices_dorsal = _positions_of(dorsal_rois)\n",
    "indices_mt = _positions_of(mt_rois)\n",
    "\n",
    "# The groups together must account for as many areas as there are labels.\n",
    "num_grouped = sum(len(rois) for rois in (primary_rois, early_rois, ventral_rois, dorsal_rois, mt_rois))\n",
    "assert num_grouped == len(unique_labels_left)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4965fda2",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(4.5, 3.5))\n",
    "ax = fig.gca()\n",
    "\n",
    "# One error-bar series per ROI group (mean with std as error), coloured in palette order.\n",
    "groups = (indices_primary, indices_early, indices_ventral, indices_dorsal, indices_mt)\n",
    "for indices_group, color in zip(groups, palette):\n",
    "    ax.errorbar(distance[indices_group], mean_dim[indices_group], yerr=std_dim[indices_group],\n",
    "                fmt='o', capsize=5, capthick=2, color=color)\n",
    "\n",
    "ax.set_xlabel('Cortical Distance (mm)')\n",
    "ax.set_ylabel('Linear Dimensionality')\n",
    "ax.legend(['Primary', 'Early', 'Ventral', 'Dorsal', 'MT'],\n",
    "          loc='upper left', bbox_to_anchor=(1, 1), fontsize=10, frameon=False)\n",
    "\n",
    "fig.suptitle(subject_id, fontsize=16)\n",
    "fig.tight_layout()\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d127c2e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the primary, early and ventral areas enter the fit, in that order.\n",
    "indices_fit = np.concatenate((indices_primary, indices_early, indices_ventral))\n",
    "X_linear = distance[indices_fit]\n",
    "y_linear = mean_dim[indices_fit]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da45e902",
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(x, a, b, c):\n",
    "    \"\"\"Power law with offset: a * x**b + c.\"\"\"\n",
    "    return a * x**b + c\n",
    "\n",
    "\n",
    "# Fit the power law to the linear estimates; curve_fit is imported in the first cell and\n",
    "# the fitted curve is evaluated in the next cell.\n",
    "p0 = [1.0, 2.0, 1.0]  # initial guesses for a, b, c\n",
    "popt, pcov = curve_fit(model, X_linear, y_linear, p0=p0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "919c87e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the fitted curve on 200 evenly spaced points from 0 to the largest fitted distance.\n",
    "x_plot = np.linspace(0, X_linear.max(), num=200)\n",
    "a_fit, b_fit, c_fit = popt\n",
    "y_plot = model(x_plot, a_fit, b_fit, c_fit)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a6a5dd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output location of the exported figure.\n",
    "os.makedirs('svg', exist_ok=True)\n",
    "path_figure = os.path.join('svg', f'figure02_fMRI_effective_{num_neighbors}.svg')\n",
    "\n",
    "fig = plt.figure(figsize=(3, 3))\n",
    "ax = fig.gca()\n",
    "\n",
    "# The fitted groups and the fitted curve all share one colour.\n",
    "color_fit = palette[1]\n",
    "for indices_group in (indices_primary, indices_early, indices_ventral):\n",
    "    ax.errorbar(distance[indices_group], mean_dim[indices_group], yerr=std_dim[indices_group],\n",
    "                fmt='o', capsize=5, capthick=2, color=color_fit)\n",
    "\n",
    "ax.plot(x_plot, y_plot, color=color_fit, linewidth=2)\n",
    "\n",
    "ax.set_xlabel('Cortical Distance (mm)')\n",
    "ax.set_ylabel('Effective Dimensionality')\n",
    "ax.set_xlim(-4, 90)\n",
    "ax.set_ylim(2.5, 6)\n",
    "ax.set_xticks([0, 45, 90], [0, 45, 90])\n",
    "ax.set_title(subject_id)\n",
    "fig.tight_layout()\n",
    "fig.savefig(path_figure, transparent=True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11c49865",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Refit with the distances divided by their maximum. New names are used so that X_linear\n",
    "# and popt from the cells above are not overwritten and those cells stay re-runnable.\n",
    "X_linear_norm = X_linear / np.max(X_linear)\n",
    "popt_norm, pcov_norm = curve_fit(model, X_linear_norm, y_linear, p0=p0)\n",
    "a, b, c = popt_norm\n",
    "print(f'{a:.2f}x^{b:.2f} + {c:.2f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9a52b65",
   "metadata": {},
   "source": [
    "## Non-Linear Dimensionality"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97ed37db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dimensionality estimates of the 'mle' estimator for the chosen split and neighbourhood size.\n",
    "name_estimator = 'mle'\n",
    "file_dims_ID = f'{subject_id}_dimensionality_{split}_{num_neighbors}_{name_estimator}.pickle'\n",
    "path_dict_dims_ID = os.path.join(path_results, 'dimensionality', file_dims_ID)\n",
    "# Report the missing path on failure, like the geodesic-matrix check does.\n",
    "assert os.path.exists(path_dict_dims_ID), path_dict_dims_ID\n",
    "dict_dims_ID = load_pickle(path_dict_dims_ID)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f728672a",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(5678)\n",
    "\n",
    "dims_areas_nonlinear = []\n",
    "\n",
    "# Pairs the i-th unique label with the i-th area name -- assumes both sequences share\n",
    "# the same order (TODO confirm).\n",
    "for label, name in zip(unique_labels_left, names_vertices_left):\n",
    "    # Pool the per-vertex estimates of this area across both hemispheres.\n",
    "    indices_area_left = np.where(labels_vertices_left == label)[0]\n",
    "    indices_area_right = np.where(labels_vertices_right == label)[0]\n",
    "    vertices_area = np.concatenate(\n",
    "        (dict_dims_ID['lh'][indices_area_left], dict_dims_ID['rh'][indices_area_right]), axis=0)\n",
    "    num_vertices_area = vertices_area.shape[0]\n",
    "\n",
    "    # Draw `total_subsample` subsamples of `num_subsample` vertices each, without replacement,\n",
    "    # and keep the mean of every draw. The repetition count used to reuse `num_subsample`,\n",
    "    # leaving `total_subsample` unused; both are 100, so the results are unchanged.\n",
    "    # NOTE(review): confirm which constant is meant as the count and which as the size.\n",
    "    for _ in range(total_subsample):\n",
    "        sampled_indices = np.random.choice(num_vertices_area, size=num_subsample, replace=False)\n",
    "        dims_areas_nonlinear.append({\n",
    "            'Area': name,\n",
    "            'Non-Linear Dimensionality': vertices_area[sampled_indices].mean(),\n",
    "            'Distance (mm)': dict_distances[name],\n",
    "        })\n",
    "\n",
    "df_dims_areas_nonlinear = pd.DataFrame(dims_areas_nonlinear)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33a8b869",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows of every area, selected once and kept in the order of names_vertices_left.\n",
    "rows_per_area = [\n",
    "    df_dims_areas_nonlinear[df_dims_areas_nonlinear['Area'] == area]\n",
    "    for area in names_vertices_left\n",
    "]\n",
    "\n",
    "# Mean and standard deviation over the subsample means, plus the area's distance value.\n",
    "mean_dim = np.array([rows['Non-Linear Dimensionality'].mean() for rows in rows_per_area])\n",
    "std_dim = np.array([rows['Non-Linear Dimensionality'].std() for rows in rows_per_area])\n",
    "distance = np.array([rows['Distance (mm)'].mean() for rows in rows_per_area])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0115c951",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(4.5, 3.5))\n",
    "ax = fig.gca()\n",
    "\n",
    "# One error-bar series per ROI group (mean with std as error), coloured in palette order.\n",
    "groups = (indices_primary, indices_early, indices_ventral, indices_dorsal, indices_mt)\n",
    "for indices_group, color in zip(groups, palette):\n",
    "    ax.errorbar(distance[indices_group], mean_dim[indices_group], yerr=std_dim[indices_group],\n",
    "                fmt='o', capsize=5, capthick=2, color=color)\n",
    "\n",
    "ax.set_xlabel('Cortical Distance (mm)')\n",
    "# This section plots the non-linear estimates, so the axis must not say 'Linear'.\n",
    "ax.set_ylabel('Non-Linear Dimensionality')\n",
    "ax.legend(['Primary', 'Early', 'Ventral', 'Dorsal', 'MT'],\n",
    "          loc='upper left', bbox_to_anchor=(1, 1), fontsize=10, frameon=False)\n",
    "\n",
    "fig.suptitle(subject_id, fontsize=16)\n",
    "fig.tight_layout()\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b3f5183",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the primary, early and ventral areas enter the fit, in that order.\n",
    "indices_fit = np.concatenate((indices_primary, indices_early, indices_ventral))\n",
    "X_nonlinear = distance[indices_fit]\n",
    "y_nonlinear = mean_dim[indices_fit]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec1c4246",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit the power law to the non-linear estimates, starting from a, b, c = 1, 2, 1.\n",
    "p0 = [1.0, 2.0, 1.0]\n",
    "popt, pcov = curve_fit(f=model, xdata=X_nonlinear, ydata=y_nonlinear, p0=p0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "745c653e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the fitted curve on 200 evenly spaced points from 0 to the largest fitted distance.\n",
    "x_plot = np.linspace(0, X_nonlinear.max(), num=200)\n",
    "a_fit, b_fit, c_fit = popt\n",
    "y_plot = model(x_plot, a_fit, b_fit, c_fit)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06c23bd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ensure the output folder exists here as well, so this cell does not depend on the\n",
    "# effective-dimensionality figure cell having been run first.\n",
    "os.makedirs('svg', exist_ok=True)\n",
    "path_figure = os.path.join('svg', f'figure02_fMRI_intrinsic_{num_neighbors}.svg')\n",
    "\n",
    "fig = plt.figure(figsize=(3, 3))\n",
    "ax = fig.gca()\n",
    "\n",
    "# The fitted groups and the fitted curve all share one colour.\n",
    "color_fit = palette[0]\n",
    "for indices_group in (indices_primary, indices_early, indices_ventral):\n",
    "    ax.errorbar(distance[indices_group], mean_dim[indices_group], yerr=std_dim[indices_group],\n",
    "                fmt='o', capsize=5, capthick=2, color=color_fit)\n",
    "\n",
    "ax.plot(x_plot, y_plot, color=color_fit, linewidth=2)\n",
    "\n",
    "ax.set_xlabel('Cortical Distance (mm)')\n",
    "ax.set_ylabel('Intrinsic Dimensionality')\n",
    "ax.set_xlim(-4, 90)\n",
    "ax.set_ylim(10, 13)\n",
    "ax.set_xticks([0, 45, 90], [0, 45, 90])\n",
    "ax.set_title(subject_id)\n",
    "fig.tight_layout()\n",
    "fig.savefig(path_figure, transparent=True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bc0bce2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Refit with the distances divided by their maximum. New names are used so that X_nonlinear\n",
    "# and popt from the cells above are not overwritten and those cells stay re-runnable.\n",
    "X_nonlinear_norm = X_nonlinear / np.max(X_nonlinear)\n",
    "popt_norm, pcov_norm = curve_fit(model, X_nonlinear_norm, y_nonlinear, p0=p0)\n",
    "a, b, c = popt_norm\n",
    "print(f'{a:.2f}x^{b:.2f} + {c:.2f}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cognitive_maps",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
