{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fc08d8ec-79a5-4152-9388-fd85dd80e70f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import openai\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7743df77-485e-4abc-b7d7-e3be719a14ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('/Users/cherian/Projects/conformal-safety/data/factscore_final_dataset.pkl', 'rb') as fp:\n",
    "    dataset = pickle.load(fp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "720d9a15-00f1-4514-a768-bca0732abd75",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b389232b2d1e4e67b3b6f78be3992a10",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/8515 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from openai import OpenAI\n",
    "from tqdm.notebook import tqdm\n",
    "from concurrent.futures import ThreadPoolExecutor\n",
    "\n",
    "def get_embeddings(client, afs, model=\"text-embedding-3-small\"):\n",
    "    embeds = []\n",
    "    for af in afs:\n",
    "        text = af['atom'].strip()\n",
    "        e = client.embeddings.create(input = [text], model=model).data[0].embedding\n",
    "        embeds.append(e)\n",
    "    return embeds\n",
    "\n",
    "# df['ada_embedding'] = df.combined.apply(lambda x: get_embedding(x, model='text-embedding-3-small'))\n",
    "# df.to_csv('output/embedded_1k_reviews.csv', index=False)\n",
    "\n",
    "\n",
    "client = OpenAI()\n",
    "\n",
    "with ThreadPoolExecutor(max_workers=25) as executor:\n",
    "    embeddings = list(\n",
    "        tqdm(\n",
    "            executor.map(\n",
    "                lambda x : get_embeddings(client, x),\n",
    "                [dat['atomic_facts'] for dat in dataset]\n",
    "            ),\n",
    "            total=len(dataset)\n",
    "        )\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "5e9cc8f2-0103-4155-9301-ff7b4a5d4234",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "embeddings_to_np = [np.asarray(embed) for embed in embeddings]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "9645ee57-c677-499b-828d-3095ef142932",
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings_dict = {dat['prompt']:embed for dat, embed in zip(dataset, embeddings_to_np)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "18e11b8a-94af-4366-b191-361472271f51",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(28, 5)\n"
     ]
    }
   ],
   "source": [
    "embeddings_to_np = [embed[:,0:5] for embed in embeddings_to_np]\n",
    "for embed in embeddings_to_np:\n",
    "    print(embed.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "79f730dc-d381-4f24-a01a-da131051cfd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.savez('/Users/cherian/Projects/conformal-safety/data/factscore_final_embeddings.npz', **embeddings_dict)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
