{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "A100"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "# Colab setup: install/upgrade the third-party libraries imported below.\n",
        "# `%pip` (rather than `!pip`) installs into the environment of the running kernel.\n",
        "%pip install -q -U \"transformers>=4.43.3\" accelerate bitsandbytes datasets sentence-transformers faiss-cpu"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "TDbv63cfLkMI",
        "outputId": "60c4c376-eeec-4e2f-a3c8-803f927ced37"
      },
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.1/40.1 kB\u001b[0m \u001b[31m3.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m11.6/11.6 MB\u001b[0m \u001b[31m134.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m61.3/61.3 MB\u001b[0m \u001b[31m43.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m503.6/503.6 kB\u001b[0m \u001b[31m46.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m486.6/486.6 kB\u001b[0m \u001b[31m40.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m31.4/31.4 MB\u001b[0m \u001b[31m80.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m42.8/42.8 MB\u001b[0m \u001b[31m61.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "pylibcudf-cu12 25.6.0 requires pyarrow<20.0.0a0,>=14.0.0; platform_machine == \"x86_64\", but you have pyarrow 21.0.0 which is incompatible.\n",
            "cudf-cu12 25.6.0 requires pyarrow<20.0.0a0,>=14.0.0; platform_machine == \"x86_64\", but you have pyarrow 21.0.0 which is incompatible.\u001b[0m\u001b[31m\n",
            "\u001b[0m"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import os, math, random, argparse\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn.functional as F\n",
        "from datasets import load_dataset\n",
        "from sentence_transformers import SentenceTransformer\n",
        "import faiss\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
        "from huggingface_hub import login\n",
        "\n",
        "# --- Model selection ---\n",
        "PREFERRED_LLAMA2 = \"meta-llama/Llama-2-7b-hf\"             # tried first by main()\n",
        "FALLBACK_MODEL   = \"TinyLlama/TinyLlama-1.1B-Chat-v1.0\"   # used when the preferred model fails to load\n",
        "USE_4BIT         = True                                    # 4-bit loading; load_llm() only applies it on CUDA\n",
        "DEVICE           = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "\n",
        "# --- Reproducibility: seed each RNG used in this notebook ---\n",
        "SEED = 42\n",
        "random.seed(SEED)\n",
        "np.random.seed(SEED)\n",
        "torch.manual_seed(SEED)\n",
        "\n",
        "# --- Pipeline hyper-parameters (passed to unre_once() by main()) ---\n",
        "MAX_FORGET_DOCS = 1000   # cap on the number of passages loaded from the forget corpus\n",
        "TAU             = 0.32   # similarity threshold at which the gate fires\n",
        "TOPK            = 8      # number of passages searched/retrieved per answer\n",
        "PGD_STEPS       = 4      # gradient steps when optimizing delta\n",
        "PGD_LR          = 0.05   # step size of those updates\n",
        "PGD_EPS         = 0.5    # bound on the norm of delta\n",
        "CTX_MAXTOK      = 256    # the context block is truncated to this many tokens\n",
        "GEN_NEW         = 96     # max_new_tokens used for generation\n",
        "\n",
        "def try_hf_login():\n",
        "    \"\"\"Log in to the Hugging Face Hub when the ``HF_TOKEN`` env var is set.\n",
        "\n",
        "    Does nothing if the variable is missing or blank. A failed login is\n",
        "    reported on stdout instead of being raised, so the caller can carry on.\n",
        "    \"\"\"\n",
        "    hf_token = os.getenv(\"HF_TOKEN\", \"\").strip()\n",
        "    if not hf_token:\n",
        "        return\n",
        "    try:\n",
        "        login(token=hf_token, add_to_git_credential=False)\n",
        "        print(\"Hugging Face login successful\")\n",
        "    except Exception as err:\n",
        "        print(\"HF login failed:\", err)\n",
        "\n",
        "def load_llm(model_id):\n",
        "    \"\"\"Load a frozen causal language model together with its tokenizer.\n",
        "\n",
        "    With ``USE_4BIT`` enabled on CUDA the weights are loaded in 4-bit NF4\n",
        "    (double quantization, float16 compute). Otherwise the model is loaded in\n",
        "    float16 on CUDA or float32 on CPU. Both paths use ``device_map=\"auto\"``.\n",
        "\n",
        "    Args:\n",
        "        model_id: model identifier passed to ``from_pretrained``.\n",
        "\n",
        "    Returns:\n",
        "        ``(tokenizer, model)``. The tokenizer's pad token falls back to its EOS\n",
        "        token when none is defined; the model has gradients disabled on every\n",
        "        parameter and is in eval mode.\n",
        "    \"\"\"\n",
        "    load_kwargs = {\"device_map\": \"auto\"}\n",
        "    if USE_4BIT and DEVICE == \"cuda\":\n",
        "        load_kwargs[\"quantization_config\"] = BitsAndBytesConfig(\n",
        "            load_in_4bit=True,\n",
        "            bnb_4bit_use_double_quant=True,\n",
        "            bnb_4bit_quant_type=\"nf4\",\n",
        "            bnb_4bit_compute_dtype=torch.float16,\n",
        "        )\n",
        "    else:\n",
        "        load_kwargs[\"torch_dtype\"] = torch.float16 if DEVICE == \"cuda\" else torch.float32\n",
        "    model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)\n",
        "\n",
        "    tok = AutoTokenizer.from_pretrained(model_id)\n",
        "    if tok.pad_token is None:\n",
        "        tok.pad_token = tok.eos_token\n",
        "\n",
        "    # The weights are never trained here; only input perturbations receive gradients.\n",
        "    for param in model.parameters():\n",
        "        param.requires_grad = False\n",
        "    model.eval()\n",
        "    return tok, model\n",
        "\n",
        "def st_model():\n",
        "    \"\"\"Create the ``all-MiniLM-L6-v2`` sentence-embedding model on ``DEVICE``.\"\"\"\n",
        "    model_name = \"sentence-transformers/all-MiniLM-L6-v2\"\n",
        "    return SentenceTransformer(model_name, device=DEVICE)\n",
        "\n",
        "def st_encode(embedder, texts, bs=64):\n",
        "    \"\"\"Embed ``texts`` with ``embedder`` and return the result cast to float32.\n",
        "\n",
        "    The encoder is asked for normalized NumPy embeddings, ``bs`` texts per batch.\n",
        "    \"\"\"\n",
        "    encode_opts = dict(batch_size=bs, convert_to_numpy=True, normalize_embeddings=True)\n",
        "    embeddings = embedder.encode(texts, **encode_opts)\n",
        "    return embeddings.astype(\"float32\")\n",
        "\n",
        "def load_forget_corpus():\n",
        "    \"\"\"Load up to ``MAX_FORGET_DOCS`` passages from the RWKU forget corpus.\n",
        "\n",
        "    Returns:\n",
        "        A list with one entry per selected row, read from the ``passage``\n",
        "        column if present, else ``text``, else the first column.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if the loaded dataset contains no split at all.\n",
        "    \"\"\"\n",
        "    ds_passages = load_dataset(\"jinzhuoran/RWKU\", \"train_original_passage\")\n",
        "\n",
        "    # Prefer \"train\", then \"test\", then whatever split exists. The container\n",
        "    # returned by load_dataset is not itself a usable split (its len() counts\n",
        "    # splits, not rows), so never fall back to it.\n",
        "    split_name = next((name for name in (\"train\", \"test\") if name in ds_passages), None)\n",
        "    if split_name is None:\n",
        "        available = list(ds_passages.keys())\n",
        "        if not available:\n",
        "            raise ValueError(\"RWKU 'train_original_passage' contains no splits\")\n",
        "        split_name = available[0]\n",
        "    split = ds_passages[split_name]\n",
        "\n",
        "    if len(split) > MAX_FORGET_DOCS:\n",
        "        split = split.select(range(MAX_FORGET_DOCS))\n",
        "    cols = split.column_names\n",
        "    text_field = \"passage\" if \"passage\" in cols else (\"text\" if \"text\" in cols else cols[0])\n",
        "    forget_texts = [ex[text_field] for ex in split]\n",
        "    return forget_texts\n",
        "\n",
        "def build_faiss_index(embedder, forget_texts):\n",
        "    \"\"\"Embed the forget passages and index them for inner-product search.\n",
        "\n",
        "    Returns:\n",
        "        ``(index, n)``: the FAISS ``IndexFlatIP`` holding the embeddings and\n",
        "        the number of vectors that were added to it.\n",
        "    \"\"\"\n",
        "    passage_vecs = st_encode(embedder, forget_texts)\n",
        "    ip_index = faiss.IndexFlatIP(passage_vecs.shape[1])\n",
        "    ip_index.add(passage_vecs)\n",
        "    return ip_index, passage_vecs.shape[0]\n",
        "\n",
        "def ids_from_text(tokenizer, text):\n",
        "    \"\"\"Tokenize ``text`` without special tokens; return its ``input_ids`` tensor on ``DEVICE``.\"\"\"\n",
        "    encoded = tokenizer(text, add_special_tokens=False, return_tensors=\"pt\")\n",
        "    encoded = encoded.to(DEVICE)\n",
        "    return encoded[\"input_ids\"]\n",
        "\n",
        "def generate_plain(tokenizer, model, query, max_new=96, temperature=0.0):\n",
        "    \"\"\"Generate a continuation of ``query`` and return only the new text.\n",
        "\n",
        "    Args:\n",
        "        tokenizer, model: tokenizer and causal LM used for generation.\n",
        "        query: prompt text.\n",
        "        max_new: maximum number of new tokens to generate.\n",
        "        temperature: sampling temperature; values <= 0 select greedy decoding.\n",
        "\n",
        "    Returns:\n",
        "        The decoded continuation (special tokens skipped, whitespace stripped).\n",
        "    \"\"\"\n",
        "    inputs = tokenizer(query, return_tensors=\"pt\").to(DEVICE)\n",
        "    gen = model.generate(\n",
        "        **inputs,\n",
        "        do_sample=bool(temperature>0),\n",
        "        temperature=temperature if temperature>0 else None,\n",
        "        max_new_tokens=max_new,\n",
        "        pad_token_id=tokenizer.eos_token_id\n",
        "    )\n",
        "    # The returned sequence starts with the prompt tokens. Drop them by token\n",
        "    # count: slicing the decoded string by len(query) breaks whenever decoding\n",
        "    # the prompt does not reproduce `query` character for character.\n",
        "    prompt_len = inputs[\"input_ids\"].shape[1]\n",
        "    new_tokens = gen[0][prompt_len:]\n",
        "    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()\n",
        "\n",
        "def prefix_embeds(tokenizer, model, query, ctx_ids, delta=None):\n",
        "    \"\"\"Build the embedding prefix: query embeddings followed by context embeddings.\n",
        "\n",
        "    Args:\n",
        "        query: prompt text, tokenized via ``ids_from_text``.\n",
        "        ctx_ids: token ids of the context block.\n",
        "        delta: optional perturbation added to the context embeddings only.\n",
        "\n",
        "    Returns:\n",
        "        ``(prefix, q_ids)``: the concatenated embeddings (joined along dim 1)\n",
        "        and the query token ids.\n",
        "    \"\"\"\n",
        "    embed = model.get_input_embeddings()\n",
        "    q_ids = ids_from_text(tokenizer, query)\n",
        "    query_part = embed(q_ids)\n",
        "    context_part = embed(ctx_ids)\n",
        "    if delta is not None:\n",
        "        context_part = context_part + delta\n",
        "    return torch.cat([query_part, context_part], dim=1), q_ids\n",
        "\n",
        "@torch.no_grad()\n",
        "def build_context_block(tokenizer, model, ctx_texts, max_ctx_tokens=256):\n",
        "    \"\"\"Format the retrieved passages into one context block and embed it.\n",
        "\n",
        "    The passages are numbered from 1, stripped, and joined by blank lines under\n",
        "    a fixed header line. The token ids are cut to the first ``max_ctx_tokens``.\n",
        "\n",
        "    Returns:\n",
        "        ``(ctx_block, ctx_ids, v1)``: the full (untruncated) block text, the\n",
        "        possibly truncated token ids, and the input embeddings of those ids.\n",
        "    \"\"\"\n",
        "    header = \"Below are retrieved passages (for unlearning); avoid recalling protected facts:\\n\"\n",
        "    numbered = []\n",
        "    for pos, passage in enumerate(ctx_texts, start=1):\n",
        "        numbered.append(f\"[{pos}] {passage.strip()}\")\n",
        "    ctx_block = header + \"\\n\\n\".join(numbered)\n",
        "\n",
        "    encoded = tokenizer(ctx_block, add_special_tokens=False, return_tensors=\"pt\").to(DEVICE)\n",
        "    ctx_ids = encoded[\"input_ids\"]\n",
        "    if ctx_ids.shape[1] > max_ctx_tokens:\n",
        "        ctx_ids = ctx_ids[:, :max_ctx_tokens]\n",
        "\n",
        "    v1 = model.get_input_embeddings()(ctx_ids)\n",
        "    return ctx_block, ctx_ids, v1\n",
        "\n",
        "def forward_logits_hidden(model, prefix_emb, y_ids):\n",
        "    \"\"\"Run the model on ``prefix_emb`` followed by the (shifted) target tokens.\n",
        "\n",
        "    The target is fed without its final token; a target of at most one token\n",
        "    is fed unchanged. An all-ones attention mask is used.\n",
        "\n",
        "    Returns:\n",
        "        ``(logits_y, h_last_y)``: logits and last-layer hidden states, sliced\n",
        "        to the trailing positions occupied by the fed target tokens.\n",
        "    \"\"\"\n",
        "    if y_ids.shape[1] > 1:\n",
        "        y_in = y_ids[:, :-1]\n",
        "    else:\n",
        "        y_in = y_ids\n",
        "    n_target = y_in.shape[1]\n",
        "\n",
        "    target_emb = model.get_input_embeddings()(y_in)\n",
        "    full_embeds = torch.cat([prefix_emb, target_emb], dim=1)\n",
        "    mask = torch.ones(full_embeds.shape[:2], dtype=torch.long, device=full_embeds.device)\n",
        "    out = model(\n",
        "        inputs_embeds=full_embeds,\n",
        "        attention_mask=mask,\n",
        "        output_hidden_states=True,\n",
        "        return_dict=True,\n",
        "    )\n",
        "\n",
        "    tail = slice(-n_target, None)\n",
        "    return out.logits[:, tail, :], out.hidden_states[-1][:, tail, :]\n",
        "\n",
        "def generate_with_delta(tokenizer, model, query, ctx_ids, delta, max_new=128, temperature=0.0):\n",
        "    \"\"\"Generate from the query embeddings followed by the perturbed context embeddings.\n",
        "\n",
        "    Args:\n",
        "        query: prompt text, tokenized via ``ids_from_text``.\n",
        "        ctx_ids: token ids of the context block.\n",
        "        delta: perturbation added to the context embeddings.\n",
        "        max_new: maximum number of new tokens to generate.\n",
        "        temperature: sampling temperature; values <= 0 select greedy decoding.\n",
        "\n",
        "    Returns:\n",
        "        The first returned sequence, decoded without special tokens and stripped.\n",
        "    \"\"\"\n",
        "    embed = model.get_input_embeddings()\n",
        "    query_part = embed(ids_from_text(tokenizer, query))\n",
        "    context_part = embed(ctx_ids) + delta\n",
        "    prompt_embeds = torch.cat([query_part, context_part], dim=1)\n",
        "    mask = torch.ones(prompt_embeds.shape[:2], dtype=torch.long, device=prompt_embeds.device)\n",
        "\n",
        "    sampling = bool(temperature > 0)\n",
        "    gen = model.generate(\n",
        "        inputs_embeds=prompt_embeds,\n",
        "        attention_mask=mask,\n",
        "        do_sample=sampling,\n",
        "        temperature=temperature if sampling else None,\n",
        "        max_new_tokens=max_new,\n",
        "        pad_token_id=tokenizer.eos_token_id,\n",
        "        use_cache=True,\n",
        "    )\n",
        "    decoded = tokenizer.decode(gen[0], skip_special_tokens=True)\n",
        "    return decoded.strip()\n",
        "\n",
        "def precheck_gate(embedder, faiss_index, y_text, tau=0.32, topM=32):\n",
        "    \"\"\"Check whether ``y_text`` is similar enough to any indexed passage.\n",
        "\n",
        "    Searches the ``topM`` nearest passages and keeps those whose score is at\n",
        "    least ``tau``.\n",
        "\n",
        "    Returns:\n",
        "        ``(triggered, hit_ids)``: whether at least one passage passed the\n",
        "        threshold, and the ids of the passages that did.\n",
        "    \"\"\"\n",
        "    query_vec = st_encode(embedder, [y_text])\n",
        "    scores, neighbours = faiss_index.search(query_vec, topM)\n",
        "    hit = []\n",
        "    for doc_id, score in zip(neighbours[0], scores[0]):\n",
        "        if score >= tau:\n",
        "            hit.append(int(doc_id))\n",
        "    return bool(hit), hit\n",
        "\n",
        "def retrieve_by_y(embedder, faiss_index, y_text, K=8):\n",
        "    \"\"\"Return the ids of up to ``K`` indexed passages nearest to ``y_text``.\n",
        "\n",
        "    Args:\n",
        "        embedder: sentence encoder passed to ``st_encode``.\n",
        "        faiss_index: index to search.\n",
        "        y_text: text to embed and use as the search query.\n",
        "        K: number of neighbours requested.\n",
        "\n",
        "    Returns:\n",
        "        A list of non-negative passage ids in the order FAISS returned them;\n",
        "        shorter than ``K`` when the index holds fewer than ``K`` vectors.\n",
        "    \"\"\"\n",
        "    y_vec = st_encode(embedder, [y_text])\n",
        "    _, idxs = faiss_index.search(y_vec, K)\n",
        "    # FAISS pads the result with -1 when fewer than K neighbours exist. Used as\n",
        "    # a list index, -1 would silently select the last passage, so drop it.\n",
        "    return [int(i) for i in idxs[0] if i >= 0]\n",
        "\n",
        "def optimize_delta(tokenizer, model, query, ctx_ids, v1, y_text, steps=4, lr=0.05, eps=0.5,\n",
        "                   init_scale=0.1):\n",
        "    \"\"\"Optimize a norm-bounded perturbation ``delta`` of the context embeddings.\n",
        "\n",
        "    Each step descends on ``softplus(pi - S)``, where ``pi`` is the mean\n",
        "    per-position cosine similarity between the perturbed and the reference\n",
        "    logits over the ``y_text`` positions, and ``S`` is the cosine similarity\n",
        "    between the mean last-layer hidden states of the two passes. After every\n",
        "    step ``delta`` is rescaled so that its overall norm does not exceed ``eps``.\n",
        "\n",
        "    Args:\n",
        "        tokenizer, model: tokenizer and causal LM.\n",
        "        query: prompt text; its embeddings precede the context embeddings.\n",
        "        ctx_ids: token ids of the context block.\n",
        "        v1: embeddings of ``ctx_ids``; only shape, dtype and device are used.\n",
        "        y_text: reference answer that is teacher-forced in both passes.\n",
        "        steps: number of gradient steps.\n",
        "        lr: step size.\n",
        "        eps: upper bound on the norm of ``delta``.\n",
        "        init_scale: norm of the random starting point as a fraction of\n",
        "            ``eps``; ``0`` gives a zero start.\n",
        "\n",
        "    Returns:\n",
        "        The optimized ``delta`` (detached), shaped and typed like ``v1``.\n",
        "    \"\"\"\n",
        "    y_ids = ids_from_text(tokenizer, y_text)\n",
        "\n",
        "    # Reference (unperturbed) pass. Its outputs are constants of the loss, so\n",
        "    # no autograd graph is needed for it.\n",
        "    with torch.no_grad():\n",
        "        pref0, _ = prefix_embeds(tokenizer, model, query, ctx_ids, delta=None)\n",
        "        z0, h0   = forward_logits_hidden(model, pref0, y_ids)\n",
        "        z0_unit  = F.normalize(z0.float(), p=2, dim=-1)\n",
        "        h0_bar   = h0.float().mean(dim=1)\n",
        "\n",
        "    # With delta == 0 the perturbed pass equals the reference pass: both cosine\n",
        "    # terms sit at their maximum, their gradients vanish, and the updates would\n",
        "    # never leave zero. Start from a small random point inside the eps-ball.\n",
        "    if init_scale > 0 and eps > 0:\n",
        "        noise = torch.randn(v1.shape, dtype=torch.float32, device=v1.device)\n",
        "        noise *= (init_scale * eps) / noise.norm().clamp_min(1e-12)\n",
        "        delta = noise.to(v1.dtype)\n",
        "    else:\n",
        "        delta = torch.zeros_like(v1)\n",
        "    delta.requires_grad_(True)\n",
        "\n",
        "    for _step in range(steps):\n",
        "        pref_d, _ = prefix_embeds(tokenizer, model, query, ctx_ids, delta=delta)\n",
        "        z_d, h_d  = forward_logits_hidden(model, pref_d, y_ids)\n",
        "\n",
        "        # pi: how closely the perturbed logits still match the reference logits.\n",
        "        z_d_unit = F.normalize(z_d.float(), p=2, dim=-1)\n",
        "        cos_steps = F.cosine_similarity(z_d_unit, z0_unit, dim=-1)\n",
        "        pi = cos_steps.mean()\n",
        "\n",
        "        # S: similarity of the mean last-layer hidden states of the two passes.\n",
        "        h_d_bar = h_d.float().mean(dim=1)\n",
        "        S = F.cosine_similarity(h_d_bar, h0_bar, dim=-1).mean()\n",
        "\n",
        "        loss = F.softplus(pi - S)\n",
        "\n",
        "        loss.backward()\n",
        "        with torch.no_grad():\n",
        "            delta -= lr * delta.grad\n",
        "            # Rescale back onto the eps-ball when the step left it.\n",
        "            nrm = torch.norm(delta)\n",
        "            if nrm.item() > eps:\n",
        "                delta *= (eps / nrm)\n",
        "            delta.grad.zero_()\n",
        "\n",
        "    return delta.detach()\n",
        "\n",
        "def unre_once(tokenizer, model, embedder, faiss_index, forget_texts,\n",
        "              query, tau=0.32, K=8, steps=4, lr=0.05, eps=0.5,\n",
        "              max_ctx_tokens=256, max_new_tokens=96):\n",
        "    \"\"\"Run one gated pass of the pipeline for a single ``query``.\n",
        "\n",
        "    Steps: generate a plain answer, gate it against the forget index, and\n",
        "    (only if the gate fires) retrieve passages, optimize ``delta`` on the\n",
        "    context embeddings, regenerate, and gate the new answer again.\n",
        "\n",
        "    Args:\n",
        "        forget_texts: passages in the same order as the vectors in ``faiss_index``.\n",
        "        tau: gate threshold; K: number of neighbours searched/retrieved.\n",
        "        steps, lr, eps: forwarded to ``optimize_delta``.\n",
        "        max_ctx_tokens: token budget of the context block.\n",
        "        max_new_tokens: generation budget for both answers.\n",
        "\n",
        "    Returns:\n",
        "        A dict with keys ``activated``, ``y_regular``, ``y_final``,\n",
        "        ``ctx_texts``, ``hit_ids`` and ``gate_again``. When the gate does not\n",
        "        fire, ``y_final`` equals ``y_regular`` and ``gate_again`` is ``False``.\n",
        "    \"\"\"\n",
        "    yq = generate_plain(tokenizer, model, query, max_new=max_new_tokens)\n",
        "\n",
        "    need_unre, gate_ids = precheck_gate(embedder, faiss_index, yq, tau=tau, topM=K)\n",
        "    if not need_unre:\n",
        "        # Same keys as the activated result so callers can read them unconditionally.\n",
        "        return dict(activated=False, y_regular=yq, y_final=yq, ctx_texts=[], hit_ids=[],\n",
        "                    gate_again=False)\n",
        "\n",
        "    top_ids = retrieve_by_y(embedder, faiss_index, yq, K=K)\n",
        "    # De-duplicate while keeping order, and drop ids that are not valid list\n",
        "    # positions (FAISS reports -1 when it has fewer than K vectors).\n",
        "    n_docs = len(forget_texts)\n",
        "    ctx_idx = [i for i in dict.fromkeys(gate_ids + top_ids) if 0 <= i < n_docs]\n",
        "    ctx_texts = [forget_texts[i] for i in ctx_idx]\n",
        "    ctx_block, ctx_token_ids, v1 = build_context_block(tokenizer, model, ctx_texts, max_ctx_tokens)\n",
        "\n",
        "    delta = optimize_delta(tokenizer, model, query, ctx_token_ids, v1, yq,\n",
        "                           steps=steps, lr=lr, eps=eps)\n",
        "\n",
        "    y_final = generate_with_delta(tokenizer, model, query, ctx_token_ids, delta,\n",
        "                                  max_new=max_new_tokens)\n",
        "\n",
        "    need_again, _ = precheck_gate(embedder, faiss_index, y_final, tau=tau, topM=K)\n",
        "\n",
        "    return dict(activated=True, y_regular=yq, y_final=y_final,\n",
        "                ctx_texts=ctx_texts, hit_ids=ctx_idx, gate_again=need_again)\n",
        "\n",
        "def main():\n",
        "    \"\"\"Run the end-to-end demo on one shuffled RWKU ``forget_level2`` query.\n",
        "\n",
        "    Loads the LLM (falling back to ``FALLBACK_MODEL`` if the preferred model\n",
        "    cannot be loaded), builds the FAISS index over the forget corpus, runs\n",
        "    ``unre_once`` on the query and prints the outcome.\n",
        "    \"\"\"\n",
        "    try_hf_login()\n",
        "\n",
        "    try:\n",
        "        tokenizer, model = load_llm(PREFERRED_LLAMA2)\n",
        "        print(f\"Loaded: {PREFERRED_LLAMA2}\")\n",
        "    except Exception as e:\n",
        "        # Best effort: any load failure switches to the fallback model.\n",
        "        print(f\"Failed to load {PREFERRED_LLAMA2}, falling back to {FALLBACK_MODEL}:\", e)\n",
        "        tokenizer, model = load_llm(FALLBACK_MODEL)\n",
        "        print(f\"Loaded: {FALLBACK_MODEL}\")\n",
        "    print(\"Device:\", DEVICE)\n",
        "\n",
        "    embedder = st_model()\n",
        "    forget_texts = load_forget_corpus()\n",
        "    faiss_index, n_corpus = build_faiss_index(embedder, forget_texts)\n",
        "    print(f\"FAISS built over forget corpus: {n_corpus} passages\")\n",
        "\n",
        "    ds_forget = load_dataset(\"jinzhuoran/RWKU\", \"forget_level2\", split=\"test\")\n",
        "    ex = ds_forget.shuffle(seed=SEED).select(range(1))[0]\n",
        "    # Use the \"query\" field when present and non-empty, else the first field.\n",
        "    q = ex.get(\"query\", None) or ex[list(ex.keys())[0]]\n",
        "    print(\"\\n[DEMO QUERY]\\n\", q)\n",
        "\n",
        "    result = unre_once(\n",
        "        tokenizer, model, embedder, faiss_index, forget_texts, q,\n",
        "        tau=TAU, K=TOPK, steps=PGD_STEPS, lr=PGD_LR, eps=PGD_EPS,\n",
        "        max_ctx_tokens=CTX_MAXTOK, max_new_tokens=GEN_NEW\n",
        "    )\n",
        "\n",
        "    print(\"\\n=== Gate triggered? ===\", result[\"activated\"])\n",
        "    print(\"\\n--- Original (y_q) ---\\n\", result[\"y_regular\"])\n",
        "    if result[\"activated\"]:\n",
        "        print(\"\\n--- UNRE (y_u) ---\\n\", result[\"y_final\"])\n",
        "        print(\"\\nGate again after UNRE? \", result[\"gate_again\"])\n",
        "        print(\"\\nTop retrieved & used passages (first 3, truncated):\")\n",
        "        for i, t in enumerate(result[\"ctx_texts\"][:3], 1):\n",
        "            # Flatten newlines outside the f-string: a backslash inside an\n",
        "            # f-string expression is a SyntaxError before Python 3.12.\n",
        "            snippet = t[:200].replace(\"\\n\", \" \")\n",
        "            print(f\"[{i}] {snippet} ...\")\n",
        "\n",
        "# Entry point: run the demo when this code is executed directly rather than imported.\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "UFmVGfOKMsQg"
      },
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "NbYjrgrzMsz8"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}