{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c6fe08d-6a9b-49e1-9119-cea36d24d5ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from jax import random\n",
    "\n",
    "import jax\n",
    "jax.config.update(\"jax_enable_x64\", True)\n",
    "\n",
    "import jax.numpy as jnp\n",
    "\n",
    "from src import data_utils\n",
    "from src.moment_utils import get_dimensionality_avg\n",
    "\n",
    "#DATASET_ROOT = './'\n",
    "\n",
    "#DATASET_ROOT = '/Users/acanatar/datasets/'\n",
    "DATASET_ROOT = '/mnt/ceph/users/acanatar/'\n",
    "\n",
    "#ACTIVATIONS_ROOT = os.path.join(DATASET_ROOT, 'TVSD', 'TVSD_activations')\n",
    "#RESULTS_ROOT = os.path.join(DATASET_ROOT, 'TVSD', 'results')\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import jax.numpy as jnp\n",
    "import jax\n",
    "from functools import partial\n",
    "from jax import random\n",
    "plt.rcParams.update({'font.size': 12})\n",
    "import matplotlib\n",
    "matplotlib.rcParams['pdf.fonttype']=42\n",
    "\n",
    "\n",
    "def gett_all(ijlm: str, A, B):\n",
    "    i, j, l, m = list(ijlm)\n",
    "    pexp = i+'a,'+j+'a,'+l+'b,'+m+'b->'\n",
    "    qexp = i+'a,'+j+'a,'+l+'a,'+m+'a->'\n",
    "    pval = jnp.einsum(pexp, A, B, A, B)\n",
    "    pqval = pval - jnp.einsum(qexp, A, B, A, B)\n",
    "    return pval, pqval\n",
    "\n",
    "@jax.jit\n",
    "def estimate_dimensionality_no_centering(A, B):\n",
    "    \"\"\"\n",
    "    Estimate the dimensionality of the population A without centering.\n",
    "    B is either the same as A or a different trial of the same population.\n",
    "\n",
    "    Args:\n",
    "        A, B: (P, Q) arrays.\n",
    "\n",
    "    Returns:\n",
    "        [naive, row_exp, col_exp, double_exp]; each entry is a\n",
    "        [numerator, denominator] pair (estimate = numerator / denominator).\n",
    "        row_exp subtracts the all-rows-equal ('iiii') term and rescales by P/(P-1);\n",
    "        col_exp uses the a != b column moments rescaled by Q/(Q-1);\n",
    "        double_exp applies both corrections.\n",
    "    \"\"\"\n",
    "    P, Q = jnp.shape(A)\n",
    "    norm = (P * Q) ** 0.5\n",
    "    A_s, B_s = A / norm, B / norm\n",
    "\n",
    "    t1, t1d = gett_all('iijj', A_s, B_s)\n",
    "    t3, t3d = gett_all('ijij', A_s, B_s)\n",
    "    t4, t4d = gett_all('iiii', A_s, B_s)\n",
    "\n",
    "    row_fac = P / (P - 1)\n",
    "    col_fac = Q / (Q - 1)\n",
    "    both_fac = P / (P - 1) * Q / (Q - 1)\n",
    "\n",
    "    naive = [t1, t3]\n",
    "    row_exp = [row_fac * (t1 - t4), row_fac * (t3 - t4)]\n",
    "    col_exp = [col_fac * t1d, col_fac * t3d]\n",
    "    double_exp = [both_fac * (t1d - t4d), both_fac * (t3d - t4d)]\n",
    "    return [naive, row_exp, col_exp, double_exp]\n",
    "\n",
    "def twoNN_intrinsic_dimension(X, subsample=None, key=None):\n",
    "    \"\"\"\n",
    "    Estimate the intrinsic dimension of a dataset using the Two Nearest Neighbors (TwoNN) algorithm.\n",
    "    \n",
    "    Args:\n",
    "        X (jnp.ndarray): Input data of shape (N, D) where N is the number of samples\n",
    "                         and D is the number of features.\n",
    "        subsample (int, optional): If provided and N is large, randomly choose 'subsample' points\n",
    "                                   from X for the estimation.\n",
    "        key (jax.random.PRNGKey, optional): PRNG key for random sampling (required if subsample is not None).\n",
    "    \n",
    "    Returns:\n",
    "        float: Estimated intrinsic dimension using the TwoNN method.\n",
    "    \"\"\"\n",
    "    N, D = X.shape\n",
    "\n",
    "    # Possibly subsample the dataset if requested\n",
    "    if subsample is not None and subsample < N:\n",
    "        if key is None:\n",
    "            raise ValueError(\"A jax.random.PRNGKey must be provided when subsample is used.\")\n",
    "        indices = jax.random.choice(key, N, shape=(subsample,), replace=False)\n",
    "        X = X[indices]\n",
    "        N = subsample\n",
    "\n",
    "    # Compute the pairwise distance matrix (squared distances)\n",
    "    # shape: (N, N)\n",
    "    # We add a small epsilon to avoid exact zeros due to floating precision (if any).\n",
    "    dist_sq = jnp.sum((X[:, None, :] - X[None, :, :])**2, axis=-1)\n",
    "\n",
    "    # Get actual distances\n",
    "    distances = jnp.sqrt(dist_sq + 1e-12)\n",
    "\n",
    "    # Sort the distances for each point\n",
    "    # sorted_distances[i, :] will have distances in ascending order for point i\n",
    "    sorted_distances = jnp.sort(distances, axis=1)\n",
    "\n",
    "    # The 0th item in each row is the distance to itself (0), so:\n",
    "    # nearest neighbor distance = sorted_distances[:, 1]\n",
    "    # second nearest neighbor distance = sorted_distances[:, 2]\n",
    "    r1 = sorted_distances[:, 1]\n",
    "    r2 = sorted_distances[:, 2]\n",
    "\n",
    "    # Compute ratio mu = r2 / r1\n",
    "    # (Add small epsilon to r1 to avoid division by zero if there are duplicates)\n",
    "    mu = r2 / (r1 + 1e-12)\n",
    "\n",
    "    # Intrinsic dimension estimate\n",
    "    # d = 1 / mean(log(mu))\n",
    "    log_mu = jnp.log(mu)\n",
    "    d_est = 1.0 / jnp.mean(log_mu)\n",
    "\n",
    "    return d_est\n",
    "\n",
    "\n",
    "###\n",
    "\n",
    "@jax.jit\n",
    "def find_neighbors_naive_jax_jitted(X, R, chosen_indices):\n",
    "    \"\"\"\n",
    "    Computes, for each index in `chosen_indices`, which rows in X lie\n",
    "    within radius R of X[index]. Returns a boolean mask of shape (S, P),\n",
    "    the array of neighbor counts, and the max neighbor count.\n",
    "    \n",
    "    X: jnp.array of shape (P, Q)\n",
    "    R: float (radius)\n",
    "    chosen_indices: jnp.array of shape (S,) - the row indices chosen without replacement\n",
    "    \n",
    "    Returns:\n",
    "       neighbors_mask: boolean array of shape (S, P) \n",
    "                       neighbors_mask[i, j] = True if sample j is within R of chosen_indices[i]\n",
    "       neighbors_count: int32 array of shape (S,) \n",
    "                       neighbors_count[i] = number of True entries in neighbors_mask[i]\n",
    "       max_neighbors: a scalar int32, the maximum neighbor count across i\n",
    "    \"\"\"\n",
    "    R2 = R * R\n",
    "\n",
    "    def compute_neighbors_mask(idx):\n",
    "        # For a single index, compute distance^2 to all points\n",
    "        dist_sq = jnp.mean((X - X[idx])**2, axis=1)   # shape (P,)\n",
    "        # Boolean mask: whether each point is within radius R\n",
    "        within_mask = dist_sq <= R2\n",
    "        # Count how many True in this mask\n",
    "        count = jnp.sum(within_mask)\n",
    "        return within_mask, count\n",
    "\n",
    "    # Vectorize (vmap) over each chosen index\n",
    "    neighbors_mask, neighbors_count = jax.vmap(compute_neighbors_mask)(chosen_indices)\n",
    "    # neighbors_mask: shape (S, P)\n",
    "    # neighbors_count: shape (S,)\n",
    "\n",
    "    max_neighbors = jnp.max(neighbors_count)\n",
    "\n",
    "    return neighbors_mask, neighbors_count, max_neighbors\n",
    "\n",
    "\n",
    "def find_neighbors_naive_jax(X, R, S, key):\n",
    "    \"\"\"\n",
    "    Wrapper function that:\n",
    "    1) Samples S distinct indices from X\n",
    "    2) Calls the JIT-compiled routine to compute the neighbor masks\n",
    "    3) Converts the neighbor masks to a Python list of jnp-arrays (neighbors_list)\n",
    "    4) Prints 'no data' if max_neighbors == 0\n",
    "    5) Returns chosen_indices, neighbors_list, max_neighbors\n",
    "\n",
    "    X: jax.numpy array of shape (P, Q)\n",
    "    R: float, radius\n",
    "    S: int, number of points to subsample\n",
    "    key: jax.random.PRNGKey\n",
    "    \"\"\"\n",
    "\n",
    "    P, _ = X.shape\n",
    "\n",
    "    # 1) Sample S distinct indices\n",
    "    #chosen_indices = jax.random.choice(key, P, shape=(S,), replace=False)\n",
    "    chosen_indices = jnp.arange(S)\n",
    "    \n",
    "    # 2) Call the JIT-compiled function\n",
    "    neighbors_mask, neighbors_count, max_neighbors = find_neighbors_naive_jax_jitted(\n",
    "        X, R, chosen_indices\n",
    "    )\n",
    "\n",
    "    # 3) Convert the boolean masks to a Python list of neighbor index arrays\n",
    "    #    Because each row i in neighbors_mask is shape (P,)\n",
    "    #    we find the where() indices for all True values.\n",
    "    #    This is done in Python space, so it won't be JIT-compiled (which is fine).\n",
    "    neighbors_list = []\n",
    "    for i in range(S):\n",
    "        # Grab all j where neighbors_mask[i, j] == True\n",
    "        these_neighbors = jnp.where(neighbors_mask[i])[0]\n",
    "        # If you want to skip single-point or zero neighbors, you can do:\n",
    "        if neighbors_count[i] < 4:\n",
    "            continue\n",
    "        neighbors_list.append(these_neighbors)\n",
    "\n",
    "    return chosen_indices, neighbors_list, max_neighbors\n",
    "\n",
    "\n",
    "\n",
    "def the_machine(A,B,nf):\n",
    "    \"\"\"\n",
    "    Fourth-order moment terms used by estimate_weighted_dimensionality.\n",
    "\n",
    "    Each rX, rXd pair is gett_all(pattern, A/nf, B/nf): rX sums over all column\n",
    "    pairs (a, b), rXd over a != b only. The terms are then grouped four ways:\n",
    "      *_both  : integer combinations of the a != b moments (rXd),\n",
    "      *_row   : the same combinations applied to the all-column moments (rX),\n",
    "      *_col   : the a != b moments as-is,\n",
    "      *_naive : the all-column moments as-is.\n",
    "    NOTE(review): the integer coefficients look like inclusion-exclusion over\n",
    "    coinciding row indices; confirm against the derivation.\n",
    "\n",
    "    Args:\n",
    "        A, B: 2-D arrays of matching shape (B may equal A).\n",
    "        nf: scalar normalisation; both arrays are divided by it.\n",
    "\n",
    "    Returns:\n",
    "        List of 20 scalars: t1..t5 for _both, then _row, _col, _naive.\n",
    "    \"\"\"\n",
    "    r1, r1d = gett_all('ijji', A/nf, B/nf)\n",
    "    r2, r2d = gett_all('iiii', A/nf, B/nf)\n",
    "    r3, r3d = gett_all('ijjj', A/nf, B/nf)\n",
    "    # t4,t4d = gett_all('iiij', A/nf)#<\n",
    "    r5, r5d = gett_all('ijjl', A/nf, B/nf)\n",
    "    r6, r6d = gett_all('iijj', A/nf, B/nf)\n",
    "    r7, r7d = gett_all('iijl', A/nf, B/nf)\n",
    "    # t8,t8d = gett_all('ijll', A/nf)#<\n",
    "    r9, r9d = gett_all('ijlm', A/nf, B/nf)\n",
    "\n",
    "    # Double-corrected terms: combinations of the a != b moments.\n",
    "    t1_both = r6d - r2d\n",
    "    t2_both = r7d - 2*r3d - r6d + 2*r2d\n",
    "    t3_both = r1d - r2d\n",
    "    t4_both = r5d - 2*r3d - r1d + 2*r2d\n",
    "    t5_both = r9d - 2*(r7d + 2*r5d) + r6d + 8*r3d+ 2*r1d - 6*r2d\n",
    "    \n",
    "    # Row-corrected terms: the same combinations on the all-column moments.\n",
    "    t1_row = r6 - r2\n",
    "    t2_row = r7 - 2*r3 - r6 + 2*r2\n",
    "    t3_row = r1 - r2\n",
    "    t4_row = r5 - 2*r3 - r1 + 2*r2\n",
    "    t5_row = r9 - 2*(r7 + 2*r5) + r6 + 8*r3 + 2*r1 - 6*r2\n",
    "\n",
    "    # Column-corrected only: raw a != b moments.\n",
    "    t1_col = r6d\n",
    "    t2_col = r7d\n",
    "    t3_col = r1d\n",
    "    t4_col = r5d\n",
    "    t5_col = r9d\n",
    "\n",
    "    # No correction: raw all-column moments.\n",
    "    t1_naive = r6\n",
    "    t2_naive = r7\n",
    "    t3_naive = r1\n",
    "    t4_naive = r5\n",
    "    t5_naive = r9\n",
    "    \n",
    "    return  [t1_both,t2_both,t3_both,t4_both,t5_both,t1_row,t2_row,t3_row,t4_row,t5_row,t1_col,t2_col,t3_col,t4_col,t5_col,t1_naive,t2_naive,t3_naive,t4_naive,t5_naive]\n",
    "\n",
    "\n",
    "def estimate_weighted_dimensionality(A, B, w):\n",
    "    \"\"\"\n",
    "    Estimate the dimensionality of the row-weighted population A.\n",
    "    B is either the same as A or a different trial of the same population.\n",
    "\n",
    "    Rows of A and B are scaled by w, and each moment from the_machine is divided\n",
    "    by the same moment computed on the weight pattern alone.\n",
    "\n",
    "    Args:\n",
    "        A, B: (P, Q) arrays.\n",
    "        w: (P,) row weights (getLocalPR3 passes a 0/1 neighbourhood mask).\n",
    "\n",
    "    Returns:\n",
    "        [naive, row_exp, col_exp, double_exp], each a [numerator, denominator] pair.\n",
    "    \"\"\"\n",
    "    _, n_cols = A.shape\n",
    "\n",
    "    # Effective number of weighted rows: (sum w)^2 / sum w^2.\n",
    "    eff_rows = jnp.square(jnp.sum(w)) / jnp.sum(jnp.square(w))\n",
    "    scale = (eff_rows * n_cols) ** 0.5\n",
    "\n",
    "    w_col = w[:, None]\n",
    "    weight_grid = jnp.tile(w_col, (n_cols,))  # w repeated across the n_cols columns\n",
    "    raw = the_machine(A * w_col, B * w_col, scale)\n",
    "    ref = the_machine(weight_grid, weight_grid, scale)\n",
    "    terms = [num / den for num, den in zip(raw, ref)]\n",
    "\n",
    "    def as_ratio(t1, t2, t3, t4, t5):\n",
    "        # [numerator, denominator] shared by every estimator variant.\n",
    "        return [t1 - 2*t2 + t5, t3 - 2*t4 + t5]\n",
    "\n",
    "    # the_machine orders its output as 4 groups of 5: both, row, col, naive.\n",
    "    both, row, col, naive = (as_ratio(*terms[k:k + 5]) for k in range(0, 20, 5))\n",
    "    return [naive, row, col, both]\n",
    "    \n",
    "\n",
    "def local_mahalanobis_distances(X0: jnp.ndarray,\n",
    "                                X: jnp.ndarray,\n",
    "                                k: int | None = None,\n",
    "                                ridge: float = 1e-6) -> jnp.ndarray:\n",
    "    \"\"\"\n",
    "    Compute Mahalanobis distances from each x0 in X0 to all points in X,\n",
    "    where the metric for each x0 is the inverse (pseudoinverse) of the\n",
    "    local covariance estimated from the k nearest neighbors of x0 in X.\n",
    "\n",
    "    Args:\n",
    "        X0: (P0, Q) array of query points (centers).\n",
    "        X:  (P,  Q) array of dataset points.\n",
    "        k:  number of neighbors to estimate local covariance. If None,\n",
    "            defaults to min(P, max(50, 3*Q)) to make the covariance well-conditioned.\n",
    "        ridge: small nonnegative value added to the diagonal of the covariance\n",
    "               before pseudoinverse for numerical stability.\n",
    "\n",
    "    Returns:\n",
    "        D: (P0, P) array where D[a, i] = Mahalanobis distance between X[i] and X0[a]\n",
    "           using the local metric estimated at X0[a].\n",
    "    \"\"\"\n",
    "    P, Q = X.shape\n",
    "    P0 = X0.shape[0]\n",
    "    if X0.shape[1] != Q:\n",
    "        raise ValueError(f\"X0 has Q={X0.shape[1]} dims but X has Q={Q} dims.\")\n",
    "\n",
    "    if k is None:\n",
    "        k = int(jnp.minimum(P, jnp.maximum(50, 3 * Q)))\n",
    "    if k < 2:\n",
    "        raise ValueError(\"k must be at least 2 to form a covariance.\")\n",
    "    if k > P:\n",
    "        raise ValueError(f\"k={k} exceeds number of points P={P} in X.\")\n",
    "\n",
    "    # Pairwise squared Euclidean distances between X0 and X for neighbor selection\n",
    "    def pairwise_sqdist(A, B):\n",
    "        # A: (n, Q), B: (m, Q)\n",
    "        A2 = jnp.sum(A * A, axis=1, keepdims=True)       # (n, 1)\n",
    "        B2 = jnp.sum(B * B, axis=1, keepdims=True).T     # (1, m)\n",
    "        return A2 + B2 - 2.0 * (A @ B.T)                 # (n, m)\n",
    "\n",
    "    sqd = pairwise_sqdist(X0, X)                         # (P0, P)\n",
    "    nn_idx = jnp.argsort(sqd, axis=1)[:, :k]             # (P0, k), indices of kNN in X\n",
    "\n",
    "    # Helper: for one x0 row and its neighbor indices, build precision M and distances to all X\n",
    "    def per_center(x0, idx_row):\n",
    "        # Gather neighbors (k, Q)\n",
    "        nbrs = X[idx_row, :]                              # (k, Q)\n",
    "        # Weighted/unweighted mean (unweighted here)\n",
    "        mu = jnp.mean(nbrs, axis=0, keepdims=True)        # (1, Q)\n",
    "        Z = nbrs - mu                                     # (k, Q)\n",
    "\n",
    "        # Sample covariance (Q, Q); divide by (k-1) to be unbiased\n",
    "        cov = (Z.T @ Z) / (k - 1)                         # (Q, Q)\n",
    "        cov = cov + ridge * jnp.eye(Q, dtype=X.dtype)\n",
    "        #cov = jnp.eye(Q, dtype=X.dtype)\n",
    "\n",
    "        # Precision via pseudoinverse for stability (handles rank-deficient local covariance)\n",
    "        prec = jnp.linalg.pinv(cov)                       # (Q, Q)\n",
    "\n",
    "        # Distances from x0 to all X under this local metric\n",
    "        Δ = X - x0                                        # (P, Q)\n",
    "        # d^2 = diag( Δ @ prec @ Δ^T )\n",
    "        d2 = jnp.einsum('iq,qr,ir->i', Δ, prec, Δ)        # (P,)\n",
    "        # Ensure numerical nonnegativity before sqrt\n",
    "        d2 = jnp.clip(d2, a_min=0.0)\n",
    "        return jnp.sqrt(d2)                               # (P,)\n",
    "\n",
    "    # Vectorize over centers\n",
    "    D = jax.vmap(per_center, in_axes=(0, 0), out_axes=0)(X0, nn_idx)  # (P0, P)\n",
    "    return D\n",
    "\n",
    "\n",
    "def getLocalPR3(Theta1,Theta2,key,rad=0.07,SP=43,batch_size=10,iter_cap=100,\n",
    "                min_neighbors=10,verbose=True):\n",
    "    \"\"\"\n",
    "    Average local (neighbourhood-restricted) dimensionality estimates.\n",
    "\n",
    "    SP centre rows are drawn without replacement. For each centre, its\n",
    "    neighbourhood is every row whose local Mahalanobis distance (computed on the\n",
    "    trial average (Theta1 + Theta2) / 2) is below sqrt(Q) * rad. Neighbourhoods\n",
    "    with fewer than `min_neighbors` rows are dropped. estimate_weighted_dimensionality\n",
    "    is evaluated on each kept neighbourhood (passed as a 0/1 row mask), numerators\n",
    "    and denominators are nan-averaged across neighbourhoods, and their ratios returned.\n",
    "\n",
    "    Args:\n",
    "        Theta1, Theta2: (P, Q) responses from two trials (may be the same array).\n",
    "        key: jax.random.PRNGKey used to choose the centres.\n",
    "        rad: radius scale; the distance threshold is sqrt(Q) * rad.\n",
    "        SP: number of centres (must not exceed P).\n",
    "        batch_size: centres per distance computation (bounds memory).\n",
    "        iter_cap: maximum number of neighbourhoods evaluated.\n",
    "        min_neighbors: minimum neighbourhood size to keep.\n",
    "        verbose: print the size of each evaluated neighbourhood.\n",
    "\n",
    "    Returns:\n",
    "        (ndim, ddim): naive and double-corrected estimates; (nan, nan) if no\n",
    "        neighbourhood qualified.\n",
    "    \"\"\"\n",
    "    P, Q = jnp.shape(Theta1)\n",
    "    Tmean = (Theta1 + Theta2) / 2\n",
    "    centres = random.choice(key, jnp.arange(P), (SP,), replace=False)\n",
    "\n",
    "    estvals = []\n",
    "    for start in range(0, SP, batch_size):\n",
    "        # Stop before computing more distances once the cap is reached.\n",
    "        if len(estvals) >= iter_cap:\n",
    "            break\n",
    "        batch = centres[start:start + batch_size]\n",
    "        dist = local_mahalanobis_distances(Tmean[batch, :], Tmean)\n",
    "        masks = (dist < (jnp.sqrt(Q) * rad)) * 1.0\n",
    "        masks = masks[jnp.sum(masks, axis=1) >= min_neighbors, :]\n",
    "        for mask in masks:\n",
    "            # At most iter_cap neighbourhoods are evaluated in total.\n",
    "            if len(estvals) >= iter_cap:\n",
    "                break\n",
    "            estvals.append(estimate_weighted_dimensionality(Theta1, Theta2, mask))\n",
    "            if verbose:\n",
    "                print('\\n{}'.format(jnp.sum(mask)), end=' ')\n",
    "\n",
    "    if not estvals:\n",
    "        return jnp.nan, jnp.nan\n",
    "\n",
    "    # (n_neighbourhoods, 4 estimators, [numer, denom]) -> (4, 2)\n",
    "    mean_est = jnp.nanmean(jnp.array(estvals), axis=0)\n",
    "    ndim = mean_est[0, 0] / mean_est[0, 1]     # naive\n",
    "    ddim = mean_est[-1, 0] / mean_est[-1, 1]   # double_exp\n",
    "    return ndim, ddim\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ab57148-2143-431a-b12b-c34c86586d65",
   "metadata": {},
   "outputs": [],
   "source": [
    "df, df_img = data_utils.get_stringer(DATASET_ROOT, avg_trials=False)\n",
    "\n",
    "# One response block per stimulus id ('istim'); each block's rows are presumably\n",
    "# that stimulus' repeats. np.array needs every block to have the same shape.\n",
    "per_stimulus = [stim_df.values for _, stim_df in df.groupby('istim')]\n",
    "# Swap the first two axes so axis 0 indexes the repeat (presumably).\n",
    "data_stringer = np.array(per_stimulus).swapaxes(0, 1)\n",
    "print(data_stringer.shape)\n",
    "df.head()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a6b0309",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First two repeats of the Stringer responses; index before casting so only\n",
    "# the two slices are converted to float64.\n",
    "Phi0 = data_stringer[0].astype(np.float64)\n",
    "Phi1 = data_stringer[1].astype(np.float64)\n",
    "\n",
    "key = random.PRNGKey(50)\n",
    "\n",
    "(Pmax,Qmax)=Phi0.shape\n",
    "\n",
    "# Size of the random subset shown in the preview image below.\n",
    "P=100\n",
    "Q=500\n",
    "\n",
    "key, key0 = random.split(key, 2)\n",
    "pidx = random.choice(key0,jnp.arange(Pmax),shape=(P,),replace=False)\n",
    "key, key0 = random.split(key, 2)\n",
    "qidx = random.choice(key0,jnp.arange(Qmax),shape=(Q,),replace=False)\n",
    "\n",
    "def zscore_columns(M):\n",
    "    \"\"\"Z-score each column of M; zero-variance columns are divided by 1 instead of producing 0/0 = NaN.\"\"\"\n",
    "    centred = M - jnp.mean(M, axis=0)[None, :]\n",
    "    sd = jnp.std(centred, axis=0)\n",
    "    return centred / jnp.where(sd > 0, sd, 1.0)[None, :]\n",
    "\n",
    "Phi0 = zscore_columns(Phi0)\n",
    "Phi1 = zscore_columns(Phi1)\n",
    "\n",
    "Theta = Phi0[pidx,:][:,qidx]\n",
    "\n",
    "fig,ax=plt.subplots(1,1,figsize=(10,5))\n",
    "ax.imshow(Theta)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94d69bfe-340f-4cfc-bc7b-0809d6a2c1c3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "\n",
    "key = random.PRNGKey(2)\n",
    "\n",
    "(Pmax,Qmax)=Phi0.shape\n",
    "\n",
    "P=2000   # rows of Phi0/Phi1 subsampled per run\n",
    "Q=200    # columns subsampled per run\n",
    "\n",
    "numits=40   # number of random (row, column) subsets\n",
    "\n",
    "rads=np.linspace(0.08,0.2,5)\n",
    "\n",
    "SP=P\n",
    "batch_size=200\n",
    "iter_cap=50\n",
    "\n",
    "destss=[]\n",
    "for j in range(numits):\n",
    "    key, key0 = random.split(key, 2)\n",
    "    pidx = random.choice(key0,jnp.arange(Pmax),shape=(P,),replace=False)\n",
    "    key, key0 = random.split(key, 2)\n",
    "    qidx = random.choice(key0,jnp.arange(Qmax),shape=(Q,),replace=False)\n",
    "\n",
    "    # The subset and its TwoNN estimate do not depend on the radius: compute once per run.\n",
    "    Theta1=jnp.array(Phi0[pidx,:][:,qidx])\n",
    "    Theta2=jnp.array(Phi1[pidx,:][:,qidx])\n",
    "    tv = twoNN_intrinsic_dimension(Theta1)\n",
    "\n",
    "    dests=[]\n",
    "    for ir,rad in enumerate(rads):\n",
    "        # Split off a fresh subkey; the key handed to getLocalPR3 is never split again.\n",
    "        key, subkey = random.split(key, 2)\n",
    "        ln,lv = getLocalPR3(Theta1,Theta2,subkey,rad=rad,SP=SP,batch_size=batch_size,iter_cap=iter_cap)\n",
    "        dests.append([tv,lv,ln])  # [TwoNN, double-corrected local, naive local]\n",
    "\n",
    "        print((ir+1)/len(rads),j)\n",
    "    destss.append(dests)\n",
    "\n",
    "# (numits, len(rads), 3) -> (len(rads), numits, 3)\n",
    "destss=np.array(destss).transpose((1,0,2))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3acf34fc-ca82-43a1-8e6c-8ef6ca8170e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plotting the noise correcting version:\n",
    "def re(x):\n",
    "    \"\"\"Return x as an array with negative entries replaced by 0.\"\"\"\n",
    "    arr = np.array(x)\n",
    "    return np.where(arr > 0, arr, 0)\n",
    "\n",
    "xoi = rads\n",
    "fig, ax = plt.subplots(1, 1, figsize=(4, 3.2))\n",
    "\n",
    "# TwoNN is radius-independent: pool every run into one flat interquartile band.\n",
    "twonn_all = destss[:, :, 0].flatten()\n",
    "flat = xoi * 0\n",
    "ax.plot(xoi, flat + np.nanmean(twonn_all), alpha=1, c='y', label='TwoNN')\n",
    "ax.fill_between(xoi, flat + np.nanquantile(twonn_all, q=0.25), flat + np.nanquantile(twonn_all, q=0.75), alpha=0.5, facecolor='y')\n",
    "\n",
    "# Local estimators (columns 1 and 2 of destss): mean over runs, interquartile error bars.\n",
    "series = [(1, 'r', r'$\\gamma^\\text{local}_\\text{both}(r)$'),\n",
    "          (2, 'k', r'$\\gamma^\\text{local}_\\text{naive}(r)$')]\n",
    "for col, colour, label in series:\n",
    "    qoi = destss[:, :, col]\n",
    "    yval = np.nanmean(qoi, axis=1)\n",
    "    lower_error = np.nanquantile(qoi, axis=1, q=0.25)\n",
    "    upper_error = np.nanquantile(qoi, axis=1, q=0.75)\n",
    "    asymmetric_error = re([(yval - lower_error), (upper_error - yval)])\n",
    "    ax.errorbar(xoi, yval, yerr=asymmetric_error, c=colour, marker='o', ls='-', alpha=0.5, ms=3, label=label)\n",
    "\n",
    "ax.set_xlabel('r: Local ball radius')\n",
    "ax.set_ylabel('Local dimensionality')\n",
    "ax.legend()\n",
    "fig.tight_layout()\n",
    "#plt.savefig(\"stringer_local_dim.pdf\", format=\"pdf\", bbox_inches = 'tight')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0efab4b3-eff5-4b5b-abf4-3660f1817c16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect one local neighbourhood of the most recent (Theta1, Theta2) subset.\n",
    "Tmean=(Theta1+Theta2)/2\n",
    "# Column count of this subset (the global Q may come from a different cell).\n",
    "n_cols=jnp.shape(Theta1)[1]\n",
    "\n",
    "key = random.PRNGKey(50)\n",
    "\n",
    "rad=1.0\n",
    "n_centres=200\n",
    "\n",
    "key, key0 = random.split(key, 2)\n",
    "binv = random.choice(key0,jnp.arange(jnp.shape(Theta1)[0]),shape=(n_centres,),replace=False)\n",
    "\n",
    "sqdist=local_mahalanobis_distances(Tmean[binv,:],Tmean)\n",
    "masks=(sqdist<(jnp.sqrt(n_cols)*rad))*1.0\n",
    "masks=masks[jnp.sum(masks,axis=1)>=4,:]\n",
    "\n",
    "print(jnp.sum(masks,axis=1))\n",
    "\n",
    "idx=3\n",
    "assert masks.shape[0] > idx, f'only {masks.shape[0]} neighbourhoods have >= 4 points; lower idx or raise rad'\n",
    "mask=masks[idx,:]==1\n",
    "\n",
    "# Centre on the two-trial neighbourhood mean, then scale each column.\n",
    "meanv = (jnp.mean(Theta1[mask,:],axis=0)+jnp.mean(Theta2[mask,:],axis=0))/2\n",
    "Theta1_n = Theta1[mask,:] - meanv[None,:]\n",
    "Theta2_n = Theta2[mask,:] - meanv[None,:]\n",
    "sd1 = jnp.std(Theta1_n,axis=0)\n",
    "sd2 = jnp.std(Theta2_n,axis=0)\n",
    "# Columns with zero variance inside the neighbourhood would give 0/0 = NaN; divide them by 1 instead.\n",
    "Theta1_n = Theta1_n/jnp.where(sd1>0,sd1,1.0)[None,:]\n",
    "Theta2_n = Theta2_n/jnp.where(sd2>0,sd2,1.0)[None,:]\n",
    "\n",
    "fig,ax=plt.subplots(1,1,figsize=(10,5))\n",
    "ax.imshow((Theta1_n+Theta2_n)/2)\n",
    "\n",
    "# Cross-trial kernel over neighbourhood points and its participation ratio tr(K)^2 / tr(K^2).\n",
    "K=1/jnp.shape(Theta1_n)[0] * jnp.matmul(Theta1_n,Theta2_n.T)\n",
    "PR=jnp.square(jnp.trace(K))/jnp.trace(jnp.linalg.matrix_power(K,2))\n",
    "print(PR)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6385ca74-9992-4ba4-8301-9f71b4e015dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TVSD\n",
    "# Local dimensionality of TVSD responses (loaded with data_utils.get_tvsd below)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17b2ba23",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src import data_utils\n",
    "from src.moment_utils import get_dimensionality_avg\n",
    "\n",
    "# Dataset-wide options (not all are used by the cells below).\n",
    "regions = ['V1', 'V4', 'IT']\n",
    "subject_names = ['F', 'N']\n",
    "estimators = ['naive', 'row_exp', 'col_exp', 'double_exp']\n",
    "\n",
    "# Monkey and area selected for this analysis.\n",
    "monkey_name, region = 'N', 'V1'\n",
    "\n",
    "df, df_img = data_utils.get_tvsd(DATASET_ROOT, monkey_name, region, get_stimuli=True)\n",
    "print(df.shape, df_img.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0d17096",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Response matrix as a float64 NumPy array (same rows/columns as df).\n",
    "Phi = np.array(df.values, dtype=np.float64)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bb9d2db",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "key = random.PRNGKey(0)\n",
    "\n",
    "(Pmax,Qmax)=Phi.shape\n",
    "\n",
    "P=2000   # rows subsampled per run\n",
    "Q=100    # columns subsampled per run\n",
    "\n",
    "numits=10   # number of random (row, column) subsets\n",
    "\n",
    "rads=np.linspace(1.08,1.4,10)\n",
    "\n",
    "SP=P\n",
    "batch_size=20\n",
    "iter_cap=50\n",
    "\n",
    "destss=[]\n",
    "for j in range(numits):\n",
    "    key, key0 = random.split(key, 2)\n",
    "    pidx = random.choice(key0,jnp.arange(Pmax),shape=(P,),replace=False)\n",
    "    key, key0 = random.split(key, 2)\n",
    "    qidx = random.choice(key0,jnp.arange(Qmax),shape=(Q,),replace=False)\n",
    "\n",
    "    # The subset and its TwoNN estimate do not depend on the radius: compute once per run.\n",
    "    Theta=jnp.array(Phi[pidx,:][:,qidx])\n",
    "    tv = twoNN_intrinsic_dimension(Theta)\n",
    "\n",
    "    dests=[]\n",
    "    for ir,rad in enumerate(rads):\n",
    "        # Split off a fresh subkey; the key handed to getLocalPR3 is never split again.\n",
    "        key, subkey = random.split(key, 2)\n",
    "        # NOTE(review): only one response matrix is available here, so Theta is passed\n",
    "        # as both trials -- confirm this is intended for the noise-corrected estimate.\n",
    "        ln,lv = getLocalPR3(Theta,Theta,subkey,rad=rad,SP=SP,batch_size=batch_size,iter_cap=iter_cap)\n",
    "        dests.append([tv,lv,ln])  # [TwoNN, double-corrected local, naive local]\n",
    "\n",
    "        print((ir+1)/len(rads),j)\n",
    "    destss.append(dests)\n",
    "\n",
    "# (numits, len(rads), 3) -> (len(rads), numits, 3)\n",
    "destss=np.array(destss).transpose((1,0,2))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6eba627f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plotting the noise correcting version:\n",
    "def re(x):\n",
    "    \"\"\"Return x as an array with negative entries replaced by 0 (errorbar lengths must be >= 0).\"\"\"\n",
    "    x=np.array(x)\n",
    "    return np.where(x>0,x,0)\n",
    "\n",
    "xoi=rads\n",
    "\n",
    "fig,ax=plt.subplots(1,1,figsize=(4,3.6))\n",
    "\n",
    "# TwoNN is radius-independent: pool every run into one flat interquartile band.\n",
    "twonn_all=destss[:,:,0].flatten()\n",
    "flat=xoi*0\n",
    "ax.plot(xoi, flat+np.nanmean(twonn_all), ls='--',alpha=1, c='y',label='TwoNN')\n",
    "ax.fill_between(xoi, flat+np.nanquantile(twonn_all,q=0.25), flat+np.nanquantile(twonn_all,q=0.75) ,alpha=0.5, facecolor='y')\n",
    "\n",
    "# Local estimators (columns 1 and 2 of destss): mean over runs, interquartile error bars.\n",
    "for col, colour, label in [(1, 'r', r'$\\gamma^\\text{local}_\\text{both}(r)$'),\n",
    "                           (2, 'k', r'$\\gamma^\\text{local}_\\text{naive}(r)$')]:\n",
    "    qoi=destss[:,:,col]\n",
    "    yval=np.nanmean(qoi,axis=1)\n",
    "    lower_error=np.nanquantile(qoi,axis=1,q=0.25)\n",
    "    upper_error=np.nanquantile(qoi,axis=1,q=0.75)\n",
    "    asymmetric_error = re([(yval-lower_error), (upper_error-yval)])\n",
    "    ax.errorbar(xoi, yval, yerr=asymmetric_error,c=colour,marker='o',ls='-',alpha=0.5,ms=3,label=label)\n",
    "\n",
    "ax.set_xlabel('r: Local ball radius')\n",
    "ax.set_ylabel('Local dimensionality')\n",
    "ax.set_ylim([-5,50])\n",
    "# Title follows the loaded `region` instead of a hard-coded 'V1'.\n",
    "ax.set_title(f'{region} - Electrodes (LFP) [Papale et al. (2025)] ')\n",
    "fig.tight_layout()\n",
    "plt.savefig(\"tvsd_local_dim.svg\", format=\"svg\", bbox_inches = 'tight')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e73b4666",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Freemanziemba\n",
    "# Local dimensionality of FreemanZiemba V1/V2 responses (loaded with data_utils.get_brainscore below)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf3c2dab",
   "metadata": {},
   "outputs": [],
   "source": [
    "regions = ['V1', 'V2']\n",
    "estimators = ['naive', 'row_exp', 'col_exp', 'double_exp']\n",
    "\n",
    "# Row counts a repetition must have to be kept; others are skipped.\n",
    "# NOTE(review): presumably the sizes of the two stimulus sets -- confirm. All kept\n",
    "# repetitions of a region must share a shape for np.array to stack them.\n",
    "full_rep_sizes = [3200, 135]\n",
    "\n",
    "# region -> array stacking the kept per-repetition response matrices\n",
    "data_freemanziemba = {name: [] for name in regions}\n",
    "for region in regions:\n",
    "    df, _ = data_utils.get_brainscore(region, avg_trials=False)\n",
    "    print(region, df.shape)\n",
    "\n",
    "    kept = [block.values for _, block in df.groupby('repetition') if block.shape[0] in full_rep_sizes]\n",
    "    data_freemanziemba[region] = np.array(kept)\n",
    "    print(data_freemanziemba[region].shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5983a816",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Local-dimensionality sweep on FreemanZiemba responses. Each draw picks two\n",
    "# distinct repetitions (two trials of the same population), subsamples P rows\n",
    "# and Q columns, then evaluates TwoNN and the two local estimators at every radius.\n",
    "roi = 'V2'\n",
    "\n",
    "key = random.PRNGKey(50)\n",
    "\n",
    "P = 110        # rows subsampled per draw\n",
    "Q = 100        # columns subsampled per draw\n",
    "numits = 50    # number of random draws\n",
    "rads = np.linspace(1.0, 1.4, 5)  # local ball radii\n",
    "\n",
    "Ntrials = np.shape(data_freemanziemba[roi])[0]\n",
    "\n",
    "SP = P\n",
    "batch_size = 200\n",
    "iter_cap = 50\n",
    "\n",
    "destss = []\n",
    "for j in range(numits):\n",
    "    key, key0 = random.split(key, 2)\n",
    "    ids = random.choice(key0, jnp.arange(Ntrials), shape=(2,), replace=False)\n",
    "    Phi0 = data_freemanziemba[roi][ids[0]]\n",
    "    Phi1 = data_freemanziemba[roi][ids[1]]\n",
    "\n",
    "    (Pmax, Qmax) = Phi0.shape\n",
    "\n",
    "    key, key0 = random.split(key, 2)\n",
    "    pidx = random.choice(key0, jnp.arange(Pmax), shape=(P,), replace=False)\n",
    "    key, key0 = random.split(key, 2)\n",
    "    qidx = random.choice(key0, jnp.arange(Qmax), shape=(Q,), replace=False)\n",
    "\n",
    "    # The subsampled matrices do not depend on the radius: build them (and the\n",
    "    # TwoNN estimate, assumed deterministic) once per draw, not once per radius.\n",
    "    Theta1 = jnp.array(Phi0[pidx, :][:, qidx])\n",
    "    Theta2 = jnp.array(Phi1[pidx, :][:, qidx])\n",
    "    tv = twoNN_intrinsic_dimension(Theta1)\n",
    "\n",
    "    dests = []\n",
    "    for ir, rad in enumerate(rads):\n",
    "        # Hand getLocalPR3 a fresh subkey and carry `key` forward unused, so a\n",
    "        # consumed key is never split again (JAX rule: never reuse a key).\n",
    "        key, key_local = random.split(key, 2)\n",
    "        ln, lv = getLocalPR3(Theta1, Theta2, key_local, rad=rad, SP=SP,\n",
    "                             batch_size=batch_size, iter_cap=iter_cap)\n",
    "        dests.append([tv, lv, ln])  # ln is the naive estimate\n",
    "\n",
    "        print((ir + 1) / len(rads), j)\n",
    "    destss.append(dests)\n",
    "\n",
    "# (numits, len(rads), 3) -> (len(rads), numits, 3): axis 1 indexes draws.\n",
    "destss = np.array(destss).transpose((1, 0, 2))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a4142be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot local dimensionality vs. ball radius. TwoNN has no radius, so it is\n",
    "# drawn as a flat band (pooled over radii and draws); the two local estimators\n",
    "# are drawn as mean with interquartile error bars over draws (axis 1).\n",
    "def re(x):\n",
    "    # Clip negative error-bar lengths to 0: the mean can fall outside the\n",
    "    # interquartile range, and matplotlib rejects negative yerr.\n",
    "    x = np.array(x)\n",
    "    return np.where(x > 0, x, 0)\n",
    "\n",
    "def _iqr_summary(qoi, axis=None):\n",
    "    \"\"\"Return (nanmean, 25th, 75th nan-percentile) of `qoi` along `axis`.\"\"\"\n",
    "    return (np.nanmean(qoi, axis=axis),\n",
    "            np.nanquantile(qoi, q=0.25, axis=axis),\n",
    "            np.nanquantile(qoi, q=0.75, axis=axis))\n",
    "\n",
    "xoi = rads\n",
    "fig, ax = plt.subplots(1, 1, figsize=(4, 3.2))\n",
    "\n",
    "mean, lo, hi = _iqr_summary(destss[:, :, 0].flatten())\n",
    "ax.plot(xoi, np.full_like(xoi, mean), alpha=1, c='y', label='TwoNN')\n",
    "ax.fill_between(xoi, np.full_like(xoi, lo), np.full_like(xoi, hi), alpha=0.5, facecolor='y')\n",
    "\n",
    "for col, color, label in [(1, 'r', r'$\\gamma^\\text{local}_\\text{both}(r)$'),\n",
    "                          (2, 'k', r'$\\gamma^\\text{local}_\\text{naive}(r)$')]:\n",
    "    mean, lo, hi = _iqr_summary(destss[:, :, col], axis=1)\n",
    "    ax.errorbar(xoi, mean, yerr=re([mean - lo, hi - mean]), c=color, marker='o',\n",
    "                ls='-', alpha=0.5, ms=3, label=label)\n",
    "\n",
    "ax.set_xlabel('r: Local ball radius')\n",
    "ax.set_ylabel('Local dimensionality')\n",
    "ax.set_title(f'{roi} - FreemanZiemba')\n",
    "ax.set_ylim([-5, 5])\n",
    "ax.legend()\n",
    "fig.tight_layout()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b57169a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# fMRI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48cbae53",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_rois = ['V1', 'V2', 'V3', 'hV4', 'VO1', 'VO2', 'TO1',\n",
    "            'lFFA', 'rFFA', 'lPPA', 'rPPA', 'lEBA', 'rEBA',\n",
    "            'IT', 'glasser-FEF']\n",
    "regions = ['V1', 'V2', 'V3', 'hV4', 'VO1', 'glasser-FEF', 'IT']\n",
    "subject_names = ['01', '02', '03']\n",
    "estimators = ['naive', 'row_exp', 'col_exp', 'double_exp']\n",
    "\n",
    "# The single subject / ROI loaded for the cells below.\n",
    "subject_name, roi = '01', 'V1'\n",
    "\n",
    "df, df_img = data_utils.get_things(DATASET_ROOT, subject_name, roi, get_stimuli=False)\n",
    "print(df.shape)\n",
    "\n",
    "# Response matrix as a fresh float64 array (explicit copy, independent of df).\n",
    "Phi = df.to_numpy(dtype=np.float64, copy=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbdb6492",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Local-dimensionality sweep on the fMRI response matrix `Phi`.\n",
    "# NOTE(review): only one response matrix is available here, so `Theta` is passed\n",
    "# to getLocalPR3 as both trials; the two-trial estimate therefore sees no\n",
    "# independent noise between trials -- confirm this is intended.\n",
    "key = random.PRNGKey(0)\n",
    "\n",
    "(Pmax, Qmax) = Phi.shape\n",
    "\n",
    "P = 3000       # rows subsampled per draw\n",
    "Q = 100        # columns subsampled per draw\n",
    "numits = 3     # number of random draws\n",
    "rads = np.linspace(1.1, 1.4, 6)  # local ball radii\n",
    "\n",
    "SP = P\n",
    "batch_size = 20\n",
    "iter_cap = 100\n",
    "\n",
    "# Subsampling is without replacement, so it needs P <= Pmax and Q <= Qmax.\n",
    "assert P <= Pmax and Q <= Qmax, f'cannot subsample ({P}, {Q}) from Phi of shape {Phi.shape}'\n",
    "\n",
    "destss = []\n",
    "for j in range(numits):\n",
    "    key, key0 = random.split(key, 2)\n",
    "    pidx = random.choice(key0, jnp.arange(Pmax), shape=(P,), replace=False)\n",
    "    key, key0 = random.split(key, 2)\n",
    "    qidx = random.choice(key0, jnp.arange(Qmax), shape=(Q,), replace=False)\n",
    "\n",
    "    # Theta (and TwoNN, assumed deterministic) do not depend on the radius.\n",
    "    Theta = jnp.array(Phi[pidx, :][:, qidx])\n",
    "    tv = twoNN_intrinsic_dimension(Theta)\n",
    "\n",
    "    dests = []\n",
    "    for ir, rad in enumerate(rads):\n",
    "        # Fresh subkey for getLocalPR3; a consumed key is never split again.\n",
    "        key, key_local = random.split(key, 2)\n",
    "        ln, lv = getLocalPR3(Theta, Theta, key_local, rad=rad, SP=SP,\n",
    "                             batch_size=batch_size, iter_cap=iter_cap)\n",
    "        dests.append([tv, lv, ln])  # ln is the naive estimate\n",
    "\n",
    "        print((ir + 1) / len(rads), j)\n",
    "    destss.append(dests)\n",
    "\n",
    "# (numits, len(rads), 3) -> (len(rads), numits, 3): axis 1 indexes draws.\n",
    "destss = np.array(destss).transpose((1, 0, 2))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f12d5fcc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot local dimensionality vs. ball radius. TwoNN has no radius, so it is\n",
    "# drawn as a flat band (pooled over radii and draws); the two local estimators\n",
    "# are drawn as mean with interquartile error bars over draws (axis 1).\n",
    "def re(x):\n",
    "    # Clip negative error-bar lengths to 0: the mean can fall outside the\n",
    "    # interquartile range, and matplotlib rejects negative yerr.\n",
    "    x = np.array(x)\n",
    "    return np.where(x > 0, x, 0)\n",
    "\n",
    "def _iqr_summary(qoi, axis=None):\n",
    "    \"\"\"Return (nanmean, 25th, 75th nan-percentile) of `qoi` along `axis`.\"\"\"\n",
    "    return (np.nanmean(qoi, axis=axis),\n",
    "            np.nanquantile(qoi, q=0.25, axis=axis),\n",
    "            np.nanquantile(qoi, q=0.75, axis=axis))\n",
    "\n",
    "xoi = rads\n",
    "fig, ax = plt.subplots(1, 1, figsize=(4, 3.2))\n",
    "\n",
    "mean, lo, hi = _iqr_summary(destss[:, :, 0].flatten())\n",
    "ax.plot(xoi, np.full_like(xoi, mean), alpha=1, c='y', label='TwoNN')\n",
    "ax.fill_between(xoi, np.full_like(xoi, lo), np.full_like(xoi, hi), alpha=0.5, facecolor='y')\n",
    "\n",
    "for col, color, label in [(1, 'r', r'$\\gamma^\\text{local}_\\text{both}(r)$'),\n",
    "                          (2, 'k', r'$\\gamma^\\text{local}_\\text{naive}(r)$')]:\n",
    "    mean, lo, hi = _iqr_summary(destss[:, :, col], axis=1)\n",
    "    ax.errorbar(xoi, mean, yerr=re([mean - lo, hi - mean]), c=color, marker='o',\n",
    "                ls='-', alpha=0.5, ms=3, label=label)\n",
    "\n",
    "ax.set_xlabel('r: Local ball radius')\n",
    "ax.set_ylabel('Local dimensionality')\n",
    "ax.set_title(f'{roi} - fMRI (subject {subject_name})')\n",
    "ax.legend()\n",
    "fig.tight_layout()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b7d574e",
   "metadata": {},
   "outputs": [],
   "source": [
    "## MajajHong"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8f7a657",
   "metadata": {},
   "outputs": [],
   "source": [
    "regions = ['V4', 'IT']\n",
    "subject_names = ['Tito', 'Chabo']\n",
    "estimators = ['naive', 'row_exp', 'col_exp', 'double_exp']\n",
    "\n",
    "# data_majajhong[(region, subject)] stacks the response matrices of every kept\n",
    "# repetition along axis 0. Only repetitions with exactly 3200 or 135 rows are\n",
    "# kept -- NOTE(review): presumably the complete stimulus sets; confirm.\n",
    "data_majajhong = {}\n",
    "for region in regions:\n",
    "    for subject in subject_names:\n",
    "        pair = (region, subject)\n",
    "        df, _ = data_utils.get_brainscore(region, subject_name=subject, avg_trials=False)\n",
    "        print(pair, df.shape)\n",
    "\n",
    "        data_majajhong[pair] = np.array([\n",
    "            rep_df.values\n",
    "            for _, rep_df in df.groupby('repetition')\n",
    "            if rep_df.shape[0] in (3200, 135)\n",
    "        ])\n",
    "        print(data_majajhong[pair].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "207ebad8",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Local-dimensionality sweep on MajajHong responses for one region / subject.\n",
    "# Each draw picks two distinct repetitions, subsamples P rows and Q columns,\n",
    "# then evaluates TwoNN and the two local estimators at every radius.\n",
    "region = 'IT'\n",
    "subject = 'Chabo'\n",
    "responses = data_majajhong[(region, subject)]\n",
    "\n",
    "key = random.PRNGKey(50)\n",
    "\n",
    "P = 2000       # rows subsampled per draw\n",
    "Q = 50         # columns subsampled per draw\n",
    "numits = 2     # number of random draws\n",
    "rads = np.linspace(1.0, 1.4, 5)  # local ball radii\n",
    "\n",
    "Ntrials = np.shape(responses)[0]\n",
    "\n",
    "SP = P\n",
    "batch_size = 200\n",
    "iter_cap = 50\n",
    "\n",
    "destss = []\n",
    "for j in range(numits):\n",
    "    key, key0 = random.split(key, 2)\n",
    "    ids = random.choice(key0, jnp.arange(Ntrials), shape=(2,), replace=False)\n",
    "    Phi0 = responses[ids[0]]\n",
    "    Phi1 = responses[ids[1]]\n",
    "\n",
    "    (Pmax, Qmax) = Phi0.shape\n",
    "\n",
    "    key, key0 = random.split(key, 2)\n",
    "    pidx = random.choice(key0, jnp.arange(Pmax), shape=(P,), replace=False)\n",
    "    key, key0 = random.split(key, 2)\n",
    "    qidx = random.choice(key0, jnp.arange(Qmax), shape=(Q,), replace=False)\n",
    "\n",
    "    # The subsampled matrices do not depend on the radius: build them (and the\n",
    "    # TwoNN estimate, assumed deterministic) once per draw, not once per radius.\n",
    "    Theta1 = jnp.array(Phi0[pidx, :][:, qidx])\n",
    "    Theta2 = jnp.array(Phi1[pidx, :][:, qidx])\n",
    "    tv = twoNN_intrinsic_dimension(Theta1)\n",
    "\n",
    "    dests = []\n",
    "    for ir, rad in enumerate(rads):\n",
    "        # Hand getLocalPR3 a fresh subkey and carry `key` forward unused, so a\n",
    "        # consumed key is never split again (JAX rule: never reuse a key).\n",
    "        key, key_local = random.split(key, 2)\n",
    "        ln, lv = getLocalPR3(Theta1, Theta2, key_local, rad=rad, SP=SP,\n",
    "                             batch_size=batch_size, iter_cap=iter_cap)\n",
    "        dests.append([tv, lv, ln])  # ln is the naive estimate\n",
    "\n",
    "        print((ir + 1) / len(rads), j)\n",
    "    destss.append(dests)\n",
    "\n",
    "# (numits, len(rads), 3) -> (len(rads), numits, 3): axis 1 indexes draws.\n",
    "destss = np.array(destss).transpose((1, 0, 2))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5100dbec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot local dimensionality vs. ball radius. TwoNN has no radius, so it is\n",
    "# drawn as a flat band (pooled over radii and draws); the two local estimators\n",
    "# are drawn as mean with interquartile error bars over draws (axis 1).\n",
    "def re(x):\n",
    "    # Clip negative error-bar lengths to 0: the mean can fall outside the\n",
    "    # interquartile range, and matplotlib rejects negative yerr.\n",
    "    x = np.array(x)\n",
    "    return np.where(x > 0, x, 0)\n",
    "\n",
    "def _iqr_summary(qoi, axis=None):\n",
    "    \"\"\"Return (nanmean, 25th, 75th nan-percentile) of `qoi` along `axis`.\"\"\"\n",
    "    return (np.nanmean(qoi, axis=axis),\n",
    "            np.nanquantile(qoi, q=0.25, axis=axis),\n",
    "            np.nanquantile(qoi, q=0.75, axis=axis))\n",
    "\n",
    "xoi = rads\n",
    "fig, ax = plt.subplots(1, 1, figsize=(4, 3.2))\n",
    "\n",
    "mean, lo, hi = _iqr_summary(destss[:, :, 0].flatten())\n",
    "ax.plot(xoi, np.full_like(xoi, mean), alpha=1, c='y', label='TwoNN')\n",
    "ax.fill_between(xoi, np.full_like(xoi, lo), np.full_like(xoi, hi), alpha=0.5, facecolor='y')\n",
    "\n",
    "for col, color, label in [(1, 'r', r'$\\gamma^\\text{local}_\\text{both}(r)$'),\n",
    "                          (2, 'k', r'$\\gamma^\\text{local}_\\text{naive}(r)$')]:\n",
    "    mean, lo, hi = _iqr_summary(destss[:, :, col], axis=1)\n",
    "    ax.errorbar(xoi, mean, yerr=re([mean - lo, hi - mean]), c=color, marker='o',\n",
    "                ls='-', alpha=0.5, ms=3, label=label)\n",
    "\n",
    "ax.set_xlabel('r: Local ball radius')\n",
    "ax.set_ylabel('Local dimensionality')\n",
    "ax.set_title(f'{region} ({subject}) - MajajHong')\n",
    "ax.set_ylim([-5, 5])\n",
    "ax.legend()\n",
    "fig.tight_layout()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15f31bad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# RBF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b5f8a5d-70a1-4606-99a9-e73532dfe899",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def getRBFData3(Sigma_x,Sigma,P,Q,keyx,keyw,keyn,eps0=0,eps1=0,return_latent=False):\n",
    "    \"\"\"Sample two noisy trials of a random-Fourier-feature population.\n",
    "\n",
    "    Latents X ~ N(0, Sigma_x) with P rows, weights W ~ N(0, inv(Sigma)) with\n",
    "    Q columns and phases p ~ U(-pi/2, pi/2) give the clean responses\n",
    "    Phi = sqrt(2) * sin(X @ W - p), of shape (P, Q). Each trial is\n",
    "    Theta_k = Phi * Mult_k + Add_k with independent Gaussian noise:\n",
    "    additive N(0, eps0^2) (omitted when eps0 == 0) and multiplicative\n",
    "    N(1, eps1^2) (omitted when eps1 == 0).\n",
    "\n",
    "    Args:\n",
    "        Sigma_x: (d, d) covariance of the latent inputs X.\n",
    "        Sigma: (d, d) matrix whose inverse is the covariance of the weights.\n",
    "        P, Q: number of inputs (rows) and features (columns).\n",
    "        keyx, keyw, keyn: PRNG keys for inputs, weights/phases and noise.\n",
    "        eps0: std of the additive noise.\n",
    "        eps1: std of the multiplicative noise.\n",
    "        return_latent: if True, also return X (P, d) and W (d, Q).\n",
    "\n",
    "    Returns:\n",
    "        (Theta1, Theta2, keyx2, keyw2, keyn2), or\n",
    "        (Theta1, Theta2, X, W, keyx2, keyw2, keyn2) when return_latent is True.\n",
    "        keyx2, keyw2, keyn2 are fresh keys to pass to the next call.\n",
    "    \"\"\"\n",
    "    sqrt2 = jnp.sqrt(2)\n",
    "    keyww, keywb, keyw2 = random.split(keyw, 3)\n",
    "    # Split keyx *before* sampling: the key consumed by the sampler must not\n",
    "    # also be split to make the next call's key (JAX: never reuse a key).\n",
    "    keyxs, keyx2 = random.split(keyx, 2)\n",
    "\n",
    "    d = jnp.shape(Sigma)[0]\n",
    "    Sigma_inv = jnp.linalg.inv(Sigma)\n",
    "    X = random.multivariate_normal(keyxs, jnp.zeros(d), Sigma_x, shape=(P,))\n",
    "    W = random.multivariate_normal(keyww, jnp.zeros(d), Sigma_inv, shape=(Q,)).T\n",
    "\n",
    "    p = random.uniform(keywb, (Q,), minval=-jnp.pi/2, maxval=jnp.pi/2)\n",
    "    Phi = sqrt2*jnp.sin(jnp.matmul(X, W)-p[None, :])\n",
    "\n",
    "    if eps0 != 0:  # additive independent noise, drawn separately per trial\n",
    "        key1, key2, _, keyn = random.split(keyn, 4)\n",
    "        Psi1 = random.normal(key1, (P, Q))*eps0\n",
    "        Psi2 = random.normal(key2, (P, Q))*eps0\n",
    "    else:\n",
    "        Psi1 = 0.0\n",
    "        Psi2 = 0.0\n",
    "\n",
    "    if eps1 != 0:  # multiplicative independent noise around 1\n",
    "        key1, key2, keyn = random.split(keyn, 3)\n",
    "        Psi1a = random.normal(key1, (P, Q))*eps1+1.0\n",
    "        Psi2a = random.normal(key2, (P, Q))*eps1+1.0\n",
    "    else:\n",
    "        Psi1a = 1.0\n",
    "        Psi2a = 1.0\n",
    "\n",
    "    Theta1 = Phi*Psi1a + Psi1\n",
    "    Theta2 = Phi*Psi2a + Psi2\n",
    "\n",
    "    # keyn and keyw2 have not been consumed by a sampler, so splitting them is safe.\n",
    "    _, keyn2 = random.split(keyn, 2)\n",
    "    _, keyw2 = random.split(keyw2, 2)\n",
    "    if return_latent:\n",
    "        return Theta1, Theta2, X, W, keyx2, keyw2, keyn2\n",
    "    return Theta1, Theta2, keyx2, keyw2, keyn2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1eaf47a-4c43-4798-9d07-d934b9ee9a04",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic benchmark: random Fourier features (getRBFData3) observed in two\n",
    "# independently noisy trials. For every radius, `numits` fresh datasets are drawn\n",
    "# and the estimates nia, dia (estimate_dimensionality_no_centering), tv (TwoNN)\n",
    "# and lv, ln (getLocalPR3) are recorded.\n",
    "key = random.PRNGKey(50)\n",
    "# One split gives independent streams for inputs, weights, noise and the local\n",
    "# estimator; the master key is never used again.\n",
    "key, keyx, keyw, keyn = random.split(key, 4)\n",
    "\n",
    "d = 4\n",
    "sigma = 1.0\n",
    "sigmax = 0.7\n",
    "\n",
    "Sigma = np.eye(d) * sigma**2\n",
    "Sigma_x = np.eye(d) * sigmax**2\n",
    "\n",
    "eps_add = 0.3   # std of additive trial noise\n",
    "eps_mult = 0.0  # std of multiplicative trial noise\n",
    "\n",
    "P = 4000\n",
    "Q = 100\n",
    "\n",
    "numits = 20\n",
    "\n",
    "# NOTE(review): presumably the closed-form global dimensionality of this model; confirm.\n",
    "GT = np.power(1 + 4 * (sigmax / sigma)**2, d / 2)\n",
    "print(GT)\n",
    "\n",
    "rads = np.linspace(1.1, 1.6, 12)\n",
    "\n",
    "SP = P\n",
    "batch_size = 200\n",
    "iter_cap = 50\n",
    "\n",
    "destss = []\n",
    "for ir, rad in enumerate(rads):\n",
    "    dests = []\n",
    "    for j in range(numits):\n",
    "        Theta1, Theta2, keyx, keyw, keyn = getRBFData3(\n",
    "            Sigma_x, Sigma, P, Q, keyx, keyw, keyn,\n",
    "            eps0=eps_add, eps1=eps_mult, return_latent=False)\n",
    "        # estimate_dimensionality_no_centering(A, B) reads P, Q from A's shape;\n",
    "        # it takes only the two trial matrices.\n",
    "        results = estimate_dimensionality_no_centering(Theta1, Theta2)\n",
    "        nia = results[0][0] / results[0][1]\n",
    "        dia = results[-1][0] / results[-1][1]\n",
    "        tv = twoNN_intrinsic_dimension(Theta1)\n",
    "        # Fresh subkey for getLocalPR3; a consumed key is never split again.\n",
    "        key, key_local = random.split(key, 2)\n",
    "        ln, lv = getLocalPR3(Theta1, Theta2, key_local, rad=rad, SP=SP,\n",
    "                             batch_size=batch_size, iter_cap=iter_cap)\n",
    "        dests.append([nia, dia, tv, lv, ln])  # ln is the naive estimate\n",
    "\n",
    "        print((ir + 1) / len(rads), j)\n",
    "    destss.append(dests)\n",
    "\n",
    "# Shape (len(rads), numits, 5): axis 1 indexes the independent datasets.\n",
    "destss = np.array(destss)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44ce87ca-8709-46ea-b5a9-882b5ab92698",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot local dimensionality vs. ball radius for the RFF benchmark. TwoNN has no\n",
    "# radius, so it is drawn as a flat band (pooled over radii and datasets); the two\n",
    "# local estimators are drawn as mean with interquartile error bars over datasets\n",
    "# (axis 1), against the true local dimension d.\n",
    "def re(x):\n",
    "    # Clip negative error-bar lengths to 0: the mean can fall outside the\n",
    "    # interquartile range, and matplotlib rejects negative yerr.\n",
    "    x = np.array(x)\n",
    "    return np.where(x > 0, x, 0)\n",
    "\n",
    "def _iqr_summary(qoi, axis=None):\n",
    "    \"\"\"Return (nanmean, 25th, 75th nan-percentile) of `qoi` along `axis`.\"\"\"\n",
    "    return (np.nanmean(qoi, axis=axis),\n",
    "            np.nanquantile(qoi, q=0.25, axis=axis),\n",
    "            np.nanquantile(qoi, q=0.75, axis=axis))\n",
    "\n",
    "xoi = rads\n",
    "fig, ax = plt.subplots(1, 1, figsize=(4, 3.6))\n",
    "\n",
    "mean, lo, hi = _iqr_summary(destss[:, :, 2].flatten())\n",
    "ax.plot(xoi, np.full_like(xoi, mean), ls='--', alpha=1, c='y', label='TwoNN')\n",
    "ax.fill_between(xoi, np.full_like(xoi, lo), np.full_like(xoi, hi), alpha=0.5, facecolor='y')\n",
    "\n",
    "for col, color, label in [(3, 'r', r'$\\gamma^\\text{local}_\\text{both}(r)$'),\n",
    "                          (4, 'k', r'$\\gamma^\\text{local}_\\text{naive}(r)$')]:\n",
    "    mean, lo, hi = _iqr_summary(destss[:, :, col], axis=1)\n",
    "    ax.errorbar(xoi, mean, yerr=re([mean - lo, hi - mean]), c=color, marker='o',\n",
    "                ls='-', alpha=0.5, ms=3, label=label)\n",
    "\n",
    "ax.hlines(d, np.min(xoi), np.max(xoi), color='c', lw=2, alpha=0.5, ls='--', label='True local dim.')\n",
    "ax.set_ylim([0, d * 4.5])\n",
    "ax.set_xlabel('r: Local ball radius')\n",
    "ax.set_ylabel('Local dimensionality')\n",
    "ax.set_title('Random Fourier features')\n",
    "ax.legend()\n",
    "fig.tight_layout()\n",
    "fig.savefig('rbf_local_dim.svg', format='svg', bbox_inches='tight')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1535128-b508-4fe7-a0a3-1904e094a96a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "brainscore",
   "language": "python",
   "name": "brainscore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
