{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9e0e1bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import torch\n",
    "import heapq\n",
    "from tqdm import tqdm\n",
    "from tqdm import trange\n",
    "from datasets import load_dataset\n",
    "import information_geometry as ig\n",
    "\n",
     "# Root directory for every artifact this notebook writes or reads: the texts_10k_*.json dumps,\n",
     "# the embeddings/ shard files, and one sub-directory per concept.\n",
     "base_path = \"BASE_PATH\" # Replace with the actual base path where data is stored\n",
     "os.makedirs(base_path, exist_ok=True)\n",
     "\n",
     "MODEL_NAME = \"google/gemma-3-4b-pt\"\n",
     "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
     "print(f\"Using device: {DEVICE}\")\n",
     "\n",
     "# Of the returned values, only vocab_dict (passed to ig.get_clean_mapping) and vocab_list\n",
     "# (indexed by token id when filtering, presumably yielding token strings) are used later in\n",
     "# this notebook; model, tokenizer and G are not referenced by the cells below.\n",
     "model, tokenizer, vocab_dict, vocab_list, G = ig.load_model_and_vocab(MODEL_NAME, device=DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22d9ab09",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Extract top 10k longest articles from C4 dataset (validation split) for English and French ####\n",
    "def top_longest_articles_from_stream(dataset_name, lang_code, num_samples=10000, max_scan=1e7):\n",
    "    ds = load_dataset(dataset_name, lang_code, split=\"validation\", streaming=True)\n",
    "    top_articles = []\n",
    "    processed_count = 0\n",
    "\n",
    "    print(f\"\\nScanning {dataset_name} ({lang_code}) validation split...\")\n",
    "\n",
    "    for example in tqdm(ds, desc=f\"Scanning {lang_code}\"):\n",
    "        text = example['text']\n",
    "        text_length = len(text)\n",
    "        processed_count += 1\n",
    "        \n",
    "        if len(top_articles) < num_samples:\n",
    "            heapq.heappush(top_articles, (text_length, text))\n",
    "        elif text_length > top_articles[0][0]:\n",
    "            heapq.heapreplace(top_articles, (text_length, text))\n",
    "        \n",
    "        if processed_count >= max_scan:\n",
    "            break\n",
    "            \n",
    "    top_articles.sort(key=lambda x: x[0], reverse=True)\n",
    "    return [article[1] for article in top_articles]\n",
    "\n",
    "en_texts = top_longest_articles_from_stream(\"allenai/c4\", \"en\")\n",
    "fr_texts = top_longest_articles_from_stream(\"allenai/c4\", \"fr\")\n",
    "\n",
    "with open(os.path.join(base_path, 'texts_10k_en.json'), 'w', encoding='utf-8') as f:\n",
    "    json.dump(en_texts, f, ensure_ascii=False, indent=4)\n",
    "\n",
    "with open(os.path.join(base_path, 'texts_10k_fr.json'), 'w', encoding='utf-8') as f:\n",
    "    json.dump(fr_texts, f, ensure_ascii=False, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "292a4c45",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Save texts in shards for embedding computation ###\n",
    "def save_text_shards(texts, num_shards, lang, base_path):\n",
    "    shard_size = len(texts) // num_shards\n",
    "    os.makedirs(os.path.join(base_path, 'embeddings'), exist_ok=True)\n",
    "    \n",
    "    for shard_idx in trange(num_shards):\n",
    "        start_idx = shard_idx * shard_size\n",
    "        end_idx = len(texts) if shard_idx == num_shards - 1 else (shard_idx + 1) * shard_size\n",
    "        \n",
    "        shard_texts = texts[start_idx:end_idx]\n",
    "        with open(os.path.join(base_path, f'embeddings/{lang}_texts_shard_{shard_idx}.json'), 'w') as f:\n",
    "            json.dump(shard_texts, f)\n",
    "\n",
    "\n",
    "save_text_shards(en_texts, num_shards=5, lang='en', base_path=base_path)\n",
    "save_text_shards(fr_texts, num_shards=5, lang='fr', base_path=base_path)\n",
    "\n",
    "### Then, run llm_embedding.py to compute and save embeddings for each shard ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ca31fcf5",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_en_embeddings = []\n",
    "for shard_ind in range(5):\n",
    "    shard_path = os.path.join(base_path, f'embeddings/en_embeddings_shard_{shard_ind}.pt')\n",
    "    shard_data = torch.load(shard_path)\n",
    "    all_en_embeddings.append(shard_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6fc46bb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_fr_embeddings = []\n",
    "for shard_ind in range(5):\n",
    "    shard_path = os.path.join(base_path, f'embeddings/fr_embeddings_shard_{shard_ind}.pt')\n",
    "    shard_data = torch.load(shard_path)\n",
    "    all_fr_embeddings.append(shard_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3a836718",
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_indices_batch(embeddings, vocab_list, keys):\n",
    "    new_indices = []\n",
    "    for i in range(len(embeddings['topk_ids'])):\n",
    "        topk_indices = embeddings['topk_ids'][i][:3]\n",
    "        topk_values = torch.tensor(embeddings['topk_probs'][i][:3])\n",
    "        if all(vocab_list[idx] in keys for idx in topk_indices):\n",
    "            if topk_values.sum() > 0.7:\n",
    "                new_indices.append(i)\n",
    "    return new_indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0b2c13f",
   "metadata": {},
   "outputs": [],
   "source": [
    "concept_names = [\"verb_en_fr\", \"verb_ing\", \"verb_past\", \"verb_third\"]\n",
    "\n",
    "#### Prepare datasets for each concept ####\n",
    "for concept_name in concept_names:\n",
    "    torch.random.manual_seed(100)\n",
    "    print(f\"\\nProcessing concept: {concept_name}\")\n",
    "\n",
    "    concept_path = os.path.join(base_path, concept_name)\n",
    "    os.makedirs(concept_path, exist_ok=True)\n",
    "\n",
    "    with open(os.path.join(f\"data/mapping_{concept_name}.json\"), \"r\") as f:\n",
    "        mapping = json.load(f)\n",
    "    mapping = ig.get_clean_mapping(mapping, vocab_dict)\n",
    "    \n",
    "    shard_idx0 = []\n",
    "    shard_idx1 = []\n",
    "    embeddings0 = []\n",
    "    embeddings1 = []\n",
    "    text_idx0 = []\n",
    "    text_idx1 = []\n",
    "    token_pos0 = []\n",
    "    token_pos1 = []\n",
    "    for i in range(5):\n",
    "        indices0 = filter_indices_batch(all_en_embeddings[i], vocab_list, list(mapping.keys()))\n",
    "        embeddings0.append(all_en_embeddings[i]['embeddings'][indices0])\n",
    "        text_idx0.append(all_en_embeddings[i]['text_idx'][indices0])\n",
    "        token_pos0.append(all_en_embeddings[i]['token_pos'][indices0])\n",
    "        shard_idx0.append(i * torch.ones(len(indices0), dtype=torch.long))\n",
    "\n",
    "        if concept_name == \"verb_en_fr\":\n",
    "            indices1 = filter_indices_batch(all_fr_embeddings[i], vocab_list, list(mapping.values()))\n",
    "            embeddings1.append(all_fr_embeddings[i]['embeddings'][indices1])\n",
    "            text_idx1.append(all_fr_embeddings[i]['text_idx'][indices1])\n",
    "            token_pos1.append(all_fr_embeddings[i]['token_pos'][indices1])\n",
    "            shard_idx1.append(i * torch.ones(len(indices1), dtype=torch.long))\n",
    "        else:\n",
    "            indices1 = filter_indices_batch(all_en_embeddings[i], vocab_list, list(mapping.values()))\n",
    "            embeddings1.append(all_en_embeddings[i]['embeddings'][indices1])\n",
    "            text_idx1.append(all_en_embeddings[i]['text_idx'][indices1])\n",
    "            token_pos1.append(all_en_embeddings[i]['token_pos'][indices1])\n",
    "            shard_idx1.append(i * torch.ones(len(indices1), dtype=torch.long))\n",
    "        print(f\"{len(indices0)} base indices, {len(indices1)} target indices in shard {i}.\")\n",
    "\n",
    "    embeddings0 = torch.cat(embeddings0, dim=0)\n",
    "    embeddings1 = torch.cat(embeddings1, dim=0)\n",
    "    text_idx0 = torch.cat(text_idx0, dim=0)\n",
    "    text_idx1 = torch.cat(text_idx1, dim=0)\n",
    "    token_pos0 = torch.cat(token_pos0, dim=0)\n",
    "    token_pos1 = torch.cat(token_pos1, dim=0)\n",
    "    shard_idx0 = torch.cat(shard_idx0, dim=0)\n",
    "    shard_idx1 = torch.cat(shard_idx1, dim=0)\n",
    "\n",
    "    num_train = 400\n",
    "    num_test = 100\n",
    "\n",
    "    if len(embeddings0) < num_train + num_test or len(embeddings1) < num_train + num_test:\n",
    "        print(f\"Not enough embeddings for concept {concept_name}. Skipping.\")\n",
    "        continue\n",
    "\n",
    "    base_perm = torch.randperm(len(embeddings0))[:num_train + num_test]\n",
    "    train_embeddings0 = embeddings0[base_perm[:num_train]]\n",
    "    test_embeddings0 = embeddings0[base_perm[num_train:]]\n",
    "    \n",
    "\n",
    "    target_perm = torch.randperm(len(embeddings1))[:num_train + num_test]\n",
    "    train_embeddings1 = embeddings1[target_perm[:num_train]]\n",
    "    test_embeddings1 = embeddings1[target_perm[num_train:]]\n",
    "\n",
    "    text_indices = {\n",
    "        \"train_idx0\": (shard_idx0[base_perm[:num_train]],\n",
    "                       text_idx0[base_perm[:num_train]],\n",
    "                       token_pos0[base_perm[:num_train]]),\n",
    "        \"test_idx0\": (shard_idx0[base_perm[num_train:]],\n",
    "                      text_idx0[base_perm[num_train:]],\n",
    "                      token_pos0[base_perm[num_train:]]),\n",
    "        \"train_idx1\": (shard_idx1[target_perm[:num_train]],\n",
    "                       text_idx1[target_perm[:num_train]],\n",
    "                       token_pos1[target_perm[:num_train]]),\n",
    "        \"test_idx1\": (shard_idx1[target_perm[num_train:]],\n",
    "                      text_idx1[target_perm[num_train:]],\n",
    "                      token_pos1[target_perm[num_train:]])\n",
    "    }\n",
    "\n",
    "    print(len(train_embeddings0), len(test_embeddings0), len(train_embeddings1), len(test_embeddings1))\n",
    "\n",
    "    torch.save(train_embeddings0, os.path.join(concept_path, \"train_embeddings0.pt\"))\n",
    "    torch.save(test_embeddings0 , os.path.join(concept_path, \"test_embeddings0.pt\"))\n",
    "    torch.save(train_embeddings1, os.path.join(concept_path, \"train_embeddings1.pt\"))\n",
    "    torch.save(test_embeddings1, os.path.join(concept_path, \"test_embeddings1.pt\"))\n",
    "    torch.save(text_indices, os.path.join(concept_path, \"text_indices.pt\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "kiho",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
