{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "# ============================================================\n",
        "# NYTimes (UCI Bag-of-Words) with RTPM SPCA\n",
        "# ============================================================\n",
        "\n",
        "!pip -q install numpy scipy scikit-learn tqdm matplotlib\n",
        "\n",
        "import os, subprocess, gzip, time\n",
        "import numpy as np\n",
        "import scipy.sparse as sp\n",
        "from tqdm import tqdm\n",
        "from sklearn.preprocessing import normalize as sk_normalize\n",
        "from array import array\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# -----------------------\n",
        "# Utils\n",
        "# -----------------------\n",
        "def sh(cmd):\n",
        "    print(\"\\n>\", cmd)\n",
        "    out = subprocess.run(cmd, shell=True, capture_output=True, text=True)\n",
        "    if out.returncode != 0:\n",
        "        print(out.stdout)\n",
        "        print(out.stderr)\n",
        "        raise RuntimeError(f\"Command failed: {cmd}\")\n",
        "    return out.stdout\n",
        "\n",
        "def file_info(path):\n",
        "    if not os.path.exists(path):\n",
        "        return \"MISSING\"\n",
        "    return f\"{path}: {os.path.getsize(path)} bytes\"\n",
        "\n",
        "def tnow():\n",
        "    return time.time()\n",
        "\n",
        "def fmt_secs(s):\n",
        "    if s < 60: return f\"{s:.2f}s\"\n",
        "    if s < 3600: return f\"{s/60:.2f}m\"\n",
        "    return f\"{s/3600:.2f}h\"\n",
        "\n",
        "def banner(msg):\n",
        "    print(\"\\n\" + \"=\"*80)\n",
        "    print(msg)\n",
        "    print(\"=\"*80)\n",
        "\n",
        "print(\"Ready.\")\n",
        "\n",
        "# -----------------------\n",
        "# 1) Download NYTimes\n",
        "# -----------------------\n",
        "banner(\"1) Download NYTimes dataset (UCI Bag-of-Words)\")\n",
        "\n",
        "BASE = \"https://archive.ics.uci.edu/ml/machine-learning-databases/bag-of-words/\"\n",
        "DATASET = \"nytimes\"\n",
        "doc_url = f\"{BASE}docword.{DATASET}.txt.gz\"\n",
        "voc_url = f\"{BASE}vocab.{DATASET}.txt\"\n",
        "\n",
        "sh(\"rm -f docword.txt.gz vocab.txt\")\n",
        "sh(f\"curl -L -f -o docword.txt.gz '{doc_url}'\")\n",
        "sh(f\"curl -L -f -o vocab.txt '{voc_url}'\")\n",
        "\n",
        "print(file_info(\"docword.txt.gz\"))\n",
        "print(file_info(\"vocab.txt\"))\n",
        "\n",
        "assert os.path.getsize(\"docword.txt.gz\") > 1_000_000, \"docword file too small (download likely failed)\"\n",
        "assert os.path.getsize(\"vocab.txt\") > 100_000, \"vocab file too small (download likely failed)\"\n",
        "\n",
        "print(\"\\nFirst vocab lines:\")\n",
        "sh(\"head -n 5 vocab.txt\")\n",
        "\n",
        "with gzip.open(\"docword.txt.gz\", \"rt\") as f:\n",
        "    head = [f.readline().strip() for _ in range(3)]\n",
        "print(\"\\ndocword header lines:\", head)\n",
        "\n",
        "# -----------------------\n",
        "# 2) Parameters\n",
        "# -----------------------\n",
        "banner(\"2) Parameters (edit these)\")\n",
        "\n",
        "DOC_LIMIT   = 10_000   # #docs loaded into X\n",
        "M           = 20_000   # target vocab size AFTER filtering (auto-shrinks)\n",
        "MIN_DF      = 10\n",
        "MAX_DF_FRAC = 0.5\n",
        "USE_LOG1P   = True\n",
        "DOC_L2_NORM = True\n",
        "\n",
        "# RTPM / Algorithm params\n",
        "K           = 5\n",
        "R           = 50       # truncation per step (top-R coords)\n",
        "T_MAX       = 50        # max iterations with early stopping\n",
        "TOL         = 1e-4      # relative tolerance for Rayleigh improvement\n",
        "BATCH_BASIS = 128\n",
        "\n",
        "RESTART_BUDGET = DOC_LIMIT  # choose this many basis seeds by largest variance\n",
        "\n",
        "# Printing\n",
        "TOP_WORDS_ABS = 20\n",
        "TOP_WORDS_POS = 12\n",
        "TOP_WORDS_NEG = 12\n",
        "TOP_DOCS      = 5\n",
        "\n",
        "# Visualization\n",
        "VIZ_M_TOTAL = 200  # subset size for heatmaps\n",
        "\n",
        "print(f\"DOC_LIMIT={DOC_LIMIT}, M={M}, MIN_DF={MIN_DF}, MAX_DF_FRAC={MAX_DF_FRAC}\")\n",
        "print(f\"USE_LOG1P={USE_LOG1P}, DOC_L2_NORM={DOC_L2_NORM}\")\n",
        "print(f\"K={K}, R={R}, T_MAX={T_MAX}, TOL={TOL}, BATCH_BASIS={BATCH_BASIS}, RESTART_BUDGET={RESTART_BUDGET}\")\n",
        "print(f\"Heatmap subset size VIZ_M_TOTAL={VIZ_M_TOTAL}\")\n",
        "\n",
        "# -----------------------\n",
        "# 3) Load vocab\n",
        "# -----------------------\n",
        "banner(\"3) Load vocabulary\")\n",
        "\n",
        "t0 = tnow()\n",
        "with open(\"vocab.txt\", \"rt\") as f:\n",
        "    vocab_all = [line.strip() for line in f]\n",
        "print(\"Loaded vocab lines:\", len(vocab_all), \"time:\", fmt_secs(tnow()-t0))\n",
        "print(\"Sample vocab:\", vocab_all[:10])\n",
        "\n",
        "# -----------------------\n",
        "# 4) Pass 1: DF over first DOC_LIMIT docs\n",
        "# -----------------------\n",
        "banner(\"4) Pass 1: compute document frequency (DF) over first DOC_LIMIT docs\")\n",
        "\n",
        "def pass1_compute_df(path_gz: str, doc_limit: int):\n",
        "    with gzip.open(path_gz, \"rt\") as f:\n",
        "        D = int(f.readline().strip())\n",
        "        V = int(f.readline().strip())\n",
        "        NNZ = int(f.readline().strip())\n",
        "        df = np.zeros(V, dtype=np.int32)\n",
        "\n",
        "        # NOTE: In UCI docword, each line is a nonzero count (doc, word, count).\n",
        "        # Incrementing df[w] counts how often w appears among those nonzeros in the first DOC_LIMIT docs.\n",
        "        for line in tqdm(f, desc=\"Pass1 streaming df\"):\n",
        "            d_str, w_str, _ = line.split()\n",
        "            d = int(d_str)\n",
        "            if d > doc_limit:\n",
        "                break\n",
        "            w = int(w_str) - 1\n",
        "            df[w] += 1\n",
        "    return D, V, NNZ, df\n",
        "\n",
        "t0 = tnow()\n",
        "D_header, V_header, NNZ_header, df = pass1_compute_df(\"docword.txt.gz\", DOC_LIMIT)\n",
        "print(\"Header D,V,NNZ:\", D_header, V_header, NNZ_header, \"time:\", fmt_secs(tnow()-t0))\n",
        "print(\"Words seen in first DOC_LIMIT:\", int((df > 0).sum()))\n",
        "print(\"DF stats (over all words): min/median/max:\", int(df.min()), int(np.median(df)), int(df.max()))\n",
        "\n",
        "max_df = int(MAX_DF_FRAC * DOC_LIMIT)\n",
        "eligible = np.where((df >= MIN_DF) & (df <= max_df))[0]\n",
        "print(\"\\nFiltering criteria:\")\n",
        "print(f\"  MIN_DF={MIN_DF}\")\n",
        "print(f\"  MAX_DF_FRAC={MAX_DF_FRAC} -> MAX_DF={max_df}\")\n",
        "print(\"Eligible words:\", eligible.size)\n",
        "\n",
        "if eligible.size == 0:\n",
        "    raise ValueError(\"No eligible words. Lower MIN_DF and/or increase MAX_DF_FRAC.\")\n",
        "\n",
        "if eligible.size < M:\n",
        "    print(f\"[WARN] Eligible words ({eligible.size}) < M ({M}). Shrinking M -> {eligible.size}.\")\n",
        "    M = int(eligible.size)\n",
        "\n",
        "# Select top-M by DF among eligible (keeps moderately common words)\n",
        "top = eligible[np.argsort(df[eligible])[-M:]]\n",
        "top = np.sort(top)\n",
        "\n",
        "print(\"Selected M:\", top.size)\n",
        "print(\"Selected DF stats: min/median/max:\",\n",
        "      int(df[top].min()), int(np.median(df[top])), int(df[top].max()))\n",
        "\n",
        "# Fast mapping: old vocab index -> new [0..M-1] or -1\n",
        "map_arr = -np.ones(V_header, dtype=np.int32)\n",
        "map_arr[top] = np.arange(len(top), dtype=np.int32)\n",
        "\n",
        "vocab_sel = [vocab_all[i] for i in top]\n",
        "print(\"Selected vocab size:\", len(vocab_sel))\n",
        "print(\"Selected vocab sample:\", vocab_sel[:25])\n",
        "\n",
        "# -----------------------\n",
        "# 5) Pass 2: build sparse X (DOC_LIMIT x M)\n",
        "# -----------------------\n",
        "banner(\"5) Pass 2: build sparse doc-term matrix X\")\n",
        "\n",
        "def pass2_build_X(path_gz: str, doc_limit: int, map_arr: np.ndarray, use_log1p: bool = True):\n",
        "    rows = array('I')\n",
        "    cols = array('I')\n",
        "    data = array('f')\n",
        "\n",
        "    with gzip.open(path_gz, \"rt\") as f:\n",
        "        _D = int(f.readline().strip())\n",
        "        _V = int(f.readline().strip())\n",
        "        _NNZ = int(f.readline().strip())\n",
        "\n",
        "        kept = 0\n",
        "        total = 0\n",
        "\n",
        "        for line in tqdm(f, desc=\"Pass2 building X\"):\n",
        "            d_str, w_str, c_str = line.split()\n",
        "            d = int(d_str)\n",
        "            if d > doc_limit:\n",
        "                break\n",
        "            total += 1\n",
        "            w_old = int(w_str) - 1\n",
        "            j = int(map_arr[w_old])\n",
        "            if j >= 0:\n",
        "                kept += 1\n",
        "                c = float(c_str)\n",
        "                val = np.log1p(c) if use_log1p else c\n",
        "                rows.append(d - 1)\n",
        "                cols.append(j)\n",
        "                data.append(val)\n",
        "\n",
        "    X = sp.csr_matrix(\n",
        "        (np.array(data, dtype=np.float32),\n",
        "         (np.array(rows, dtype=np.int32), np.array(cols, dtype=np.int32))),\n",
        "        shape=(doc_limit, int(map_arr.max()) + 1),\n",
        "        dtype=np.float32\n",
        "    )\n",
        "    X.sum_duplicates()\n",
        "    return X, kept, total\n",
        "\n",
        "t0 = tnow()\n",
        "X, kept_nnz, total_lines = pass2_build_X(\"docword.txt.gz\", DOC_LIMIT, map_arr, use_log1p=USE_LOG1P)\n",
        "print(\"Built X in\", fmt_secs(tnow()-t0))\n",
        "print(\"Lines read (within DOC_LIMIT):\", total_lines)\n",
        "print(\"Kept lines (after vocab filter):\", kept_nnz)\n",
        "print(\"X shape:\", X.shape, \"nnz:\", X.nnz, \"density:\", X.nnz/(X.shape[0]*X.shape[1]))\n",
        "\n",
        "# Row length diagnostics\n",
        "row_nnz = np.diff(X.indptr)\n",
        "print(\"Per-doc nnz stats: min/median/max:\", int(row_nnz.min()), int(np.median(row_nnz)), int(row_nnz.max()))\n",
        "\n",
        "if DOC_L2_NORM:\n",
        "    t0 = tnow()\n",
        "    X = sk_normalize(X, norm=\"l2\", axis=1, copy=False)\n",
        "    print(\"Applied per-doc L2 normalization in\", fmt_secs(tnow()-t0))\n",
        "\n",
        "n, d = X.shape\n",
        "print(\"Final n,d:\", n, d)\n",
        "\n",
        "# -----------------------\n",
        "# 6) Mean/variance (centered covariance)\n",
        "# -----------------------\n",
        "banner(\"6) Compute feature mean and variance (for centered covariance and restarts)\")\n",
        "\n",
        "def feature_mean_and_var(Xcsr: sp.csr_matrix):\n",
        "    mu = np.asarray(Xcsr.mean(axis=0)).ravel().astype(np.float32)           # E[x]\n",
        "    m2 = np.asarray(Xcsr.power(2).mean(axis=0)).ravel().astype(np.float32)  # E[x^2]\n",
        "    var = np.maximum(m2 - mu**2, 1e-12).astype(np.float32)                  # Var(x)\n",
        "    return mu, var, m2\n",
        "\n",
        "t0 = tnow()\n",
        "mu, var, m2 = feature_mean_and_var(X)\n",
        "print(\"Computed mu/var in\", fmt_secs(tnow()-t0))\n",
        "print(\"mu stats: min/median/max:\", float(mu.min()), float(np.median(mu)), float(mu.max()))\n",
        "print(\"var stats: min/median/max:\", float(var.min()), float(np.median(var)), float(var.max()))\n",
        "print(\"Top-5 variance words:\")\n",
        "top_var_idx = np.argsort(var)[-5:][::-1]\n",
        "for j in top_var_idx:\n",
        "    print(f\"  {vocab_sel[j]:25s} var={float(var[j]):.6g} mean={float(mu[j]):.6g}\")\n",
        "\n",
        "# -----------------------\n",
        "# 7) Centered covariance operator: Σ̂ U = X^T(XU)/n - μ(μ^T U)\n",
        "# -----------------------\n",
        "banner(\"7) Build centered covariance operator (matmat)\")\n",
        "\n",
        "def make_centered_cov_matmat(Xcsr: sp.csr_matrix, mu: np.ndarray):\n",
        "    n = Xcsr.shape[0]\n",
        "    mu = mu.astype(np.float32)\n",
        "\n",
        "    def matmat(U: np.ndarray) -> np.ndarray:\n",
        "        # U: (d,B)\n",
        "        XU = Xcsr @ U                              # (n,B)\n",
        "        XtXU = (Xcsr.T @ XU) / n                   # (d,B)\n",
        "        muTU = (mu.reshape(1, -1) @ U)             # (1,B)\n",
        "        return XtXU - mu.reshape(-1, 1) * muTU     # (d,B)\n",
        "\n",
        "    return matmat\n",
        "\n",
        "base_mm = make_centered_cov_matmat(X, mu)\n",
        "print(\"Centered covariance operator ready.\")\n",
        "print(\"Sanity check: apply Σ̂ to a random vector and report norm:\")\n",
        "v_test = np.random.default_rng(0).standard_normal(d).astype(np.float32)\n",
        "U_test = v_test.reshape(-1, 1)\n",
        "out = base_mm(U_test).reshape(-1)\n",
        "print(\"  ||v|| =\", float(np.linalg.norm(v_test)), \"||Σ̂ v|| =\", float(np.linalg.norm(out)))\n",
        "\n",
        "# -----------------------\n",
        "# 8) RTPM helpers + restart selection\n",
        "# -----------------------\n",
        "banner(\"8) RTPM helpers + choose basis restarts by largest variance\")\n",
        "\n",
        "def top_r_abs_cols(U: np.ndarray, r: int) -> np.ndarray:\n",
        "    d_, B = U.shape\n",
        "    if r >= d_:\n",
        "        return U\n",
        "    idx = np.argpartition(np.abs(U), -r, axis=0)[-r:, :]  # (r,B)\n",
        "    out = np.zeros_like(U)\n",
        "    cols = np.arange(B)[None, :]\n",
        "    out[idx, cols] = U[idx, cols]\n",
        "    return out\n",
        "\n",
        "def choose_restart_indices_by_var(var: np.ndarray, budget: int):\n",
        "    budget = min(budget, var.size)\n",
        "    idx = np.argsort(var)[-budget:]      # largest variances\n",
        "    return np.sort(idx).astype(np.int32)\n",
        "\n",
        "restart_idx0 = choose_restart_indices_by_var(var, RESTART_BUDGET)\n",
        "print(\"Restart budget:\", restart_idx0.size, \"out of d =\", d)\n",
        "print(\"Top-10 restart seeds by variance (word, var):\")\n",
        "seeds = restart_idx0[-10:][::-1]  # last are highest variance after sort\n",
        "for j in seeds:\n",
        "    print(f\"  seed feature {int(j):6d} word={vocab_sel[int(j)]:25s} var={float(var[int(j)]):.6g}\")\n",
        "\n",
        "# -----------------------\n",
        "# 9) RTPM with early stopping (batch logs)\n",
        "# -----------------------\n",
        "banner(\"9) RTPM (Alg-style) with early stopping + very verbose logs\")\n",
        "\n",
        "def rtpm_subset_with_stop(matmat, d: int, r: int, T_max: int, tol: float,\n",
        "                          restart_idx: np.ndarray, batch: int = 128,\n",
        "                          eps: float = 1e-12,\n",
        "                          verbose_batches: int = 3):\n",
        "    \"\"\"\n",
        "    Basis restarts on restart_idx.\n",
        "    Iteration: U <- top_r(Σ̂ U); normalize columns.\n",
        "    Early stop per batch when best Rayleigh stalls.\n",
        "\n",
        "    verbose_batches: print iteration progress for first few batches (to avoid spam)\n",
        "    \"\"\"\n",
        "    best_obj = -np.inf\n",
        "    best_u = None\n",
        "    best_i = None\n",
        "\n",
        "    m = restart_idx.size\n",
        "    num_batches = (m + batch - 1) // batch\n",
        "    print(f\"RTPM: scanning {m} seeds in {num_batches} batches (batch={batch}), r={r}, T_max={T_max}, tol={tol}\")\n",
        "\n",
        "    for b in range(num_batches):\n",
        "        idx = restart_idx[b*batch : min(m, (b+1)*batch)]\n",
        "        B = idx.size\n",
        "\n",
        "        # init as basis vectors e_i\n",
        "        U = np.zeros((d, B), dtype=np.float32)\n",
        "        U[idx, np.arange(B)] = 1.0\n",
        "\n",
        "        prev_best = None\n",
        "        iters_used = 0\n",
        "\n",
        "        # show some batch headers\n",
        "        if b < verbose_batches:\n",
        "            print(f\"\\n  [Batch {b+1}/{num_batches}] B={B}, seed idx range [{int(idx[0])}, {int(idx[-1])}]\")\n",
        "\n",
        "        for t in range(T_max):\n",
        "            iters_used += 1\n",
        "            U = matmat(U).astype(np.float32)\n",
        "            U = top_r_abs_cols(U, r)\n",
        "            norms = np.linalg.norm(U, axis=0, keepdims=True)\n",
        "            U = U / (norms + eps)\n",
        "\n",
        "            AU = matmat(U).astype(np.float32)\n",
        "            objs = np.sum(U * AU, axis=0)  # Rayleigh per restart\n",
        "            cur_best = float(np.max(objs))\n",
        "            cur_mean = float(np.mean(objs))\n",
        "\n",
        "            if b < verbose_batches and (t == 0 or (t+1) % 5 == 0):\n",
        "                print(f\"    iter {t+1:02d}: best Rayleigh={cur_best:.6g}, mean Rayleigh={cur_mean:.6g}\")\n",
        "\n",
        "            if prev_best is not None and abs(cur_best - prev_best) <= tol * (abs(prev_best) + 1e-8):\n",
        "                if b < verbose_batches:\n",
        "                    print(f\"    early-stop at iter {t+1} (Δbest={abs(cur_best-prev_best):.3g})\")\n",
        "                break\n",
        "            prev_best = cur_best\n",
        "\n",
        "        j = int(np.argmax(objs))\n",
        "        batch_best = float(objs[j])\n",
        "        batch_seed = int(idx[j])\n",
        "\n",
        "        if b < verbose_batches:\n",
        "            print(f\"    batch best: seed={batch_seed}, Rayleigh={batch_best:.6g}, iters_used={iters_used}\")\n",
        "\n",
        "        if batch_best > best_obj:\n",
        "            best_obj = batch_best\n",
        "            best_u = U[:, j].copy()\n",
        "            best_i = batch_seed\n",
        "\n",
        "    print(f\"RTPM result: best seed={best_i}, best Rayleigh={best_obj:.6g}\")\n",
        "    return best_u, best_obj, best_i\n",
        "\n",
        "# -----------------------\n",
        "# 10) Deflation wrapper + top-K run with diagnostics\n",
        "# -----------------------\n",
        "banner(\"10) Top-K components via deflation (with diagnostics)\")\n",
        "\n",
        "def deflated_matmat(base_matmat, comps, lambdas):\n",
        "    if len(comps) == 0:\n",
        "        return base_matmat\n",
        "    Umat = np.stack(comps, axis=1).astype(np.float32)  # (d,m)\n",
        "    lam = np.array(lambdas, dtype=np.float32)          # (m,)\n",
        "\n",
        "    def matmat_def(U: np.ndarray) -> np.ndarray:\n",
        "        Y = base_matmat(U).astype(np.float32)\n",
        "        proj = Umat.T @ U\n",
        "        Y -= Umat @ (lam[:, None] * proj)\n",
        "        return Y\n",
        "\n",
        "    return matmat_def\n",
        "\n",
        "def orthogonality_matrix(U):\n",
        "    # returns U^T U\n",
        "    return (U.T @ U).astype(np.float32)\n",
        "\n",
        "comps, lams, objs = [], [], []\n",
        "mm = base_mm\n",
        "\n",
        "for comp_id in range(K):\n",
        "    print(f\"\\n----- Computing component {comp_id+1}/{K} -----\")\n",
        "    t0 = tnow()\n",
        "    u, obj, start_i = rtpm_subset_with_stop(\n",
        "        mm, d=d, r=R, T_max=T_MAX, tol=TOL,\n",
        "        restart_idx=restart_idx0, batch=BATCH_BASIS,\n",
        "        verbose_batches=3\n",
        "    )\n",
        "    elapsed = fmt_secs(tnow()-t0)\n",
        "    print(f\"Component {comp_id+1} RTPM done in {elapsed}\")\n",
        "\n",
        "    Au = mm(u.reshape(-1, 1)).reshape(-1)\n",
        "    lam = float(u @ Au)\n",
        "    nnz = int(np.count_nonzero(u))\n",
        "    print(f\"  Rayleigh(obj)={obj:.6g}, lambda(u^T Σ u)={lam:.6g}, seed={start_i}, nnz={nnz}\")\n",
        "\n",
        "    comps.append(u.astype(np.float32))\n",
        "    lams.append(lam)\n",
        "    objs.append(obj)\n",
        "\n",
        "    mm = deflated_matmat(base_mm, comps, lams)\n",
        "    print(f\"  Deflation updated. Lambdas so far: {[float(x) for x in lams]}\")\n",
        "\n",
        "U_sparse = np.column_stack(comps).astype(np.float32)\n",
        "print(\"\\nAll components computed. U_sparse shape:\", U_sparse.shape)\n",
        "\n",
        "# -----------------------\n",
        "# 11) Print top words with sign alignment\n",
        "# -----------------------\n",
        "banner(\"11) Print top words for each component\")\n",
        "\n",
        "def align_component_sign(u: np.ndarray) -> np.ndarray:\n",
        "    j = np.argmax(np.abs(u))\n",
        "    return u if u[j] >= 0 else -u\n",
        "\n",
        "def top_words_abs(u: np.ndarray, vocab, topn=20):\n",
        "    idx = np.argsort(np.abs(u))[-topn:][::-1]\n",
        "    return [(vocab[i], float(u[i])) for i in idx]\n",
        "\n",
        "def top_words_signed(u: np.ndarray, vocab, top_pos=12, top_neg=12):\n",
        "    pos_idx = np.where(u > 0)[0]\n",
        "    neg_idx = np.where(u < 0)[0]\n",
        "    pos_sorted = pos_idx[np.argsort(u[pos_idx])][::-1][:top_pos]\n",
        "    neg_sorted = neg_idx[np.argsort(u[neg_idx])][:top_neg]\n",
        "    pos = [(vocab[i], float(u[i])) for i in pos_sorted]\n",
        "    neg = [(vocab[i], float(u[i])) for i in neg_sorted]\n",
        "    return pos, neg\n",
        "\n",
        "U_print = U_sparse.copy()\n",
        "for j in range(U_print.shape[1]):\n",
        "    U_print[:, j] = align_component_sign(U_print[:, j])\n",
        "\n",
        "for j in range(min(5, U_print.shape[1])):\n",
        "    u = U_print[:, j]\n",
        "    print(f\"\\n==================== PC {j+1} ====================\")\n",
        "    print(\"Top |loading| words:\")\n",
        "    for w, val in top_words_abs(u, vocab_sel, topn=TOP_WORDS_ABS):\n",
        "        print(f\"  {w:25s} {val:+.5f}\")\n",
        "\n",
        "    pos, neg = top_words_signed(u, vocab_sel, top_pos=TOP_WORDS_POS, top_neg=TOP_WORDS_NEG)\n",
        "    print(\"\\nTop POS words:\")\n",
        "    for w, val in pos:\n",
        "        print(f\"  {w:25s} {val:+.5f}\")\n",
        "\n",
        "    print(\"\\nTop NEG words:\")\n",
        "    for w, val in neg:\n",
        "        print(f\"  {w:25s} {val:+.5f}\")\n",
        "\n",
        "# -----------------------\n",
        "# 12) Representative docs (CENTERED score = <x - mu, u>)\n",
        "# -----------------------\n",
        "banner(\"12) Representative documents for each PC (debuggable interpretation)\")\n",
        "\n",
        "def doc_scores_centered(Xcsr: sp.csr_matrix, u: np.ndarray, mu: np.ndarray):\n",
        "    base = float(mu @ u.astype(np.float32))\n",
        "    scores = (Xcsr @ u.astype(np.float32)) - base\n",
        "    return np.asarray(scores).ravel()\n",
        "\n",
        "def print_doc_top_words(Xcsr: sp.csr_matrix, vocab, doc_id: int, topn: int = 12):\n",
        "    row = Xcsr.getrow(doc_id)\n",
        "    cols = row.indices\n",
        "    vals = row.data\n",
        "    if len(cols) == 0:\n",
        "        print(f\"Doc #{doc_id}: (empty)\")\n",
        "        return\n",
        "    order = np.argsort(vals)[::-1][:topn]\n",
        "    print(f\"Doc #{doc_id} top words:\")\n",
        "    for t in order:\n",
        "        print(f\"  {vocab[cols[t]]:25s} {float(vals[t]):.5f}\")\n",
        "\n",
        "for j in range(min(5, U_print.shape[1])):\n",
        "    u = U_print[:, j]\n",
        "    scores = doc_scores_centered(X, u, mu)\n",
        "\n",
        "    top_pos_docs = np.argsort(scores)[-TOP_DOCS:][::-1]\n",
        "    top_neg_docs = np.argsort(scores)[:TOP_DOCS]\n",
        "\n",
        "    print(f\"\\n==================== PC {j+1} Docs ====================\")\n",
        "    print(\"Score stats: min/median/max:\",\n",
        "          float(scores.min()), float(np.median(scores)), float(scores.max()))\n",
        "    print(\"Top POS docs:\", top_pos_docs, \"scores:\", scores[top_pos_docs])\n",
        "    for di in top_pos_docs:\n",
        "        print_doc_top_words(X, vocab_sel, int(di), topn=12)\n",
        "\n",
        "    print(\"\\nTop NEG docs:\", top_neg_docs, \"scores:\", scores[top_neg_docs])\n",
        "    for di in top_neg_docs:\n",
        "        print_doc_top_words(X, vocab_sel, int(di), topn=12)\n",
        "\n",
        "# -----------------------\n",
        "# 13) Final debug summary\n",
        "# -----------------------\n",
        "banner(\"13) Summary of what happened (debug recap)\")\n",
        "\n",
        "print(f\"Final X: n={n}, d={d}, nnz={X.nnz}, density={X.nnz/(n*d):.3e}\")\n",
        "print(f\"DF filtering: MIN_DF={MIN_DF}, MAX_DF={int(MAX_DF_FRAC*DOC_LIMIT)} -> selected vocab size M={len(vocab_sel)}\")\n",
        "print(f\"RTPM: r={R}, T_MAX={T_MAX}, tol={TOL}, restart_budget={RESTART_BUDGET}, batch={BATCH_BASIS}\")\n",
        "print(\"Lambdas (Rayleigh per component):\", [float(x) for x in lams])\n",
        "print(\"Done.\")"
      ],
      "metadata": {
        "id": "OUGeMR8jrkGr"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ============================================================\n",
        "# 13) Visualizations (CENTERED covariance + correlation)\n",
        "# ============================================================\n",
        "\n",
        "banner(\"13) Visualizations: centered covariance & correlation heatmaps; eigenvector plots (seaborn + clean words)\")\n",
        "\n",
        "!pip -q install seaborn\n",
        "import seaborn as sns\n",
        "\n",
        "# ---------- Word cleaning ----------\n",
        "def clean_word(w: str) -> str:\n",
        "    return w[4:] if w.startswith(\"zzz_\") else w\n",
        "\n",
        "def clean_words(ws):\n",
        "    return [clean_word(w) for w in ws]\n",
        "\n",
        "# ---------- Seaborn styling ----------\n",
        "sns.set_theme(style=\"whitegrid\", context=\"talk\")\n",
        "sns.set_context(\"talk\", font_scale=0.95)\n",
        "plt.rcParams.update({\n",
        "    \"figure.dpi\": 140,\n",
        "    \"savefig.dpi\": 140,\n",
        "})\n",
        "\n",
        "# ---------- Heatmap helpers ----------\n",
        "def choose_feature_subset_for_viz(U_print, var, m_total=200, extra_by_var=800):\n",
        "    supp = set()\n",
        "    for j in range(min(5, U_print.shape[1])):\n",
        "        supp |= set(np.nonzero(U_print[:, j])[0].tolist())\n",
        "    supp = sorted(supp)\n",
        "\n",
        "    if len(supp) >= m_total:\n",
        "        return np.array(supp[:m_total], dtype=np.int32)\n",
        "\n",
        "    extra = np.argsort(var)[-extra_by_var:][::-1]\n",
        "    subset = list(dict.fromkeys(supp + extra.tolist()))\n",
        "    subset = subset[:m_total]\n",
        "    return np.array(subset, dtype=np.int32)\n",
        "\n",
        "def dense_cov_corr_from_sparse(Xcsr: sp.csr_matrix, idx: np.ndarray, mu_full: np.ndarray, eps=1e-8):\n",
        "    Xs = Xcsr[:, idx].toarray().astype(np.float32)\n",
        "    mu_sub = mu_full[idx].astype(np.float32).reshape(1, -1)\n",
        "    Xc = Xs - mu_sub\n",
        "    n = Xc.shape[0]\n",
        "    Cov = (Xc.T @ Xc) / n\n",
        "    sd = np.sqrt(np.maximum(np.diag(Cov), eps)).astype(np.float32)\n",
        "    Corr = Cov / (sd[:, None] * sd[None, :])\n",
        "    return Cov, Corr\n",
        "\n",
        "def seaborn_heatmap(M, title, cmap=\"viridis\", show_tick_labels=False, tick_labels=None, max_ticks=40):\n",
        "    fig, ax = plt.subplots(figsize=(8.2, 7.0))\n",
        "    if show_tick_labels and tick_labels is not None and len(tick_labels) <= max_ticks:\n",
        "        labels = clean_words(tick_labels)\n",
        "        sns.heatmap(M, ax=ax, cmap=cmap, square=True, cbar=True,\n",
        "                    xticklabels=labels, yticklabels=labels)\n",
        "        ax.tick_params(axis=\"x\", rotation=90, labelsize=9)\n",
        "        ax.tick_params(axis=\"y\", rotation=0, labelsize=9)\n",
        "    else:\n",
        "        sns.heatmap(M, ax=ax, cmap=cmap, square=True, cbar=True,\n",
        "                    xticklabels=False, yticklabels=False)\n",
        "    ax.set_title(title, pad=12)\n",
        "    plt.tight_layout()\n",
        "    plt.show()\n",
        "\n",
        "# ---------- Run heatmaps ----------\n",
        "idx_viz = choose_feature_subset_for_viz(U_print, var, m_total=VIZ_M_TOTAL, extra_by_var=1200)\n",
        "words_viz = [vocab_sel[i] for i in idx_viz]\n",
        "\n",
        "print(\"Heatmap feature subset size:\", len(idx_viz))\n",
        "print(\"Heatmap subset includes\", len(set(idx_viz) & set(np.nonzero(U_print[:,0])[0])), \"features from PC1 support\")\n",
        "\n",
        "Cov_c, Corr = dense_cov_corr_from_sparse(X, idx_viz, mu)\n",
        "\n",
        "seaborn_heatmap(Cov_c, \"Centered covariance heatmap (subset)\", cmap=\"viridis\",\n",
        "                show_tick_labels=False, tick_labels=words_viz)\n",
        "seaborn_heatmap(Corr,  \"Correlation heatmap (subset)\",        cmap=\"coolwarm\",\n",
        "                show_tick_labels=False, tick_labels=words_viz)\n",
        "\n",
        "# ============================================================\n",
        "# Eigenvector plots — seaborn + clean top-10 words\n",
        "# ============================================================\n",
        "\n",
        "def plot_component_support(u, title):\n",
        "    nz = np.nonzero(u)[0]\n",
        "    vals = u[nz]\n",
        "    dfp = {\"index\": nz, \"entry\": vals}  # renamed key\n",
        "    fig, ax = plt.subplots(figsize=(9.4, 3.8))\n",
        "    sns.scatterplot(x=\"index\", y=\"entry\", data=dfp, s=30, alpha=0.85, ax=ax)\n",
        "    ax.axhline(0.0, linewidth=1.2, alpha=0.5)\n",
        "    ax.set_title(title + f\" | nnz={len(nz)}\", pad=10)\n",
        "    ax.set_xlabel(\"feature index\")\n",
        "    ax.set_ylabel(\"entry\")\n",
        "    plt.tight_layout()\n",
        "    plt.show()\n",
        "\n",
        "def plot_component_sorted_abs(u, title, topn=250):\n",
        "    a = np.sort(np.abs(u))[::-1]\n",
        "    topn = min(topn, len(a))\n",
        "    fig, ax = plt.subplots(figsize=(8.8, 3.8))\n",
        "    sns.lineplot(x=np.arange(1, topn+1), y=a[:topn], linewidth=2.2, ax=ax)\n",
        "    ax.set_title(title + \" | sorted |entry| (top)\", pad=10)\n",
        "    ax.set_xlabel(\"rank\")\n",
        "    ax.set_ylabel(\"|entry|\")\n",
        "    plt.tight_layout()\n",
        "    plt.show()\n",
        "\n",
        "def plot_component_top_words(u, vocab, topn=10, title=\"\"):\n",
        "    idx = np.argsort(np.abs(u))[-topn:][::-1]\n",
        "    labels = clean_words([vocab[i] for i in idx])\n",
        "    vals = u[idx]\n",
        "\n",
        "    fig, ax = plt.subplots(figsize=(9.0, 4.4))\n",
        "    sns.barplot(x=labels, y=vals, ax=ax)\n",
        "    ax.axhline(0.0, linewidth=1.2, alpha=0.6)\n",
        "    ax.set_title(title + f\" (top-{topn} words)\", pad=10)\n",
        "    ax.set_xlabel(\"\")\n",
        "    ax.set_ylabel(\"entry\")\n",
        "    ax.tick_params(axis=\"x\", rotation=30, labelsize=12)\n",
        "    plt.tight_layout()\n",
        "    plt.show()\n",
        "\n",
        "for j in range(min(5, U_print.shape[1])):\n",
        "    u = U_print[:, j]\n",
        "    plot_component_support(u, f\"PC {j+1}: nonzero entries (support scatter)\")\n",
        "    plot_component_sorted_abs(u, f\"PC {j+1}\", topn=250)\n",
        "    plot_component_top_words(u, vocab_sel, topn=10, title=f\"PC {j+1}: top words by |entry|\")\n",
        "\n",
        "# ============================================================\n",
        "# Component × feature heatmap on union support — seaborn\n",
        "# ============================================================\n",
        "\n",
        "def union_support(U, num_pcs=4):\n",
        "    supp = set()\n",
        "    for j in range(min(num_pcs, U.shape[1])):\n",
        "        supp |= set(np.nonzero(U[:, j])[0].tolist())\n",
        "    return np.array(sorted(supp), dtype=np.int32)\n",
        "\n",
        "supp = union_support(U_print, num_pcs=4)\n",
        "print(\"Union support size (PC1..PC4):\", len(supp))\n",
        "\n",
        "W = U_print[supp, :4].T  # (4, |supp|)\n",
        "fig, ax = plt.subplots(figsize=(11.2, 4.0))\n",
        "sns.heatmap(W, ax=ax, cmap=\"coolwarm\", cbar=True, xticklabels=False,\n",
        "            yticklabels=[f\"PC{j+1}\" for j in range(4)])\n",
        "ax.set_title(\"Top-4 components restricted to union support\", pad=12)\n",
        "plt.tight_layout()\n",
        "plt.show()\n",
        "\n",
        "# ============================================================\n",
        "# Component × feature heatmap on union support — PER-PC SORTED BY |entry|\n",
        "# ============================================================\n",
        "\n",
        "supp = union_support(U_print, num_pcs=4)\n",
        "print(\"Union support size (PC1..PC4):\", len(supp))\n",
        "\n",
        "W_rank = np.zeros((4, len(supp)), dtype=np.float32)\n",
        "for j in range(4):\n",
        "    vals = U_print[supp, j]\n",
        "    ordj = np.argsort(np.abs(vals))[::-1]\n",
        "    W_rank[j, :] = vals[ordj]\n",
        "\n",
        "fig, ax = plt.subplots(figsize=(12.2, 4.4))\n",
        "sns.heatmap(\n",
        "    W_rank[:, 0:50],\n",
        "    ax=ax,\n",
        "    cmap=\"coolwarm\",\n",
        "    center=0.0,\n",
        "    cbar=True,\n",
        "    xticklabels=False,\n",
        "    yticklabels=[f\"PC{j+1}\" for j in range(4)]\n",
        ")\n",
        "\n",
        "ax.set_title(\"Top-4 components on union support (each row sorted by |entry| rank)\", pad=12)\n",
        "ax.set_xlabel(\"Entry rank within each PC (largest → smallest |entry|)\")\n",
        "\n",
        "L = W_rank.shape[1]\n",
        "ranks = [10, 25, 50]\n",
        "ranks = sorted({r for r in ranks if 1 <= r <= L})\n",
        "\n",
        "xticks = [(r - 1) + 0.5 for r in ranks]\n",
        "ax.set_xticks(xticks)\n",
        "ax.set_xticklabels([str(r) for r in ranks], rotation=90, ha=\"center\", fontsize=15)\n",
        "\n",
        "ax.tick_params(axis=\"x\", pad=0.1)\n",
        "plt.tight_layout()\n",
        "plt.show()"
      ],
      "metadata": {
        "id": "61YgsOwdsNKQ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "TIpkUSuexJBN"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}