{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Microsoft dataset\n",
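    "Clone or download the dataset from https://github.com/microsoft/clustered-nanopore-reads-dataset into this folder, so that `clustered-nanopore-reads-dataset/Centers.txt` and `clustered-nanopore-reads-dataset/Clusters.txt` exist relative to this notebook's working directory."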
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.) Visualize cluster size distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Settings \n",
    "FONTSIZE = 11\n",
    "plt.rcParams.update({\n",
    "    'text.usetex': True,\n",
    "    'font.family': 'serif',\n",
    "    'font.serif': ['Computer Modern Roman'],\n",
    "    'text.latex.preamble': r'\\usepackage{amsmath} \\usepackage{type1cm}',\n",
    "    'font.size': FONTSIZE,\n",
    "})\n",
    "\n",
    "# Define file paths (the dataset repository is expected inside the notebook folder)\n",
    "notebook_dir = os.getcwd()\n",
    "dataset_dir = os.path.join(notebook_dir, \"clustered-nanopore-reads-dataset\")\n",
    "centers_path = os.path.join(dataset_dir, \"Centers.txt\")\n",
    "clusters_path = os.path.join(dataset_dir, \"Clusters.txt\")\n",
    "separator = \"===============================\"\n",
    "\n",
    "# Read Centers.txt: one ground-truth sequence per line, aligned with the clusters\n",
    "with open(centers_path, 'r') as f:\n",
    "    ground_truth_list = [line.strip() for line in f]\n",
    "\n",
    "# Read Clusters.txt: reads of consecutive clusters delimited by separator lines\n",
    "with open(clusters_path, 'r') as f:\n",
    "    raw = f.read()\n",
    "\n",
    "# Split clusters using separator\n",
    "clusters_raw = raw.split(separator)\n",
    "\n",
    "# Remove leading empty cluster if file starts with a separator\n",
    "if len(clusters_raw) > 0 and clusters_raw[0].strip() == '':\n",
    "    clusters_raw = clusters_raw[1:]\n",
    "\n",
    "# Clean up reads in each cluster. Every read is stripped individually so stray\n",
    "# leading/trailing whitespace on a line cannot end up inside a read (reads are later\n",
    "# mapped to tokens through the stoi vocabulary); blank lines are dropped.\n",
    "clusters = [\n",
    "    [read.strip() for read in cluster.splitlines() if read.strip()]\n",
    "    for cluster in clusters_raw\n",
    "]\n",
    "\n",
    "# Sanity check: exactly one cluster per ground-truth sequence\n",
    "assert len(clusters) == len(ground_truth_list), \\\n",
    "       f\"Expected {len(ground_truth_list)} clusters, got {len(clusters)}\"\n",
    "\n",
    "# Compute cluster size distribution\n",
    "cluster_sizes = [len(c) for c in clusters]\n",
    "\n",
    "# Filter out empty clusters and compute mean size over non-empty clusters\n",
    "nonempty_sizes = [sz for sz in cluster_sizes if sz > 0]\n",
    "if nonempty_sizes:\n",
    "    mean_nonempty = np.mean(nonempty_sizes)\n",
    "    print(f\"Mean cluster size (excluding empty): {mean_nonempty:.2f}\")\n",
    "else:\n",
    "    print(\"No non‐empty clusters to average over.\")\n",
    "\n",
    "# Count empty clusters (sizes are never negative, so this is the complement)\n",
    "num_empty_clusters = len(cluster_sizes) - len(nonempty_sizes)\n",
    "print(f\"Number of empty clusters: {num_empty_clusters}\")\n",
    "\n",
    "# Plot and show cluster size distribution on an explicit figure/axes\n",
    "fig, ax = plt.subplots()\n",
    "ax.hist(cluster_sizes, bins=50)\n",
    "ax.set_xlabel(\"Cluster size\")\n",
    "ax.set_ylabel(\"Frequency\")\n",
    "ax.set_title(\"Distribution of Cluster Sizes\")\n",
    "ax.grid(True)\n",
    "plt.show()\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.) Split into 80\\% train, 10\\% validation, and 10\\% test sets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.ticker import ScalarFormatter, FuncFormatter\n",
    "from collections import defaultdict\n",
    "import wandb\n",
    "from matplotlib import cm\n",
    "from matplotlib.patches import Patch\n",
    "\n",
    "# Settings \n",
    "FONTSIZE = 6\n",
    "plt.rcParams.update({\n",
    "    'text.usetex': True,\n",
    "    'font.family': 'serif',\n",
    "    'font.serif': ['Computer Modern Roman'],\n",
    "    'text.latex.preamble': r'\\usepackage{amsmath} \\usepackage{type1cm}',\n",
    "    'font.size': FONTSIZE,\n",
    "})\n",
    "\n",
    "# Data split on cluster indices, so all reads of a cluster stay in one split:\n",
    "# 80% train, then the remaining 20% is halved into 10% val and 10% test.\n",
    "indices = list(range(len(clusters)))\n",
    "train_indices, temp_indices = train_test_split(indices, test_size=0.2, random_state=42)\n",
    "val_indices,   test_indices = train_test_split(temp_indices,   test_size=0.5, random_state=42)\n",
    "assert len(train_indices) + len(val_indices) + len(test_indices) == len(clusters)\n",
    "\n",
    "train_clusters = [clusters[i] for i in train_indices]\n",
    "val_clusters   = [clusters[i] for i in val_indices]\n",
    "test_clusters  = [clusters[i] for i in test_indices]\n",
    "\n",
    "# Ground truths are selected with the same indices, so they stay aligned with the clusters\n",
    "train_ground_truth = [ground_truth_list[i] for i in train_indices]\n",
    "val_ground_truth   = [ground_truth_list[i] for i in val_indices]\n",
    "test_ground_truth  = [ground_truth_list[i] for i in test_indices]\n",
    "\n",
    "\n",
    "# Plot \n",
    "# plt.get_cmap instead of cm.get_cmap: matplotlib.cm.get_cmap was deprecated in\n",
    "# Matplotlib 3.7 and removed in 3.9.\n",
    "cmap    = plt.get_cmap('PuBu')\n",
    "colors  = cmap(np.linspace(0.2, 0.8, 3))\n",
    "labels  = ['Train', 'Val', 'Test']\n",
    "splits  = [train_clusters, val_clusters, test_clusters]\n",
    "\n",
    "fig, axs = plt.subplots(1, 3, figsize=(5.5, 1.3), dpi=300,\n",
    "                        gridspec_kw={'wspace': 0.4})\n",
    "\n",
    "for ax, data, col in zip(axs, splits, colors):\n",
    "    sizes = [len(c) for c in data]\n",
    "    ax.hist(\n",
    "        sizes,\n",
    "        bins=50,\n",
    "        color=col,\n",
    "        edgecolor='none'\n",
    "    )\n",
    "    ax.set_xlabel(\"Cluster size\")\n",
    "    ax.grid(True, which='both', linestyle='--', linewidth=0.3)\n",
    "\n",
    "axs[0].set_ylabel(\"Frequency\")\n",
    "\n",
    "for ax in axs:\n",
    "    # light-gray frame\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_color('lightgray')\n",
    "    # tick lines lightgray, tick labels black\n",
    "    ax.tick_params(\n",
    "        axis='both',\n",
    "        which='both',\n",
    "        color='lightgray',\n",
    "        labelcolor='black', \n",
    "    )\n",
    "\n",
    "handles = [Patch(facecolor=c, edgecolor='none') for c in colors]\n",
    "fig.legend(handles, labels,\n",
    "           loc='upper center',\n",
    "           ncol=3,\n",
    "           frameon=False,\n",
    "           fontsize=FONTSIZE)\n",
    "\n",
    "plt.tight_layout(rect=[0, 0, 1, 0.85])\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per split, count clusters holding exactly one read and clusters holding several\n",
    "split_clusters = (train_clusters, val_clusters, test_clusters)\n",
    "\n",
    "num_train_singletons, num_val_singletons, num_test_singletons = (\n",
    "    sum(len(cluster) == 1 for cluster in split) for split in split_clusters\n",
    ")\n",
    "num_train_multi, num_val_multi, num_test_multi = (\n",
    "    sum(len(cluster) > 1 for cluster in split) for split in split_clusters\n",
    ")\n",
    "\n",
    "# Report singleton counts first, then multi-read counts\n",
    "print(f\"Train clusters with only 1 sequence: {num_train_singletons}\")\n",
    "print(f\"Validation clusters with only 1 sequence: {num_val_singletons}\")\n",
    "print(f\"Test clusters with only 1 sequence: {num_test_singletons}\")\n",
    "\n",
    "print(f\"Train clusters with >1 sequence: {num_train_multi}\")\n",
    "print(f\"Validation clusters with >1 sequence: {num_val_multi}\")\n",
    "print(f\"Test clusters with >1 sequence: {num_test_multi}\")\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.) Split validation and test clusters into subclusters of at most 10 reads to fit the model’s context length (training clusters are kept whole)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List, Optional, Tuple\n",
    "import random\n",
    "\n",
    "# Settings \n",
    "FONTSIZE = 6\n",
    "plt.rcParams.update({\n",
    "    'text.usetex': True,\n",
    "    'font.family': 'serif',\n",
    "    'font.serif': ['Computer Modern Roman'],\n",
    "    'text.latex.preamble': r'\\usepackage{amsmath} \\usepackage{type1cm}',\n",
    "    'font.size': FONTSIZE,\n",
    "})\n",
    "\n",
    "# For reproducibility: create_subclusters draws all of its randomness (shuffle,\n",
    "# sample, randint) from this seeded global RNG, so the generated subclusters also\n",
    "# depend on the order in which it is called.\n",
    "random.seed(42)\n",
    "\n",
    "def create_subclusters(\n",
    "    clusters: List[List[str]],\n",
    "    ground_truths: List[str],\n",
    "    max_reads: Optional[int] = 10,\n",
    "    truncate: bool = False\n",
    ") -> Tuple[List[str], List[int]]:\n",
    "    \"\"\"\n",
    "    Turn clusters of reads into `read1|read2|...|readN:ground_truth` examples.\n",
    "\n",
    "    The reads of each cluster are shuffled (on a copy) first, then:\n",
    "      - If truncate=True: randomly sample max_reads reads when the cluster is larger\n",
    "        than max_reads (otherwise keep all reads) and emit a single example.\n",
    "      - If truncate=False: split the cluster into consecutive random-sized subclusters\n",
    "        (each of size between 2 and max_reads) until fewer than 2 reads remain; a\n",
    "        single leftover read is discarded. With max_reads=None the whole cluster\n",
    "        becomes one example.\n",
    "    Clusters (or samples) with fewer than 2 reads produce no example.\n",
    "\n",
    "    Args:\n",
    "        clusters: One list of reads per cluster.\n",
    "        ground_truths: Ground-truth sequence of each cluster, aligned with `clusters`.\n",
    "        max_reads: Maximum number of reads per example; None means no limit.\n",
    "        truncate: Select the sampling (True) or splitting (False) strategy above.\n",
    "\n",
    "    Returns:\n",
    "        examples: List of strings read1|read2|...|readN:ground_truth\n",
    "        new_cluster_sizes: List of the corresponding subcluster sizes\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `clusters` and `ground_truths` differ in length, or if\n",
    "            truncate=False and max_reads < 2 (no valid subcluster size exists).\n",
    "    \"\"\"\n",
    "    # zip() would silently drop unmatched clusters, so check the alignment up front\n",
    "    if len(clusters) != len(ground_truths):\n",
    "        raise ValueError(\n",
    "            f\"Got {len(clusters)} clusters but {len(ground_truths)} ground truths\"\n",
    "        )\n",
    "    # Otherwise random.randint(2, max_reads) below fails with an opaque 'empty range' error\n",
    "    if not truncate and max_reads is not None and max_reads < 2:\n",
    "        raise ValueError(f\"max_reads must be >= 2 when splitting clusters, got {max_reads}\")\n",
    "\n",
    "    examples: List[str] = []\n",
    "    new_cluster_sizes: List[int] = []\n",
    "\n",
    "    for reads, gt in zip(clusters, ground_truths):\n",
    "        # Shuffle a copy so the caller's cluster lists stay untouched\n",
    "        reads_copy = reads.copy()\n",
    "        random.shuffle(reads_copy)\n",
    "\n",
    "        if truncate:\n",
    "            # If cluster is larger than max_reads, randomly pick max_reads reads\n",
    "            if max_reads is not None and len(reads_copy) > max_reads:\n",
    "                sampled = random.sample(reads_copy, max_reads)\n",
    "            else:\n",
    "                sampled = reads_copy\n",
    "\n",
    "            if len(sampled) >= 2:\n",
    "                example = \"|\".join(sampled) + \":\" + gt\n",
    "                examples.append(example)\n",
    "                new_cluster_sizes.append(len(sampled))\n",
    "\n",
    "        else:\n",
    "            # Slice off random-sized subclusters until fewer than 2 remain\n",
    "            reads_remaining = reads_copy\n",
    "            if max_reads is None:\n",
    "                # One full cluster if large enough\n",
    "                if len(reads_remaining) >= 2:\n",
    "                    example = \"|\".join(reads_remaining) + \":\" + gt\n",
    "                    examples.append(example)\n",
    "                    new_cluster_sizes.append(len(reads_remaining))\n",
    "            else:\n",
    "                while len(reads_remaining) >= 2:\n",
    "                    subcluster_size = random.randint(2, min(max_reads, len(reads_remaining)))\n",
    "                    subcluster_reads = reads_remaining[:subcluster_size]\n",
    "                    reads_remaining = reads_remaining[subcluster_size:]\n",
    "                    example = \"|\".join(subcluster_reads) + \":\" + gt\n",
    "                    examples.append(example)\n",
    "                    new_cluster_sizes.append(len(subcluster_reads))\n",
    "                # Discard leftover if only 1 read remains\n",
    "\n",
    "    return examples, new_cluster_sizes\n",
    "\n",
    "\n",
    "# Create subclusters for the val and test sets. The train set keeps whole clusters\n",
    "# (max_reads=None); its subcluster size is sampled dynamically during training.\n",
    "train_examples, train_cluster_sizes = create_subclusters(\n",
    "    train_clusters, train_ground_truth, max_reads=None\n",
    ")\n",
    "val_examples, val_cluster_sizes = create_subclusters(\n",
    "    val_clusters, val_ground_truth\n",
    ")\n",
    "test_examples, test_cluster_sizes = create_subclusters(\n",
    "    test_clusters, test_ground_truth\n",
    ")\n",
    "\n",
    "print(f\"Total training examples:   {len(train_examples)}\")\n",
    "print(f\"Total validation examples: {len(val_examples)}\")\n",
    "print(f\"Total test examples:       {len(test_examples)}\")\n",
    "\n",
    "# Plot\n",
    "# plt.get_cmap instead of cm.get_cmap, which was removed in Matplotlib 3.9\n",
    "cmap   = plt.get_cmap('PuBu')\n",
    "colors = cmap(np.linspace(0.2, 0.8, 3))\n",
    "labels = ['Train', 'Val', 'Test']\n",
    "split_sizes = [train_cluster_sizes, val_cluster_sizes, test_cluster_sizes]\n",
    "\n",
    "fig, axs = plt.subplots(\n",
    "    1, 3,\n",
    "    figsize=(5.5, 1.3),\n",
    "    dpi=300,\n",
    "    gridspec_kw={'wspace': 0.4}\n",
    ")\n",
    "\n",
    "# One histogram per split, all with the same styling\n",
    "for ax, sizes, col in zip(axs, split_sizes, colors):\n",
    "    # Light-gray spines and tick lines, black tick labels\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_color('lightgray')\n",
    "    ax.tick_params(\n",
    "        axis='both',\n",
    "        which='both',\n",
    "        color='lightgray',\n",
    "        labelcolor='black'\n",
    "    )\n",
    "\n",
    "    # One bin per integer size starting at 2 (the smallest subcluster that is kept);\n",
    "    # align='left' centers each bar on its integer. default=2 keeps max() from\n",
    "    # failing on an empty split.\n",
    "    bins = range(2, max(sizes, default=2) + 2)\n",
    "    ax.hist(\n",
    "        sizes,\n",
    "        bins=bins,\n",
    "        align='left',\n",
    "        rwidth=0.8,\n",
    "        color=col,\n",
    "        edgecolor='none'\n",
    "    )\n",
    "    ax.set_xlabel(\"Reads per subcluster\")\n",
    "    ax.grid(True, which='both', linestyle='--', linewidth=0.3)\n",
    "\n",
    "axs[0].set_ylabel(\"Frequency\")\n",
    "\n",
    "handles = [Patch(facecolor=c, edgecolor='none') for c in colors]\n",
    "fig.legend(\n",
    "    handles, \n",
    "    labels,\n",
    "    loc='upper center',\n",
    "    ncol=3,\n",
    "    frameon=False,\n",
    "    fontsize=FONTSIZE\n",
    ")\n",
    "\n",
    "plt.tight_layout(rect=[0, 0, 1, 0.85])\n",
    "plt.show()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.) Visualize Average Levenshtein Distance per Subcluster Size (Test Set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from Levenshtein import distance  \n",
    "from collections import defaultdict\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Settings \n",
    "FONTSIZE = 6\n",
    "plt.rcParams.update({\n",
    "    'text.usetex': True,\n",
    "    'font.family': 'serif',\n",
    "    'font.serif': ['Computer Modern Roman'],\n",
    "    'text.latex.preamble': r'\\usepackage{amsmath} \\usepackage{type1cm}',\n",
    "    'font.size': FONTSIZE,\n",
    "})\n",
    "\n",
    "# For every test example, average the Levenshtein distance of its reads to the\n",
    "# ground truth, and group these averages by the number of reads N.\n",
    "dist_by_size = defaultdict(list)   # key = N reads, value = list[float]\n",
    "\n",
    "for ex in test_examples:\n",
    "    reads_part, gt = ex.split(\":\")\n",
    "    reads = reads_part.split(\"|\")\n",
    "    N = len(reads)\n",
    "\n",
    "    dists = [distance(r, gt) for r in reads]\n",
    "    dist_by_size[N].append(np.mean(dists))\n",
    "\n",
    "# Per N: mean and standard deviation of the per-example averages\n",
    "sizes    = sorted(dist_by_size.keys())\n",
    "avg_dist = [np.mean(dist_by_size[n]) for n in sizes]\n",
    "err      = [np.std(dist_by_size[n]) for n in sizes]\n",
    "\n",
    "# Same color as the test split in the previous figures (PuBu at 0.8), computed here\n",
    "# so this cell does not depend on `colors` left over from another cell.\n",
    "test_color = plt.get_cmap('PuBu')(0.8)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(1.5,1), dpi=300)\n",
    "\n",
    "# Apply light-gray spines & tick formatting \n",
    "for spine in ax.spines.values():\n",
    "    spine.set_color('lightgray')\n",
    "ax.tick_params(\n",
    "    axis='both',\n",
    "    which='both',\n",
    "    color='lightgray',\n",
    "    labelcolor='black'\n",
    ")\n",
    "\n",
    "# Plot with error bars\n",
    "ax.errorbar(\n",
    "    sizes,\n",
    "    avg_dist,\n",
    "    yerr=err,\n",
    "    fmt='o-',\n",
    "    capsize=1.5,\n",
    "    color=test_color,\n",
    "    markersize=1.5, \n",
    "    linewidth=0.5,             \n",
    "    markerfacecolor=test_color\n",
    ")\n",
    "# ASCII hyphen: with text.usetex=True a Unicode hyphen (U+2010) is not understood by LaTeX\n",
    "ax.set_xlabel(\"Sub-cluster size N\")\n",
    "ax.set_ylabel(\"Avg. $d_L$\")\n",
    "ax.grid(True, which='both', linestyle='--', linewidth=0.3)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "print(\"#clusters per N\")\n",
    "for n, mean_val in zip(sizes, avg_dist):\n",
    "    count = len(dist_by_size[n])\n",
    "    print(f\"  N={n:2d}: {count:>5d} clusters   mean={mean_val:.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the first five formatted training samples (read1|...|readN:ground_truth)\n",
    "preview = train_examples[:5]\n",
    "print(\"\\nSample training examples:\")\n",
    "for sample in preview:\n",
    "    print(sample)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.) Save Train, Val, and Test Data to Disk and Upload Test Data to Weights & Biases"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "import wandb\n",
    "import torch\n",
    "import pickle\n",
    "import numpy as np\n",
    "from datetime import datetime\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "# Path Setup: the project root is two levels above the working directory; it is put\n",
    "# on sys.path so that the project's `src` package can be imported.\n",
    "project_root = Path().cwd().parents[1]\n",
    "\n",
    "sys.path.insert(0, str(project_root))\n",
    "\n",
    "from src.data_pkg.prepare import encode_list, pad_encoded_data\n",
    "\n",
    "# Config\n",
    "# The W&B entity and the output folder are read from the environment instead of being\n",
    "# hard-coded. With WANDB_ENTITY unset, entity=None lets wandb.init use the default entity.\n",
    "wandb_entity = os.environ.get(\"WANDB_ENTITY\")\n",
    "wandb_project = \"TRACE_RECONSTRUCTION\"\n",
    "sequence_type = \"nuc\"\n",
    "block_size = 1500\n",
    "output_dir = os.environ.get(\n",
    "    \"MICROSOFT_OUTPUT_DIR\", os.path.join(os.getcwd(), \"processed_microsoft\")\n",
    ")\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "# Load vocab\n",
    "data_pkg_dir = os.path.join(project_root, \"src\", \"data_pkg\")\n",
    "meta_path = os.path.join(data_pkg_dir, f'meta_{sequence_type}.pkl')\n",
    "\n",
    "if not os.path.exists(meta_path):\n",
    "    raise FileNotFoundError(f\"Meta file not found: {meta_path}\")\n",
    "# pickle can execute code while loading: only point meta_path at trusted project files\n",
    "with open(meta_path, 'rb') as f:\n",
    "    meta = pickle.load(f)\n",
    "meta_vocab_size = meta['vocab_size']\n",
    "stoi, itos = meta['stoi'], meta['itos']\n",
    "encode = lambda s: [stoi[c] for c in s]\n",
    "decode = lambda l: ''.join([itos[i] for i in l])\n",
    "\n",
    "# Save raw text\n",
    "def write_data_to_file(filepath, data):\n",
    "    \"\"\"Write every entry of `data`, stripped of surrounding whitespace, on its own line.\"\"\"\n",
    "    with open(filepath, 'w') as f:\n",
    "        for line in data:\n",
    "            f.write(line.strip() + '\\n')\n",
    "\n",
    "write_data_to_file(os.path.join(output_dir, \"train.txt\"), train_examples)  # e.g., read1|read2:gt\n",
    "write_data_to_file(os.path.join(output_dir, \"val.txt\"), val_examples)\n",
    "write_data_to_file(os.path.join(output_dir, \"test.txt\"), test_examples)\n",
    "\n",
    "# Regenerate reads.txt and ground_truth.txt from test_examples: reads.txt lists the\n",
    "# reads of each example followed by a separator line; ground_truth.txt holds one\n",
    "# ground-truth sequence per example, in the same order.\n",
    "reads_txt_path = os.path.join(output_dir, \"reads.txt\")\n",
    "ground_truth_path = os.path.join(output_dir, \"ground_truth.txt\")\n",
    "\n",
    "with open(reads_txt_path, 'w') as rf, open(ground_truth_path, 'w') as gf:\n",
    "    for ex in test_examples:\n",
    "        reads_part, gt = ex.split(\":\")\n",
    "        reads = reads_part.split(\"|\")\n",
    "        for r in reads:\n",
    "            rf.write(r.strip() + '\\n')\n",
    "        rf.write('===============================\\n')\n",
    "        gf.write(gt.strip() + '\\n')\n",
    "\n",
    "# Encode and pad only test data (val and train is done in finetuning script)\n",
    "def encode_and_pad(sequences, name, out_dir=None):\n",
    "    \"\"\"Encode, pad and save `sequences` as shifted input/target tensors.\n",
    "\n",
    "    Args:\n",
    "        sequences: Example strings, encoded with the loaded `stoi` vocabulary.\n",
    "        name: File prefix; `{name}_x.pt` and `{name}_y.pt` are written.\n",
    "        out_dir: Target folder; defaults to the global `output_dir`.\n",
    "\n",
    "    Every padded row contributes x = tokens 0..block_size-2 and y = tokens\n",
    "    1..block_size-1, i.e. y holds the next token for each position of x.\n",
    "    \"\"\"\n",
    "    if out_dir is None:\n",
    "        out_dir = output_dir\n",
    "    print(f\"Number of {name} examples before encoding: {len(sequences)}\")\n",
    "    \n",
    "    encoded = encode_list(sequences, stoi)\n",
    "    print(f\"Successfully encoded {len(encoded)} examples\")\n",
    "\n",
    "    padded = pad_encoded_data(encoded, block_size, stoi)\n",
    "    print(f\"Padded to block size {block_size} — total examples: {len(padded)}\")\n",
    "\n",
    "    np_data = np.array(padded, dtype=np.int64)\n",
    "    print(f\"Final NumPy array shape: {np_data.shape}\")\n",
    "\n",
    "    x = torch.from_numpy(np_data[:, 0:block_size-1])\n",
    "    y = torch.from_numpy(np_data[:, 1:block_size])\n",
    "    print(f\"Final torch tensor shapes — x: {x.shape}, y: {y.shape}\")\n",
    "\n",
    "    torch.save(x, os.path.join(out_dir, f\"{name}_x.pt\"))\n",
    "    torch.save(y, os.path.join(out_dir, f\"{name}_y.pt\"))\n",
    "\n",
    "# Only encode and save test set\n",
    "encode_and_pad(test_examples, \"test\")\n",
    "\n",
    "# Upload test set to W&B if < 2 GiB \n",
    "test_files = [\"test.txt\", \"test_x.pt\", \"test_y.pt\", \"ground_truth.txt\", \"reads.txt\"]\n",
    "total_bytes = sum(os.path.getsize(os.path.join(output_dir, f)) for f in test_files)\n",
    "max_bytes = 2 * 1024**3  # 2 GiB\n",
    "\n",
    "if total_bytes < max_bytes:\n",
    "    run = wandb.init(\n",
    "        project=wandb_project,\n",
    "        entity=wandb_entity,\n",
    "        name=\"microsoft_test_artifact\",\n",
    "        job_type=\"data_processing\"\n",
    "    )\n",
    "\n",
    "    artifact_test = wandb.Artifact(\n",
    "        name=f\"Microsoft-test-{datetime.now().strftime('%Y%m%d_%H%M%S')}\",\n",
    "        type=\"dataset\",\n",
    "        description=\"Microsoft DNA dataset test split with torch tensors\",\n",
    "        metadata={\n",
    "            \"sequence_type\": sequence_type,\n",
    "            \"block_size\": block_size,\n",
    "            \"vocab_size\": meta_vocab_size,\n",
    "        }\n",
    "    )\n",
    "\n",
    "    for f in test_files:\n",
    "        artifact_test.add_file(os.path.join(output_dir, f))\n",
    "\n",
    "    run.log_artifact(artifact_test)\n",
    "    run.finish()\n",
    "else:\n",
    "    print(f\"Test data size {total_bytes / 1024**3:.2f} GiB exceeds 2 GiB — skipping W&B artifact.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the size of every split followed by its first five examples\n",
    "def _show_head(header, caption, examples, k=5):\n",
    "    \"\"\"Print a header line, a caption, then the first `k` examples indented.\"\"\"\n",
    "    print(header)\n",
    "    print(caption)\n",
    "    for item in examples[:k]:\n",
    "        print(\"  \", item)\n",
    "\n",
    "_show_head(f\"\\n# Train examples: {len(train_examples)}\", \"First 5 train examples:\", train_examples)\n",
    "_show_head(f\"\\n# Validation examples: {len(val_examples)}\", \"First 5 val examples:\", val_examples)\n",
    "_show_head(f\"\\n# Test examples: {len(test_examples)}\", \"First 5 test examples:\", test_examples)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reclustering test data for self-reported DNAformer comparison\n",
    "\n",
    "The Microsoft dataset does not provide explicit indices. Therefore, we first extract the shortest prefix length that uniquely identifies all ground truth sequences in the test set. These unique prefixes are then treated as indices for clustering, following the approach described in [Bar-Lev et al. (2025)](https://www.nature.com/articles/s42256-025-01003-z).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "from typing import List\n",
    "\n",
    "def find_min_unique_prefix_length(sequences: List[str]) -> int:\n",
    "    \"\"\"Find the shortest prefix length such that all sequences are uniquely identified.\n",
    "\n",
    "    Returns the smallest k for which the length-k prefixes of `sequences` are pairwise\n",
    "    distinct. This only fails when `sequences` contains duplicates; then a warning is\n",
    "    issued and the length of the longest sequence is returned, which does NOT make\n",
    "    the prefixes unique.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `sequences` is empty.\n",
    "    \"\"\"\n",
    "    if not sequences:\n",
    "        raise ValueError(\"find_min_unique_prefix_length() needs at least one sequence\")\n",
    "    n_sequences = len(sequences)\n",
    "    max_len = max(len(seq) for seq in sequences)\n",
    "    if len(set(sequences)) < n_sequences:\n",
    "        # Identical sequences can never be told apart by a prefix\n",
    "        warnings.warn(\"Duplicate sequences: no prefix length identifies every sequence uniquely\")\n",
    "        return max_len\n",
    "    for prefix_len in range(1, max_len + 1):\n",
    "        if len({seq[:prefix_len] for seq in sequences}) == n_sequences:\n",
    "            return prefix_len\n",
    "    return max_len  # only reached for [\"\"] (max_len == 0)\n",
    "\n",
    "\n",
    "# Computed over the test-split ground truths only.\n",
    "# NOTE(review): the reclustering cell applies this length to `ground_truth_list` (all\n",
    "# clusters), where these prefixes are not guaranteed to be unique. Confirm this is intended.\n",
    "min_prefix_length = find_min_unique_prefix_length(test_ground_truth)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "message = f\"Minimum prefix length is {min_prefix_length}\"\n",
    "print(message)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "from statistics import mean\n",
    "\n",
    "# Flatten all reads from all clusters (train, val and test)\n",
    "all_reads = [read for cluster in clusters for read in cluster]\n",
    "\n",
    "# Extract unique prefixes from ground_truth\n",
    "def extract_unique_prefixes(sequences: List[str], prefix_len: int) -> dict:\n",
    "    \"\"\"Map the length-`prefix_len` prefix of each sequence to that sequence.\n",
    "\n",
    "    If different sequences share a prefix, the last one wins; the number of such\n",
    "    collisions is printed so it does not go unnoticed.\n",
    "    \"\"\"\n",
    "    mapping = {}\n",
    "    n_collisions = 0\n",
    "    for seq in sequences:\n",
    "        prefix = seq[:prefix_len]\n",
    "        if prefix in mapping and mapping[prefix] != seq:\n",
    "            n_collisions += 1\n",
    "        mapping[prefix] = seq\n",
    "    if n_collisions:\n",
    "        print(f\"Warning: {n_collisions} ground-truth sequences share their length-{prefix_len} \"\n",
    "              \"prefix with a different sequence; only the last one is kept per prefix.\")\n",
    "    return mapping\n",
    "\n",
    "prefix_to_gt = extract_unique_prefixes(ground_truth_list, min_prefix_length)\n",
    "\n",
    "# Bin reads by prefix; reads whose prefix matches no ground truth are dropped\n",
    "binned_reads = defaultdict(list)\n",
    "for read in all_reads:\n",
    "    prefix = read[:min_prefix_length]\n",
    "    if prefix in prefix_to_gt:\n",
    "        binned_reads[prefix].append(read)\n",
    "\n",
    "# Filter out singleton bins\n",
    "filtered_bins = {prefix: reads for prefix, reads in binned_reads.items() if len(reads) > 1}\n",
    "\n",
    "# Create final reclustered set with ground truths\n",
    "reclustered = [\n",
    "    {\"index\": prefix, \"data\": prefix_to_gt[prefix], \"noisy_copies\": reads}\n",
    "    for prefix, reads in filtered_bins.items()\n",
    "]\n",
    "\n",
    "# Compute stats (statistics.mean raises on an empty input, hence the guard)\n",
    "num_clusters = len(reclustered)\n",
    "avg_cluster_size = mean(len(entry[\"noisy_copies\"]) for entry in reclustered) if reclustered else 0.0\n",
    "\n",
    "# Reference statistics reported by Bar-Lev et al. (2025)\n",
    "REF_NUM_CLUSTERS = 9954\n",
    "REF_AVG_CLUSTER_SIZE = 17\n",
    "\n",
    "print(f\"Reclustered set vs. stats reported by Bar-Lev et al. (2025): {num_clusters} clusters vs. {REF_NUM_CLUSTERS} (Bar-Lev et al.), average size {avg_cluster_size:.2f} vs. {REF_AVG_CLUSTER_SIZE} (Bar-Lev et al.)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Bin sizes of the reclustered set before truncation. A dedicated name keeps the\n",
    "# `cluster_sizes` computed in the first cell from being overwritten.\n",
    "reclustered_bin_sizes = [len(entry[\"noisy_copies\"]) for entry in reclustered]\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "# Integer bins with align='left' center each bar on its cluster size;\n",
    "# default=1 keeps max() from failing if no bins survived reclustering.\n",
    "ax.hist(\n",
    "    reclustered_bin_sizes,\n",
    "    bins=range(1, max(reclustered_bin_sizes, default=1) + 2),\n",
    "    align='left'\n",
    ")\n",
    "ax.set_xlabel(\"Cluster size\")\n",
    "ax.set_ylabel(\"Number of clusters\")\n",
    "ax.set_title(\"Reclustered Size Distribution\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the noisy reads and the ground truth out of every reclustered entry\n",
    "reads_list, gt_list = [], []\n",
    "for entry in reclustered:\n",
    "    reads_list.append(entry[\"noisy_copies\"])\n",
    "    gt_list.append(entry[\"data\"])\n",
    "\n",
    "# Build read1|read2|...|readN:ground_truth examples; truncate=True samples at most\n",
    "# max_reads reads (default 10) per cluster so that each example fits the model's context\n",
    "reclustered_examples, reclustered_cluster_sizes = create_subclusters(\n",
    "    reads_list, gt_list, truncate=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Subcluster sizes after truncation. create_subclusters already returned them, so the\n",
    "# example strings need not be re-parsed; a dedicated name leaves `cluster_sizes` untouched.\n",
    "truncated_sizes = reclustered_cluster_sizes\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "# default=1 keeps max() from failing if no example was produced\n",
    "ax.hist(\n",
    "    truncated_sizes,\n",
    "    bins=range(1, max(truncated_sizes, default=1) + 2),\n",
    "    align='left'\n",
    ")\n",
    "ax.set_xlabel(\"Cluster size\")\n",
    "ax.set_ylabel(\"Number of clusters\")\n",
    "ax.set_title(\"Reclustered Size Distribution (at most 10 reads)\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define output directory (next to this notebook, independent of output_dir)\n",
    "reclustered_output_dir = os.path.join(notebook_dir, \"reclustered_microsoft\")\n",
    "os.makedirs(reclustered_output_dir, exist_ok=True)\n",
    "\n",
    "# Save to .txt (one read1|...|readN:ground_truth example per line)\n",
    "write_data_to_file(os.path.join(reclustered_output_dir, \"test.txt\"), reclustered_examples)\n",
    "\n",
    "# Regenerate reads.txt and ground_truth.txt: the reads of each example followed by a\n",
    "# separator line, and one ground truth per example, in the same order\n",
    "with open(os.path.join(reclustered_output_dir, \"reads.txt\"), 'w') as rf, \\\n",
    "     open(os.path.join(reclustered_output_dir, \"ground_truth.txt\"), 'w') as gf:\n",
    "    for ex in reclustered_examples:\n",
    "        reads_part, gt = ex.split(\":\")\n",
    "        reads = reads_part.split(\"|\")\n",
    "        for r in reads:\n",
    "            rf.write(r.strip() + '\\n')\n",
    "        rf.write('===============================\\n')\n",
    "        gf.write(gt.strip() + '\\n')\n",
    "\n",
    "# Encode and pad. encode_and_pad joins `name` onto the global output_dir. Because this\n",
    "# name is an absolute path (notebook_dir comes from os.getcwd()), os.path.join discards\n",
    "# output_dir, so the tensors land in reclustered_output_dir. The absolute path also\n",
    "# appears in encode_and_pad's progress messages.\n",
    "encode_and_pad(reclustered_examples, os.path.join(reclustered_output_dir, \"test\"))  # will save test_x.pt and test_y.pt\n",
    "\n",
    "# Upload to W&B (same 2 GiB limit, max_bytes, as the regular test split)\n",
    "reclustered_test_files = [\"test.txt\", \"test_x.pt\", \"test_y.pt\", \"ground_truth.txt\", \"reads.txt\"]\n",
    "total_bytes = sum(os.path.getsize(os.path.join(reclustered_output_dir, f)) for f in reclustered_test_files)\n",
    "\n",
    "if total_bytes < max_bytes:\n",
    "    run = wandb.init(\n",
    "        project=wandb_project,\n",
    "        entity=wandb_entity,\n",
    "        name=\"microsoft_reclustered_test_artifact\",\n",
    "        job_type=\"data_processing\"\n",
    "    )\n",
    "\n",
    "    # NOTE(review): the description says \"test split\", but the reclustering cell bins\n",
    "    # reads from all clusters (`clusters`). Confirm the intended wording.\n",
    "    artifact_test = wandb.Artifact(\n",
    "        name=f\"Microsoft-test-reclustered-{datetime.now().strftime('%Y%m%d_%H%M%S')}\",\n",
    "        type=\"dataset\",\n",
    "        description=\"Microsoft DNA dataset test split reclustered by unique prefix (DNAformer-style)\",\n",
    "        metadata={\n",
    "            \"sequence_type\": sequence_type,\n",
    "            \"block_size\": block_size,\n",
    "            \"vocab_size\": meta_vocab_size,\n",
    "            \"reclustered\": True,\n",
    "            \"min_prefix_length\": min_prefix_length,\n",
    "        }\n",
    "    )\n",
    "\n",
    "    for f in reclustered_test_files:\n",
    "        artifact_test.add_file(os.path.join(reclustered_output_dir, f))\n",
    "\n",
    "    run.log_artifact(artifact_test)\n",
    "    run.finish()\n",
    "else:\n",
    "    print(f\"Reclustered test data size {total_bytes / 1024**3:.2f} GiB exceeds 2 GiB — skipping W&B artifact.\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "treconlm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
