{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_sHINySkHtkC",
        "outputId": "eef3da44-a7a9-4e61-9ab9-f58cdc65de88"
      },
      "outputs": [],
      "source": [
        "# ============================================\n",
        "# 🧰 Install dependencies\n",
        "# ============================================\n",
        "!pip install pybloom_live tqdm pandas msgspec polars pyarrow networkit numpy scipy nltk --quiet\n",
        "\n",
        "# Download NLTK tokenizer\n",
        "import nltk\n",
        "nltk.download('punkt', quiet=True)\n",
        "\n",
        "# ============================================\n",
        "# 📥 Clone RedPajama-Data source and add sys.path\n",
        "# ============================================\n",
        "!git clone https://github.com/togethercomputer/RedPajama-Data.git -q\n",
        "\n",
        "import sys, os\n",
        "sys.path.append(\"/content/RedPajama-Data/app/src\")\n",
        "\n",
        "from utilities.text.normalization import normalize\n",
        "from utilities.text.ngrams import form_ngrams\n",
        "from dedupe.utils import optimal_param\n",
        "from dedupe.minhash import MinHash as RPMinHash\n",
        "\n",
        "# 👉 Introducing RedPajama's Document and quality_signals registration functions.\n",
        "from core.document import Document\n",
        "from core.quality_signals.content import register_content_callables\n",
        "from core.quality_signals.lines import register_lines_callables\n",
        "from core.quality_signals.natural_language import register_natural_language_callables\n",
        "from core.quality_signals.repetitions import register_repetitions_callables\n",
        "from core.quality_signals.classifiers import register_classifier_callables\n",
        "from core.quality_signals.importance_weights import register_importance_weights_callables\n",
        "\n",
        "# ============================================\n",
        "# 📦 Paths and configuration\n",
        "# ============================================\n",
        "from pathlib import Path\n",
        "\n",
        "DATA_PATH = \"/content/final_uniform_replace.jsonl\"\n",
        "MYTEXT_PATH = \"/content/myText.txt\"\n",
        "\n",
        "EXACT_OUT_DIR = Path(\"/content/rp_exact_dedup_output\")\n",
        "EXACT_OUT_DIR.mkdir(parents=True, exist_ok=True)\n",
        "LSH_OUT_DIR = Path(\"/content/rp_lsh_output\")\n",
        "LSH_OUT_DIR.mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "# ✅ Use the artifacts directory that comes with the RedPajama repository\n",
        "ARTIFACTS_DIR = Path(\"/content/RedPajama-Data/artifacts\")\n",
        "BAD_WORDS_DIR = ARTIFACTS_DIR / \"bad_words\"\n",
        "BAD_URLS_DIR = ARTIFACTS_DIR / \"bad_urls\"\n",
        "CLASSIFIERS_DIR = ARTIFACTS_DIR / \"classifiers\"\n",
        "DSIR_DIR = ARTIFACTS_DIR / \"dsir\"\n",
        "\n",
        "# Here an empty dict is used to simulate the case where there is no dsir / classifier file.\n",
        "# Calls are made in the same way as the official worker, except that the model is not actually loaded.\n",
        "EMPTY_DSIr_FILES = {}\n",
        "EMPTY_CLASSIFIER_FILES = {}\n",
        "\n",
        "# ============================================\n",
        "# 💧 Load invisible characters from myText.txt\n",
        "# ============================================\n",
        "with open(MYTEXT_PATH, \"r\", encoding=\"utf-8\") as f:\n",
        "    chars = f.read()\n",
        "INVISIBLE_CHARS = [c for c in chars if not c.isprintable() and c != \"\\n\"]\n",
        "INVISIBLE_CODES = [f\"U+{ord(c):04X}\" for c in INVISIBLE_CHARS]\n",
        "print(f\"💧 Loaded {len(INVISIBLE_CHARS)} invisible watermark characters:\")\n",
        "print(\", \".join(INVISIBLE_CODES[:10]) + (\" ...\" if len(INVISIBLE_CODES) > 10 else \"\"))\n",
        "\n",
        "def count_invisible(text: str, idx: int) -> int:\n",
        "    \"\"\"\n",
        "    Count only the specific invisible character for this line index.\n",
        "    Note: It is assumed that only one character, INVISIBLE_CHARS[idx], is used as the watermark for the idx sample.\n",
        "    \"\"\"\n",
        "    if idx >= len(INVISIBLE_CHARS):\n",
        "        return 0\n",
        "    ch = INVISIBLE_CHARS[idx]\n",
        "    return text.count(ch)\n",
        "\n",
        "# ============================================\n",
        "# 🧼 Stage 0: RedPajama-style quality signals / cleaning\n",
        "# ============================================\n",
        "\n",
        "def init_rp_quality_callables(\n",
        "    lang: str = \"en\",\n",
        "    ldnoobw_dir: Path = BAD_WORDS_DIR,\n",
        "    ut1_dir: Path = BAD_URLS_DIR,\n",
        "):\n",
        "    \"\"\"\n",
        "    Modeled after core.worker.Worker.__init_quality_signals,\n",
        "    but do try/except for missing files and use as much as you can.\n",
        "    \"\"\"\n",
        "    callables = []\n",
        "\n",
        "    # 1) content signal (depends on bad_words / bad_urls)\n",
        "    try:\n",
        "        callables += register_content_callables(\n",
        "            language=lang,\n",
        "            bad_urls_dir=ut1_dir,\n",
        "            bad_words_dir=ldnoobw_dir,\n",
        "        )\n",
        "    except FileNotFoundError as e:\n",
        "        print(f\"[RP] Skipping content signals: {e}\")\n",
        "\n",
        "    # 2) Repeatable signals\n",
        "    try:\n",
        "        callables += register_repetitions_callables()\n",
        "    except Exception as e:\n",
        "        print(f\"[RP] Skipping repetitions signals: {e}\")\n",
        "\n",
        "    # 3) Natural Language / Unnatural Language Signals\n",
        "    try:\n",
        "        callables += register_natural_language_callables()\n",
        "    except Exception as e:\n",
        "        print(f\"[RP] Skipping natlang signals: {e}\")\n",
        "\n",
        "    # 4) Line Level Signal\n",
        "    try:\n",
        "        callables += register_lines_callables()\n",
        "    except Exception as e:\n",
        "        print(f\"[RP] Skipping line-level signals: {e}\")\n",
        "\n",
        "    # 5) fastText classifier signals (if there is no .bin it will give an error, just skip it)\n",
        "    try:\n",
        "        wikiref_model = next((p for p in (CLASSIFIERS_DIR / lang).glob(\"wikiref*.bin\")), None) \\\n",
        "            if (CLASSIFIERS_DIR / lang).exists() else None\n",
        "        palm_model = next((p for p in (CLASSIFIERS_DIR / lang).glob(\"palm*.bin\")), None) \\\n",
        "            if (CLASSIFIERS_DIR / lang).exists() else None\n",
        "        wikipedia_model = next((p for p in (CLASSIFIERS_DIR / lang).glob(\"wikipedia*.bin\")), None) \\\n",
        "            if (CLASSIFIERS_DIR / lang).exists() else None\n",
        "\n",
        "        callables += register_classifier_callables(\n",
        "            wikiref_model=str(wikiref_model) if wikiref_model else None,\n",
        "            palm_model=str(palm_model) if palm_model else None,\n",
        "            wikipedia_model=str(wikipedia_model) if wikipedia_model else None,\n",
        "        )\n",
        "    except FileNotFoundError as e:\n",
        "        print(f\"[RP] Skipping classifier signals: {e}\")\n",
        "    except Exception as e:\n",
        "        print(f\"[RP] Skipping classifier signals (other error): {e}\")\n",
        "\n",
        "    # 6) DSIR importance weighting (skip if no .npy file)\n",
        "    try:\n",
        "        def collect_dsir(domain: str):\n",
        "            dom_dir = DSIR_DIR / lang\n",
        "            if not dom_dir.exists():\n",
        "                return None\n",
        "            counts = list(dom_dir.glob(f\"{domain}.counts.npy\"))\n",
        "            lambdas = list(dom_dir.glob(f\"{domain}.lambda.npy\"))\n",
        "            if not counts or not lambdas:\n",
        "                return None\n",
        "            return [str(counts[0]), str(lambdas[0])]\n",
        "\n",
        "        callables += register_importance_weights_callables(\n",
        "            source_fps=collect_dsir(\"ccnet\"),\n",
        "            wiki_fps=collect_dsir(\"wikipedia\"),\n",
        "            openwebtext_fps=collect_dsir(\"openwebtext\"),\n",
        "            books_fps=collect_dsir(\"books\"),\n",
        "            language=lang,\n",
        "        )\n",
        "    except FileNotFoundError as e:\n",
        "        print(f\"[RP] Skipping DSIR signals: {e}\")\n",
        "    except Exception as e:\n",
        "        print(f\"[RP] Skipping DSIR signals (other error): {e}\")\n",
        "\n",
        "    return callables\n",
        "\n",
        "RP_QUALITY_CALLABLES = init_rp_quality_callables(lang=\"en\")\n",
        "\n",
        "def compute_rp_quality_signals(text: str, language: str = \"en\") -> dict:\n",
        "    \"\"\"\n",
        "    Using RedPajama's official Document + quality_signals callables,\n",
        "    computes all RP signals for a given text and returns a dict.\n",
        "    \"\"\"\n",
        "    document = Document(\n",
        "        text,\n",
        "        domain=\"\",\n",
        "        precompute_ngrams=True,\n",
        "        precompute_hash_features=True,\n",
        "        dsir_buckets=10_000,       # Consistent with Worker\n",
        "    )\n",
        "\n",
        "    signals = {}\n",
        "    for func in RP_QUALITY_CALLABLES:\n",
        "        try:\n",
        "            signals[func.field_name] = func(document)\n",
        "        except Exception as e:\n",
        "            signals[func.field_name] = f\"ERROR: {e.__class__.__name__}\"\n",
        "\n",
        "    return signals\n",
        "\n",
        "# ============================================\n",
        "# 🔹 Stage 1️⃣: Exact Deduplication (Bloom Filter)\n",
        "# ============================================\n",
        "import hashlib, json, msgspec\n",
        "from pybloom_live import BloomFilter\n",
        "from tqdm import tqdm\n",
        "\n",
        "capacity = 1_000_000\n",
        "error_rate = 1e-4\n",
        "unique_fp = EXACT_OUT_DIR / \"unique.jsonl\"\n",
        "duplicates_fp = EXACT_OUT_DIR / \"duplicates.jsonl\"\n",
        "bloom_fp = EXACT_OUT_DIR / \"bloomfilter.pkl\"  # Optional: If you want to persist it, you can store it in a pickle.\n",
        "\n",
        "bloom = BloomFilter(capacity=capacity, error_rate=error_rate)\n",
        "print(f\"🧮 Bloom Filter initialized. capacity={capacity}, error_rate={error_rate}\")\n",
        "\n",
        "def compute_digest(text: str) -> str:\n",
        "    return hashlib.sha1(text.encode(\"utf-8\")).hexdigest()\n",
        "\n",
        "num_docs = 0\n",
        "num_dupes = 0\n",
        "\n",
        "with open(DATA_PATH, \"r\", encoding=\"utf-8\") as fin, \\\n",
        "     open(unique_fp, \"w\", encoding=\"utf-8\") as fout_unique, \\\n",
        "     open(duplicates_fp, \"w\", encoding=\"utf-8\") as fout_dupes:\n",
        "    for line in tqdm(fin, desc=\"Processing documents\"):\n",
        "        try:\n",
        "            record = json.loads(line)\n",
        "        except json.JSONDecodeError:\n",
        "            continue\n",
        "        text = record.get(\"watermarked\", \"\").strip()\n",
        "        if not text:\n",
        "            continue\n",
        "\n",
        "        # ===== 🧼 Call RedPajama official cleaning / quality pipeline =====\n",
        "        rp_signals = compute_rp_quality_signals(text, language=\"en\")\n",
        "        record[\"rp_quality_signals\"] = rp_signals\n",
        "        # If you want to actually filter based on these signals in the future, you can write rules here, for example:\n",
        "        # if rp_signals[\"rps_doc_is_natural_language\"][0][2] < 0.5: continue\n",
        "        # Currently this code just calculates and saves the signal, it does not do hard filter.\n",
        "\n",
        "        digest = compute_digest(text)\n",
        "        num_docs += 1\n",
        "        if digest in bloom:\n",
        "            num_dupes += 1\n",
        "            fout_dupes.write(json.dumps({\"digest\": digest, **record}, ensure_ascii=False) + \"\\n\")\n",
        "        else:\n",
        "            bloom.add(digest)\n",
        "            fout_unique.write(json.dumps({\"digest\": digest, **record}, ensure_ascii=False) + \"\\n\")\n",
        "\n",
        "print(\"✅ Exact Dedup complete!\")\n",
        "print(f\"Total: {num_docs:,}, Duplicates: {num_dupes:,}, Unique: {num_docs - num_dupes:,}\")\n",
        "print(f\"Results saved in: {EXACT_OUT_DIR}\")\n",
        "\n",
        "# (Optional) Persist BloomFilter to file\n",
        "import pickle\n",
        "with open(bloom_fp, \"wb\") as f:\n",
        "    pickle.dump(bloom, f)\n",
        "\n",
        "# ============================================\n",
        "# 💧 Watermark retention check (exact dedup, 1-to-1 character mapping)\n",
        "# ============================================\n",
        "total_wm = 0\n",
        "unique_wm_retained = 0\n",
        "dupe_wm_removed = 0\n",
        "\n",
        "with open(DATA_PATH, \"r\", encoding=\"utf-8\") as f:\n",
        "    for i, line in enumerate(f):\n",
        "        try:\n",
        "            rec = json.loads(line)\n",
        "        except json.JSONDecodeError:\n",
        "            continue\n",
        "        if rec.get(\"is_watermarked\", False):\n",
        "            total_wm += 1\n",
        "\n",
        "with open(unique_fp, \"r\", encoding=\"utf-8\") as f:\n",
        "    for i, line in enumerate(f):\n",
        "        rec = json.loads(line)\n",
        "        if rec.get(\"is_watermarked\", False):\n",
        "            text = rec.get(\"watermarked\", \"\")\n",
        "            if count_invisible(text, i) > 0:\n",
        "                unique_wm_retained += 1\n",
        "\n",
        "with open(duplicates_fp, \"r\", encoding=\"utf-8\") as f:\n",
        "    for i, line in enumerate(f):\n",
        "        rec = json.loads(line)\n",
        "        if rec.get(\"is_watermarked\", False):\n",
        "            text = rec.get(\"watermarked\", \"\")\n",
        "            if count_invisible(text, i) > 0:\n",
        "                dupe_wm_removed += 1\n",
        "\n",
        "if total_wm > 0:\n",
        "    retention_rate = unique_wm_retained / total_wm * 100\n",
        "    lost_rate = dupe_wm_removed / total_wm * 100\n",
        "else:\n",
        "    retention_rate = lost_rate = 0\n",
        "\n",
        "print(\"\\n=================== 💧 1-to-1 Watermark Retention after Exact Dedup ===================\")\n",
        "print(f\"Total watermarked docs: {total_wm}\")\n",
        "print(f\"Retained (unique) docs w/ watermark: {unique_wm_retained}\")\n",
        "print(f\"Removed (duplicate) docs w/ watermark: {dupe_wm_removed}\")\n",
        "print(f\"✅ Retention rate: {retention_rate:.2f}% | ⚠️ Lost rate: {lost_rate:.2f}%\")\n",
        "print(\"======================================================================================\")\n",
        "\n",
        "# ============================================\n",
        "# 🔸 Stage 2️⃣: RedPajama-style Fuzzy Dedup (MinHash + LSH)\n",
        "# Using unique.jsonl as input\n",
        "# ============================================\n",
        "print(\"\\n🚀 Entering LSH fuzzy deduplication stage...\")\n",
        "NEW_DATA_PATH = str(unique_fp)   # Input is the unique file from exact dedup\n",
        "\n",
        "# ============================================\n",
        "# 🧩 Paths and parameters (RedPajama-style)\n",
        "# ============================================\n",
        "DATA_PATH = NEW_DATA_PATH\n",
        "OUT_DIR = LSH_OUT_DIR\n",
        "OUT_DIR.mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "SIMILARITIES = [0.7]\n",
        "NUM_PERM = 128         # number of permutations\n",
        "NGRAM_SIZE = 3         # token 3-gram\n",
        "SEED = 42\n",
        "\n",
        "# ============================================\n",
        "# 📦 Build MinHash parquet file (RedPajama-compatible format)\n",
        "# Each row: id, id_int, shard_id, signature_sim0.7 (list[bytes])\n",
        "# ============================================\n",
        "import polars as pl\n",
        "\n",
        "def sha1_64_int(doc_id: str) -> int:\n",
        "    \"\"\"Convert document ID to 64-bit integer (same as RedPajama worker.py).\"\"\"\n",
        "    byteorder = sys.byteorder  # usually 'little'\n",
        "    return int.from_bytes(\n",
        "        hashlib.sha1(doc_id.encode(\"utf-8\")).digest()[:8],\n",
        "        byteorder=byteorder,\n",
        "        signed=False,\n",
        "    )\n",
        "\n",
        "def build_minhash_parquet(\n",
        "    data_path: str,\n",
        "    out_parquet: Path,\n",
        "    similarities: list,\n",
        "    num_perm: int,\n",
        "    ngram_size: int,\n",
        "    seed: int,\n",
        "    shard_id: str = \"local/0000\",\n",
        "):\n",
        "    \"\"\"Generate MinHash signatures for all docs and save to a Parquet file using RedPajama's MinHash.\"\"\"\n",
        "    mh = RPMinHash(\n",
        "        similarity_thresholds=similarities,\n",
        "        ngram_size=ngram_size,\n",
        "        num_permutations=num_perm,\n",
        "        seed=seed,\n",
        "    )\n",
        "\n",
        "    ids, id_ints = [], []\n",
        "    sig_cols = {f\"signature_sim{s}\": [] for s in similarities}\n",
        "\n",
        "    # Read JSONL line by line\n",
        "    with open(data_path, \"r\", encoding=\"utf-8\") as f:\n",
        "        for i, line in enumerate(f):\n",
        "            try:\n",
        "                rec = json.loads(line)\n",
        "            except json.JSONDecodeError:\n",
        "                continue\n",
        "\n",
        "            text = rec.get(\"watermarked\", \"\")\n",
        "            if not text:\n",
        "                continue\n",
        "\n",
        "            # RedPajama-compatible normalize + tokenize\n",
        "            norm = normalize(text)\n",
        "            tokens = tuple(norm.split())  # equivalent to Document.normalized_words for words\n",
        "\n",
        "            # Compute signatures via RedPajama MinHash\n",
        "            sigs = mh.compute_banded_signatures(tokens)\n",
        "\n",
        "            doc_id = f\"{shard_id}/{i}\"        # matches worker-style ID\n",
        "            id_int = sha1_64_int(doc_id)\n",
        "\n",
        "            ids.append(doc_id)\n",
        "            id_ints.append(id_int)\n",
        "            for s in similarities:\n",
        "                key = f\"signature_sim{s}\"\n",
        "                sig_cols[key].append(sigs[key])\n",
        "\n",
        "    # Assemble Polars DataFrame\n",
        "    df = pl.DataFrame(\n",
        "        {\n",
        "            \"id\": ids,\n",
        "            \"id_int\": id_ints,\n",
        "            \"shard_id\": [shard_id] * len(ids),\n",
        "            **sig_cols,\n",
        "        }\n",
        "    )\n",
        "\n",
        "    df.write_parquet(out_parquet)\n",
        "    return out_parquet\n",
        "\n",
        "minhash_parquet = OUT_DIR / \"minhash.parquet\"\n",
        "_ = build_minhash_parquet(\n",
        "    data_path=DATA_PATH,\n",
        "    out_parquet=minhash_parquet,\n",
        "    similarities=SIMILARITIES,\n",
        "    num_perm=NUM_PERM,\n",
        "    ngram_size=NGRAM_SIZE,\n",
        "    seed=SEED,\n",
        "    shard_id=\"local/0000\",\n",
        ")\n",
        "print(f\"✅ MinHash parquet written: {minhash_parquet}\")\n",
        "\n",
        "# ============================================\n",
        "# 🧭 LSH-based fuzzy deduplication (RedPajama run_lsh.py logic)\n",
        "# Group by band → connect same-bucket pairs → compute connected components\n",
        "# ============================================\n",
        "import pyarrow.dataset as ds\n",
        "import networkit.components as nk_components\n",
        "import networkit.graph as nk_graph\n",
        "import numpy as np\n",
        "\n",
        "def run_rp_lsh(\n",
        "    minhash_parquets: list,\n",
        "    out_dir: Path,\n",
        "    similarity: float,\n",
        "    num_perm: int,\n",
        "    max_docs: int = -1,\n",
        "):\n",
        "    \"\"\"Run RedPajama-style LSH fuzzy deduplication.\"\"\"\n",
        "    sig_key = f\"signature_sim{similarity}\"\n",
        "\n",
        "    # Load Parquet dataset\n",
        "    try:\n",
        "        dset = ds.dataset(source=minhash_parquets, format=\"parquet\")\n",
        "    except TypeError:\n",
        "        dset = ds.dataset(minhash_parquets, format=\"parquet\")\n",
        "\n",
        "    # Scan and select only necessary columns\n",
        "    query = (\n",
        "        pl.scan_pyarrow_dataset(dset)\n",
        "        .select(pl.col([\"id_int\", sig_key]))\n",
        "        .filter(~pl.col(sig_key).is_null())\n",
        "    )\n",
        "\n",
        "    # Compute (bands, rows) for this similarity using RedPajama's optimal_param\n",
        "    b, r = optimal_param(similarity, num_perm)\n",
        "    num_bands = b\n",
        "\n",
        "    # Expand by band and bucketize\n",
        "    query = (\n",
        "        query\n",
        "        .with_columns(pl.lit(list(range(num_bands))).alias(\"band\"))\n",
        "        .explode(sig_key, \"band\")\n",
        "        .group_by(sig_key, \"band\")\n",
        "        .agg(pl.col(\"id_int\"))\n",
        "    )\n",
        "\n",
        "    # Keep only buckets with more than one doc\n",
        "    try:\n",
        "        query = query.filter(pl.col(\"id_int\").list.lengths() > 1)\n",
        "    except AttributeError:\n",
        "        query = query.filter(pl.col(\"id_int\").list.len() > 1)\n",
        "\n",
        "    # Create edges between docs sharing same bucket\n",
        "    query = (\n",
        "        query\n",
        "        .select(\n",
        "            pl.col(\"id_int\"),\n",
        "            pl.col(\"id_int\").list.min().alias(\"min_node\"),\n",
        "        )\n",
        "        .explode(\"id_int\")\n",
        "        .filter(pl.col(\"id_int\") != pl.col(\"min_node\"))\n",
        "        .select(pl.concat_list([\"id_int\", \"min_node\"]).alias(\"edges\"))\n",
        "        .unique(\"edges\")\n",
        "    )\n",
        "\n",
        "    # Collect edges as list of [id_int, min_node]\n",
        "    edges_df = query.collect()\n",
        "    edges = edges_df[\"edges\"].to_list()  # list of [n1, n2]\n",
        "    print(f\"🧱 Edges: {len(edges)} total\")\n",
        "\n",
        "    # Build undirected graph\n",
        "    graph = nk_graph.Graph()\n",
        "    node_map = {}\n",
        "    for n1, n2 in edges:\n",
        "        if n1 not in node_map:\n",
        "            node_map[n1] = graph.addNode()\n",
        "        if n2 not in node_map:\n",
        "            node_map[n2] = graph.addNode()\n",
        "        graph.addEdge(node_map[n1], node_map[n2])\n",
        "\n",
        "    rev_map = {v: k for k, v in node_map.items()}\n",
        "\n",
        "    # Compute connected components\n",
        "    cc = nk_components.ConnectedComponents(G=graph)\n",
        "    cc.run()\n",
        "    comps = cc.getComponents()\n",
        "    print(f\"🔗 Connected components (clusters): {len(comps)}\")\n",
        "\n",
        "    # Convert to doc → cluster_id mapping\n",
        "    def process_comp(comp):\n",
        "        if not comp:\n",
        "            return np.empty((0, 2), dtype=np.uint64)  # skip empty clusters\n",
        "        nodes = np.array(list(map(rev_map.get, comp))).reshape(-1, 1)\n",
        "        cid = np.repeat(min(map(rev_map.get, comp)), len(nodes)).reshape(-1, 1)\n",
        "        return np.hstack((nodes, cid))\n",
        "\n",
        "    if len(comps) > 0:\n",
        "        data = np.vstack(tuple(map(process_comp, comps)))\n",
        "        clusters_df = pl.DataFrame(data, schema=[\"id_int\", \"cluster_id\"])\n",
        "    else:\n",
        "        clusters_df = pl.DataFrame({\"id_int\": [], \"cluster_id\": []})\n",
        "\n",
        "    out_path = out_dir / f\"clusters_sim{similarity}.parquet\"\n",
        "    clusters_df.write_parquet(out_path)\n",
        "    print(f\"✅ Clusters written: {out_path}\")\n",
        "    return out_path\n",
        "\n",
        "# Run LSH fuzzy deduplication\n",
        "clusters_parquet = run_rp_lsh(\n",
        "    minhash_parquets=[minhash_parquet],\n",
        "    out_dir=OUT_DIR,\n",
        "    similarity=SIMILARITIES[0],\n",
        "    num_perm=NUM_PERM,\n",
        "    max_docs=-1,\n",
        ")\n",
        "\n",
        "# ============================================\n",
        "# 💧 LSH watermark detect\n",
        "# ============================================\n",
        "clusters_parquet = str(LSH_OUT_DIR / \"clusters_sim0.7.parquet\")\n",
        "clusters = pl.read_parquet(clusters_parquet)\n",
        "\n",
        "def doc_id_of(i):\n",
        "    return f\"local/0000/{i}\"\n",
        "\n",
        "def id_int_of(i):\n",
        "    return int.from_bytes(\n",
        "        hashlib.sha1(doc_id_of(i).encode(\"utf-8\")).digest()[:8],\n",
        "        byteorder=sys.byteorder,\n",
        "        signed=False,\n",
        "    )\n",
        "\n",
        "wm_id_ints = []\n",
        "with open(DATA_PATH, \"r\", encoding=\"utf-8\") as f:\n",
        "    for i, line in enumerate(f):\n",
        "        try:\n",
        "            rec = json.loads(line)\n",
        "        except json.JSONDecodeError:\n",
        "            continue\n",
        "        if rec.get(\"is_watermarked\", False):\n",
        "            text = rec.get(\"watermarked\", \"\")\n",
        "            if count_invisible(text, i) > 0:\n",
        "                wm_id_ints.append(id_int_of(i))\n",
        "\n",
        "wm_total = len(wm_id_ints)\n",
        "print(f\"\\n📘 Total watermarked documents: {wm_total}\")\n",
        "\n",
        "# Identify duplicates in clusters\n",
        "if clusters.height > 0:\n",
        "    rep = clusters.group_by(\"cluster_id\").agg(\n",
        "        pl.col(\"id_int\").min().alias(\"rep_id\")\n",
        "    )\n",
        "    clusters = clusters.join(rep, on=\"cluster_id\", how=\"left\")\n",
        "    clusters = clusters.with_columns(\n",
        "        (pl.col(\"id_int\") != pl.col(\"rep_id\")).alias(\"is_duplicate_member\")\n",
        "    )\n",
        "else:\n",
        "    clusters = clusters.with_columns(\n",
        "        pl.Series(name=\"rep_id\", values=[]),\n",
        "        pl.Series(name=\"is_duplicate_member\", values=[]),\n",
        "    )\n",
        "\n",
        "if clusters.is_empty():\n",
        "    dupe_id_ints = set()\n",
        "else:\n",
        "    dupe_id_ints = set(\n",
        "        clusters.filter(pl.col(\"is_duplicate_member\"))\n",
        "        .get_column(\"id_int\")\n",
        "        .to_list()\n",
        "    )\n",
        "\n",
        "wm_dupes = sum(1 for _id in wm_id_ints if _id in dupe_id_ints)\n",
        "wm_retained = wm_total - wm_dupes\n",
        "ret_rate = (wm_retained / wm_total * 100) if wm_total > 0 else 0.0\n",
        "\n",
        "print(\"====== 💧 1-to-1 Watermark Retention after LSH Dedup ======\")\n",
        "print(f\"Removed watermarked docs: {wm_dupes}\")\n",
        "print(f\"Retained watermarked docs: {wm_retained}\")\n",
        "print(f\"✅ Retention rate: {ret_rate:.2f}%\")\n",
        "print(\"===========================================================\")\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "欢迎使用 Colab",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
