{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2f51da84-3d0e-4de3-91df-4038c918c87a",
   "metadata": {},
   "source": [
    "To get started with this notebook you first need to get a token by \n",
    "registering for access by following the directions here (Scenario 1): \n",
    "\n",
    "https://tutorial.microns-explorer.org/quickstart_notebooks/01-caveclient-setup.html\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e005ec16-0e36-43e8-b06f-d4cecee4b9fe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cortical neurons: 50,594\n",
      "Sub‑matrix 5000×5000  nnz=27290\n"
     ]
    }
   ],
   "source": [
    "from caveclient import CAVEclient\n",
    "import pandas as pd, numpy as np, matplotlib.pyplot as plt, time, threading, pickle, os\n",
    "from scipy.sparse import coo_matrix, csr_matrix\n",
    "from sklearn.utils.extmath import randomized_svd\n",
    "from concurrent.futures import ThreadPoolExecutor, as_completed\n",
    "from tqdm import tqdm\n",
    "from pathlib import Path\n",
    "\n",
    "# ────────────────────────────── 0  CONFIG ──────────────────────────────\n",
    "CACHE       = Path(\"microns_cache\"); CACHE.mkdir(exist_ok=True)\n",
    "META_F      = CACHE/\"nucleus_meta.parquet\"\n",
    "TRIPLET_F   = CACHE/\"syn_triplets.npz\"\n",
    "SUB_F       = CACHE/\"W_sub.pkl\"\n",
    "DATSTACK    = \"minnie65_public\" # from https://www.biorxiv.org/content/10.1101/2022.07.20.499976v2\n",
    "DEPTH_MIN, DEPTH_MAX = 18000, 25000   # restrict to cortex band\n",
    "\n",
    "# ---- run‑time knobs -----------------------------------------------------------\n",
    "SUB_P   = int(os.getenv(\"SUB_P\", 3000))   # neurons kept in sub‑matrix\n",
    "P_DIM   = int(os.getenv(\"P_DIM\", 50))      # #PCs in each subset (P = N)\n",
    "R_TRIAL = int(os.getenv(\"R_TRIAL\", 1000))   # #random subsets (R)\n",
    "POWER_ITERS = int(os.getenv(\"POWER_ITERS\", 200))  # power iterations for HT method\n",
    "TOPK_SUPPORT = 50\n",
    "\n",
    "# ───────────────────────── Helpers ─────────────────────────\n",
    "\n",
    "def normalise_prob(p):\n",
    "    if p is None:\n",
    "        return None\n",
    "    p = np.asarray(p, float)\n",
    "    p[p < 0] = 0.0\n",
    "    s = p.sum()\n",
    "    return None if s <= 0 or not np.isfinite(s) else p / s\n",
    "\n",
    "# ─────────────────── 1  Soma metadata ──────────────────\n",
    "\n",
    "def download_metadata():\n",
    "    client = CAVEclient(DATSTACK)\n",
    "    meta = client.materialize.query_table(\"nucleus_ref_neuron_svm\")\n",
    "    good = meta[\"pt_position\"].apply(lambda p: isinstance(p,(list,tuple,np.ndarray)) and len(p)==3)\n",
    "    meta = meta[good].copy()\n",
    "    meta[\"z\"] = meta[\"pt_position\"].apply(lambda p:p[2])\n",
    "    meta = meta[(meta[\"cell_type\"]==\"neuron\") & meta[\"z\"].between(DEPTH_MIN, DEPTH_MAX)]\n",
    "    meta.to_parquet(META_F)\n",
    "    return meta\n",
    "\n",
    "if META_F.exists():\n",
    "    meta = pd.read_parquet(META_F)\n",
    "else:\n",
    "    print(\"Downloading nucleus_ref_neuron_svm …\")\n",
    "    meta = download_metadata(); print(\"Cached:\", META_F)\n",
    "\n",
    "valid_ids  = meta[\"pt_root_id\"].astype(np.int64).tolist()\n",
    "print(f\"Cortical neurons: {len(valid_ids):,}\")\n",
    "id2idx     = {rid:i for i,rid in enumerate(valid_ids)}\n",
    "post_is_neu= meta.set_index(\"pt_root_id\")[\"cell_type\"].to_dict()\n",
    "\n",
    "# ─────────────────── 2  Synapse triplets ─────────────────\n",
    "\n",
    "def download_triplets():\n",
    "    client = CAVEclient(DATSTACK)\n",
    "    rows, cols, wgts = [], [], []\n",
    "    chunk_ids  = 4000\n",
    "    chunks = [valid_ids[i:i+chunk_ids] for i in range(0, len(valid_ids), chunk_ids)]\n",
    "\n",
    "    def fetch(pre_ids):\n",
    "        try:\n",
    "            df = client.materialize.query_table(\n",
    "                \"synapses_pni_2\",\n",
    "                filter_in_dict={\"pre_pt_root_id\": pre_ids},\n",
    "                select_columns=[\"pre_pt_root_id\",\"post_pt_root_id\",\"size\"])\n",
    "            df = df[df[\"post_pt_root_id\"].map(post_is_neu).eq(\"neuron\")]\n",
    "            return (df[\"pre_pt_root_id\"].map(id2idx).values,\n",
    "                    df[\"post_pt_root_id\"].map(id2idx).values,\n",
    "                    df[\"size\"].values)\n",
    "        except Exception:\n",
    "            return ([],[],[])\n",
    "\n",
    "    with ThreadPoolExecutor(4) as ex:\n",
    "        futs=[ex.submit(fetch,c) for c in chunks]\n",
    "        for fut in tqdm(as_completed(futs), total=len(futs)):\n",
    "            r,c,s=fut.result(); rows+=r.tolist(); cols+=c.tolist(); wgts+=s.tolist()\n",
    "\n",
    "    np.savez_compressed(TRIPLET_F, pre_idx=np.asarray(rows), post_idx=np.asarray(cols), size=np.asarray(wgts))\n",
    "    print(\"Triplets cached:\", TRIPLET_F)\n",
    "\n",
    "if not TRIPLET_F.exists():\n",
    "    print(\"Streaming synapses_pni_2 … (~15 min one‑off)\")\n",
    "    download_triplets()\n",
    "\n",
    "# ─────────────────── 3  Sub‑matrix builder ─────────────────\n",
    "\n",
    "def build_submatrix(k=SUB_P, seed=0):\n",
    "    d      = np.load(TRIPLET_F)\n",
    "    pre,post,siz = d[\"pre_idx\"], d[\"post_idx\"], d[\"size\"]\n",
    "    pmax   = int(max(pre.max(), post.max())+1)\n",
    "    rng    = np.random.default_rng(seed)\n",
    "    keep   = rng.choice(pmax, k, replace=False)\n",
    "    mask   = np.isin(pre, keep) & np.isin(post, keep)\n",
    "    W_sub  = coo_matrix((siz[mask], (pre[mask], post[mask])), shape=(pmax,pmax)).tocsr()[keep][:,keep]\n",
    "    col_norm = np.sqrt(W_sub.power(2).sum(axis=0)).A1\n",
    "    prob_imp = normalise_prob(1/(col_norm+1))\n",
    "    pickle.dump((W_sub,prob_imp), SUB_F.open(\"wb\"))\n",
    "    print(\"Cached sub‑matrix: \", SUB_F)\n",
    "    return W_sub, prob_imp\n",
    "\n",
    "if SUB_F.exists():\n",
    "    with SUB_F.open(\"rb\") as f:\n",
    "        W_sub, prob_imp = pickle.load(f)\n",
    "else:\n",
    "    print(\"Building sub‑matrix...\"); W_sub, prob_imp = build_submatrix()\n",
    "\n",
    "print(f\"Sub‑matrix {W_sub.shape[0]}×{W_sub.shape[1]}  nnz={W_sub.nnz}\")\n",
    "\n",
    "# ─────────────────── 4  Column‑set helper ─────────────────\n",
    "\n",
    "def sample_colsets(n_cols, num_sets, set_size, seed=0, prob=None):\n",
    "    rng  = np.random.default_rng(seed)\n",
    "    prob = normalise_prob(prob)\n",
    "    return [rng.choice(n_cols, size=set_size, replace=False, p=prob) for _ in range(num_sets)]\n",
    "\n",
    "# ─────────────────── 5  PCA implementations ───────────────\n",
    "\n",
    "def heavy_tail_pca_sparse(\n",
    "    A: csr_matrix,\n",
    "    *,\n",
    "    P=P_DIM,\n",
    "    R=R_TRIAL,\n",
    "    N=P_DIM,\n",
    "    seed=None,\n",
    "    prob=None,\n",
    "    colsets=None,\n",
    "    tol=1e-6,\n",
    "    max_iter=POWER_ITERS\n",
    "):\n",
    "    \"\"\"\n",
    "    Streaming power‑iteration that approximates the draft UU^T eigenvector\n",
    "    without constructing the p×p matrix. Colsets must have length R.\n",
    "    Skips trials where randomized_svd fails to converge.\n",
    "    \"\"\"\n",
    "    from sklearn.utils.extmath import randomized_svd\n",
    "    from numpy.linalg import LinAlgError\n",
    "\n",
    "    rng, p, n = np.random.default_rng(seed), *A.shape\n",
    "    prob = normalise_prob(prob)\n",
    "    use_sets = colsets is not None\n",
    "    if use_sets and len(colsets) != R:\n",
    "        raise ValueError(\"colsets length must equal R\")\n",
    "\n",
    "    v = rng.standard_normal(p)\n",
    "    v /= np.linalg.norm(v)\n",
    "\n",
    "    for it in range(max_iter):\n",
    "        v_new = np.zeros_like(v)\n",
    "        used = 0\n",
    "        for r in range(R):\n",
    "            cols = colsets[r] if use_sets else rng.choice(n, size=N, replace=False, p=prob)\n",
    "            try:\n",
    "                sub = A[:, cols]\n",
    "                U, _, _ = randomized_svd(sub, n_components=P, n_iter=0, random_state=r + 12345)\n",
    "                v_new += U @ (U.T @ v)\n",
    "                used += 1\n",
    "            except LinAlgError:\n",
    "                print(f\"  Iter {it}, subset {r}: SVD did not converge — skipping.\")\n",
    "                continue\n",
    "        if used == 0:\n",
    "            raise RuntimeError(\"All SVD trials failed — try increasing N or reducing P.\")\n",
    "        v_new /= np.linalg.norm(v_new)\n",
    "        if np.linalg.norm(v_new - v) < tol:\n",
    "            break\n",
    "        v = v_new\n",
    "\n",
    "    return v\n",
    "\n",
    "def our_method_sparse(A: csr_matrix, *, P=P_DIM, R=R_TRIAL, seed=1, prob=None, colsets=None):\n",
    "    \"\"\"\n",
    "    Our HTPCA method from the paper that 'heavy_tail_pca_sparse' above approximates using 'randomized_svd'\n",
    "    instead of the slower 'eigh' used here.\n",
    "    \"\"\"\n",
    "    rng,p,n = np.random.default_rng(seed),*A.shape\n",
    "    prob    = normalise_prob(prob)\n",
    "    use_sets= colsets is not None\n",
    "    if use_sets and len(colsets)!=R:\n",
    "        raise ValueError(\"colsets length must equal R\")\n",
    "    V       = np.zeros((p,p), dtype=np.float32)\n",
    "    for i in range(R):\n",
    "        cols = colsets[i] if use_sets else rng.choice(n, size=P, replace=False, p=prob)\n",
    "        X    = A[:,cols].toarray()\n",
    "        _,U  = np.linalg.eigh(X@X.T)\n",
    "        Usel = U[:,-P:]\n",
    "        V   += Usel @ Usel.T\n",
    "    _,vecs = np.linalg.eigh(V)\n",
    "    v = vecs[:,-1]\n",
    "    return v/np.linalg.norm(v)\n",
    "\n",
    "# ─────────────────── 6  Diagnostics ───────────────────────\n",
    "\n",
    "def diag(v1,v2,k=TOPK_SUPPORT):\n",
    "    v1n,v2n = v1/np.linalg.norm(v1), v2/np.linalg.norm(v2)\n",
    "    if np.dot(v1n,v2n)<0: v2n=-v2n\n",
    "    print(\"  cosine  : %.4f\"%np.dot(v1n,v2n))\n",
    "    s1=set(np.argpartition(v1n,-k)[-k:]); s2=set(np.argpartition(v2n,-k)[-k:])\n",
    "    print(f\"  overlap : {len(s1&s2)}/{k}\")\n",
    "    print(\"  neg‑frac: %.2f | %.2f\"%((v1<0).mean(), (v2<0).mean()))\n",
    "\n",
    "# ─────────────────── 7  Run comparison between heavy_tail_pca_sparse and our_method_sparse   ────────────────────\n",
    "# Note: Increasing R_TRIAL increases agreement -- commented out as this takes a while to run\n",
    "# print(f\"\\n Sub‑matrix benchmark … (SUB_P={SUB_P}, R={R_TRIAL}, P=N={P_DIM})\")\n",
    "# start=time.time()\n",
    "\n",
    "# n_cols      = W_sub.shape[1]\n",
    "# shared_sets = sample_colsets(n_cols, R_TRIAL, P_DIM, seed=0, prob=prob_imp)\n",
    "\n",
    "# vec_ht  = heavy_tail_pca_sparse(W_sub, seed=0, prob=prob_imp, colsets=shared_sets)\n",
    "# vec_pap = our_method_sparse(W_sub, seed=0, prob=prob_imp, colsets=shared_sets)\n",
    "\n",
    "# diag(vec_ht, vec_pap)\n",
    "# print(f\"Elapsed {time.time()-start:.1f}s\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "707bd773-a12f-4971-abbd-8a40d061e443",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded cached alpha scores.\n",
      "\n",
      "Top 10 heavy-tailed rows (lowest α):\n",
      "  row=17599 | α=0.09\n",
      "  row=31855 | α=0.10\n",
      "  row=50346 | α=0.10\n",
      "  row=28352 | α=0.11\n",
      "  row=11549 | α=0.11\n",
      "  row=16843 | α=0.11\n",
      "  row=26903 | α=0.11\n",
      "  row=48047 | α=0.11\n",
      "  row=12312 | α=0.11\n",
      "  row=25053 | α=0.11\n",
      "Loaded cached heavy-tailed submatrix.\n",
      "\n",
      "Extracted 3000×3000 heavy-tailed submatrix\n",
      "\n",
      "Dowsampled to 1500×1500 heavy-tailed submatrix\n",
      "\n",
      "HT-PCA split-half reproducibility check ...\n",
      "Cosine similarity (HT-PCA left vs right): 0.4285\n",
      "\n",
      "Minsker method split-half reproducibility check ...\n",
      "(CVXPY) May 22 06:52:30 PM: Encountered unexpected exception importing solver MOSEK:\n",
      "AttributeError(\"module 'mosek' has no attribute 'conetype'\")\n",
      "\n",
      "--- Minsker method diagnostics ---\n",
      "Input shape: (1500, 750)\n",
      "Mean of mat: 85.90731377777777\n",
      "Constructed mats with shape: (2250000, 3)\n",
      "Final alpha sum: 1.0\n",
      "\n",
      "--- Minsker method diagnostics ---\n",
      "Input shape: (1500, 750)\n",
      "Mean of mat: 85.824288\n",
      "Constructed mats with shape: (2250000, 3)\n",
      "Final alpha sum: 1.0\n",
      "Cosine similarity (Minsker left vs right): 0.0059\n",
      "Cosine similarity (HT-PCA left vs right): 0.1526\n",
      "\n",
      "Minsker method split-half reproducibility check ...\n",
      "\n",
      "--- Minsker method diagnostics ---\n",
      "Input shape: (1500, 750)\n",
      "Mean of mat: 89.74914488888889\n",
      "Constructed mats with shape: (2250000, 3)\n",
      "Final alpha sum: 1.0\n",
      "\n",
      "--- Minsker method diagnostics ---\n",
      "Input shape: (1500, 750)\n",
      "Mean of mat: 81.98245688888889\n",
      "Constructed mats with shape: (2250000, 3)\n",
      "Final alpha sum: 1.0\n",
      "Cosine similarity (Minsker left vs right): 0.0004\n",
      "Cosine similarity (HT-PCA left vs right): 0.4182\n",
      "\n",
      "Minsker method split-half reproducibility check ...\n",
      "\n",
      "--- Minsker method diagnostics ---\n",
      "Input shape: (1500, 750)\n",
      "Mean of mat: 83.94049777777778\n",
      "Constructed mats with shape: (2250000, 3)\n",
      "Final alpha sum: 1.0000000000000002\n",
      "\n",
      "--- Minsker method diagnostics ---\n",
      "Input shape: (1500, 750)\n",
      "Mean of mat: 87.791104\n",
      "Constructed mats with shape: (2250000, 3)\n",
      "Final alpha sum: 1.0\n",
      "Cosine similarity (Minsker left vs right): 0.9875\n",
      "  Iter 11, subset 12: SVD did not converge — skipping.\n",
      "Cosine similarity (HT-PCA left vs right): 0.9560\n",
      "\n",
      "Minsker method split-half reproducibility check ...\n",
      "\n",
      "--- Minsker method diagnostics ---\n",
      "Input shape: (1500, 750)\n",
      "Mean of mat: 82.823168\n",
      "Constructed mats with shape: (2250000, 3)\n"
     ]
    }
   ],
   "source": [
    "from scipy.stats import levy_stable\n",
    "from joblib import Parallel, delayed\n",
    "\n",
    "HT_SUBMATRIX_F = CACHE / \"W_heavy_submatrix.pkl\"\n",
    "ALPHA_SCORES_F = CACHE / \"alpha_scores.pkl\"\n",
    "\n",
    "# Reconstruct full sparse matrix W from cached triplet data\n",
    "trip = np.load(TRIPLET_F)\n",
    "pre, post, size = trip[\"pre_idx\"], trip[\"post_idx\"], trip[\"size\"]\n",
    "pmax = int(max(pre.max(), post.max()) + 1)\n",
    "W = coo_matrix((size, (pre, post)), shape=(pmax, pmax)).tocsr()\n",
    "\n",
    "def estimate_alpha(data: np.ndarray, quantile: float = 0.95) -> float:\n",
    "    data = np.sort(data)\n",
    "    tail = data[int(len(data) * quantile):]\n",
    "    try:\n",
    "        alpha, _, _, _ = levy_stable.fit(tail, floc=0, fscale=1)\n",
    "        return alpha\n",
    "    except Exception:\n",
    "        return np.nan\n",
    "\n",
    "def scan_heavy_tailed_rows(W: csr_matrix, n_jobs: int = -1) -> list:\n",
    "    def process_row(i):\n",
    "        row = W[i].toarray().flatten()\n",
    "        if np.count_nonzero(row) > 10:\n",
    "            return (i, estimate_alpha(row[row > 0]))\n",
    "        return (i, np.nan)\n",
    "\n",
    "    print(\"\\n Scanning full matrix for heavy-tailed rows in parallel...\")\n",
    "    results = Parallel(n_jobs=n_jobs)(delayed(process_row)(i) for i in range(W.shape[0]))\n",
    "    return [(i, a) for i, a in results if np.isfinite(a)]\n",
    "\n",
    "def extract_heavy_tailed_submatrix(W: csr_matrix, alpha_scores: list, N: int = 3000, min_nnz: int = 100) -> csr_matrix:\n",
    "    valid = [i for i, a in alpha_scores if W[i].nnz >= min_nnz]\n",
    "    valid_scores = [(i, a) for i, a in alpha_scores if i in valid]\n",
    "    valid_scores.sort(key=lambda x: x[1])\n",
    "    top = [i for i, _ in valid_scores[:N]]\n",
    "    return W[top][:, top]\n",
    "\n",
    "# Cache or compute alpha scores\n",
    "if ALPHA_SCORES_F.exists():\n",
    "    with open(ALPHA_SCORES_F, \"rb\") as f:\n",
    "        alpha_scores = pickle.load(f)\n",
    "    print(\"Loaded cached alpha scores.\")\n",
    "else:\n",
    "    alpha_scores = scan_heavy_tailed_rows(W)\n",
    "    with open(ALPHA_SCORES_F, \"wb\") as f:\n",
    "        pickle.dump(alpha_scores, f)\n",
    "    print(\"Saved alpha scores.\")\n",
    "\n",
    "# Print top-10 most heavy-tailed\n",
    "alpha_scores.sort(key=lambda x: x[1])\n",
    "print(\"\\nTop 10 heavy-tailed rows (lowest α):\")\n",
    "for i, a in alpha_scores[:10]:\n",
    "    print(f\"  row={i:5d} | α={a:.2f}\")\n",
    "\n",
    "# Extract and cache heavy-tailed submatrix\n",
    "if HT_SUBMATRIX_F.exists():\n",
    "    with open(HT_SUBMATRIX_F, \"rb\") as f:\n",
    "        W_heavy = pickle.load(f)\n",
    "    print(\"Loaded cached heavy-tailed submatrix.\")\n",
    "else:\n",
    "    W_heavy = extract_heavy_tailed_submatrix(W, alpha_scores, N=3000)\n",
    "    with open(HT_SUBMATRIX_F, \"wb\") as f:\n",
    "        pickle.dump(W_heavy, f)\n",
    "    print(\"Saved heavy-tailed submatrix: \", HT_SUBMATRIX_F)\n",
    "\n",
    "print(f\"\\nExtracted {W_heavy.shape[0]}×{W_heavy.shape[1]} heavy-tailed submatrix\")\n",
    "\n",
    "# Extract top 1500 most heavy-tailed entries for split-half evaluation\n",
    "alpha_dict = dict(alpha_scores)\n",
    "row_ids = list(alpha_dict.keys())\n",
    "alphas = np.array([alpha_dict[i] for i in row_ids])\n",
    "sort_idx = np.argsort(alphas)\n",
    "top_1500_ids = [row_ids[i] for i in sort_idx[:1500]]\n",
    "W_heavy_1500 = W[top_1500_ids][:, top_1500_ids]\n",
    "print(f\"\\nDowsampled to {W_heavy_1500.shape[0]}×{W_heavy_1500.shape[1]} heavy-tailed submatrix\")\n",
    "\n",
    "# ——— Run split-half comparison on HT-PCA and Minsker ———\n",
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "\n",
    "def run_split_half(v1, v2):\n",
    "    norm1 = np.linalg.norm(v1)\n",
    "    norm2 = np.linalg.norm(v2)\n",
    "    if norm1 == 0 or norm2 == 0 or np.isnan(norm1) or np.isnan(norm2):\n",
    "        return float(\"nan\")\n",
    "    v1n, v2n = v1 / norm1, v2 / norm2\n",
    "    if np.dot(v1n, v2n) < 0:\n",
    "        v2n = -v2n\n",
    "    return float(np.dot(v1n, v2n))\n",
    "\n",
    "def run_split_half(v1, v2):\n",
    "    v1n = v1 / (np.linalg.norm(v1) + 1e-10)\n",
    "    v2n = v2 / (np.linalg.norm(v2) + 1e-10)\n",
    "    if np.dot(v1n, v2n) < 0:\n",
    "        v2n = -v2n\n",
    "    return float(np.dot(v1n, v2n))\n",
    "\n",
    "def minsker_method(mat, NU=[0.5], NUM_GROUPS=10, num_comps=1):\n",
    "    import cvxpy as cp\n",
    "    print(\"\\n--- Minsker method diagnostics ---\")\n",
    "    print(\"Input shape:\", mat.shape)\n",
    "    print(\"Mean of mat:\", mat.mean())\n",
    "\n",
    "    if mat.mean() == 0:\n",
    "        print(\"Matrix mean is zero. Returning zeros.\")\n",
    "        return np.zeros(mat.shape[0])\n",
    "\n",
    "    mat = mat / np.mean(mat)\n",
    "    size_groups = int(mat.shape[1] / NUM_GROUPS)\n",
    "    cols = list(range(mat.shape[1]))\n",
    "    np.random.shuffle(cols)\n",
    "    mats = []\n",
    "    for i in range(NUM_GROUPS):\n",
    "        idxs = [cols[j] for j in range(size_groups * i, size_groups * (i + 1))]\n",
    "        X_trim = mat[:, idxs]\n",
    "        if X_trim.shape[1] == 0:\n",
    "            continue\n",
    "        mats.append(X_trim @ X_trim.T / X_trim.shape[1])\n",
    "        mats[-1] = mats[-1].flatten()\n",
    "    if len(mats) == 0:\n",
    "        print(\"All groups empty. Returning zeros.\")\n",
    "        return np.zeros(mat.shape[0])\n",
    "\n",
    "    mats = np.array(mats).T\n",
    "    print(\"Constructed mats with shape:\", mats.shape)\n",
    "\n",
    "    coefs = cp.Variable(mats.shape[1])\n",
    "    c = cp.Variable(mats.shape[1])\n",
    "    objective = cp.sum(c)\n",
    "    constraints = [cp.SOC(c[i], mats @ coefs - mats[:, i]) for i in range(mats.shape[1])]\n",
    "    constraints += [coefs >= 0, cp.sum(coefs) == 1, c >= 0]\n",
    "    problem = cp.Problem(cp.Minimize(objective), constraints)\n",
    "    try:\n",
    "        problem.solve(solver='ECOS')\n",
    "    except Exception as e:\n",
    "        print(\"Solver failed with exception:\", e)\n",
    "        return np.zeros(mat.shape[0])\n",
    "\n",
    "    if coefs.value is None or np.isnan(coefs.value).any():\n",
    "        print(\"Invalid coefficients. Returning zeros.\")\n",
    "        return np.zeros(mat.shape[0])\n",
    "\n",
    "    alpha = [coefs.value[i] if coefs.value[i] >= NU[0] / mats.shape[1] else 0 for i in range(mats.shape[1])]\n",
    "    alpha = np.array(alpha)\n",
    "    if alpha.sum() == 0 or np.isnan(alpha).any():\n",
    "        print(\"Alpha vector is invalid. Returning zeros.\")\n",
    "        return np.zeros(mat.shape[0])\n",
    "\n",
    "    alpha /= alpha.sum()\n",
    "    print(\"Final alpha sum:\", alpha.sum())\n",
    "\n",
    "    Cov = sum([alpha[i] * mats[:, i] for i in range(mats.shape[1])])\n",
    "    Cov = np.reshape(Cov, (mat.shape[0], mat.shape[0]))\n",
    "    try:\n",
    "        _, U = np.linalg.eigh(Cov)\n",
    "        return U[:, -num_comps:]\n",
    "    except:\n",
    "        print(\"eigh failed. Returning zeros.\")\n",
    "        return np.zeros(mat.shape[0])\n",
    "\n",
    "#---------------- Split-half analysis (1500×1500)-----------------\n",
    "print(\"\\nHT-PCA split-half reproducibility check ...\")\n",
    "n = W_heavy_1500.shape[1]\n",
    "\n",
    "# Run split-half analysis for 'num_reps' replicates\n",
    "num_reps = 10\n",
    "ht_pca_cs = []\n",
    "minsker_cs = []\n",
    "for r in range(num_reps):\n",
    "    idx = np.random.default_rng(r).permutation(n)\n",
    "    half = n // 2\n",
    "    \n",
    "    W1 = W_heavy_1500[:, idx[:half]]\n",
    "    W2 = W_heavy_1500[:, idx[half:2*half]]\n",
    "    \n",
    "    vec1_ht = heavy_tail_pca_sparse(W1, seed=1, R=100)\n",
    "    vec2_ht = heavy_tail_pca_sparse(W2, seed=2, R=100)\n",
    "\n",
    "    ht_cs = run_split_half(vec1_ht, vec2_ht)\n",
    "    ht_pca_cs.append(ht_cs)\n",
    "    print(\"Cosine similarity (HT-PCA left vs right): %.4f\" % ht_cs)\n",
    "    \n",
    "    print(\"\\nMinsker method split-half reproducibility check ...\")\n",
    "    \n",
    "    vec1_min = minsker_method(W1.toarray(), NUM_GROUPS=3, num_comps=1)\n",
    "    vec2_min = minsker_method(W2.toarray(), NUM_GROUPS=3, num_comps=1)\n",
    "\n",
    "    min_cs = run_split_half(vec1_min.flatten(), vec2_min.flatten())\n",
    "    minsker_cs.append(min_cs)\n",
    "    print(\"Cosine similarity (Minsker left vs right): %.4f\" % min_cs)\n",
    "\n",
    "print(\"Cosine similarity (HT-PCA) summary:\", np.mean(ht_pca_cs), np.std(ht_pca_cs))\n",
    "print(\"Cosine similarity (Minsker) summary:\", np.mean(min_cs), np.std(min_cs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1891c77-74fb-4f79-bc24-490b977df8f4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
