{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f70fad0c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Channel 0: used 1707/2048 (unseen: 341)\n",
      "Channel 1: used 1893/2048 (unseen: 155)\n",
      "Channel 2: used 2001/2048 (unseen: 47)\n",
      "Channel 3: used 2033/2048 (unseen: 15)\n",
      "Channel 4: used 2036/2048 (unseen: 12)\n",
      "Channel 5: used 2036/2048 (unseen: 12)\n",
      "Channel 6: used 2045/2048 (unseen: 3)\n",
      "Channel 7: used 2046/2048 (unseen: 2)\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "samples_dir = \"/home/wmar/wmar_audio/outputs/rcc_reconstructions/original_mimi/samples\"\n",
    "matrix_dir = \"/home/wmar/wmar_audio/outputs/confusion/matrices_new\"\n",
    "os.makedirs(matrix_dir, exist_ok=True)\n",
    "\n",
    "vocab_size = 2048\n",
    "num_channels = 8\n",
    "\n",
    "\n",
    "def _read_token_grid(path):\n",
    "    \"\"\"Parse a whitespace-separated token file into an int64 array, one row per non-blank line.\"\"\"\n",
    "    with open(path) as fh:\n",
    "        rows = [[int(tok) for tok in line.split()] for line in fh if line.strip()]\n",
    "    return np.array(rows, dtype=np.int64)\n",
    "\n",
    "\n",
    "# Per-channel accumulators: conf_matrices[ch, original, roundtrip] counts mismatches only,\n",
    "# usage_counts[ch, original] counts every occurrence of an original token.\n",
    "conf_matrices = np.zeros((num_channels, vocab_size, vocab_size), dtype=np.int64)\n",
    "usage_counts = np.zeros((num_channels, vocab_size), dtype=np.int64)\n",
    "\n",
    "for name in sorted(os.listdir(samples_dir)):\n",
    "    sample_dir = os.path.join(samples_dir, name)\n",
    "    orig_path = os.path.join(sample_dir, \"orig_tokens.txt\")\n",
    "    rt_path = os.path.join(sample_dir, \"roundtrip_tokens.txt\")\n",
    "\n",
    "    # Only directories that contain both token files are processed.\n",
    "    usable = os.path.isdir(sample_dir) and os.path.exists(orig_path) and os.path.exists(rt_path)\n",
    "    if not usable:\n",
    "        continue\n",
    "\n",
    "    orig = _read_token_grid(orig_path)\n",
    "    rt = _read_token_grid(rt_path)\n",
    "\n",
    "    # Samples whose two grids differ in shape are dropped without a message.\n",
    "    if orig.shape == rt.shape:\n",
    "        n_rows = min(num_channels, orig.shape[0])\n",
    "        for ch, (row_orig, row_rt) in enumerate(zip(orig[:n_rows], rt[:n_rows])):\n",
    "            # Usage histogram of the original tokens\n",
    "            np.add.at(usage_counts[ch], row_orig, 1)\n",
    "\n",
    "            # Confusion counts for positions where the round trip changed the token\n",
    "            mismatch = row_orig != row_rt\n",
    "            if mismatch.any():\n",
    "                np.add.at(conf_matrices[ch], (row_orig[mismatch], row_rt[mismatch]), 1)\n",
    "\n",
    "# Report codebook coverage and write one matrix file per channel\n",
    "for ch, (counts, matrix) in enumerate(zip(usage_counts, conf_matrices)):\n",
    "    unseen = (counts == 0).sum()\n",
    "    print(f\"Channel {ch}: used {vocab_size - unseen}/{vocab_size} (unseen: {unseen})\")\n",
    "    np.save(os.path.join(matrix_dir, f\"confusion_{ch}.npy\"), matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "67ce2908",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Channel 0: used 1949/2048 (unseen: 99)\n",
      "Channel 1: used 1887/2048 (unseen: 161)\n",
      "Channel 2: used 1886/2048 (unseen: 162)\n",
      "Channel 3: used 1917/2048 (unseen: 131)\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "samples_dir = \"/home/wmar/wmar_audio/outputs/rcc_reconstructions/original_encodec_32/samples\"\n",
    "matrix_dir = \"/home/wmar/wmar_audio/outputs/confusion/matrices_encodec\"\n",
    "os.makedirs(matrix_dir, exist_ok=True)\n",
    "\n",
    "vocab_size, num_channels = 2048, 4\n",
    "\n",
    "# Accumulators, one slice per channel: mismatch counts indexed [original, roundtrip],\n",
    "# and a histogram of how often each original token occurs.\n",
    "conf_matrices = np.zeros((num_channels, vocab_size, vocab_size), dtype=np.int64)\n",
    "usage_counts = np.zeros((num_channels, vocab_size), dtype=np.int64)\n",
    "\n",
    "for name in sorted(os.listdir(samples_dir)):\n",
    "    sample_dir = os.path.join(samples_dir, name)\n",
    "    if not os.path.isdir(sample_dir):\n",
    "        continue\n",
    "\n",
    "    token_files = [os.path.join(sample_dir, fn) for fn in (\"orig_tokens.txt\", \"roundtrip_tokens.txt\")]\n",
    "    if not all(os.path.exists(token_file) for token_file in token_files):\n",
    "        continue\n",
    "\n",
    "    # Each non-blank line of a token file becomes one row of integers.\n",
    "    grids = []\n",
    "    for token_file in token_files:\n",
    "        with open(token_file) as fh:\n",
    "            parsed = [[int(tok) for tok in line.split()] for line in fh if line.strip()]\n",
    "        grids.append(np.array(parsed, dtype=np.int64))\n",
    "    orig, rt = grids\n",
    "\n",
    "    # Shape mismatches between the two grids are skipped silently.\n",
    "    if orig.shape != rt.shape:\n",
    "        continue\n",
    "\n",
    "    for ch in range(min(num_channels, orig.shape[0])):\n",
    "        sent, received = orig[ch], rt[ch]\n",
    "\n",
    "        # Every original token counts towards usage\n",
    "        np.add.at(usage_counts[ch], sent, 1)\n",
    "\n",
    "        # Only positions where the token changed count towards confusion\n",
    "        wrong = np.flatnonzero(sent != received)\n",
    "        if wrong.size:\n",
    "            np.add.at(conf_matrices[ch], (sent[wrong], received[wrong]), 1)\n",
    "\n",
    "# Coverage summary, then one .npy per channel\n",
    "for ch in range(num_channels):\n",
    "    unseen = (usage_counts[ch] == 0).sum()\n",
    "    used = vocab_size - unseen\n",
    "    print(f\"Channel {ch}: used {used}/{vocab_size} (unseen: {unseen})\")\n",
    "    np.save(os.path.join(matrix_dir, f\"confusion_{ch}.npy\"), conf_matrices[ch])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0f041683",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import networkx as nx\n",
    "import igraph as ig\n",
    "import leidenalg\n",
    "import community as community_louvain\n",
    "\n",
    "\n",
    "def louvain_clusters(S, min_sub_count=1, resolution=1):\n",
    "    \"\"\"Louvain community labels for the thresholded, symmetrized matrix ``S``.\n",
    "\n",
    "    Entries of ``S`` below ``min_sub_count`` are zeroed and the result is added to its own\n",
    "    transpose, so the graph handed to ``community_louvain.best_partition`` is undirected.\n",
    "    The partition is computed with ``random_state=27``.\n",
    "\n",
    "    Returns:\n",
    "        Integer array with one label per row of ``S``, renumbered to 0..K-1.\n",
    "    \"\"\"\n",
    "    kept = np.where(S >= min_sub_count, S, 0)\n",
    "    undirected = nx.from_numpy_array(kept + kept.T)\n",
    "\n",
    "    node_to_community = community_louvain.best_partition(\n",
    "        undirected, weight='weight', resolution=resolution, random_state=27\n",
    "    )\n",
    "\n",
    "    # best_partition returns {node_id: community_id}; read it out in node order.\n",
    "    raw_labels = np.array([node_to_community[node] for node in range(S.shape[0])])\n",
    "\n",
    "    # The inverse indices from np.unique are the labels renumbered to 0..K-1.\n",
    "    return np.unique(raw_labels, return_inverse=True)[1]\n",
    "\n",
    "\n",
    "def leiden_clusters(S, min_sub_count=1, resolution=1):\n",
    "    \"\"\"Leiden community labels for the thresholded matrix ``S``, treated as a directed graph.\n",
    "\n",
    "    Entries of ``S`` below ``min_sub_count`` are zeroed; the remainder becomes the weighted\n",
    "    adjacency of a directed igraph graph, which is partitioned with\n",
    "    ``leidenalg.RBConfigurationVertexPartition`` using ``seed=27``.\n",
    "\n",
    "    Returns:\n",
    "        Integer array with one label per node, renumbered to 0..K-1.\n",
    "    \"\"\"\n",
    "    kept = np.where(S >= min_sub_count, S, 0)\n",
    "    weighted_digraph = ig.Graph.Weighted_Adjacency(kept.tolist(), mode='directed')\n",
    "\n",
    "    communities = leidenalg.find_partition(\n",
    "        weighted_digraph,\n",
    "        leidenalg.RBConfigurationVertexPartition,\n",
    "        weights=weighted_digraph.es['weight'],\n",
    "        resolution_parameter=resolution,\n",
    "        seed=27,\n",
    "    )\n",
    "\n",
    "    # membership[i] is the community of node i.\n",
    "    membership = np.array(communities.membership)\n",
    "\n",
    "    # The inverse indices from np.unique are the labels renumbered to 0..K-1.\n",
    "    return np.unique(membership, return_inverse=True)[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fb2aa87",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Channel: 0\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=1): 203\n",
      "Leiden components (min_sub=1): 202\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=5): 873\n",
      "Leiden components (min_sub=5): 873\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=10): 1446\n",
      "Leiden components (min_sub=10): 1447\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=25): 1913\n",
      "Leiden components (min_sub=25): 1913\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=50): 2007\n",
      "Leiden components (min_sub=50): 2007\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=100): 2038\n",
      "Leiden components (min_sub=100): 2038\n",
      "Channel: 1\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=1): 173\n",
      "Leiden components (min_sub=1): 173\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=5): 673\n",
      "Leiden components (min_sub=5): 675\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=10): 972\n",
      "Leiden components (min_sub=10): 972\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=25): 1494\n",
      "Leiden components (min_sub=25): 1496\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=50): 1782\n",
      "Leiden components (min_sub=50): 1781\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=100): 1952\n",
      "Leiden components (min_sub=100): 1951\n",
      "Channel: 2\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=1): 40\n",
      "Leiden components (min_sub=1): 39\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=5): 961\n",
      "Leiden components (min_sub=5): 962\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=10): 1663\n",
      "Leiden components (min_sub=10): 1661\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=25): 1974\n",
      "Leiden components (min_sub=25): 1974\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=50): 2027\n",
      "Leiden components (min_sub=50): 2027\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=100): 2039\n",
      "Leiden components (min_sub=100): 2039\n",
      "Channel: 3\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=1): 14\n",
      "Leiden components (min_sub=1): 13\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=5): 1607\n",
      "Leiden components (min_sub=5): 1610\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=10): 1904\n",
      "Leiden components (min_sub=10): 1905\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=25): 2016\n",
      "Leiden components (min_sub=25): 2016\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=50): 2028\n",
      "Leiden components (min_sub=50): 2028\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=100): 2040\n",
      "Leiden components (min_sub=100): 2040\n",
      "Channel: 4\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=1): 13\n",
      "Leiden components (min_sub=1): 13\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=5): 1733\n",
      "Leiden components (min_sub=5): 1733\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=10): 1959\n",
      "Leiden components (min_sub=10): 1959\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=25): 2010\n",
      "Leiden components (min_sub=25): 2010\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=50): 2032\n",
      "Leiden components (min_sub=50): 2032\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=100): 2038\n",
      "Leiden components (min_sub=100): 2038\n",
      "Channel: 5\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=1): 16\n",
      "Leiden components (min_sub=1): 16\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=5): 1812\n",
      "Leiden components (min_sub=5): 1813\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=10): 1968\n",
      "Leiden components (min_sub=10): 1968\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=25): 2010\n",
      "Leiden components (min_sub=25): 2010\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=50): 2023\n",
      "Leiden components (min_sub=50): 2023\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=100): 2037\n",
      "Leiden components (min_sub=100): 2037\n",
      "Channel: 6\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=1): 7\n",
      "Leiden components (min_sub=1): 8\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=5): 1821\n",
      "Leiden components (min_sub=5): 1821\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=10): 1972\n",
      "Leiden components (min_sub=10): 1971\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=25): 2014\n",
      "Leiden components (min_sub=25): 2015\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=50): 2029\n",
      "Leiden components (min_sub=50): 2029\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=100): 2036\n",
      "Leiden components (min_sub=100): 2037\n",
      "Channel: 7\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=1): 6\n",
      "Leiden components (min_sub=1): 6\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=5): 1828\n",
      "Leiden components (min_sub=5): 1827\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=10): 1957\n",
      "Leiden components (min_sub=10): 1958\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=25): 2012\n",
      "Leiden components (min_sub=25): 2013\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=50): 2030\n",
      "Leiden components (min_sub=50): 2030\n",
      "Resolution: 1\n",
      "Louvain components (min_sub=100): 2034\n",
      "Leiden components (min_sub=100): 2034\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "# Mimi codebook names in channel order; entry i is paired with confusion_{i}.npy below.\n",
    "CHANNELS = [\n",
    "    'rvq_first_0',\n",
    "    'rvq_rest_0',\n",
    "    'rvq_rest_1',\n",
    "    'rvq_rest_2',\n",
    "    'rvq_rest_3',\n",
    "    'rvq_rest_4',\n",
    "    'rvq_rest_5',\n",
    "    'rvq_rest_6',\n",
    "]\n",
    "\n",
    "# Minimum confusion count for an entry to be kept as an edge (also tried: 2, 3, 4, 15, 20, 30).\n",
    "COUNTS = [1, 5, 10, 25, 50, 100]\n",
    "\n",
    "# Resolution passed to both algorithms (also tried: 0.8, 1.2).\n",
    "RESOLUTIONS = [1]\n",
    "\n",
    "# Read the Mimi matrices from an explicit path rather than the global `matrix_dir`: the EnCodec\n",
    "# cell rebinds `matrix_dir`, so relying on it makes this cell depend on execution order.\n",
    "MIMI_MATRIX_DIR = \"/home/wmar/wmar_audio/outputs/confusion/matrices_new\"\n",
    "CLUSTERINGS_DIR = \"/home/wmar/wmar_audio/models/embeddings/new_clusterings\"\n",
    "os.makedirs(CLUSTERINGS_DIR, exist_ok=True)  # open(..., \"wb\") does not create directories\n",
    "\n",
    "# The saved dicts are keyed by min_sub_count only, so with several resolutions the last one wins.\n",
    "if len(RESOLUTIONS) > 1:\n",
    "    print(\"Warning: multiple resolutions share one key per min_sub_count; only the last is saved\")\n",
    "\n",
    "for channel, name in enumerate(CHANNELS):\n",
    "    print(\"Channel:\", channel)\n",
    "    matrix = np.load(os.path.join(MIMI_MATRIX_DIR, f\"confusion_{channel}.npy\"))\n",
    "\n",
    "    # {min_sub_count: label array} for each algorithm\n",
    "    louvain_clusterings = {}\n",
    "    leiden_clusterings = {}\n",
    "\n",
    "    for count in COUNTS:\n",
    "        for resolution in RESOLUTIONS:\n",
    "            print(\"Resolution:\", resolution)\n",
    "            louvain_labels = louvain_clusters(matrix, min_sub_count=count, resolution=resolution)\n",
    "            leiden_labels = leiden_clusters(matrix, min_sub_count=count, resolution=resolution)\n",
    "\n",
    "            # Labels are 0..K-1, so max + 1 is the number of clusters.\n",
    "            print(f\"Louvain components (min_sub={count}): {louvain_labels.max() + 1}\")\n",
    "            print(f\"Leiden components (min_sub={count}): {leiden_labels.max() + 1}\")\n",
    "\n",
    "            louvain_clusterings[count] = louvain_labels\n",
    "            leiden_clusterings[count] = leiden_labels\n",
    "\n",
    "    with open(os.path.join(CLUSTERINGS_DIR, f\"mimi_louvain_{name}_clusterings.pkl\"), \"wb\") as f:\n",
    "        pickle.dump(louvain_clusterings, f)\n",
    "\n",
    "    with open(os.path.join(CLUSTERINGS_DIR, f\"mimi_leiden_{name}_clusterings.pkl\"), \"wb\") as f:\n",
    "        pickle.dump(leiden_clusterings, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a3fc6dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# After clustering, inspect number of clusters, singletons, and largest cluster"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef03bd20",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from glob import glob\n",
    "from collections import defaultdict\n",
    "\n",
    "def load_data(samples_dir, num_channels=4):\n",
    "    orig_list, rt_list = [], []\n",
    "    paths = sorted(glob(os.path.join(samples_dir, \"*\")))\n",
    "    print(f\"Loading {len(paths)} samples...\")\n",
    "\n",
    "    for d in paths:\n",
    "        o_path, r_path = os.path.join(d, \"orig_tokens.txt\"), os.path.join(d, \"roundtrip_tokens.txt\")\n",
    "        if not (os.path.exists(o_path) and os.path.exists(r_path)): continue\n",
    "\n",
    "        with open(o_path, \"r\") as f: \n",
    "            o = [np.fromstring(l, sep=\" \", dtype=np.int64) for l in f.readlines()[:num_channels]]\n",
    "        with open(r_path, \"r\") as f: \n",
    "            r = [np.fromstring(l, sep=\" \", dtype=np.int64) for l in f.readlines()[:num_channels]]\n",
    "\n",
    "        T = min(o[0].size, r[0].size)\n",
    "        if T == 0: continue\n",
    "        \n",
    "        orig_list.append(np.stack([x[:T] for x in o], axis=1))\n",
    "        rt_list.append(np.stack([x[:T] for x in r], axis=1))\n",
    "\n",
    "    assert len(orig_list) > 0, \"No data found\"\n",
    "    full_orig = torch.from_numpy(np.concatenate(orig_list, axis=0))\n",
    "    full_rt = torch.from_numpy(np.concatenate(rt_list, axis=0))\n",
    "    assert full_orig.shape == full_rt.shape\n",
    "    return full_orig, full_rt\n",
    "\n",
    "def compute_metrics(x, y, k, cmap=None):\n",
    "    device = x.device\n",
    "    if cmap is not None:\n",
    "        if not isinstance(cmap, torch.Tensor): cmap = torch.tensor(cmap, device=device, dtype=torch.long)\n",
    "        assert cmap.dim() == 1 and cmap.size(0) == 2048\n",
    "        # Ensure k matches the map's actual range for safety\n",
    "        assert cmap.max() < k, f\"Map has index {cmap.max()} but k is {k}\"\n",
    "        x, y = cmap[x], cmap[y]\n",
    "\n",
    "    flat = x.long() * k + y.long()\n",
    "    counts = torch.bincount(flat, minlength=k*k).float()\n",
    "    joint = counts.reshape(k, k)\n",
    "    total = joint.sum()\n",
    "    \n",
    "    if total == 0: return 0.0, 0.0\n",
    "\n",
    "    Pxy = joint / total\n",
    "    Px, Py = Pxy.sum(1), Pxy.sum(0)\n",
    "    nz = Pxy > 0\n",
    "    mi = (Pxy[nz] * torch.log(Pxy[nz] / (Px.unsqueeze(1) * Py.unsqueeze(0))[nz])).sum().item()\n",
    "    \n",
    "    nz_x = Px > 0\n",
    "    hx = -(Px[nz_x] * torch.log(Px[nz_x])).sum().item()\n",
    "    mi_frac = mi / hx if hx > 1e-9 else 0.0\n",
    "\n",
    "    correct = joint.diag().sum().item()\n",
    "    e_raw = 1.0 - correct / total.item()\n",
    "    e_amort = e_raw / (1.0 - 1.0 / k) if k > 1 else 0.0\n",
    "    return mi_frac, e_amort\n",
    "\n",
    "# Configs\n",
    "SAMPLES_DIR = \"/home/wmar/wmar_audio/outputs/rcc_reconstructions/original_mimi/samples\"\n",
    "# NOTE(review): this glob reads from a different directory than the one the clustering cell in\n",
    "# this notebook writes to (.../new_clusterings) -- confirm these are the intended clusterings.\n",
    "CLUSTERS_GLOB = \"/home/AlignedIS-dev/models/embeddings/clusterings/mimi_*_rvq_*.pkl\"\n",
    "# NOTE: rebinds CHANNELS to an int; the clustering cell binds the same name to a list of\n",
    "# channel names, so its value depends on which cell ran last.\n",
    "CHANNELS = 4\n",
    "VOCAB = 2048\n",
    "OUT_DIR = \"results/plots/retokenization\"\n",
    "\n",
    "# Token pairs are loaded once and moved to the GPU when one is available.\n",
    "orig_all, rt_all = load_data(SAMPLES_DIR, CHANNELS)\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "orig_all, rt_all = orig_all.to(device), rt_all.to(device)\n",
    "\n",
    "# Load maps: maps[channel][method][min_count] = (k_actual, clustering_array)\n",
    "maps = defaultdict(lambda: defaultdict(dict))\n",
    "\n",
    "for p in sorted(glob(CLUSTERS_GLOB)):\n",
    "    fname = os.path.basename(p)\n",
    "    # Channel index is inferred from a substring of the file name; other files are skipped.\n",
    "    if \"first_0\" in fname: c = 0\n",
    "    elif \"rest_0\" in fname: c = 1\n",
    "    elif \"rest_1\" in fname: c = 2\n",
    "    elif \"rest_2\" in fname: c = 3\n",
    "    else: continue\n",
    "    \n",
    "    # Method is inferred the same way; files naming neither method are skipped.\n",
    "    if \"leiden\" in fname: mtype = \"leiden\"\n",
    "    elif \"louvain\" in fname: mtype = \"louvain\"\n",
    "    else: continue \n",
    "\n",
    "    # SECURITY: pickle.load can execute arbitrary code; only point CLUSTERS_GLOB at trusted files.\n",
    "    with open(p, \"rb\") as f:\n",
    "        # Dictionary is {min_count: clustering_array}\n",
    "        data = pickle.load(f)\n",
    "        for min_count, cmap in data.items():\n",
    "            # Number of clusters is taken as the largest label plus one.\n",
    "            k_actual = int(cmap.max()) + 1\n",
    "            maps[c][mtype][min_count] = (k_actual, cmap)\n",
    "\n",
    "os.makedirs(OUT_DIR, exist_ok=True)\n",
    "\n",
    "for ch in range(CHANNELS):\n",
    "    print(f\"Processing Channel {ch}...\")\n",
    "    x, y = orig_all[:, ch], rt_all[:, ch]\n",
    "    assert x.max() < VOCAB and y.max() < VOCAB\n",
    "\n",
    "    # Baseline on the raw tokens (no clustering), drawn as a dashed line in both panels.\n",
    "    base_mi, base_err = compute_metrics(x, y, VOCAB, None)\n",
    "    \n",
    "    # res stores: (min_count, k_actual, mi, err)\n",
    "    res = defaultdict(list)\n",
    "    for mtype in [\"leiden\", \"louvain\"]:\n",
    "        for mc in sorted(maps[ch][mtype].keys()):\n",
    "            k_act, cmap = maps[ch][mtype][mc]\n",
    "            mi, err = compute_metrics(x, y, k_act, cmap)\n",
    "            res[mtype].append((mc, k_act, mi, err))\n",
    "        res[mtype].sort() # Sort by min_count\n",
    "\n",
    "    fig, ax = plt.subplots(1, 2, figsize=(14, 5))\n",
    "    \n",
    "    # Unified X-axis based on min_count\n",
    "    # min_count values are placed at evenly spaced positions (their rank), not at their numeric value.\n",
    "    all_mcs = sorted(list(set([v[0] for v in res[\"leiden\"]] + [v[0] for v in res[\"louvain\"]])))\n",
    "    mc_map = {mc: i for i, mc in enumerate(all_mcs)}\n",
    "\n",
    "    for i, title in enumerate([\"MI / H(X)\", \"Amortized Error\"]):\n",
    "        ax[i].axhline(base_mi if i==0 else base_err, color='k', ls='--', label='Unclustered')\n",
    "        \n",
    "        # Index of metric in the tuple (mc, k, mi, err) -> mi is index 2, err is index 3\n",
    "        idx = i + 2 \n",
    "        \n",
    "        if res[\"leiden\"]:\n",
    "            ax[i].plot([mc_map[v[0]] for v in res[\"leiden\"]], [v[idx] for v in res[\"leiden\"]], 'o-', label='Leiden')\n",
    "        if res[\"louvain\"]:\n",
    "            ax[i].plot([mc_map[v[0]] for v in res[\"louvain\"]], [v[idx] for v in res[\"louvain\"]], 's-', label='Louvain')\n",
    "            \n",
    "        ax[i].set_title(title)\n",
    "        ax[i].set_xticks(range(len(all_mcs)))\n",
    "        ax[i].set_xticklabels([str(mc) for mc in all_mcs], rotation=45)\n",
    "        ax[i].set_xlabel(\"min_count\")\n",
    "        ax[i].grid(True, alpha=0.3)\n",
    "        if i == 0: ax[i].legend()\n",
    "\n",
    "    plt.suptitle(f\"Channel {ch}\")\n",
    "    plt.tight_layout()\n",
    "    # NOTE: with savefig commented out, the figure is closed below without being written to OUT_DIR.\n",
    "    # plt.savefig(os.path.join(OUT_DIR, f\"metrics_channel_community_{ch}.png\"))\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "651b8cb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from glob import glob\n",
    "from collections import defaultdict\n",
    "\n",
    "# -----------------------------------------------------------------------------\n",
    "# 1. Config & Mode Selection\n",
    "# -----------------------------------------------------------------------------\n",
    "\n",
    "# OPTIONS: \"community\" (Leiden/Louvain) or \"kmeans\" (Standard/Even)\n",
    "MODE = \"kmeans\"  # Change to \"community\" to run Leiden/Louvain\n",
    "\n",
    "SAMPLES_DIR = \"/home/wmar/wmar_audio/outputs/rcc_reconstructions/original_mimi/samples\"\n",
    "OUT_DIR = f\"results/plots/pareto\"\n",
    "CHANNELS = 4\n",
    "VOCAB = 2048\n",
    "\n",
    "if MODE == \"community\":\n",
    "    CLUSTERS_GLOB = \"/home/AlignedIS-dev/models/embeddings/clusterings/mimi_*_rvq_*.pkl\"\n",
    "elif MODE == \"kmeans\":\n",
    "    CLUSTERS_GLOB = \"/home/AlignedIS-dev/models/embeddings/clusterings/mimi_rvq_*.pkl\"\n",
    "else:\n",
    "    raise ValueError(\"MODE must be 'community' or 'kmeans'\")\n",
    "\n",
    "# -----------------------------------------------------------------------------\n",
    "# 2. Metrics Implementation\n",
    "# -----------------------------------------------------------------------------\n",
    "\n",
    "def compute_basic_metrics(x, y, k, cmap=None):\n",
    "    \"\"\"Computes Normalized MI and Amortized Error.\"\"\"\n",
    "    device = x.device\n",
    "    if cmap is not None:\n",
    "        if not isinstance(cmap, torch.Tensor): cmap = torch.tensor(cmap, device=device, dtype=torch.long)\n",
    "        x, y = cmap[x], cmap[y]\n",
    "\n",
    "    flat = x.long() * k + y.long()\n",
    "    counts = torch.bincount(flat, minlength=k*k).float()\n",
    "    joint = counts.reshape(k, k)\n",
    "    total = joint.sum()\n",
    "    \n",
    "    if total == 0: return 0.0, 0.0\n",
    "\n",
    "    # MI / H(X)\n",
    "    Pxy = joint / total\n",
    "    Px, Py = Pxy.sum(1), Pxy.sum(0)\n",
    "    nz = Pxy > 0\n",
    "    mi = (Pxy[nz] * torch.log(Pxy[nz] / (Px.unsqueeze(1) * Py.unsqueeze(0))[nz])).sum().item()\n",
    "    \n",
    "    nz_x = Px > 0\n",
    "    hx = -(Px[nz_x] * torch.log(Px[nz_x])).sum().item()\n",
    "    norm_mi = mi / hx if hx > 1e-9 else 0.0\n",
    "\n",
    "    # Amortized Error\n",
    "    correct = joint.diag().sum().item()\n",
    "    e_raw = 1.0 - correct / total.item()\n",
    "    e_amort = e_raw / (1.0 - 1.0 / k) if k > 1 else 0.0\n",
    "\n",
    "    return norm_mi, e_amort\n",
    "\n",
    "def compute_effective_vocab(x, k, cmap=None):\n",
    "    \"\"\"Computes Effective Vocabulary Size (2^Entropy).\"\"\"\n",
    "    device = x.device\n",
    "    if cmap is not None:\n",
    "        if not isinstance(cmap, torch.Tensor): cmap = torch.tensor(cmap, device=device, dtype=torch.long)\n",
    "        x = cmap[x]\n",
    "\n",
    "    counts = torch.bincount(x, minlength=k).float()\n",
    "    probs = counts / counts.sum()\n",
    "    probs = probs[probs > 0] \n",
    "    entropy = -(probs * torch.log2(probs)).sum().item()\n",
    "    return 2 ** entropy\n",
    "\n",
    "def compute_merger_consistency(x, y, cmap, vocab_size=2048):\n",
    "    \"\"\"Computes Merger Consistency (Precision).\"\"\"\n",
    "    device = x.device\n",
    "    if not isinstance(cmap, torch.Tensor): cmap = torch.tensor(cmap, device=device, dtype=torch.long)\n",
    "\n",
    "    flat_idx = x.long() * vocab_size + y.long()\n",
    "    counts = torch.bincount(flat_idx, minlength=vocab_size**2).float()\n",
    "    confusion = counts.reshape(vocab_size, vocab_size)\n",
    "    adj = (confusion + confusion.T) > 0\n",
    "\n",
    "    unique_clusters = torch.unique(cmap)\n",
    "    total_pairs = 0\n",
    "    valid_pairs = 0\n",
    "    \n",
    "    for c_id in unique_clusters:\n",
    "        members = torch.nonzero(cmap == c_id, as_tuple=True)[0]\n",
    "        n = len(members)\n",
    "        if n < 2: continue\n",
    "            \n",
    "        sub_adj = adj[members][:, members]\n",
    "        links = sub_adj.sum().item() - n \n",
    "        valid_pairs += (links / 2.0)\n",
    "        total_pairs += ((n * (n - 1)) / 2.0)\n",
    "        \n",
    "    if total_pairs == 0: return 1.0\n",
    "    return valid_pairs / total_pairs\n",
    "\n",
    "# -----------------------------------------------------------------------------\n",
    "# 3. Data Loading & Parsing\n",
    "# -----------------------------------------------------------------------------\n",
    "\n",
    "def load_data(samples_dir, num_channels=4):\n",
    "    orig_list, rt_list = [], []\n",
    "    paths = sorted(glob(os.path.join(samples_dir, \"*\")))\n",
    "    print(f\"Loading {len(paths)} samples...\")\n",
    "\n",
    "    for d in paths:\n",
    "        o_path, r_path = os.path.join(d, \"orig_tokens.txt\"), os.path.join(d, \"roundtrip_tokens.txt\")\n",
    "        if not (os.path.exists(o_path) and os.path.exists(r_path)): continue\n",
    "        try:\n",
    "            with open(o_path, \"r\") as f: \n",
    "                o = [np.fromstring(l, sep=\" \", dtype=np.int64) for l in f.readlines()[:num_channels]]\n",
    "            with open(r_path, \"r\") as f: \n",
    "                r = [np.fromstring(l, sep=\" \", dtype=np.int64) for l in f.readlines()[:num_channels]]\n",
    "            T = min(o[0].size, r[0].size)\n",
    "            if T == 0: continue\n",
    "            orig_list.append(np.stack([x[:T] for x in o], axis=1))\n",
    "            rt_list.append(np.stack([x[:T] for x in r], axis=1))\n",
    "        except: continue\n",
    "\n",
    "    full_orig = torch.from_numpy(np.concatenate(orig_list, axis=0))\n",
    "    full_rt = torch.from_numpy(np.concatenate(rt_list, axis=0))\n",
    "    return full_orig, full_rt\n",
    "\n",
    "# -----------------------------------------------------------------------------\n",
    "# 4. Main Processing\n",
    "# -----------------------------------------------------------------------------\n",
    "\n",
    "# Token pairs are loaded once and moved to the GPU when one is available.\n",
    "orig_all, rt_all = load_data(SAMPLES_DIR, CHANNELS)\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "orig_all, rt_all = orig_all.to(device), rt_all.to(device)\n",
    "\n",
    "# maps[channel][method][param] = (k, cmap); param is min_count in community mode and k in kmeans mode.\n",
    "maps = defaultdict(lambda: defaultdict(dict))\n",
    "\n",
    "print(f\"Parsing cluster files for mode: {MODE}\")\n",
    "for p in sorted(glob(CLUSTERS_GLOB)):\n",
    "    fname = os.path.basename(p)\n",
    "    \n",
    "    # Parse Channel\n",
    "    # The channel index is inferred from a substring of the file name; other files are skipped.\n",
    "    if \"first_0\" in fname: c = 0\n",
    "    elif \"rest_0\" in fname: c = 1\n",
    "    elif \"rest_1\" in fname: c = 2\n",
    "    elif \"rest_2\" in fname: c = 3\n",
    "    else: continue\n",
    "    \n",
    "    # Parse Method\n",
    "    if MODE == \"community\":\n",
    "        if \"leiden\" in fname: mtype = \"Leiden\"\n",
    "        elif \"louvain\" in fname: mtype = \"Louvain\"\n",
    "        else: continue \n",
    "    else: # kmeans\n",
    "        # Community-detection files are ignored here; \"even\" in the name selects the Even variant.\n",
    "        if \"leiden\" in fname or \"louvain\" in fname: continue\n",
    "        mtype = \"Even K-Means\" if \"even\" in fname else \"K-Means\"\n",
    "\n",
    "    # SECURITY: pickle.load can execute arbitrary code; only point CLUSTERS_GLOB at trusted files.\n",
    "    with open(p, \"rb\") as f:\n",
    "        data = pickle.load(f)\n",
    "        \n",
    "        if MODE == \"community\":\n",
    "            # Data: {min_count: cmap}\n",
    "            for min_count, cmap in data.items():\n",
    "                # Number of clusters is taken as the largest label plus one.\n",
    "                k_act = int(cmap.max()) + 1\n",
    "                maps[c][mtype][min_count] = (k_act, cmap)\n",
    "        else:\n",
    "            # Data: {k: {seed: cmap, ...}}\n",
    "            # We strictly look for seed 0\n",
    "            # NOTE(review): the dict key k_val is trusted as the cluster count; nothing here\n",
    "            # checks that cmap's labels are all below it.\n",
    "            for k_val, seeds_dict in data.items():\n",
    "                if 0 in seeds_dict:\n",
    "                    cmap = seeds_dict[0]\n",
    "                    maps[c][mtype][k_val] = (k_val, cmap)\n",
    "\n",
    "os.makedirs(OUT_DIR, exist_ok=True)\n",
    "\n",
    "for ch in range(CHANNELS):\n",
    "    print(f\"Processing Channel {ch}...\")\n",
    "    x, y = orig_all[:, ch], rt_all[:, ch]\n",
    "    results = defaultdict(list)\n",
    "    \n",
    "    methods = [\"Leiden\", \"Louvain\"] if MODE == \"community\" else [\"K-Means\", \"Even K-Means\"]\n",
    "    \n",
    "    for mtype in methods:\n",
    "        if mtype not in maps[ch]: continue\n",
    "        \n",
    "        sorted_params = sorted(maps[ch][mtype].keys())\n",
    "        \n",
    "        for param in sorted_params:\n",
    "            k, cmap = maps[ch][mtype][param]\n",
    "            \n",
    "            n_mi, err = compute_basic_metrics(x, y, k, cmap)\n",
    "            eff_vocab = compute_effective_vocab(x, k, cmap)\n",
    "            consistency = compute_merger_consistency(x, y, cmap, VOCAB)\n",
    "            \n",
    "            results[mtype].append({\n",
    "                \"param\": param,\n",
    "                \"norm_mi\": n_mi,\n",
    "                \"error\": err,\n",
    "                \"eff_vocab\": eff_vocab,\n",
    "                \"consistency\": consistency\n",
    "            })\n",
    "\n",
    "    # --- Plotting ---\n",
    "    fig, axes = plt.subplots(1, 2, figsize=(12, 7))\n",
    "    \n",
    "    # Ax0: Capacity vs Robustness\n",
    "    ax = axes[0]\n",
    "    for mtype, data in results.items():\n",
    "        if not data: continue\n",
    "        x_vals = [d[\"eff_vocab\"] for d in data]\n",
    "        y_vals = [d[\"norm_mi\"] for d in data]\n",
    "        marker = 'o' if \"Leiden\" in mtype or \"K-Means\" == mtype else 's'\n",
    "        \n",
    "        ax.plot(x_vals, y_vals, marker=marker, label=mtype)\n",
    "        for i, d in enumerate(data):\n",
    "            ax.annotate(str(d['param']), (x_vals[i], y_vals[i]), fontsize=8, alpha=0.7)\n",
    "\n",
    "    ax.set_title(\"Option A: Watermark Efficacy (Pareto)\\nCapacity vs Robustness (Higher is Better)\")\n",
    "    ax.set_xlabel(\"Effective Vocab Size (Bits)\")\n",
    "    ax.set_ylabel(\"Normalized MI\")\n",
    "    ax.grid(True, linestyle='--', alpha=0.5)\n",
    "    ax.legend()\n",
    "    \n",
    "    # Ax1: Precision vs Error\n",
    "    ax = axes[1]\n",
    "    for mtype, data in results.items():\n",
    "        if not data: continue\n",
    "        x_vals = [d[\"error\"] for d in data]\n",
    "        y_vals = [d[\"consistency\"] for d in data]\n",
    "        marker = 'o' if \"Leiden\" in mtype or \"K-Means\" == mtype else 's'\n",
    "        \n",
    "        ax.plot(x_vals, y_vals, marker=marker, label=mtype)\n",
    "        for i, d in enumerate(data):\n",
    "            ax.annotate(str(d['param']), (x_vals[i], y_vals[i]), fontsize=8, alpha=0.7)\n",
    "\n",
    "    ax.set_title(\"Option B: Cluster Quality (Pareto)\\nPrecision (Y) vs Error (X)\")\n",
    "    ax.set_xlabel(\"Amortized Error (Lower is Better)\")\n",
    "    ax.set_ylabel(\"Merger Consistency (Higher is Better)\")\n",
    "    ax.grid(True, linestyle='--', alpha=0.5)\n",
    "    ax.legend()\n",
    "\n",
    "    plt.suptitle(f\"Channel {ch} Analysis ({MODE})\")\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(os.path.join(OUT_DIR, f\"pareto_{MODE}_channel_{ch}.png\"))\n",
    "    plt.close()\n",
    "\n",
    "print(f\"Done. Saved to {OUT_DIR}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e8d0111",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import os\n",
    "import pickle\n",
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from glob import glob\n",
    "from collections import defaultdict\n",
    "\n",
    "# ==========================================\n",
    "# 1. Load Data (Standard)\n",
    "# ==========================================\n",
    "def load_data(samples_dir, num_channels=4):\n",
    "    \"\"\"Load paired original / round-trip token matrices from ``samples_dir``.\n",
    "\n",
    "    Every entry of ``samples_dir`` that contains both ``orig_tokens.txt`` and\n",
    "    ``roundtrip_tokens.txt`` contributes one sample. Each non-blank line of a\n",
    "    token file is read as one channel of space-separated integers, and only\n",
    "    the first ``num_channels`` such lines are used.\n",
    "\n",
    "    Args:\n",
    "        samples_dir: Directory whose entries are per-sample directories.\n",
    "        num_channels: Number of leading channel rows to read from each file.\n",
    "\n",
    "    Returns:\n",
    "        ``(full_orig, full_rt)``: two int64 tensors with one row per frame and\n",
    "        one column per channel row read, all samples concatenated in sorted\n",
    "        path order.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If no usable sample is found under ``samples_dir``.\n",
    "    \"\"\"\n",
    "    def _read_channels(path):\n",
    "        # Blank lines are dropped so they can neither shift nor empty a channel row.\n",
    "        with open(path, \"r\") as f:\n",
    "            rows = [line for line in f if line.strip()][:num_channels]\n",
    "        return [np.fromstring(line, sep=\" \", dtype=np.int64) for line in rows]\n",
    "\n",
    "    orig_list, rt_list = [], []\n",
    "    paths = sorted(glob(os.path.join(samples_dir, \"*\")))\n",
    "    print(f\"Loading {len(paths)} samples...\")\n",
    "\n",
    "    for d in paths:\n",
    "        o_path = os.path.join(d, \"orig_tokens.txt\")\n",
    "        r_path = os.path.join(d, \"roundtrip_tokens.txt\")\n",
    "        if not (os.path.exists(o_path) and os.path.exists(r_path)):\n",
    "            continue\n",
    "\n",
    "        o = _read_channels(o_path)\n",
    "        r = _read_channels(r_path)\n",
    "        if not o or not r:\n",
    "            continue  # one of the token files has no rows at all\n",
    "\n",
    "        # Truncate to the shortest row across BOTH files so that ragged\n",
    "        # channels still stack into a rectangular (T, channels) array.\n",
    "        T = min(row.size for row in o + r)\n",
    "        if T == 0:\n",
    "            continue\n",
    "\n",
    "        orig_list.append(np.stack([x[:T] for x in o], axis=1))\n",
    "        rt_list.append(np.stack([x[:T] for x in r], axis=1))\n",
    "\n",
    "    if not orig_list:\n",
    "        # np.concatenate([]) would fail with a message that hides the real cause.\n",
    "        raise ValueError(f\"No usable samples found in {samples_dir!r}\")\n",
    "\n",
    "    full_orig = torch.from_numpy(np.concatenate(orig_list, axis=0))\n",
    "    full_rt = torch.from_numpy(np.concatenate(rt_list, axis=0))\n",
    "    return full_orig, full_rt\n",
    "\n",
    "def get_eta_ctx(x, y, cmap):\n",
    "    \"\"\"Computes eta_ctx (Cluster Match Rate) for the Pareto Y-axis.\n",
    "\n",
    "    Args:\n",
    "        x: Integer tensor of original token ids.\n",
    "        y: Integer tensor of round-trip token ids, compared element-wise with ``x``.\n",
    "        cmap: Token-id -> cluster-id lookup (tensor, NumPy array or sequence)\n",
    "            that can be indexed by every id occurring in ``x`` and ``y``.\n",
    "\n",
    "    Returns:\n",
    "        Fraction of positions (Python float) at which ``x`` and ``y`` fall in\n",
    "        the same cluster.\n",
    "    \"\"\"\n",
    "    # as_tensor converts arrays/sequences AND relocates an existing tensor:\n",
    "    # a CPU map loaded from disk cannot be indexed with CUDA token tensors.\n",
    "    cmap = torch.as_tensor(cmap, dtype=torch.long, device=x.device)\n",
    "\n",
    "    cx, cy = cmap[x], cmap[y]\n",
    "    return (cx == cy).float().mean().item()\n",
    "\n",
    "# ==========================================\n",
    "# 2. Configuration & Execution\n",
    "# ==========================================\n",
    "# Input / output locations and run-wide settings.\n",
    "SAMPLES_DIR = \"/home/wmar/wmar_audio/outputs/rcc_reconstructions/original_mimi/samples\"\n",
    "CLUSTERS_GLOB = \"/home/AlignedIS-dev/models/embeddings/clusterings/mimi_*_rvq_*.pkl\"\n",
    "OUT_DIR = \"/home/AlignedIS-dev/results/plots/channel_comparison\"\n",
    "CHANNELS = 4\n",
    "LIMIT_INSTANCES = 3  # how many of the smallest min_count keys are kept per file\n",
    "\n",
    "os.makedirs(OUT_DIR, exist_ok=True)\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "\n",
    "print(\"Loading Data...\")\n",
    "# Load both token matrices, then move each one onto the compute device.\n",
    "orig_all, rt_all = (tokens.to(device) for tokens in load_data(SAMPLES_DIR, CHANNELS))\n",
    "\n",
    "print(\"Loading Maps...\")\n",
    "# maps[channel][method] -> [(min_count, k_act, cmap), ...]\n",
    "maps = defaultdict(lambda: defaultdict(list))\n",
    "\n",
    "# Filename tag -> channel index; tags are tried in this order, first hit wins.\n",
    "channel_tags = ((\"first_0\", 0), (\"rest_0\", 1), (\"rest_1\", 2), (\"rest_2\", 3))\n",
    "\n",
    "for pkl_path in sorted(glob(CLUSTERS_GLOB)):\n",
    "    fname = os.path.basename(pkl_path)\n",
    "\n",
    "    c = next((idx for tag, idx in channel_tags if tag in fname), None)\n",
    "    if c is None:\n",
    "        continue  # file does not belong to any of the tracked channels\n",
    "\n",
    "    if \"leiden\" in fname:\n",
    "        mtype = \"leiden\"\n",
    "    elif \"louvain\" in fname:\n",
    "        mtype = \"louvain\"\n",
    "    else:\n",
    "        continue  # only community-detection clusterings are compared here\n",
    "\n",
    "    # NOTE: unpickling can execute arbitrary code; CLUSTERS_GLOB must only\n",
    "    # match files from a trusted source.\n",
    "    with open(pkl_path, \"rb\") as f:\n",
    "        data = pickle.load(f)\n",
    "\n",
    "    # Keep the LIMIT_INSTANCES smallest keys; k_act is the largest cluster id + 1.\n",
    "    for mc in sorted(data)[:LIMIT_INSTANCES]:\n",
    "        cmap = data[mc]\n",
    "        maps[c][mtype].append((mc, int(cmap.max()) + 1, cmap))\n",
    "\n",
    "# ==========================================\n",
    "# 3. Pareto Plotting\n",
    "# ==========================================\n",
    "fig, ax = plt.subplots(1, 1, figsize=(10, 8))\n",
    "\n",
    "# One colour per channel; marker and line style tell the two methods apart.\n",
    "colors = ['#1f77b4', '#2ca02c', '#9467bd', '#d62728'] # Blue, Green, Purple, Red\n",
    "markers = {'leiden': 'o', 'louvain': '^'}\n",
    "linestyles = {'leiden': '-', 'louvain': '--'}\n",
    "\n",
    "HIGHLIGHT_CH = 3   # channel drawn bold and given the unclustered baseline marker\n",
    "BASELINE_K = 2048  # x-position of the unclustered baseline point\n",
    "\n",
    "print(\"Computing Pareto Frontiers...\")\n",
    "\n",
    "for ch in range(CHANNELS):\n",
    "    x_tok, y_tok = orig_all[:, ch], rt_all[:, ch]\n",
    "    emphasised = ch == HIGHLIGHT_CH\n",
    "\n",
    "    for mtype in [\"leiden\", \"louvain\"]:\n",
    "        data_points = maps[ch][mtype]\n",
    "        if not data_points:\n",
    "            continue\n",
    "\n",
    "        # One (min_count, K, eta) triple per clustering of this channel/method.\n",
    "        points = [(mc, k_act, get_eta_ctx(x_tok, y_tok, cmap))\n",
    "                  for mc, k_act, cmap in data_points]\n",
    "\n",
    "        # Connect the points in ascending-K order so the line never doubles back.\n",
    "        ks, etas = zip(*sorted((k_val, eta_val) for _, k_val, eta_val in points))\n",
    "\n",
    "        ax.plot(ks, etas,\n",
    "                color=colors[ch],\n",
    "                marker=markers[mtype],\n",
    "                linestyle=linestyles[mtype],\n",
    "                linewidth=2 if emphasised else 1.5,\n",
    "                alpha=1.0 if emphasised else 0.7,\n",
    "                label=f\"Ch {ch} {mtype.capitalize()}\"\n",
    "        )\n",
    "\n",
    "        # Annotate each point with the min_count that produced it.\n",
    "        for mc, k_val, eta_val in points:\n",
    "            ax.annotate(f\"{mc}\", (k_val, eta_val),\n",
    "                        textcoords=\"offset points\", xytext=(0, 5), ha='center', fontsize=8)\n",
    "\n",
    "# Unclustered reference for the highlighted channel: the raw token match rate,\n",
    "# placed at K = BASELINE_K.\n",
    "r_baseline = (orig_all[:, HIGHLIGHT_CH] == rt_all[:, HIGHLIGHT_CH]).float().mean().item()\n",
    "ax.scatter([BASELINE_K], [r_baseline], color=colors[HIGHLIGHT_CH], s=100, marker='*',\n",
    "           label=f\"Ch {HIGHLIGHT_CH} Baseline (No Cluster)\")\n",
    "\n",
    "ax.set_xlabel(\"Effective Vocabulary Size ($K$)\", fontsize=12)\n",
    "ax.set_ylabel(r\"Cluster Match Rate ($\\eta_{ctx}$)\", fontsize=12)\n",
    "# The subtitle follows LIMIT_INSTANCES so it cannot drift from what was loaded.\n",
    "ax.set_title(f\"Pareto Frontier: Robustness vs. Capacity\\n(First {LIMIT_INSTANCES} min_counts)\", fontsize=14)\n",
    "ax.grid(True, which='both', linestyle='--', alpha=0.5)\n",
    "ax.set_xscale('log') # Log scale is usually better for K\n",
    "ax.legend()\n",
    "\n",
    "# Label the two extreme corners of the plot: top-right (high K, high eta) is\n",
    "# the ideal, bottom-left is the worst.\n",
    "ax.text(0.05, 0.05, \"Low Capacity,\\nLow Robustness\", transform=ax.transAxes, color='gray')\n",
    "ax.text(0.95, 0.95, \"High Capacity,\\nHigh Robustness\", transform=ax.transAxes, ha='right', va='top', color='gray')\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig(os.path.join(OUT_DIR, \"pareto_channel_comparison.png\"))\n",
    "print(f\"Saved Pareto plot to {OUT_DIR}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "wmar",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
