{
 "cells": [
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import argparse\n",
    "import json\n",
    "import logging\n",
    "import math\n",
    "import os\n",
    "import random\n",
    "from pathlib import Path\n",
    "from tqdm import tqdm\n",
    "\n",
    "import datasets\n",
    "from datasets import load_dataset, DatasetDict\n",
    "\n",
    "import evaluate\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import transformers\n",
    "from transformers import AutoTokenizer, AutoModel, default_data_collator, SchedulerType, get_scheduler\n",
    "from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry\n",
    "from transformers.utils.versions import require_version\n",
    "\n",
    "from huggingface_hub import Repository, create_repo\n",
    "\n",
    "from accelerate import Accelerator\n",
    "from accelerate.logging import get_logger\n",
    "from accelerate.utils import set_seed\n",
    "\n",
    "from peft import PeftModel\n",
    "\n",
    "import hnswlib"
   ],
   "id": "5e143e59bc67983c"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "class AutoModelForSentenceEmbedding(nn.Module):\n",
    "    def __init__(self, model_name, tokenizer, normalize=True):\n",
    "        super(AutoModelForSentenceEmbedding, self).__init__()\n",
    "\n",
    "        self.model = AutoModel.from_pretrained(\n",
    "            model_name)  # , quantizaton_config=BitsAndBytesConfig(load_in_8bit=True), device_map={\"\":0})\n",
    "        self.normalize = normalize\n",
    "        self.tokenizer = tokenizer\n",
    "\n",
    "    def forward(self, **kwargs):\n",
    "        model_output = self.model(**kwargs)\n",
    "        embeddings = self.mean_pooling(model_output, kwargs[\"attention_mask\"])\n",
    "        if self.normalize:\n",
    "            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)\n",
    "\n",
    "        return embeddings\n",
    "\n",
    "    def mean_pooling(self, model_output, attention_mask):\n",
    "        token_embeddings = model_output[0]  # First element of model_output contains all token embeddings\n",
    "        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()\n",
    "        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)\n",
    "\n",
    "    def __getattr__(self, name: str):\n",
    "        \"\"\"Forward missing attributes to the wrapped module.\"\"\"\n",
    "        try:\n",
    "            return super().__getattr__(name)  # defer to nn.Module's logic\n",
    "        except AttributeError:\n",
    "            return getattr(self.model, name)\n",
    "\n",
    "\n",
    "def get_cosing_embeddings(query_embs, product_embs):\n",
    "    return torch.sum(query_embs * product_embs, axis=1)"
   ],
   "id": "c2d543f697edb5dd"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "model_name_or_path = \"intfloat/e5-large-v2\"\n",
    "peft_model_id = \"smangrul/peft_lora_e5_semantic_search\"\n",
    "dataset_name = \"smangrul/amazon_esci\"\n",
    "max_length = 70\n",
    "batch_size = 256"
   ],
   "id": "beedb69bd3f283d5"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import pandas as pd\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)\n",
    "dataset = load_dataset(dataset_name)\n",
    "train_product_dataset = dataset[\"train\"].to_pandas()[[\"product_title\"]]\n",
    "val_product_dataset = dataset[\"validation\"].to_pandas()[[\"product_title\"]]\n",
    "product_dataset_for_indexing = pd.concat([train_product_dataset, val_product_dataset])\n",
    "product_dataset_for_indexing = product_dataset_for_indexing.drop_duplicates()\n",
    "product_dataset_for_indexing.reset_index(drop=True, inplace=True)\n",
    "product_dataset_for_indexing.reset_index(inplace=True)"
   ],
   "id": "9ed04fda54b293cd"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "product_dataset_for_indexing",
   "id": "f1c9a92842c8d575"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "pd.set_option(\"max_colwidth\", 300)\n",
    "product_dataset_for_indexing.sample(10)"
   ],
   "id": "67b83904605b9fb6"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from datasets import Dataset\n",
    "\n",
    "dataset = Dataset.from_pandas(product_dataset_for_indexing)\n",
    "\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    products = examples[\"product_title\"]\n",
    "    result = tokenizer(products, padding=\"max_length\", max_length=70, truncation=True)\n",
    "    return result\n",
    "\n",
    "\n",
    "processed_dataset = dataset.map(\n",
    "    preprocess_function,\n",
    "    batched=True,\n",
    "    remove_columns=dataset.column_names,\n",
    "    desc=\"Running tokenizer on dataset\",\n",
    ")\n",
    "processed_dataset"
   ],
   "id": "2fc314496191724c"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# base model\n",
    "model = AutoModelForSentenceEmbedding(model_name_or_path, tokenizer)\n",
    "\n",
    "# peft config and wrapping\n",
    "model = PeftModel.from_pretrained(model, peft_model_id)\n",
    "\n",
    "print(model)"
   ],
   "id": "c02ef3fd18abb8cd"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "dataloader = DataLoader(\n",
    "    processed_dataset,\n",
    "    shuffle=False,\n",
    "    collate_fn=default_data_collator,\n",
    "    batch_size=batch_size,\n",
    "    pin_memory=True,\n",
    ")"
   ],
   "id": "ef12430398deb22c"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "next(iter(dataloader))",
   "id": "7d57436e173e58f3"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "ids_to_products_dict = {i: p for i, p in zip(dataset[\"index\"], dataset[\"product_title\"])}\n",
    "ids_to_products_dict"
   ],
   "id": "ed0dae073caa61e4"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "device = \"cuda\"\n",
    "model.to(device)\n",
    "model.eval()\n",
    "model = model.merge_and_unload()"
   ],
   "id": "732b41090817ecef"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import numpy as np\n",
    "\n",
    "num_products = len(dataset)\n",
    "d = 1024\n",
    "\n",
    "product_embeddings_array = np.zeros((num_products, d))\n",
    "for step, batch in enumerate(tqdm(dataloader)):\n",
    "    with torch.no_grad():\n",
    "        with torch.amp.autocast(dtype=torch.bfloat16, device_type=\"cuda\"):\n",
    "            product_embs = model(**{k: v.to(device) for k, v in batch.items()}).detach().float().cpu()\n",
    "    start_index = step * batch_size\n",
    "    end_index = start_index + batch_size if (start_index + batch_size) < num_products else num_products\n",
    "    product_embeddings_array[start_index:end_index] = product_embs\n",
    "    del product_embs, batch"
   ],
   "id": "3d9e4787a8472a74"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def construct_search_index(dim, num_elements, data):\n",
    "    # Declaring index\n",
    "    search_index = hnswlib.Index(space=\"ip\", dim=dim)  # possible options are l2, cosine or ip\n",
    "\n",
    "    # Initializing index - the maximum number of elements should be known beforehand\n",
    "    search_index.init_index(max_elements=num_elements, ef_construction=200, M=100)\n",
    "\n",
    "    # Element insertion (can be called several times):\n",
    "    ids = np.arange(num_elements)\n",
    "    search_index.add_items(data, ids)\n",
    "\n",
    "    return search_index"
   ],
   "id": "8c9161d4916ffe62"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "product_search_index = construct_search_index(d, num_products, product_embeddings_array)",
   "id": "84b021c6c59d0a09"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def get_query_embeddings(query, model, tokenizer, device):\n",
    "    inputs = tokenizer(query, padding=\"max_length\", max_length=70, truncation=True, return_tensors=\"pt\")\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        query_embs = model(**{k: v.to(device) for k, v in inputs.items()}).detach().cpu()\n",
    "    return query_embs[0]\n",
    "\n",
    "\n",
    "def get_nearest_neighbours(k, search_index, query_embeddings, ids_to_products_dict, threshold=0.7):\n",
    "    # Controlling the recall by setting ef:\n",
    "    search_index.set_ef(100)  # ef should always be > k\n",
    "\n",
    "    # Query dataset, k - number of the closest elements (returns 2 numpy arrays)\n",
    "    labels, distances = search_index.knn_query(query_embeddings, k=k)\n",
    "\n",
    "    return [\n",
    "        (ids_to_products_dict[label], (1 - distance))\n",
    "        for label, distance in zip(labels[0], distances[0])\n",
    "        if (1 - distance) >= threshold\n",
    "    ]"
   ],
   "id": "830f23c61b535f9e"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "query = \"NLP and ML books\"\n",
    "k = 10\n",
    "query_embeddings = get_query_embeddings(query, model, tokenizer, device)\n",
    "search_results = get_nearest_neighbours(k, product_search_index, query_embeddings, ids_to_products_dict, threshold=0.7)\n",
    "\n",
    "print(f\"{query=}\")\n",
    "for product, cosine_sim_score in search_results:\n",
    "    print(f\"cosine_sim_score={round(cosine_sim_score, 2)} {product=}\")"
   ],
   "id": "d0ebba03d67a3d49"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "d28675b778960fef"
  }
 ],
 "metadata": {},
 "nbformat": 5,
 "nbformat_minor": 9
}
