{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "import random\n",
    "import json\n",
    "import gc\n",
    "\n",
    "import PIL\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "import h5py\n",
    "import matplotlib.pyplot as plt\n",
    "from pathlib import Path\n",
    "from tqdm.notebook import tqdm\n",
    "import nibabel as nib\n",
    "from einops import rearrange\n",
    "from scipy import ndimage, stats\n",
    "from sklearn.neighbors import NearestNeighbors\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE\n",
    "from torchvision import transforms\n",
    "from PIL import Image\n",
    "from PIL import ImageFont\n",
    "from PIL import ImageDraw \n",
    "from scipy.stats import sem\n",
    "\n",
    "dir2 = os.path.abspath('../..')\n",
    "dir1 = os.path.dirname(dir2)\n",
    "if not dir1 in sys.path: \n",
    "    sys.path.append(dir1)\n",
    "    \n",
    "from research.data.natural_scenes import NaturalScenesDataset\n",
    "from research.experiments.nsd.nsd_utils import tsne_image_plot\n",
    "from research.metrics.metrics import cosine_distance, top_knn_test, r2_score, pearsonr\n",
    "from pipeline.utils import get_data_iterator, DisablePrints, read_patch\n",
    "from pipeline.compact_json_encoder import CompactJSONEncoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "nsd_path = Path('D:\\\\Datasets\\\\NSD\\\\')\n",
    "nsd = NaturalScenesDataset(nsd_path, coco_path='X:\\\\Datasets\\\\COCO')\n",
    "stimuli_path = nsd_path / 'nsddata_stimuli' / 'stimuli' / 'nsd' / 'nsd_stimuli.hdf5'\n",
    "stimulus_images = h5py.File(stimuli_path, 'r')['imgBrick']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = 'ViT-B=32' #'clip-vit-large-patch14'\n",
    "group_name = 'group-22'\n",
    "\n",
    "subjects = [f'subj0{i}' for i in range(1, 9)]\n",
    "embedding_name = 'embedding'\n",
    "fold_name = 'val'\n",
    "\n",
    "embeddings = h5py.File(nsd_path / f'derivatives/decoded_features/{model_name}/{group_name}.hdf5', 'r')\n",
    "\n",
    "results_path = nsd_path / f'derivatives/figures/decoding/{model_name}/{group_name}/{fold_name}/{embedding_name}/'\n",
    "results_path.mkdir(exist_ok=True, parents=True)\n",
    "\n",
    "Y_full = h5py.File(nsd_path / f'derivatives/stimulus_embeddings/{model_name}.hdf5', 'r')[embedding_name][:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "subj01\n",
      "subj02\n",
      "subj03\n",
      "subj04\n",
      "subj05\n",
      "subj06\n",
      "subj07\n",
      "subj08\n"
     ]
    }
   ],
   "source": [
    "# Load decoder data\n",
    "\n",
    "from research.models.fmri_decoders import Decoder\n",
    "\n",
    "num_voxels = None\n",
    "\n",
    "results = {}\n",
    "folds = {}\n",
    "for fold in ('val', 'test'):\n",
    "    fold_data = {\n",
    "        'X_all': [],\n",
    "        'Y_all': [],\n",
    "        'Y_pred_all': [],\n",
    "        'stimulus_ids_all': [],\n",
    "    }\n",
    "    folds[fold] = fold_data\n",
    "models_all = []\n",
    "state_dicts_all = []\n",
    "indices_all = []\n",
    "\n",
    "load_X = True\n",
    "\n",
    "for i, subject in enumerate(subjects):\n",
    "    print(subject)\n",
    "    \n",
    "    subject_embeddings = embeddings[f'{subject}/{embedding_name}']\n",
    "    train_mask, val_mask, test_mask = nsd.get_split(subject, 'split-01')\n",
    "\n",
    "    if load_X:\n",
    "        config = dict(subject_embeddings.attrs)\n",
    "\n",
    "        model_params = {k: config[k] for k in ('layer_sizes', 'dropout_p') if k in config}\n",
    "        model = Decoder(**model_params)\n",
    "        model = model.eval()\n",
    "        state_dict = {k: torch.from_numpy(v[:]) for k, v in subject_embeddings['model'].items()}\n",
    "        state_dicts_all.append(state_dict)\n",
    "        model.load_state_dict({k: v.clone() for k, v in state_dict.items()})\n",
    "        models_all.append(model)\n",
    "        \n",
    "        betas_params = {\n",
    "            k: config[k] \n",
    "            for k in (\n",
    "                'subject_name', 'voxel_selection_path', \n",
    "                'voxel_selection_key', 'num_voxels', 'return_volume_indices', 'threshold'\n",
    "            )\n",
    "        }\n",
    "        if betas_params['threshold'] is not None:\n",
    "            betas_params['num_voxels'] = None\n",
    "            betas_params['return_tensor_dataset'] = False\n",
    "        betas, betas_indices = nsd.load_betas(**betas_params)\n",
    "        folds['val']['X_all'].append(betas[val_mask])\n",
    "        folds['test']['X_all'].append(betas[test_mask])\n",
    "        indices_all.append(betas_indices)\n",
    "\n",
    "    stimulus_params = dict(\n",
    "        subject_name=subject,\n",
    "        stimulus_path=f'derivatives/stimulus_embeddings/{model_name}.hdf5',\n",
    "        stimulus_key=embedding_name,\n",
    "        delay_loading=False,\n",
    "        return_tensor_dataset=False,\n",
    "        return_stimulus_ids=True,\n",
    "    )\n",
    "    stimulus, stimulus_ids = nsd.load_stimulus(**stimulus_params)\n",
    "    for fold, mask in [('val', val_mask), ('test', test_mask)]:\n",
    "        \n",
    "        folds[fold]['stimulus_ids_all'].append(stimulus_ids[mask])\n",
    "    \n",
    "        Y = stimulus[mask].astype(np.float32)\n",
    "        Y = Y.reshape(Y.shape[0], -1)\n",
    "        folds[fold]['Y_all'].append(Y)\n",
    "    \n",
    "        Y_pred = subject_embeddings[f'{fold}/Y_pred'][:]\n",
    "        Y_pred = Y_pred / np.linalg.norm(Y_pred, axis=1)[:, None]\n",
    "        folds[fold][\"Y_pred_all\"].append(Y_pred)\n",
    "        \n",
    "locals().update(folds[fold_name])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DBSCAN Clustering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load a group of decoders\n",
    "\n",
    "decoding_models = h5py.File(nsd_path / f'derivatives/decoded_features/ViT-B=32/{group_name}.hdf5')\n",
    "W_decoding_all = []\n",
    "volume_indices_decoding_all = []\n",
    "for subject_id, subject_name in enumerate(subjects):\n",
    "    subject = decoding_models[f'{subject_name}/embedding/']\n",
    "    W_decoding_all.append(subject[f'model/layers.0.weight'][:])\n",
    "    volume_indices_decoding_all.append(subject[f'volume_indices'][:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alteratively load a group of decoders where there were multiple runs per participant, average the weights across all runs\n",
    "\n",
    "group_name = 'group-22_reruns'\n",
    "\n",
    "decoding_models = h5py.File(nsd_path / f'derivatives/decoded_features/ViT-B=32/{group_name}.hdf5')\n",
    "num_runs = 50\n",
    "\n",
    "W_decoding_all = []\n",
    "W_decoding_reruns_all = []\n",
    "volume_indices_decoding_all = []\n",
    "for subject_id, subject_name in enumerate(subjects):\n",
    "    print(subject_name)\n",
    "    W_decoding_subject = []\n",
    "    for run_id in range(num_runs):\n",
    "        print(run_id)\n",
    "        W_decoding_subject.append(decoding_models[f'{subject_name}/embedding/run_{run_id}/model/layers.0.weight'][:])\n",
    "    W_decoding_reruns_all.append(np.stack(W_decoding_subject))\n",
    "    W_decoding_all.append(np.stack(W_decoding_subject).mean(axis=0))\n",
    "    volume_indices_decoding_all.append(decoding_models[f'{subject_name}/embedding/run_{run_id}/volume_indices'][:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute and cache nearest neighbor queries for the voxel-wise parameter vectors\n",
    "\n",
    "tag, W = (f'linear_decoding__{group_name}', [W_subj.T for W_subj in W_decoding_all])\n",
    "num_models = 1\n",
    "\n",
    "W = [W_subj / np.linalg.norm(W_subj, axis=1, keepdims=True) for W_subj in W]\n",
    "for W_subj in W:\n",
    "    W_subj[np.isnan(W_subj)] = 0.\n",
    "nn_all = [NearestNeighbors(radius=1.0, metric='cosine').fit(W_subj) for W_subj in W]\n",
    "cluster_name = 'density'\n",
    "print(tag)\n",
    "\n",
    "out_path = nsd_path / f'derivatives/figures/concept_maps_voxel_v5/{tag}'\n",
    "out_path.mkdir(exist_ok=True, parents=True)\n",
    "\n",
    "eps = 0.7 # max neighborhood size for queries\n",
    "print('computing neighbors')\n",
    "nn_results = np.zeros((8, 8), dtype=object)\n",
    "for subject_i in range(8):\n",
    "    for subject_j in range(8):\n",
    "        print(subject_i, subject_j)\n",
    "        nn_results[subject_i, subject_j] = nn_all[subject_i].radius_neighbors(W[subject_j], radius=eps, sort_results=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optionally save these queries to the disk\n",
    "\n",
    "np.save(out_path / f'nn_results__num_models-{num_models}_v1.npy', nn_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load saved queries from the disk\n",
    "\n",
    "nn_results = np.load(out_path / f'nn_results__num_models-{num_models}_v1.npy', allow_pickle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DBSCAN code\n",
    "\n",
    "from collections import deque\n",
    "\n",
    "UNEXPLORED = -2\n",
    "OUTLIER = -1\n",
    "\n",
    "def within_participant_clustering(cluster_ids, cluster_id, subject_i, voxel_ids, eps, min_neighbors):\n",
    "    core_samples = []\n",
    "    fringe = deque(voxel_ids)\n",
    "    fringe_hash = set(voxel_ids)\n",
    "\n",
    "    while len(fringe) > 0:\n",
    "        check_voxel = fringe.popleft()\n",
    "        fringe_hash.remove(check_voxel)\n",
    "        \n",
    "        nn_dists = nn_results[subject_i, subject_i][0][check_voxel] \n",
    "        nn_ids = nn_results[subject_i, subject_i][1][check_voxel]\n",
    "\n",
    "        mask = nn_dists < eps\n",
    "        nn_dists = nn_dists[mask]\n",
    "        nn_ids = nn_ids[mask]\n",
    "        min_pts = nn_dists.shape[0]\n",
    "\n",
    "        # its a core sample\n",
    "        if min_pts >= min_neighbors:\n",
    "\n",
    "            # add to core samples list and cluster array\n",
    "            core_samples.append(check_voxel)\n",
    "            cluster_ids[check_voxel] = cluster_id\n",
    "\n",
    "            # add neighbors to the fringe\n",
    "            for neighbor_id in nn_ids:\n",
    "\n",
    "                # dont add the point we are checking\n",
    "                if neighbor_id == check_voxel:\n",
    "                    continue\n",
    "                # dont add points already in fringe\n",
    "                if neighbor_id in fringe_hash:\n",
    "                    continue\n",
    "                # only add unexplored points to the fringe\n",
    "                if cluster_ids[neighbor_id] == UNEXPLORED or cluster_ids[neighbor_id] == OUTLIER:\n",
    "                    fringe.append(neighbor_id)\n",
    "                    fringe_hash.add(neighbor_id)\n",
    "\n",
    "        # not a core sample\n",
    "        else:\n",
    "            # mark it as an outlier if its not in the neighborhood of a core sample\n",
    "            if len(core_samples) == 0:\n",
    "                cluster_ids[check_voxel] = OUTLIER\n",
    "            # otherwise its part of the cluster\n",
    "            else:\n",
    "                cluster_ids[check_voxel] = cluster_id\n",
    "                \n",
    "    return core_samples\n",
    "\n",
    "                \n",
    "def cross_participant_clustering(cluster_ids, cluster_id, subject_i, voxel_id, eps, min_neighbors):\n",
    "    core_samples = []\n",
    "    fringe = deque([(subject_i, voxel_id)])\n",
    "    fringe_hash = set([(subject_i, voxel_id)])\n",
    "\n",
    "    while len(fringe) > 0:\n",
    "        check_subject, check_voxel = fringe.popleft()\n",
    "        fringe_hash.remove((check_subject, check_voxel))\n",
    "        \n",
    "        nn_dists = [nn_results[subject_j, check_subject][0][check_voxel] for subject_j in range(8)]\n",
    "        nn_ids = [nn_results[subject_j, check_subject][1][check_voxel] for subject_j in range(8)]\n",
    "\n",
    "        nn_masks = [dists < eps for dists in nn_dists]\n",
    "        nn_dists = [dists[mask] for dists, mask in zip(nn_dists, nn_masks)]\n",
    "        nn_ids = [ids[mask] for ids, mask in zip(nn_ids, nn_masks)]\n",
    "\n",
    "        # number of subjects that have at least one\n",
    "        neighbor_counts = [d.shape[0] for d in nn_dists]\n",
    "        \n",
    "        neighbor_counts[check_subject] = 0\n",
    "        min_pts = np.sum(np.array(neighbor_counts) != 0)\n",
    "\n",
    "        # its a core sample\n",
    "        if min_pts >= min_neighbors:\n",
    "\n",
    "            # add to core samples list and cluster array\n",
    "            core_samples.append((check_subject, check_voxel))\n",
    "            cluster_ids[check_subject][check_voxel] = cluster_id\n",
    "\n",
    "            # add neighbors to the fringe\n",
    "            for subject_j, neighbor_ids in enumerate(nn_ids):\n",
    "                for neighbor_id in neighbor_ids:\n",
    "\n",
    "                    # dont add the point we are checking\n",
    "                    if subject_j == check_subject and neighbor_id == check_voxel:\n",
    "                        continue\n",
    "                    # dont add points already in fringe\n",
    "                    if (subject_j, neighbor_id) in fringe_hash:\n",
    "                        continue\n",
    "                    # dont add points already in this cluster\n",
    "                    if cluster_ids[subject_j][neighbor_id] == cluster_id:\n",
    "                        continue\n",
    "                        \n",
    "                    fringe.append((subject_j, neighbor_id))\n",
    "                    fringe_hash.add((subject_j, neighbor_id))\n",
    "\n",
    "        # not a core sample\n",
    "        else:\n",
    "\n",
    "            # mark it as an outlier if its not in the neighborhood of a core sample\n",
    "            if len(core_samples) == 0:\n",
    "                cluster_ids[check_subject][check_voxel] = OUTLIER\n",
    "            # otherwise its part of the cluster\n",
    "            else:\n",
    "                cluster_ids[check_subject][check_voxel] = cluster_id\n",
    "    return core_samples\n",
    "\n",
    "def dbscan_expansion(core_samples_all, cluster_ids, min_neighbors, eps):\n",
    "    expanded_clusters_all = []\n",
    "    for subject_id in range(8):\n",
    "        expanded_clusters = np.full((len(core_samples_all), W[subject_id].shape[0]), UNEXPLORED)\n",
    "        for c_id in range(len(core_samples_all)):\n",
    "            cluster_mask = cluster_ids[subject_id] == c_id\n",
    "            if not np.any(cluster_mask):\n",
    "                expanded_clusters[c_id] = 0\n",
    "                continue\n",
    "            voxel_ids = list(np.where(cluster_mask)[0])\n",
    "            within_participant_clustering(expanded_clusters[c_id], 1, subject_id, voxel_ids, expansion_eps, expansion_min_neighbors)\n",
    "            expanded_clusters[c_id][expanded_clusters[c_id] != 1] = 0\n",
    "            expanded_clusters[c_id][cluster_mask] = 1\n",
    "        expanded_clusters_all.append(expanded_clusters)\n",
    "    return expanded_clusters_all\n",
    "\n",
    "def distance_expansion(core_samples_all, cluster_ids, eps):\n",
    "    expanded_clusters_all = []\n",
    "    for subject_id in range(8):\n",
    "        expanded_clusters = np.full((len(core_samples_all), W[subject_id].shape[0]), UNEXPLORED)\n",
    "        for c_id in range(len(core_samples_all)):\n",
    "            cluster_mask = cluster_ids[subject_id] == c_id\n",
    "            if not np.any(cluster_mask):\n",
    "                expanded_clusters[c_id] = 0\n",
    "                continue\n",
    "            \n",
    "            voxel_ids = np.where(cluster_mask)[0]\n",
    "            nn_dists = np.concatenate([nn_results[subject_id, subject_id][0][voxel_id] for voxel_id in voxel_ids])\n",
    "            nn_ids = np.concatenate([nn_results[subject_id, subject_id][1][voxel_id] for voxel_id in voxel_ids])\n",
    "            \n",
    "            dist_mask = nn_dists < eps\n",
    "            nn_ids = nn_ids[dist_mask]\n",
    "            expanded_clusters[c_id][nn_ids] = 1\n",
    "                        \n",
    "            expanded_clusters[c_id][expanded_clusters[c_id] != 1] = 0\n",
    "            expanded_clusters[c_id][cluster_mask] = 1\n",
    "            \n",
    "        expanded_clusters_all.append(expanded_clusters)\n",
    "    return expanded_clusters_all\n",
    "\n",
    "def modified_dbscan(min_neighbors, eps, expansion_eps):\n",
    "    # initialize cluster_id maps as -2 (unexplored)\n",
    "    cluster_ids = [np.full(W[subject_id].shape[0], UNEXPLORED) for subject_id in range(8)]\n",
    "    cluster_id = 0\n",
    "    core_samples_all = []\n",
    "    \n",
    "    for subject_i in range(8):\n",
    "        for voxel_id in range(W[subject_i].shape[0]):\n",
    "            if cluster_ids[subject_i][voxel_id] != UNEXPLORED:\n",
    "                continue\n",
    "                \n",
    "            core_samples = cross_participant_clustering(cluster_ids, cluster_id, subject_i, voxel_id, eps, min_neighbors)\n",
    "\n",
    "            if len(core_samples) > 0:\n",
    "                core_samples_all.append(np.array(core_samples))\n",
    "                #print(core_samples)\n",
    "                cluster_id += 1\n",
    "    \n",
    "    expanded_clusters_all = distance_expansion(core_samples_all, cluster_ids, expansion_eps)\n",
    "            \n",
    "    return cluster_ids, core_samples_all, expanded_clusters_all\n",
    "\n",
    "def save_results(cluster_name, min_neighbors, eps, reruns_mode, cluster_ids, core_samples_all, expanded_clusters, expansion_suffix):\n",
    "    root_path =  nsd_path / f'derivatives/figures/concept_maps_voxel_v5/{tag}/'\n",
    "    root_path.mkdir(exist_ok=True, parents=True)\n",
    "    out_path = nsd_path / f'derivatives/figures/concept_maps_voxel_v5/{tag}/{cluster_name}'\n",
    "    out_path.mkdir(exist_ok=True, parents=True)\n",
    "\n",
    "    params = {'min_neighbors': min_neighbors, 'eps': eps, 'reruns_mode': reruns_mode}\n",
    "    with open(out_path / 'params.json', 'w') as f:\n",
    "        f.write(json.dumps(params))\n",
    "\n",
    "    if reruns_mode == 'multiple':\n",
    "        for subj_id, subject_name in enumerate(subjects):\n",
    "            c_ids = np.stack(np.split(cluster_ids[subj_id] + 1, num_models))\n",
    "            c_mask = F.one_hot(torch.from_numpy(c_ids).long(), num_classes=len(core_samples_all) + 1).float().numpy()[:, :, 1:].mean(axis=0)\n",
    "            \n",
    "            c_mask_expanded = np.stack(np.split(expanded_clusters[subj_id], num_models, axis=1)).mean(axis=0)\n",
    "            \n",
    "            subject_path = out_path / subject_name\n",
    "            subject_path.mkdir(exist_ok=True, parents=True)\n",
    "            np.save(subject_path / 'mask.npy', c_mask.T)\n",
    "            np.save(subject_path / f'mask{expansion_suffix}.npy', c_mask_expanded)\n",
    "    else:\n",
    "        #W_concat = np.concatenate([w for w in W])\n",
    "        for subj_id, subject_name in enumerate(subjects):\n",
    "            c_ids = np.stack(np.split(cluster_ids[subj_id] + 1, num_models))\n",
    "            c_mask = F.one_hot(torch.from_numpy(c_ids).long(), num_classes=len(core_samples_all) + 1).float().numpy()[:, :, 1:].mean(axis=0)\n",
    "            subject_path = out_path / subject_name\n",
    "            subject_path.mkdir(exist_ok=True, parents=True)\n",
    "            print(c_mask.T.shape)\n",
    "            np.save(subject_path / 'mask.npy', c_mask.T)\n",
    "            np.save(subject_path / f'mask{expansion_suffix}.npy', expanded_clusters[subj_id])\n",
    "            \n",
    "        subject_id = np.concatenate([np.full((w.shape[0],), i, int) for i, w in enumerate(W)])\n",
    "        W_t = np.concatenate(cluster_ids)\n",
    "\n",
    "        #np.save(root_path / f'{tag}__W.npy', W_concat)\n",
    "        #np.save(root_path / f'{tag}__subject_id.npy', subject_id)\n",
    "        np.save(out_path / f'{tag}__clusters.npy', W_t)\n",
    "\n",
    "        for subj_id, subject_name in enumerate(subjects):\n",
    "            subject_mask = subject_id == subj_id\n",
    "            #np.save(root_path / f'{tag}__W__{subject_name}.npy', W_concat[subject_mask])\n",
    "            np.save(out_path / f'{tag}__clusters__{subject_name}.npy', W_t[subject_mask])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run modified DBSCAN clustering algorithm for desired min_neighbors and eps combinations\n",
    "\n",
    "for min_neighbors in (1, 2, 3, 4):\n",
    "    for i, eps in enumerate((0.55, 0.6, 0.65, 0.5, 0.7, 0.45)):\n",
    "        if min_neighbors != 3:\n",
    "            continue\n",
    "        print(f'{min_neighbors=}, {eps=}')\n",
    "        cluster_ids, core_samples_all, expanded_clusters_all = modified_dbscan(min_neighbors, eps, expansion_eps=min(eps + 0.05, 0.65))\n",
    "        save_results(\n",
    "            f'num_models-{num_models}/min_neighbors-{min_neighbors}/run-{i}', min_neighbors, eps, reruns_mode, cluster_ids, core_samples_all, expanded_clusters_all, \n",
    "            expansion_suffix='_expanded'\n",
    "        )\n",
    "        \n",
    "        for i, core_samples in enumerate(core_samples_all):\n",
    "            print('cluster_id', i)\n",
    "            print([((i == c_ids).sum()) for subject_id, c_ids in enumerate(cluster_ids)])\n",
    "            print([expanded_clusters_all[subject_id][i].sum() for subject_id in range(8)])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
