{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "eef70c28",
   "metadata": {},
   "source": [
    "# Token Substitution Community Detection\n",
    "\n",
    "This notebook loads a sparse substitution counts matrix between tokens, constructs a graph, and applies several community detection algorithms (Spectral Clustering, Louvain, Girvan-Newman) to group similar tokens. Visualizations and modularity scores are used to evaluate the quality of the clusters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6867afc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from collections import defaultdict\n",
    "import networkx as nx\n",
    "import community as community_louvain\n",
    "from sklearn.cluster import SpectralClustering\n",
    "from networkx.algorithms.community.quality import modularity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0da46c78-be37-46bc-8f46-a09db5be597c",
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab = list(range(2048))\n",
    "\n",
    "raw_sub_matrix = np.load(\"outputs/confusion/matrices/confusion_0.npy\")\n",
    "sym_sub_matrix = (raw_sub_matrix + raw_sub_matrix.T) / 2  # Make it symmetric for clustering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9e4d323f-b591-4075-b257-2fdf380d7ad3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def lists_from_labels(labels, offset=32000):\n",
    "    communities = defaultdict(list)\n",
    "    for node, label in enumerate(labels):\n",
    "        communities[label].append(node + offset)\n",
    "    return list(communities.values())\n",
    "\n",
    "def dict_from_labels(labels, offset=32000):\n",
    "    return {\n",
    "        (token_id + offset): cluster_id.item()\n",
    "        for token_id, cluster_id in enumerate(labels)\n",
    "    }\n",
    "\n",
    "def lists_from_dict(partition):\n",
    "    comms = defaultdict(list)\n",
    "    for node, cid in partition.items():\n",
    "        comms[cid].append(node)\n",
    "    return list(comms.values())\n",
    "\n",
    "def dict_from_lists(lists):\n",
    "    return {\n",
    "        token_id: cluster_id\n",
    "        for cluster_id, cluster in enumerate(lists)\n",
    "        for token_id in cluster\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c468b8f0-668e-4431-afeb-f43f65932c2d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/storage/miniconda3/envs/wmar/lib/python3.11/site-packages/sklearn/manifold/_spectral_embedding.py:309: UserWarning: Array is not symmetric, and will be converted to symmetric by average with its transpose.\n",
      "  adjacency = check_symmetric(adjacency)\n",
      "/storage/miniconda3/envs/wmar/lib/python3.11/site-packages/sklearn/manifold/_spectral_embedding.py:328: UserWarning: Graph is not fully connected, spectral embedding may not work as expected.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[9 3 9 ... 3 9 0]\n",
      "Modularity (sklearn SpectralClustering): 0.5302220731009741\n"
     ]
    }
   ],
   "source": [
    "def spectral(matrix, n_clusters):\n",
    "    spectral = SpectralClustering(n_clusters=n_clusters, affinity='precomputed', random_state=42)\n",
    "    labels = spectral.fit_predict(matrix)\n",
    "    return labels\n",
    "\n",
    "\n",
    "n_clusters = 10\n",
    "labels_sklearn = spectral(raw_sub_matrix, n_clusters)\n",
    "print(labels_sklearn)\n",
    "\n",
    "G_full = nx.from_numpy_array(raw_sub_matrix)\n",
    "mod_sklearn = modularity(G_full, lists_from_labels(labels_sklearn, offset=0))\n",
    "print(\"Modularity (sklearn SpectralClustering):\", mod_sklearn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d9ba40d3-f64a-40cc-beeb-b6db36844f94",
   "metadata": {},
   "outputs": [],
   "source": [
    "# values = [labels_sklearn[node] for node in G_full.nodes()]\n",
    "# nx.draw_spring(G_full, node_color=values, node_size=30, with_labels=False)\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ad1e9271-90ae-468d-90f1-9961fed6bf12",
   "metadata": {},
   "outputs": [],
   "source": [
    "def louvain_clusters(S, min_sub_count=1):\n",
    "    G = nx.from_numpy_array(np.where(S >= min_sub_count, S, 0))\n",
    "    n_nodes = S.shape[0]\n",
    "    n_components = len(list(nx.connected_components(G)))\n",
    "\n",
    "    # Louvain clustering on the graph\n",
    "    partition = community_louvain.best_partition(G, weight='weight')\n",
    "\n",
    "    # Identify unassigned nodes\n",
    "    assigned_nodes = set(partition.keys())\n",
    "    all_nodes = set(range(n_nodes))\n",
    "    unassigned = all_nodes - assigned_nodes\n",
    "\n",
    "    # Reassign unassigned nodes to strongest neighbor's community\n",
    "    for node in unassigned:\n",
    "        # Find the strongest neighbor above threshold\n",
    "        neighbors = [(j, S[node, j]) for j in range(n_nodes) if j in partition and S[node, j] >= min_sub_count]\n",
    "        if neighbors:\n",
    "            # Pick neighbor with highest substitution count\n",
    "            best_neighbor = max(neighbors, key=lambda x: x[1])[0]\n",
    "            partition[node] = partition[best_neighbor]\n",
    "        else:\n",
    "            # No eligible neighbors; assign to new unique community\n",
    "            partition[node] = max(partition.values()) + 1\n",
    "\n",
    "    mod_louvain = modularity(G, lists_from_dict(partition))\n",
    "    print(f\"Modularity and components (min_sub={min_sub_count}):\", round(mod_louvain, 4), n_components)\n",
    "\n",
    "    # Convert partition dict to label array\n",
    "    labels = np.zeros(n_nodes, dtype=int)\n",
    "    for node in range(n_nodes):\n",
    "        labels[node] = partition[node]\n",
    "\n",
    "    # Optional: reindex labels to 0...K-1\n",
    "    unique_ids = {old: new for new, old in enumerate(sorted(set(labels)))}\n",
    "    labels = np.array([unique_ids[l] for l in labels])\n",
    "    return labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b66eb13c",
   "metadata": {},
   "outputs": [],
   "source": [
    "audio_token_start = 32002  # first token\n",
    "audio_token_end = 32502    # last token\n",
    "\n",
    "\n",
    "def merge_clusters(clusterings, n):\n",
    "    import heapq\n",
    "\n",
    "    print(\"Merging...\")\n",
    "\n",
    "    # Sort input clusters by size (descending) to greedily balance large clusters\n",
    "    clusterings = sorted(clusterings, key=len, reverse=True)\n",
    "\n",
    "    # Use a min-heap to keep track of current size of each output cluster\n",
    "    merged = [[] for _ in range(n)]\n",
    "    heap = [(0, i) for i in range(n)]  # (total size, index)\n",
    "    heapq.heapify(heap)\n",
    "\n",
    "    for cluster in clusterings:\n",
    "        size, idx = heapq.heappop(heap)\n",
    "        merged[idx].extend(cluster)\n",
    "        heapq.heappush(heap, (size + len(cluster), idx))\n",
    "    return merged\n",
    "\n",
    "\n",
    "def create_synonym_splits(clusterings, n):\n",
    "    clusters = {}\n",
    "    for cluster_id, token_ids in enumerate(clusterings):\n",
    "        for token_id in token_ids:\n",
    "            clusters[token_id] = cluster_id\n",
    "\n",
    "    # If the clusters are more than the splits, merge them\n",
    "    if n < len(clusterings):\n",
    "        splits = merge_clusters(clusterings, n)\n",
    "    else:\n",
    "        splits = [[] for _ in range(n)]\n",
    "        for token in clusters.keys():\n",
    "            splits[clusters[token]].append(token)\n",
    "\n",
    "    # # Distribute the rest of the audio tokens evenly\n",
    "    # for token in range(audio_token_start, audio_token_end):\n",
    "    #     if token not in clusters.keys():\n",
    "    #         min_split = np.argmin(np.array([len(split) for split in splits]))\n",
    "    #         splits[min_split].append(token)\n",
    "\n",
    "    return splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8753fcf2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import igraph as ig\n",
    "import leidenalg\n",
    "\n",
    "\n",
    "def leiden_clusters(S, min_sub_count=1):\n",
    "    G = nx.from_numpy_array(np.where(S >= min_sub_count, S, 0))\n",
    "\n",
    "    edges = []\n",
    "    weights = []\n",
    "    V = S.shape[0]\n",
    "\n",
    "    for i in range(V):\n",
    "        for j in range(i + 1, V):\n",
    "            if S[i, j] >= min_sub_count:\n",
    "                edges.append((i, j))\n",
    "                weights.append(S[i, j])\n",
    "\n",
    "    g = ig.Graph(edges=edges, n=V)\n",
    "    g.es[\"weight\"] = weights\n",
    "\n",
    "    partition = leidenalg.find_partition(\n",
    "        g,\n",
    "        leidenalg.ModularityVertexPartition,\n",
    "        weights=weights,\n",
    "    )\n",
    "\n",
    "    labels = np.zeros(V, dtype=int)\n",
    "    for cid, community in enumerate(partition):\n",
    "        for v in community:\n",
    "            labels[v] = cid\n",
    "\n",
    "    comms = defaultdict(set)\n",
    "    for node, cid in enumerate(labels):\n",
    "        comms[cid].add(node)\n",
    "\n",
    "    communities = list(comms.values())\n",
    "    modularity = nx.algorithms.community.modularity(\n",
    "        G, communities, weight=\"weight\"\n",
    "    )\n",
    "    print(f\"Modularity and components (min_sub={min_sub_count}):\", round(modularity, 4), len(comms))\n",
    "    return labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "293bb3ee",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Channel: 0\n",
      "Modularity and components (min_sub=1): 0.5828 191\n",
      "Modularity and components (min_sub=1): 0.5539 287\n",
      "Modularity and components (min_sub=5): 0.7881 852\n",
      "Modularity and components (min_sub=5): 0.6642 1053\n",
      "Modularity and components (min_sub=10): 0.8183 1430\n",
      "Modularity and components (min_sub=10): 0.5862 1627\n",
      "Modularity and components (min_sub=25): 0.7738 1907\n",
      "Modularity and components (min_sub=25): 0.4937 1965\n",
      "Modularity and components (min_sub=50): 0.7281 2002\n",
      "Modularity and components (min_sub=50): 0.3682 2026\n",
      "Modularity and components (min_sub=100): 0.752 2038\n",
      "Modularity and components (min_sub=100): 0.3574 2043\n",
      "Channel: 1\n",
      "Modularity and components (min_sub=1): 0.5845 191\n",
      "Modularity and components (min_sub=1): 0.541 251\n",
      "Modularity and components (min_sub=5): 0.7914 852\n",
      "Modularity and components (min_sub=5): 0.5286 889\n",
      "Modularity and components (min_sub=10): 0.8281 1430\n",
      "Modularity and components (min_sub=10): 0.5225 1214\n",
      "Modularity and components (min_sub=25): 0.7714 1907\n",
      "Modularity and components (min_sub=25): 0.5149 1675\n",
      "Modularity and components (min_sub=50): 0.7281 2002\n",
      "Modularity and components (min_sub=50): 0.5127 1876\n",
      "Modularity and components (min_sub=100): 0.752 2038\n",
      "Modularity and components (min_sub=100): 0.4606 1987\n",
      "Channel: 2\n",
      "Modularity and components (min_sub=1): 0.5812 191\n",
      "Modularity and components (min_sub=1): 0.3354 62\n",
      "Modularity and components (min_sub=5): 0.7891 852\n",
      "Modularity and components (min_sub=5): 0.4618 1214\n",
      "Modularity and components (min_sub=10): 0.8279 1430\n",
      "Modularity and components (min_sub=10): 0.3759 1761\n",
      "Modularity and components (min_sub=25): 0.7715 1907\n",
      "Modularity and components (min_sub=25): 0.1565 1998\n",
      "Modularity and components (min_sub=50): 0.7281 2002\n",
      "Modularity and components (min_sub=50): 0.068 2030\n",
      "Modularity and components (min_sub=100): 0.752 2038\n",
      "Modularity and components (min_sub=100): 0.0084 2040\n",
      "Channel: 3\n",
      "Modularity and components (min_sub=1): 0.5845 191\n",
      "Modularity and components (min_sub=1): 0.2702 25\n",
      "Modularity and components (min_sub=5): 0.7893 852\n",
      "Modularity and components (min_sub=5): 0.2586 1725\n",
      "Modularity and components (min_sub=10): 0.8288 1430\n",
      "Modularity and components (min_sub=10): 0.1688 1935\n",
      "Modularity and components (min_sub=25): 0.7681 1907\n",
      "Modularity and components (min_sub=25): 0.0842 2019\n",
      "Modularity and components (min_sub=50): 0.7281 2002\n",
      "Modularity and components (min_sub=50): 0.0402 2032\n",
      "Modularity and components (min_sub=100): 0.752 2038\n",
      "Modularity and components (min_sub=100): -0.1254 2042\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "\n",
    "\n",
    "for channel, name in enumerate(['rvq_first_0', 'rvq_rest_0', 'rvq_rest_1', 'rvq_rest_2']):\n",
    "    print(\"Channel:\", channel)\n",
    "    matrix = np.load(f\"outputs/confusion/matrices/confusion_{channel}.npy\")\n",
    "\n",
    "    louvain_clusterings = {}\n",
    "    leiden_clusterings = {}\n",
    "    for count in [1, 5, 10, 25, 50, 100]:\n",
    "        louvain_labels = louvain_clusters(raw_sub_matrix, min_sub_count=count)\n",
    "        leiden_labels = leiden_clusters(matrix, min_sub_count=count)\n",
    "        louvain_clusterings[count] = louvain_labels\n",
    "        leiden_clusterings[count] = leiden_labels\n",
    "\n",
    "    with open(f\"/home/AlignedIS-dev/models/embeddings/clusterings/mimi_louvain_{name}_clusterings.pkl\", \"wb\") as f:\n",
    "        pickle.dump(louvain_clusterings, f)\n",
    "\n",
    "    with open(f\"/home/AlignedIS-dev/models/embeddings/clusterings/mimi_leiden_{name}_clusterings.pkl\", \"wb\") as f:\n",
    "        pickle.dump(leiden_clusterings, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "wmar",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
