{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "008e1f3a-1f7d-46a5-86c3-eae2b442ca02",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading LLaMA model using transformers pipeline...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "054c34bedf8840ed9bb104f1293b6094",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some parameters are on the meta device because they were offloaded to the disk.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LLaMA model loaded successfully.\n"
     ]
    }
   ],
   "source": [
    "import transformers\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from collections import defaultdict\n",
    "\n",
    "# Load the Llama 2 tokenizer and causal-LM weights named by MODEL_NAME.\n",
    "print(\"Loading LLaMA model using transformers pipeline...\")\n",
    "MODEL_NAME = \"meta-llama/Llama-2-7b-hf\"  # Hugging Face Hub id; replace to try another Llama checkpoint\n",
    "\n",
    "# Choose the device that tokenized inputs are moved to. Fall back to CPU when neither\n",
    "# CUDA nor Apple MPS is usable, instead of assuming MPS exists on every non-CUDA machine.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "elif torch.backends.mps.is_available():\n",
    "    device = torch.device(\"mps\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "\n",
    "tokenizer = transformers.LlamaTokenizer.from_pretrained(MODEL_NAME, legacy=False)\n",
    "# get_embeddings tokenizes with padding=True, which requires a pad token; reuse EOS for it.\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "# device_map=\"auto\" lets the loader place the weights; offload_folder names the directory\n",
    "# used for any weights that end up offloaded to disk.\n",
    "model = transformers.LlamaForCausalLM.from_pretrained(MODEL_NAME, device_map=\"auto\", offload_folder=\"offload\")\n",
    "print(\"LLaMA model loaded successfully.\")\n",
    "\n",
    "# Function to get embeddings of a list of texts\n",
    "def get_embeddings(texts):\n",
    "    \"\"\"Return one mean-pooled, final-layer embedding per input text.\n",
    "\n",
    "    Args:\n",
    "        texts: list of strings; they are padded to a common length and truncated\n",
    "            to at most 512 tokens by the module-level tokenizer.\n",
    "\n",
    "    Returns:\n",
    "        A tensor with one row per text, obtained by averaging the last hidden\n",
    "        layer over the real (non-padding) tokens of that text. Expected shape is\n",
    "        (len(texts), hidden_size), assuming the usual (batch, seq, hidden) layout.\n",
    "    \"\"\"\n",
    "    tokenized_output = tokenizer(texts, return_tensors=\"pt\", padding=True, truncation=True, max_length=512).to(device)\n",
    "    attention_mask = tokenized_output[\"attention_mask\"]\n",
    "    with torch.no_grad():\n",
    "        # Pass the mask so padding tokens are not attended to.\n",
    "        outputs = model(\n",
    "            input_ids=tokenized_output[\"input_ids\"],\n",
    "            attention_mask=attention_mask,\n",
    "            output_hidden_states=True,\n",
    "        )\n",
    "    last_hidden_states = outputs.hidden_states[-1]\n",
    "    # Masked mean: zero out padding positions and divide by the number of real tokens,\n",
    "    # so texts of different lengths in one batch are pooled consistently.\n",
    "    mask = attention_mask.unsqueeze(-1).to(device=last_hidden_states.device, dtype=last_hidden_states.dtype)\n",
    "    token_counts = mask.sum(dim=1).clamp(min=1)  # guard against division by zero\n",
    "    return (last_hidden_states * mask).sum(dim=1) / token_counts\n",
    "\n",
    "# Load the animal-attribute pairs: each usable line is \"animal,attribute[,...]\".\n",
    "animal_attribute_file = './data/animal-habit.txt'\n",
    "animal_attribute_pairs = defaultdict(list)  # attribute -> list of animals having it\n",
    "all_animals = set()\n",
    "all_attributes = set()\n",
    "\n",
    "# Explicit encoding so the file is read the same way on every platform.\n",
    "with open(animal_attribute_file, 'r', encoding='utf-8') as file:\n",
    "    for line in file:\n",
    "        # Strip each field so \"dog, barks\" and \"dog,barks\" produce the same keys.\n",
    "        parts = [part.strip() for part in line.split(',')]\n",
    "        # Skip blank or malformed lines and lines with an empty animal or attribute.\n",
    "        if len(parts) < 2 or not parts[0] or not parts[1]:\n",
    "            continue\n",
    "        animal, attribute = parts[:2]\n",
    "        animal_attribute_pairs[attribute].append(animal)\n",
    "        all_animals.add(animal)\n",
    "        all_attributes.add(attribute)\n",
    "\n",
    "# Get embeddings for all animals, one call per animal. Iterate in sorted order so the\n",
    "# dict (and its displayed output) has the same order on every kernel restart.\n",
    "animal_embeddings = {animal: get_embeddings([animal]) for animal in sorted(all_animals)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "412925d7-d055-4dd9-a37a-fdb89a85c2a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the animal -> embedding dict built in the previous cell.\n",
    "animal_embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3701c491-aac2-481f-af38-21873f3f23d5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
