{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### After training a model, use it to compute the embeddings of a set of questions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import BertModel, BertTokenizer, BertConfig\n",
    "import torch\n",
    "import json\n",
    "from tqdm import tqdm\n",
    "from torch.nn.functional import cosine_similarity\n",
    "import os\n",
    "import numpy as np\n",
    "from numpy import linalg\n",
    "from sklearn.preprocessing import normalize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Folder of the training experiment; the tokenizer and the fine-tuned weights are loaded from inside it.\n",
    "folder_saved_model = '' # Experiment folder\n",
    "\n",
    "# Input metadata files, given relative to the notebook's working directory.\n",
    "path_data_questions = '../data/XES3G5M/metadata/questions_translated_kc_sol_annotated.json'\n",
    "# Kept relative like the path above: a leading '/' would make it an absolute path\n",
    "# ('/../data' resolves from the filesystem root).\n",
    "path_kc_questions_map = '../data/XES3G5M/metadata/kc_questions_map.json'\n",
    "\n",
    "# JSON is UTF-8 by specification, so do not depend on the platform's default encoding.\n",
    "with open(path_data_questions, 'r', encoding='utf-8') as file:\n",
    "    data_questions = json.load(file)\n",
    "\n",
    "with open(path_kc_questions_map, 'r', encoding='utf-8') as file:\n",
    "    kc_questions_map = json.load(file)\n",
    "\n",
    "# Folder where the computed embeddings are written.\n",
    "embeddings_save_folder = \"../data/XES3G5M/metadata/embeddings/representation_learning\"\n",
    "\n",
    "# exist_ok=True makes this a no-op when the folder already exists.\n",
    "os.makedirs(embeddings_save_folder, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the computing device first so the checkpoint can be mapped onto it while loading\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Load the tokenizer stored in the experiment folder\n",
    "tokenizer = BertTokenizer.from_pretrained(os.path.join(folder_saved_model, 'tokenizer'))\n",
    "\n",
    "# Create a configuration object or load it if you have saved one\n",
    "config = BertConfig.from_pretrained('bert-base-uncased')\n",
    "\n",
    "# Initialize the model with this configuration (the fine-tuned weights are loaded below)\n",
    "model = BertModel(config)\n",
    "\n",
    "# Adjust the model's token embeddings to account for new tokens before loading the weights\n",
    "model.resize_token_embeddings(len(tokenizer))\n",
    "\n",
    "# Load the model weights. map_location remaps tensors that were saved on another device\n",
    "# (e.g. a GPU checkpoint opened on a CPU-only machine), matching the CPU fallback above.\n",
    "# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.\n",
    "state_dict = torch.load(os.path.join(folder_saved_model, 'bert_finetuned.bin'), map_location=device)\n",
    "model.load_state_dict(state_dict)\n",
    "\n",
    "# Move the model to the selected computing device\n",
    "model = model.to(device)\n",
    "\n",
    "# This notebook only runs inference, so switch to evaluation mode (disables dropout)\n",
    "model = model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Constants\n",
    "BATCH_SIZE = 1024  # Default number of texts encoded per forward pass\n",
    "\n",
    "# Helper function to batch text data and convert to embeddings\n",
    "def text_to_embeddings(texts, max_length=128, batch_size=None):\n",
    "    \"\"\"Encode texts with the loaded model and return their [CLS] embeddings.\n",
    "\n",
    "    Args:\n",
    "        texts: List of strings to embed.\n",
    "        max_length: Maximum number of tokens per text; longer inputs are truncated.\n",
    "        batch_size: Number of texts per forward pass. Defaults to BATCH_SIZE.\n",
    "\n",
    "    Returns:\n",
    "        A tensor with one row per text (the last hidden state at position 0),\n",
    "        in the same order as `texts`. An empty `texts` yields a tensor with\n",
    "        zero rows instead of failing inside `torch.cat`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `batch_size` is not positive.\n",
    "    \"\"\"\n",
    "    if batch_size is None:\n",
    "        batch_size = BATCH_SIZE\n",
    "    if batch_size <= 0:\n",
    "        raise ValueError(\"batch_size must be a positive integer\")\n",
    "\n",
    "    # torch.cat cannot concatenate an empty list, so handle the no-input case explicitly.\n",
    "    if len(texts) == 0:\n",
    "        return torch.empty((0, model.config.hidden_size), dtype=model.dtype, device=device)\n",
    "\n",
    "    embeddings = []\n",
    "    for i in tqdm(range(0, len(texts), batch_size), desc=\"Generating Embeddings\"):\n",
    "        batch_texts = texts[i:i + batch_size]\n",
    "        inputs = tokenizer(batch_texts, return_tensors=\"pt\", padding=True, truncation=True, max_length=max_length)\n",
    "        inputs = {key: val.to(device) for key, val in inputs.items()}\n",
    "        # Inference only: skip gradient tracking\n",
    "        with torch.no_grad():\n",
    "            outputs = model(**inputs)\n",
    "        embeddings.append(outputs.last_hidden_state[:, 0, :])  # Extract [CLS] token embeddings\n",
    "    return torch.cat(embeddings, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the question text and a copy of the solution-step list of every entry,\n",
    "# following the iteration order of `data_questions`.\n",
    "list_questions = [record['question'] for record in data_questions.values()]\n",
    "list_sol_steps = [list(record['step_by_step_solution_list']) for record in data_questions.values()]\n",
    "\n",
    "# Prepend the special tokens: '[Q] ' marks a question, '[S] ' marks a solution step\n",
    "questions = ['[Q] ' + text for text in list_questions]\n",
    "sol_steps = [['[S] ' + step for step in steps] for steps in list_sol_steps]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed all questions in one pass\n",
    "question_embeddings = text_to_embeddings(questions)\n",
    "\n",
    "# Flatten the per-question solution steps (the special token is already prepended)\n",
    "# so that they can be embedded in large batches\n",
    "flat_solution_steps = [step for sublist in sol_steps for step in sublist]\n",
    "flat_solution_embeddings = text_to_embeddings(flat_solution_steps)\n",
    "\n",
    "# Regroup the flat embeddings per question: consecutive chunks along dim 0,\n",
    "# sized by the number of steps of each question\n",
    "step_counts = [len(steps) for steps in sol_steps]\n",
    "sol_step_embeddings = list(torch.split(flat_solution_embeddings, step_counts, dim=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert these embeddings to numpy arrays for the downstream pre-computations\n",
    "np_question_embeddings = question_embeddings.detach().cpu().numpy()\n",
    "\n",
    "# One array of step embeddings per question\n",
    "np_sol_step_embeddings = [steps.detach().cpu().numpy() for steps in sol_step_embeddings]\n",
    "\n",
    "# Mean-pool the step embeddings of each question. The mean over zero steps is NaN,\n",
    "# which would make the later normalization fail, so a question without solution steps\n",
    "# falls back to its own question embedding.\n",
    "np_sol_step_embeddings_mean = [\n",
    "    steps.mean(axis=0) if len(steps) > 0 else np_question_embeddings[i].copy()\n",
    "    for i, steps in enumerate(np_sol_step_embeddings)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average the question embedding with its mean solution-step embedding,\n",
    "# L2-normalize the result and store it under the question's position.\n",
    "# NOTE(review): keys are positional indices; this presumably assumes that the order of\n",
    "# `data_questions` matches the question ids - confirm against the consumer of this file.\n",
    "dict_emb = {}\n",
    "for idx, q_vec in enumerate(np_question_embeddings):\n",
    "    sol_vec = np_sol_step_embeddings_mean[idx]\n",
    "    combined = ((q_vec + sol_vec) / 2).reshape(1, -1)\n",
    "    unit_vec = normalize(combined, axis=1, norm='l2').flatten()\n",
    "    dict_emb[str(idx)] = unit_vec.tolist()\n",
    "\n",
    "save_path = os.path.join(embeddings_save_folder, 'qid2content_sol_avg_emb.json')\n",
    "\n",
    "# Write the index -> embedding mapping as JSON\n",
    "with open(save_path, 'w') as out_file:\n",
    "    json.dump(dict_emb, out_file)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env_hf_437",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
