{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Noisy DNA dataset\n",
    "\n",
    "Clone the GitHub repository into the folder where this notebook is located:\n",
    "\n",
    "- GitHub repository (contains the ground truth sequences): https://github.com/-lab/noisy_dna_data_storage\n",
    "\n",
    "Then, download the noisy reads from the following Figshare link and place them in the data folder of the cloned repository:\n",
    "\n",
    "- Noisy reads (Figshare): https://figshare.com/s/cd611884b34a8c89f4b4\n",
    "\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 1.) Load data from file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "\n",
    "# Get the directory where the notebook is\n",
    "notebook_dir = os.getcwd()\n",
    "\n",
    "# Go up two levels to reach the project root, then 'src'\n",
    "src_path = os.path.abspath(os.path.join(notebook_dir, \"..\", \"..\", \"src\"))\n",
    "sys.path.append(src_path)\n",
    "\n",
    "from utils.noisy_dna import file_to_list, fastq_to_list\n",
    "\n",
    "# Path to data directory\n",
    "data_dir = os.path.join(notebook_dir, 'noisy_dna_data_storage', 'data')\n",
    "\n",
    "# Full path to the .txt file\n",
    "orig_file = os.path.join(data_dir, \"File1_ODNA.txt\")\n",
    "orig_seqs = file_to_list(orig_file)\n",
    "\n",
    "# Full path to the .fastq file \n",
    "fastq_file = os.path.join(data_dir, \"I16_S2_R1_001.fastq.gz\")\n",
    "seqs = fastq_to_list(fastq_file)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 2.) Filter out too short/long reads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"sequences are, e.g.,  {seqs[:3]}\")\n",
    "print(f\"Ground truth are, e.g.,  {orig_seqs[:3]}\")\n",
    "\n",
    "\n",
    "print(\"all sequences: \", len(seqs))\n",
    "print(\"all orig sequences: \", len(orig_seqs))\n",
    "reads = [seq for seq in seqs if len(seq) >= 55 and len(seq)<=70]\n",
    "print(\"all trimmed sequences: \",  len(reads))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3.) Get index to ground truth sequence map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "from helper import decode_dna_index \n",
    "\n",
    "start_pos = 42\n",
    "end_pos = 54\n",
    "\n",
    "# Initialize containers\n",
    "decoded_binary = []\n",
    "decoded_ints = []\n",
    "index_to_dna = {}\n",
    "gt_map = {}\n",
    "\n",
    "# Decode each ground truth sequence's index\n",
    "for orig_seq in orig_seqs:\n",
    "    idx = orig_seq[start_pos:end_pos]\n",
    "    if idx.startswith(\"ACAAC\"):\n",
    "        try:\n",
    "            binary_str, int_val = decode_dna_index(idx)\n",
    "            decoded_binary.append(binary_str)\n",
    "            decoded_ints.append(int_val)\n",
    "            index_to_dna[int_val] = idx  # maps to DNA index string\n",
    "            gt_map[int_val] = orig_seq   # maps to full sequence \n",
    "        except Exception as e:\n",
    "            print(f\"Failed to decode {idx}: {e}\")\n",
    "    else:\n",
    "        print(f\"Index {idx} does not start with ACAAC.\")\n",
    "\n",
    "duplicates = [k for k, v in Counter(decoded_ints).items() if v > 1]\n",
    "decoded_set = set(decoded_ints)\n",
    "expected_set = set(range(16384))\n",
    "missing = expected_set - decoded_set\n",
    "extra = decoded_set - expected_set\n",
    "\n",
    "print(\"Example mappings:\")\n",
    "for i in range(5):\n",
    "    print(f\"DNA Index: {index_to_dna[decoded_ints[i]]}, Binary = {decoded_binary[i]}, Int = {decoded_ints[i]}\")\n",
    "\n",
    "if missing:\n",
    "    print(f\"Missing {len(missing)} index values. Example(s): {sorted(list(missing))[:10]}\")\n",
    "else:\n",
    "    print(\"All integer indices from 0 to 16383 are present.\")\n",
    "\n",
    "if extra:\n",
    "    print(f\"Unexpected extra index values: {sorted(list(extra))[:10]}\")\n",
    "else:\n",
    "    print(\"No extra index values found.\")\n",
    "\n",
    "if duplicates:\n",
    "    print(f\"Found {len(duplicates)} duplicated integer indices. Examples: {duplicates[:10]}\")\n",
    "else:\n",
    "    print(\"No duplicate integer indices.\")\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3.) Sweep over sliding windows around true index positions and error thresholds to known index placeholder pattern ( for this dataset true start and end position and threshold 0 worked best)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import csv\n",
    "import pickle  \n",
    "from helper import cluster_by_index_region\n",
    "\n",
    "# True index window\n",
    "start_pos = 42\n",
    "end_pos = 54\n",
    "max_expand = 6 # set to larger to perform sweep\n",
    "thresholds = [0,1,2] # add 1,2, etc. to allow error threshold\n",
    "\n",
    "# Configs\n",
    "print_flag = True\n",
    "\n",
    "# Run sweep\n",
    "for i in range(0, max_expand + 1):\n",
    "    start_window = start_pos - i\n",
    "    end_window = end_pos + i\n",
    "    for threshold in thresholds:\n",
    "        print(f\"\\nRunning with start={start_window}, end={end_window}, threshold={threshold}\")\n",
    "        clusters, failed = cluster_by_index_region(\n",
    "            reads=reads,\n",
    "            decoded_ints=decoded_ints,\n",
    "            threshold=threshold,\n",
    "            start_window=start_window,\n",
    "            end_window=end_window,\n",
    "            print_flag=print_flag\n",
    "        )\n",
    "\n",
    "        # Ensure we are storing sequences, not indices\n",
    "        for cluster_id, cluster_reads in clusters.items():\n",
    "            for r in cluster_reads:\n",
    "                if not isinstance(r, str):\n",
    "                    raise ValueError(f\"Cluster {cluster_id} contains non-string item: {r} (type: {type(r)})\")\n",
    "\n",
    "        # Save CSV version (read strings)\n",
    "        filename = f\"index_clusters_sw{start_window}_ew{end_window}_th{threshold}.csv\"\n",
    "        save_path = os.path.join(data_dir, filename)\n",
    "        with open(save_path, 'w', newline='') as f:\n",
    "            writer = csv.writer(f)\n",
    "            for cluster_reads in clusters.values():\n",
    "                writer.writerow(cluster_reads)\n",
    "        print(f\"Saved clusters to: {save_path}\")\n",
    "\n",
    "        # Save dictionary version as .pkl\n",
    "        dict_filename = f\"index_clusters_sw{start_window}_ew{end_window}_th{threshold}.pkl\"\n",
    "        dict_path = os.path.join(data_dir, dict_filename)\n",
    "        with open(dict_path, 'wb') as f:\n",
    "            pickle.dump(clusters, f)\n",
    "        print(f\"Saved cluster dictionary to: {dict_path}\")\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.) Evaluate sliding‐window and error‐threshold combinations by minimizing misclustering rate\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import csv\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from Levenshtein import distance as levenshtein_distance\n",
    "\n",
    "# Configs\n",
    "expected_mean = 60 * (0.057 + 0.06 + 0.026)  # = 8.58\n",
    "min_expected = 5\n",
    "max_expected = 13\n",
    "\n",
    "\n",
    "results = []\n",
    "\n",
    "for fname in sorted(os.listdir(data_dir)):\n",
    "    if not fname.endswith(\".pkl\"):\n",
    "        continue\n",
    "\n",
    "    full_path = os.path.join(data_dir, fname)\n",
    "    with open(full_path, 'rb') as f:\n",
    "        clusters = pickle.load(f)  # clusters[int_val] = list of read strings\n",
    "\n",
    "    distances = []\n",
    "    filtered_clusters = {}\n",
    "    clusters_with_outliers = set()\n",
    "    too_few_edits = 0\n",
    "    too_many_edits = 0\n",
    "    total_reads = 0\n",
    "    misclustered_reads = 0\n",
    "\n",
    "    for int_val, read_list in clusters.items():\n",
    "        gt_seq = gt_map.get(int_val)\n",
    "        if gt_seq is None:\n",
    "            continue\n",
    "\n",
    "        cluster_dists = []\n",
    "        filtered_reads = []\n",
    "\n",
    "        for read in read_list:\n",
    "\n",
    "            if not isinstance(read, str) or not isinstance(gt_seq, str):\n",
    "                print(f\"[TYPE ERROR] int_val={int_val} | gt_seq type: {type(gt_seq)} | read type: {type(read)}\")\n",
    "                continue\n",
    "\n",
    "\n",
    "            dist = levenshtein_distance(gt_seq, read)\n",
    "            distances.append(dist)\n",
    "            cluster_dists.append(dist)\n",
    "\n",
    "            total_reads += 1\n",
    "            if min_expected <= dist <= max_expected:\n",
    "                filtered_reads.append(read)\n",
    "            else:\n",
    "                misclustered_reads += 1\n",
    "                clusters_with_outliers.add(int_val)\n",
    "                if dist < min_expected:\n",
    "                    too_few_edits += 1\n",
    "                else:\n",
    "                    too_many_edits += 1\n",
    "\n",
    "        if filtered_reads:\n",
    "            filtered_clusters[int_val] = filtered_reads\n",
    "\n",
    "    rate = misclustered_reads / total_reads if total_reads > 0 else np.nan\n",
    "    results.append((fname, total_reads, misclustered_reads, rate))\n",
    "\n",
    "    print(f\"\\n{fname:45} | Total: {total_reads:5d} | Misclustered: {misclustered_reads:5d} | Rate: {rate*100:.2f}%\")\n",
    "\n",
    "    # Plot histogram of Levenshtein distances \n",
    "    plt.figure(figsize=(8, 5))\n",
    "    plt.hist(distances, bins=50, edgecolor='black', alpha=0.8)\n",
    "    plt.axvspan(0, min_expected, color='lightgrey', alpha=0.5, label=\"Too few edits\")\n",
    "    plt.axvspan(max_expected, max(distances), color='lightgrey', alpha=0.5, label=\"Too many edits\")\n",
    "    plt.axvline(min_expected, color='red', linestyle='--', label=f\"Min Expected ({min_expected})\")\n",
    "    plt.axvline(max_expected, color='red', linestyle='--')\n",
    "    plt.axvline(expected_mean, color='green', linestyle='--', label=f\"Expected Mean ({expected_mean:.2f})\")\n",
    "    plt.xlabel(\"Levenshtein Distance to Ground Truth\")\n",
    "    plt.ylabel(\"Frequency\")\n",
    "    plt.title(f\"Distance Distribution {fname}\")\n",
    "    plt.legend()\n",
    "    plt.tight_layout()\n",
    "    plot_path = os.path.join(data_dir, fname.replace(\".pkl\", \"_dist_hist.pdf\"))\n",
    "    plt.savefig(plot_path)\n",
    "    plt.close()\n",
    "\n",
    "    # Plot distribution of filtered cluster sizes \n",
    "    cluster_sizes = [len(reads) for reads in filtered_clusters.values()]\n",
    "    plt.figure(figsize=(8, 5))\n",
    "    plt.hist(cluster_sizes, bins=50, edgecolor='black', alpha=0.8)\n",
    "    plt.xlabel(\"Cluster Size (filtered)\")\n",
    "    plt.ylabel(\"Frequency\")\n",
    "    plt.title(f\"Filtered Cluster Sizes — {fname}\")\n",
    "    plt.tight_layout()\n",
    "    plot_path = os.path.join(data_dir, fname.replace(\".pkl\", \"_size_hist.pdf\"))\n",
    "    plt.savefig(plot_path)\n",
    "    plt.close()\n",
    "\n",
    "# Summary printout\n",
    "print(\"\\nSorted Summary of All Clusterings\")\n",
    "results.sort(key=lambda x: x[-1])  # sort by miscluster rate\n",
    "for fname, total, mis, rate in results:\n",
    "    print(f\"{fname:45} | Total: {total:5d} | Misclustered: {mis:5d} | Rate: {rate*100:.2f}%\")\n",
    "\n",
    "# Optional: save summary to CSV\n",
    "summary_path = os.path.join(data_dir, \"clustering_eval_summary.csv\")\n",
    "with open(summary_path, 'w', newline='') as f:\n",
    "    writer = csv.writer(f)\n",
    "    writer.writerow([\"filename\", \"total_reads\", \"misclustered_reads\", \"miscluster_rate\"])\n",
    "    for fname, total, mis, rate in results:\n",
    "        writer.writerow([fname, total, mis, f\"{rate:.4f}\"])\n",
    "print(f\"\\nSaved summary to: {summary_path}\")\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.) Split in 80\\% train and 10\\% validation and test sets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import random\n",
    "import csv\n",
    "\n",
    "# Configs\n",
    "random.seed(42)\n",
    "pkl_path = os.path.join(data_dir,\"index_clusters_sw42_ew54_th0.pkl\")\n",
    "\n",
    "out_dir = os.path.join(data_dir, \"splits\")\n",
    "os.makedirs(out_dir, exist_ok=True)\n",
    "# Load clusters\n",
    "with open(pkl_path, 'rb') as f:\n",
    "    clusters = pickle.load(f)\n",
    "\n",
    "\n",
    "# Filter only valid cluster keys (i.e., those with a ground truth)\n",
    "valid_ids = [k for k in clusters if k in gt_map]\n",
    "random.shuffle(valid_ids)\n",
    "\n",
    "# Split cluster IDs\n",
    "n = len(valid_ids)\n",
    "n_train = int(0.8 * n)\n",
    "n_val = int(0.1 * n)\n",
    "n_test = n - n_train - n_val\n",
    "\n",
    "train_ids = valid_ids[:n_train]\n",
    "val_ids   = valid_ids[n_train:n_train + n_val]\n",
    "test_ids  = valid_ids[n_train + n_val:]\n",
    "\n",
    "# Helper to extract clusters and ground truths \n",
    "def extract_split_data(split_ids):\n",
    "    split_clusters = [clusters[k] for k in split_ids]\n",
    "    split_gts = [gt_map[k] for k in split_ids]\n",
    "    return split_clusters, split_gts\n",
    "\n",
    "train_clusters, train_ground_truth = extract_split_data(train_ids)\n",
    "val_clusters, val_ground_truth     = extract_split_data(val_ids)\n",
    "test_clusters, test_ground_truth   = extract_split_data(test_ids)\n",
    "\n",
    "# Optional save read-gt pairs to CSVs\n",
    "def save_to_csv(clusters, gts, path):\n",
    "    with open(path, 'w', newline='') as f:\n",
    "        writer = csv.writer(f)\n",
    "        writer.writerow(['read', 'ground_truth'])\n",
    "        for reads, gt in zip(clusters, gts):\n",
    "            for r in reads:\n",
    "                writer.writerow([r, gt])\n",
    "\n",
    "save_to_csv(train_clusters, train_ground_truth, os.path.join(out_dir, \"train.csv\"))\n",
    "save_to_csv(val_clusters, val_ground_truth,     os.path.join(out_dir, \"val.csv\"))\n",
    "save_to_csv(test_clusters, test_ground_truth,   os.path.join(out_dir, \"test.csv\"))\n",
    "\n",
    "print(f\"Saved train/val/test splits with {len(train_clusters)} / {len(val_clusters)} / {len(test_clusters)} clusters\")\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.) Filter train and val set to not contain sequences too close to a test set gt sequence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filter if too close [5,13] distance to gt in test\n",
    "\n",
    "import Levenshtein  \n",
    "from tqdm.notebook import tqdm \n",
    "\n",
    "# Define the filter range\n",
    "min_dist = 5\n",
    "max_dist = 13\n",
    "\n",
    "# Create a set of test ground-truth sequences\n",
    "test_gts = set(test_ground_truth)\n",
    "\n",
    "def should_remove(read, test_gts):\n",
    "    \"\"\"Check if read is too close to any test-set GT\"\"\"\n",
    "    for gt in test_gts:\n",
    "        dist = Levenshtein.distance(read, gt)\n",
    "        if min_dist <= dist <= max_dist:\n",
    "            return True\n",
    "    return False\n",
    "\n",
    "def filter_clusters(clusters, gts, test_gts):\n",
    "    \"\"\"Remove reads that are too close to any test GT\"\"\"\n",
    "    filtered_clusters = []\n",
    "    filtered_gts = []\n",
    "    total_removed = 0\n",
    "    total_kept = 0\n",
    "\n",
    "    for reads, gt in tqdm(zip(clusters, gts), total=len(clusters), desc=f\"Filtering\"):\n",
    "        new_reads = [r for r in reads if not should_remove(r, test_gts)]\n",
    "        if new_reads:\n",
    "            filtered_clusters.append(new_reads)\n",
    "            filtered_gts.append(gt)\n",
    "            total_kept += len(new_reads)\n",
    "            total_removed += len(reads) - len(new_reads)\n",
    "        else:\n",
    "            total_removed += len(reads)\n",
    "\n",
    "    print(f\"Removed {total_removed} reads; Kept {total_kept}\", flush=True)\n",
    "    return filtered_clusters, filtered_gts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply to train and val clusters\n",
    "train_clusters_filtered,   train_gt_filtered   = filter_clusters(train_clusters, train_ground_truth, test_gts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply to train and val clusters\n",
    "val_clusters_filtered,   val_gt_filtered   = filter_clusters(val_clusters, val_ground_truth, test_gts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove any clusters that now only have 1 read\n",
    "def remove_too_small_clusters(clusters, gts, min_size=2, split_name=\"\"):\n",
    "    total = len(clusters)\n",
    "    filtered = [(c, gt) for c, gt in zip(clusters, gts) if len(c) >= min_size]\n",
    "    remaining = len(filtered)\n",
    "    removed = total - remaining\n",
    "\n",
    "    print(f\"{split_name}: Removed {removed} clusters with fewer than {min_size} reads\")\n",
    "    print(f\"{split_name}: {remaining} clusters remain\")\n",
    "\n",
    "    return [c for c, _ in filtered], [gt for _, gt in filtered]\n",
    "\n",
    "train_clusters_filtered, train_gt_filtered = remove_too_small_clusters(\n",
    "    train_clusters_filtered, train_gt_filtered, min_size=2, split_name=\"Train\"\n",
    ")\n",
    "\n",
    "val_clusters_filtered, val_gt_filtered = remove_too_small_clusters(\n",
    "    val_clusters_filtered, val_gt_filtered, min_size=2, split_name=\"Val\"\n",
    ")\n",
    "\n",
    "n_train_before = len(train_clusters)\n",
    "n_train_after_filtering = len(train_clusters_filtered)\n",
    "n_train_fully_removed = n_train_before - n_train_after_filtering\n",
    "\n",
    "print(f\"Train: {n_train_fully_removed} clusters were completely removed after filtering\")\n",
    "\n",
    "# Compute how many clusters were entirely removed after filtering\n",
    "n_val_before = len(val_clusters)\n",
    "n_val_after_filtering = len(val_clusters_filtered)\n",
    "n_val_fully_removed = n_val_before - n_val_after_filtering\n",
    "\n",
    "print(f\"Val: {n_val_fully_removed} clusters were completely removed after filtering\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect a few train clusters and their ground truth\n",
    "for i in range(5):  # show 5 random examples\n",
    "    print(f\"\\nCluster {i + 1}:\")\n",
    "    print(\"Ground truth:\", train_ground_truth[i])\n",
    "    print(\"Reads:\")\n",
    "    for r in train_clusters[i]:\n",
    "        print(f\"  {r}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 7.) Create subclusters of max. 10 reads for test and val set\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from typing import List\n",
    "import random \n",
    "from typing import List, Optional\n",
    "import random\n",
    "\n",
    "def create_subclusters(clusters: List[List[str]], ground_truths: List[str], max_reads: Optional[int] = 10):\n",
    "    \"\"\"\n",
    "    For each cluster, split it into subclusters of at most `max_reads` reads using random sampling.\n",
    "    If `max_reads` is None, use all reads in the cluster as a single subcluster.\n",
    "    Returns a list of examples in the format: read1|read2|...|readN : ground_truth\n",
    "    \"\"\"\n",
    "    examples = []\n",
    "    new_cluster_sizes = []\n",
    "\n",
    "    for reads, gt in zip(clusters, ground_truths):\n",
    "        reads = reads.copy()\n",
    "        random.shuffle(reads)\n",
    "\n",
    "        if max_reads is None:\n",
    "            # One full cluster if large enough\n",
    "            if len(reads) >= 2:\n",
    "                example = \"|\".join(reads) + \":\" + gt\n",
    "                examples.append(example)\n",
    "                new_cluster_sizes.append(len(reads))\n",
    "        else:\n",
    "            while len(reads) >= 2:\n",
    "                subcluster_size = random.randint(2, min(max_reads, len(reads)))\n",
    "                subcluster_reads = reads[:subcluster_size]\n",
    "                reads = reads[subcluster_size:]\n",
    "                example = \"|\".join(subcluster_reads) + \":\" + gt\n",
    "                examples.append(example)\n",
    "                new_cluster_sizes.append(len(subcluster_reads))\n",
    "            # discard leftover if only 1 read remains\n",
    "\n",
    "    return examples, new_cluster_sizes\n",
    "\n",
    "\n",
    "# Create subclusters for test set. For train and validation set we will dynamically during training sample cluster size. \n",
    "train_examples, train_cluster_sizes = create_subclusters(train_clusters_filtered, train_gt_filtered, max_reads=None) \n",
    "val_examples, val_cluster_sizes = create_subclusters(val_clusters_filtered, val_gt_filtered)\n",
    "test_examples, test_cluster_sizes = create_subclusters(test_clusters, test_ground_truth)  "
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 8.) Filter out sequences with unexpected characters\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_and_log_sequences_with_N(examples, name=\"\"):\n",
    "    bad = [ex for ex in examples if 'N' in ex.split(\":\")[0]]\n",
    "    if bad:\n",
    "        print(f\"[WARNING] {name}: Found {len(bad)} examples with 'N'. Printing first 5:\")\n",
    "        for ex in bad[:5]:\n",
    "            print(\"  \", ex)\n",
    "    filtered = [ex for ex in examples if 'N' not in ex.split(\":\")[0]]\n",
    "    print(f\"[INFO] {name}: {len(examples) - len(filtered)} examples removed due to 'N'\")\n",
    "    return filtered"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_examples = filter_and_log_sequences_with_N(train_examples, \"train\")\n",
    "val_examples   = filter_and_log_sequences_with_N(val_examples, \"val\")\n",
    "test_examples  = filter_and_log_sequences_with_N(test_examples, \"test\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 9.) Save Train, Val, and Test Data to Disk and Upload Test Data to Weights & Biases"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "import wandb\n",
    "import torch\n",
    "import pickle\n",
    "import numpy as np\n",
    "from datetime import datetime\n",
    "from pathlib import Path\n",
    "\n",
    "# Path Setup\n",
    "\n",
    "# notebook is in .../TReconLM/data/noisy_dna\n",
    "project_root = Path().cwd().parents[1]        \n",
    "\n",
    "sys.path.insert(0, str(project_root))\n",
    "\n",
    "from src.data_pkg.prepare import encode_list, pad_encoded_data\n",
    "# Config \n",
    "wandb_entity = \n",
    "wandb_project = \"TRACE_RECONSTRUCTION\"\n",
    "sequence_type = \"nuc\"\n",
    "block_size = 800\n",
    "\n",
    "# Load vocab\n",
    "data_pkg_dir = os.path.join(project_root, \"src\", \"data_pkg\")\n",
    "meta_path = os.path.join(data_pkg_dir, f'meta_{sequence_type}.pkl')\n",
    "\n",
    "if os.path.exists(meta_path):\n",
    "    with open(meta_path, 'rb') as f:\n",
    "        meta = pickle.load(f)\n",
    "    meta_vocab_size = meta['vocab_size']\n",
    "    stoi, itos = meta['stoi'], meta['itos']\n",
    "    encode = lambda s: [stoi[c] for c in s]\n",
    "    decode = lambda l: ''.join([itos[i] for i in l])\n",
    "else:\n",
    "    raise FileNotFoundError(f\"Meta file not found: {meta_path}\")\n",
    "\n",
    "\n",
    "# Clean test examples by removing trailing C's from reads \n",
    "def remove_trailing_Cs(examples):\n",
    "    cleaned = []\n",
    "    for ex in examples:\n",
    "        reads_part, gt = ex.split(\":\")\n",
    "        reads = reads_part.split(\"|\")\n",
    "        reads = [r.rstrip(\"C\") for r in reads]\n",
    "        cleaned.append(\"|\".join(reads) + \":\" + gt)\n",
    "    return cleaned\n",
    "\n",
    "test_cleaned_examples = remove_trailing_Cs(test_examples)\n",
    "\n",
    "# Save raw text data \n",
    "def write_data_to_file(filepath, data):\n",
    "    with open(filepath, 'w') as f:\n",
    "        for line in data:\n",
    "            f.write(line.strip() + '\\n')\n",
    "\n",
    "write_data_to_file(os.path.join(out_dir, \"train.txt\"), train_examples)\n",
    "write_data_to_file(os.path.join(out_dir, \"val.txt\"), val_examples)\n",
    "write_data_to_file(os.path.join(out_dir, \"test.txt\"), test_examples)\n",
    "write_data_to_file(os.path.join(out_dir, \"test_cleaned.txt\"), test_cleaned_examples)\n",
    "\n",
    "# Generate reads.txt and ground_truth.txt for test set \n",
    "def write_reads_and_gt(examples, reads_path, gt_path):\n",
    "    with open(reads_path, 'w') as rf, open(gt_path, 'w') as gf:\n",
    "        for ex in examples:\n",
    "            reads_part, gt = ex.split(\":\")\n",
    "            reads = reads_part.split(\"|\")\n",
    "            for r in reads:\n",
    "                rf.write(r.strip() + '\\n')\n",
    "            rf.write('===============================\\n')\n",
    "            gf.write(gt.strip() + '\\n')\n",
    "\n",
    "write_reads_and_gt(\n",
    "    test_examples,\n",
    "    os.path.join(out_dir, \"reads.txt\"),\n",
    "    os.path.join(out_dir, \"ground_truth.txt\")\n",
    ")\n",
    "\n",
    "write_reads_and_gt(\n",
    "    test_cleaned_examples,\n",
    "    os.path.join(out_dir, \"reads_cleaned.txt\"),\n",
    "    os.path.join(out_dir, \"ground_truth_cleaned.txt\")\n",
    ")\n",
    "\n",
    "# Encode and pad \n",
    "def encode_and_pad(sequences, name):\n",
    "    print(f\"[INFO] Encoding {name}: {len(sequences)} examples\")\n",
    "\n",
    "    encoded = encode_list(sequences, stoi)\n",
    "    padded = pad_encoded_data(encoded, block_size, stoi)\n",
    "    np_data = np.array(padded, dtype=np.int64)\n",
    "\n",
    "    x = torch.from_numpy(np_data[:, 0:block_size-1])\n",
    "    y = torch.from_numpy(np_data[:, 1:block_size])\n",
    "\n",
    "    torch.save(x, os.path.join(out_dir, f\"{name}_x.pt\"))\n",
    "    torch.save(y, os.path.join(out_dir, f\"{name}_y.pt\"))\n",
    "\n",
    "    print(f\"[INFO] Saved: {name}_x.pt and {name}_y.pt\")\n",
    "\n",
    "encode_and_pad(test_examples, \"test\")\n",
    "encode_and_pad(test_cleaned_examples, \"test_cleaned\")\n",
    "\n",
    "# Upload to W&B (if under 2 GiB) \n",
    "test_files = [\n",
    "    \"test.txt\", \"test_cleaned.txt\",\n",
    "    \"test_x.pt\", \"test_y.pt\",\n",
    "    \"test_cleaned_x.pt\", \"test_cleaned_y.pt\",\n",
    "    \"reads.txt\", \"ground_truth.txt\",\n",
    "    \"reads_cleaned.txt\", \"ground_truth_cleaned.txt\"\n",
    "]\n",
    "\n",
    "total_bytes = sum(os.path.getsize(os.path.join(out_dir, f)) for f in test_files)\n",
    "max_bytes = 2 * 1024**3  # 2 GiB\n",
    "\n",
    "if total_bytes < max_bytes:\n",
    "    run = wandb.init(\n",
    "        project=wandb_project,\n",
    "        entity=wandb_entity,\n",
    "        name=\"file1_test_cleaned_vs_raw\",\n",
    "        job_type=\"data_processing\"\n",
    "    )\n",
    "\n",
    "    artifact = wandb.Artifact(\n",
    "        name=f\"File1-test-{datetime.now().strftime('%Y%m%d_%H%M%S')}\",\n",
    "        type=\"dataset\",\n",
    "        description=\"Test set with original and cleaned examples\",\n",
    "        metadata={\n",
    "            \"sequence_type\": sequence_type,\n",
    "            \"block_size\": block_size,\n",
    "            \"vocab_size\": meta_vocab_size,\n",
    "            \"includes_cleaned_version\": True\n",
    "        }\n",
    "    )\n",
    "\n",
    "    for f in test_files:\n",
    "        artifact.add_file(os.path.join(out_dir, f))\n",
    "\n",
    "    run.log_artifact(artifact)\n",
    "    run.finish()\n",
    "else:\n",
    "    print(f\"[SKIPPED] Total size {total_bytes / 1024**3:.2f} GiB exceeds W&B limit.\")\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 10.) Analyze trailing C’s in dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from collections import Counter\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def count_trailing_Cs(s):\n",
    "    count = 0\n",
    "    for c in reversed(s):\n",
    "        if c == 'C':\n",
    "            count += 1\n",
    "        else:\n",
    "            break\n",
    "    return count\n",
    "\n",
    "def analyze_trailing_Cs(test_path: str):\n",
    "    trailing_C_lengths = []\n",
    "    total_reads = 0\n",
    "    affected_reads = 0\n",
    "\n",
    "    with open(test_path, 'r') as f:\n",
    "        for line in f:\n",
    "            line = line.strip()\n",
    "            if not line or \":\" not in line:\n",
    "                continue\n",
    "            reads_part, _ = line.split(\":\")\n",
    "            reads = reads_part.split(\"|\")\n",
    "\n",
    "            for r in reads:\n",
    "                total_reads += 1\n",
    "                c_tail_len = count_trailing_Cs(r)\n",
    "                if c_tail_len > 0:\n",
    "                    trailing_C_lengths.append(c_tail_len)\n",
    "                    affected_reads += 1\n",
    "\n",
    "    # Summary \n",
    "    print(f\"Total reads: {total_reads}\")\n",
    "    print(f\"Reads with trailing 'C's: {affected_reads} ({(affected_reads/total_reads)*100:.2f}%)\")\n",
    "\n",
    "    # Histogram \n",
    "    if trailing_C_lengths:\n",
    "        counter = Counter(trailing_C_lengths)\n",
    "        print(\"\\nTrailing 'C' tail length distribution:\")\n",
    "        for k in sorted(counter):\n",
    "            print(f\"Length {k}: {counter[k]} reads\")\n",
    "\n",
    "        plt.figure(figsize=(6, 4))\n",
    "        plt.hist(trailing_C_lengths, bins=range(1, max(trailing_C_lengths)+2), align='left', edgecolor='black')\n",
    "        plt.xlabel(\"Length of trailing 'C's\")\n",
    "        plt.ylabel(\"Number of reads\")\n",
    "        plt.title(\"Distribution of Trailing 'C' Tails in Test Reads\")\n",
    "        plt.grid(True, linestyle='--', linewidth=0.5)\n",
    "        plt.tight_layout()\n",
    "        plt.show()\n",
    "    else:\n",
    "        print(\"No trailing 'C's found.\")\n",
    "\n",
    "test_file = os.path.join(out_dir, \"test.txt\")\n",
    "\n",
    "analyze_trailing_Cs(test_file)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "treconlm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
