{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "_hOBuZvjtIal",
        "outputId": "8a1cf4fc-aa80-4cd6-a3f4-26ede83c6d9a"
      },
      "outputs": [],
      "source": [
        "# ============================================\n",
        "# 🧰 Install libraries\n",
        "# ============================================\n",
        "!pip install pandas ftfy matplotlib fasttext --quiet\n",
        "!pip install pytablewriter best_download lm_dataformat --quiet\n",
        "!pip install datasketch --quiet\n",
        "!pip install justext --quiet\n",
        "!pip install pycld2 --quiet\n",
        "\n",
        "import os, sys, re, html, json, urllib.request\n",
        "import justext\n",
        "import pandas as pd\n",
        "import matplotlib.pyplot as plt\n",
        "from ftfy import fix_text\n",
        "from collections import OrderedDict\n",
        "\n",
        "# --- FastText compatibility patch for NumPy 2.0 ---\n",
        "import numpy as np\n",
        "\n",
        "_old_array = np.array\n",
        "\n",
        "def _safe_array(obj, *args, **kwargs):\n",
        "    # Remove old/unsupported NumPy args that FastText might pass\n",
        "    kwargs.pop(\"copy\", None)\n",
        "    kwargs.pop(\"subok\", None)\n",
        "    return np.asarray(obj, *args, **kwargs)\n",
        "\n",
        "np.array = _safe_array\n",
        "\n",
        "import fasttext\n",
        "from datasketch import MinHash, MinHashLSH\n",
        "import pycld2 as cld2\n",
        "\n",
        "# ============================================\n",
        "# 🧰 Clone The Pile repo (if not already there)\n",
        "# ============================================\n",
        "if not os.path.isdir(\"/content/the-pile\"):\n",
        "    !git clone https://github.com/EleutherAI/the-pile.git /content/the-pile\n",
        "\n",
        "sys.path.append(\"/content/the-pile\")\n",
        "\n",
        "from the_pile.utils import strip_markdown_colons, remove_advertisement\n",
        "\n",
        "# ============================================\n",
        "# 1️⃣ Load data and watermark reference\n",
        "# ============================================\n",
        "data_path = \"/content/final_uniform_replace.jsonl\"\n",
        "mytext_path = \"/content/myText.txt\"\n",
        "\n",
        "records = [json.loads(l) for l in open(data_path, \"r\", encoding=\"utf-8\")]\n",
        "df = pd.DataFrame(records)\n",
        "df = df[df[\"is_watermarked\"] == True].copy()\n",
        "raw_texts = df[\"watermarked\"].tolist()\n",
        "print(f\"✅ Loaded {len(raw_texts)} watermarked samples.\")\n",
        "\n",
        "# =============== 13-gram MinHashLSH document-level deduplication ===============\n",
        "\n",
        "def get_char_shingles(text, k=13):\n",
        "    \"\"\"Return a set of character-level k-gram shingles.\"\"\"\n",
        "    text = text.replace(\"\\n\", \" \")\n",
        "    if len(text) <= k:\n",
        "        return {text}\n",
        "    return {text[i:i+k] for i in range(len(text) - k + 1)}\n",
        "\n",
        "def minhash_for_text(text, num_perm=128, k=13):\n",
        "    \"\"\"Build a MinHash signature from 13-gram character shingles.\"\"\"\n",
        "    shingles = get_char_shingles(text, k=k)\n",
        "    m = MinHash(num_perm=num_perm)\n",
        "    for sh in shingles:\n",
        "        m.update(sh.encode(\"utf-8\"))\n",
        "    return m\n",
        "\n",
        "def dedupe_documents_minhash(docs, threshold=0.5, num_perm=128, k=13):\n",
        "    \"\"\"\n",
        "    Document-level near-duplicate removal using MinHashLSH.\n",
        "    Returns:\n",
        "      kept_docs: list of (orig_idx, text) for unique docs\n",
        "      dup_map:   list of (dup_idx, kept_key) for duplicates (optional, for stats)\n",
        "    \"\"\"\n",
        "    lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)\n",
        "    kept_docs = []\n",
        "    dup_map = []\n",
        "\n",
        "    for idx, doc in enumerate(docs):\n",
        "        m = minhash_for_text(doc, num_perm=num_perm, k=k)\n",
        "        # query for near-duplicates among already-inserted docs\n",
        "        matches = lsh.query(m)\n",
        "        if not matches:\n",
        "            # new unique document\n",
        "            key = f\"doc_{idx}\"\n",
        "            lsh.insert(key, m)\n",
        "            kept_docs.append((idx, doc))  # keep original index for later alignment\n",
        "        else:\n",
        "            # near-duplicate of some existing doc\n",
        "            dup_map.append((idx, matches[0]))  # (duplicate_idx, representative_key)\n",
        "\n",
        "    return kept_docs, dup_map\n",
        "\n",
        "# Run MinHashLSH deduplication on the raw texts\n",
        "kept_docs, dup_map = dedupe_documents_minhash(raw_texts, threshold=0.5, num_perm=128, k=13)\n",
        "\n",
        "print(f\"🧹 After MinHashLSH dedupe: {len(kept_docs)} unique documents (from {len(raw_texts)})\")\n",
        "if dup_map:\n",
        "    print(f\"ℹ️ Removed {len(dup_map)} near-duplicate documents.\")\n",
        "\n",
        "# Replace `texts` with deduplicated texts but keep original indices\n",
        "dedup_indices  = [i for (i, _) in kept_docs]   # indices into raw_texts\n",
        "texts          = [t for (_, t) in kept_docs]   # deduplicated texts\n",
        "\n",
        "# --- Load invisible characters from myText.txt ---\n",
        "with open(mytext_path, \"r\", encoding=\"utf-8\") as f:\n",
        "    chars = f.read()\n",
        "ZWC = [c for c in chars if not c.isprintable() and c != \"\\n\"]\n",
        "ZWC_CODES = [f\"U+{ord(c):04X}\" for c in ZWC]\n",
        "print(f\"💧 Loaded {len(ZWC)} invisible watermark characters:\")\n",
        "print(\", \".join(ZWC_CODES[:10]) + (\" ...\" if len(ZWC_CODES) > 10 else \"\"))\n",
        "\n",
        "# ============================================\n",
        "# 2️⃣ Cleaning functions (The Pile-level + your own)\n",
        "# ============================================\n",
        "\n",
        "def clean_html(raw_html):\n",
        "    \"\"\"\n",
        "    Clean HTML using jusText (boilerplate removal only),\n",
        "    output is a single text string without paragraph restructuring.\n",
        "    \"\"\"\n",
        "    try:\n",
        "        # Run jusText to detect boilerplate\n",
        "        paragraphs = justext.justext(raw_html, justext.get_stoplist(\"English\"))\n",
        "\n",
        "        # Keep ONLY non-boilerplate text\n",
        "        cleaned_parts = [\n",
        "            para.text for para in paragraphs\n",
        "            if not para.is_boilerplate\n",
        "        ]\n",
        "\n",
        "        # Flatten into single string, preserve minimal spacing\n",
        "        cleaned = \" \".join(cleaned_parts)\n",
        "        cleaned = html.unescape(cleaned).strip()\n",
        "\n",
        "        return cleaned\n",
        "\n",
        "    except Exception:\n",
        "        # In case jusText crashes on malformed HTML\n",
        "        return html.unescape(raw_html).strip()\n",
        "\n",
        "def fix_unicode(text):\n",
        "    return fix_text(text)\n",
        "\n",
        "def remove_duplicate_lines(text):\n",
        "    lines = text.splitlines()\n",
        "    seen = OrderedDict()\n",
        "    for line in lines:\n",
        "        line = line.strip()\n",
        "        if line and line not in seen:\n",
        "            seen[line] = True\n",
        "    return \"\\n\".join(seen.keys())\n",
        "\n",
        "def is_english_pycld2(text, min_chars=50):\n",
        "    \"\"\"\n",
        "    Use pycld2 to keep only reliably detected English text.\n",
        "    \"\"\"\n",
        "    snippet = text.strip()\n",
        "    if len(snippet) < min_chars:\n",
        "        return False, \"too short for language detection\"\n",
        "\n",
        "    try:\n",
        "        reliable, _, details = cld2.detect(snippet, bestEffort=True)\n",
        "        # details[0]: (language name, language code, percent, score)\n",
        "        lang_code = details[0][1]  # e.g. 'EN'\n",
        "        if reliable and lang_code.lower() == \"en\":\n",
        "            return True, f\"pycld2: reliable English ({lang_code})\"\n",
        "        else:\n",
        "            return False, f\"pycld2: non-English or unreliable (reliable={reliable}, lang={lang_code})\"\n",
        "    except Exception as e:\n",
        "        return False, f\"pycld2 error: {e}\"\n",
        "\n",
        "# ============================================\n",
        "# 3️⃣ Load FastText model (for stats only)\n",
        "# ============================================\n",
        "ft_path = \"lid.176.bin\"\n",
        "\n",
        "def ensure_fasttext_model(path):\n",
        "    if os.path.exists(path):\n",
        "        if os.path.getsize(path) < 10 * 1024 * 1024:\n",
        "            print(\"⚠️ Existing lid.176.bin looks too small, removing and re-downloading...\")\n",
        "            os.remove(path)\n",
        "\n",
        "    if not os.path.exists(path):\n",
        "        print(\"📥 Downloading FastText model (lid.176.bin)...\")\n",
        "        urllib.request.urlretrieve(\n",
        "            \"https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin\",\n",
        "            path,\n",
        "        )\n",
        "        print(\"✅ Downloaded FastText model.\")\n",
        "\n",
        "ensure_fasttext_model(ft_path)\n",
        "ft_model = fasttext.load_model(ft_path)\n",
        "\n",
        "fasttext_stats = []\n",
        "\n",
        "# ============================================\n",
        "# 4️⃣ Cleaning with pycld2 English filter + FastText (stats) + debug\n",
        "# ============================================\n",
        "def pile_clean_text_debug(text, min_quality=0.5):\n",
        "    # 1) The Pile-style light cleaning\n",
        "    text = strip_markdown_colons(text)\n",
        "    text = remove_advertisement(text)\n",
        "\n",
        "    # 2) HTML → plain text via jusText\n",
        "    cleaned_html = clean_html(text)\n",
        "    if not cleaned_html.strip():\n",
        "        return None, \"HTML removal -> empty\"\n",
        "\n",
        "    # 3) Unicode fix + intra-document line dedupe\n",
        "    text = fix_unicode(cleaned_html)\n",
        "    text = remove_duplicate_lines(text)\n",
        "\n",
        "    # 4) length\n",
        "    if len(text.strip()) < 50:\n",
        "        return None, \"too short (<50 chars)\"\n",
        "\n",
        "    # 5) pycld2\n",
        "    is_en, lang_reason = is_english_pycld2(text)\n",
        "    if not is_en:\n",
        "        return None, f\"filtered by pycld2 — {lang_reason}\"\n",
        "\n",
        "    # 6) FastText\n",
        "    try:\n",
        "        labels, probs = ft_model.predict(text.replace(\"\\n\", \" \"), k=1)\n",
        "        label = labels[0].replace(\"__label__\", \"\")\n",
        "        prob = float(probs[0])\n",
        "    except Exception as e:\n",
        "        label, prob = \"error\", 0.0\n",
        "\n",
        "    fasttext_stats.append((label, prob))\n",
        "\n",
        "    return text, None\n",
        "\n",
        "# ============================================\n",
        "# 5️⃣ Apply cleaning (with index tracking)\n",
        "# ============================================\n",
        "cleaned_records = []\n",
        "filtered = []\n",
        "\n",
        "for j, t in enumerate(texts):\n",
        "    orig_idx = dedup_indices[j]  # index into raw_texts / ZWC\n",
        "    cleaned, reason = pile_clean_text_debug(t)\n",
        "    if cleaned:\n",
        "        cleaned_records.append((orig_idx, cleaned))\n",
        "    else:\n",
        "        filtered.append((orig_idx, reason))\n",
        "\n",
        "print(f\"\\n✅ Cleaned {len(cleaned_records)} / {len(texts)}\")\n",
        "for i, reason in filtered[:10]:\n",
        "    print(f\"🚫 Sample {i} filtered out — {reason}\")\n",
        "\n",
        "if fasttext_stats:\n",
        "    labels_only = [lbl for (lbl, p) in fasttext_stats]\n",
        "    label_counts = pd.Series(labels_only).value_counts()\n",
        "    print(\"\\n🌐 FastText language stats (top 10):\")\n",
        "    print(label_counts.head(10))\n",
        "\n",
        "# ============================================\n",
        "# 6️⃣ Per-document (1-to-1) watermark retention analysis — index aligned\n",
        "# ============================================\n",
        "def count_char(text, ch):\n",
        "    return text.count(ch)\n",
        "\n",
        "results = []\n",
        "for orig_idx, cleaned_text in cleaned_records:\n",
        "    if orig_idx >= len(ZWC):\n",
        "        break\n",
        "    wm_char = ZWC[orig_idx]\n",
        "    orig_text = raw_texts[orig_idx]  # use original (pre-dedup, pre-cleaning) text\n",
        "\n",
        "    orig_count = count_char(orig_text, wm_char)\n",
        "    cleaned_count = count_char(cleaned_text, wm_char)\n",
        "    retention_ratio = (cleaned_count / orig_count) if orig_count > 0 else 0\n",
        "    reduced = cleaned_count < orig_count\n",
        "    results.append({\n",
        "        \"doc_id\": orig_idx,\n",
        "        \"char_code\": f\"U+{ord(wm_char):04X}\",\n",
        "        \"orig_count\": orig_count,\n",
        "        \"cleaned_count\": cleaned_count,\n",
        "        \"retention_ratio\": retention_ratio,\n",
        "        \"reduced\": reduced\n",
        "    })\n",
        "\n",
        "df_ret = pd.DataFrame(results)\n",
        "\n",
        "if df_ret.empty:\n",
        "    print(\"\\n⚠️ No documents passed cleaning — nothing to analyze for watermark retention.\")\n",
        "else:\n",
        "    retained_docs = (df_ret[\"cleaned_count\"] > 0).sum()\n",
        "    reduced_docs = df_ret[\"reduced\"].sum()\n",
        "    overall_retention = (\n",
        "        df_ret[\"cleaned_count\"].sum() / df_ret[\"orig_count\"].sum() * 100\n",
        "        if df_ret[\"orig_count\"].sum() > 0 else 0\n",
        "    )\n",
        "\n",
        "    print(f\"\\n✅ Per-document (1-to-1) watermark analysis complete (index-aligned).\")\n",
        "    print(f\"💧 Documents with retained watermark: {retained_docs}/{len(df_ret)}\")\n",
        "    print(f\"📉 Watermark reduced in {reduced_docs} documents.\")\n",
        "    print(f\"✅ Overall watermark retention rate: {overall_retention:.2f}%\")\n",
        "\n",
        "    print(\"\\n📊 Sample of per-document watermark retention:\")\n",
        "    print(df_ret.head(10))\n",
        "\n",
        "    # ============================================\n",
        "    # 7️⃣ Visualization — per-document retention distribution\n",
        "    # ============================================\n",
        "    plt.figure(figsize=(6,4))\n",
        "    plt.hist(df_ret[\"retention_ratio\"] * 100, bins=10, color=\"#4CAF50\", edgecolor=\"black\")\n",
        "    plt.xlabel(\"Per-document watermark retention (%)\")\n",
        "    plt.ylabel(\"Document count\")\n",
        "    plt.title(\"Per-document (1-to-1) Watermark Retention after Cleaning\")\n",
        "    plt.grid(alpha=0.3)\n",
        "    plt.show()\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "欢迎使用 Colab",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
