{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "\n",
    "# REPO_DIR is the kernel's current working directory; src/, models/ and data/\n",
    "# are resolved beneath it.\n",
    "REPO_DIR = os.getcwd()\n",
    "SRC_DIR = os.path.join(REPO_DIR, 'src')\n",
    "MODEL_DIR = os.path.join(REPO_DIR, 'models')\n",
    "DATA_DIR = os.path.join(REPO_DIR, 'data')\n",
    "\n",
    "# exist_ok=True creates each directory only when it is missing, without the\n",
    "# check-then-create race of a separate os.path.exists() test.\n",
    "for d in (MODEL_DIR, DATA_DIR):\n",
    "    os.makedirs(d, exist_ok=True)\n",
    "\n",
    "\n",
    "import sys\n",
    "sys.path.append(REPO_DIR)\n",
    "sys.path.append(SRC_DIR)\n",
    "\n",
    "import numpy as np\n",
    "import random\n",
    "import torch\n",
    "import accelerate\n",
    "from nnsight import NNsight\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "def set_seed(seed):\n",
    "    \"\"\"Seed Python's `random`, NumPy's global RNG, and PyTorch (including all CUDA devices) with `seed`.\"\"\"\n",
    "    seeders = (\n",
    "        random.seed,\n",
    "        np.random.seed,\n",
    "        torch.manual_seed,\n",
    "        torch.cuda.manual_seed_all,\n",
    "    )\n",
    "    for apply_seed in seeders:\n",
    "        apply_seed(seed)\n",
    "\n",
    "# Fix every RNG once, up front, for the rest of the session.\n",
    "set_seed(0)\n",
    "\n",
    "# Pick the compute device: MPS if available, otherwise CUDA, otherwise CPU.\n",
    "if torch.backends.mps.is_available():\n",
    "    device = \"mps\"\n",
    "elif torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "else:\n",
    "    device = \"cpu\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load model\n",
    "\n",
    "\n",
    "# Hugging Face access token: prefer the HF_TOKEN environment variable and fall\n",
    "# back to the local token file. The file path is relative to the kernel's\n",
    "# working directory, so the fallback only works when launched from the\n",
    "# expected location.\n",
    "HF_TOKEN_PATH = '../../auth/hf_token.txt'\n",
    "hf_token = os.environ.get('HF_TOKEN', '').strip()\n",
    "if not hf_token:\n",
    "    if not os.path.isfile(HF_TOKEN_PATH):\n",
    "        raise FileNotFoundError(\n",
    "            \"No Hugging Face token found: set the HF_TOKEN environment variable \"\n",
    "            f\"or create {HF_TOKEN_PATH!r} (resolved against {os.getcwd()!r}).\"\n",
    "        )\n",
    "    with open(HF_TOKEN_PATH, 'r', encoding='utf-8') as f:\n",
    "        hf_token = f.read().strip()\n",
    "\n",
    "model_id = \"google/gemma-2-2b\"\n",
    "model_name = \"gemma-2-2b\"\n",
    "\n",
    "# Turn autograd off globally so later forward passes do not retain state for\n",
    "# backpropagation (keeps memory use down).\n",
    "torch.set_grad_enabled(False)\n",
    "\n",
    "# Keyword arguments forwarded verbatim to from_pretrained.\n",
    "hf_model_kwargs = {\n",
    "    \"cache_dir\": MODEL_DIR,\n",
    "    \"token\": hf_token,\n",
    "    \"device_map\": device,\n",
    "    \"low_cpu_mem_usage\": True,\n",
    "    \"attn_implementation\": \"eager\",\n",
    "}\n",
    "hf_model = AutoModelForCausalLM.from_pretrained(model_id, **hf_model_kwargs)\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    model_id,\n",
    "    cache_dir=MODEL_DIR,\n",
    "    token=hf_token,\n",
    ")\n",
    "# Reuse the EOS token as the padding token and pad on the left.\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "tokenizer.padding_side = 'left'\n",
    "\n",
    "# Read the vocabulary mapping once: on fast tokenizers every `.vocab` access\n",
    "# rebuilds the whole dict. VOCAB holds the mapping's keys ordered by their\n",
    "# mapped value (presumably token string -> token id, so VOCAB[i] is the token\n",
    "# with the i-th smallest id).\n",
    "token_to_id = tokenizer.vocab\n",
    "VOCAB = sorted(token_to_id, key=token_to_id.get)\n",
    "\n",
    "layer_idx = 10\n",
    "\n",
    "# Wrap the loaded Hugging Face model in an NNsight object.\n",
    "nnsight_model = NNsight(hf_model)\n",
    "\n",
    "# Tracer options kept in one place; nothing in this cell consumes them.\n",
    "nnsight_tracer_kwargs = dict(\n",
    "    scan=True,\n",
    "    validate=False,\n",
    "    use_cache=False,\n",
    "    output_attentions=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dataset Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ravel_dataset_builder import RAVELEntityPromptData\n",
    "\n",
    "# Positional arguments for from_files, named for readability. They are\n",
    "# presumably the entity type and the data directory; confirm against\n",
    "# RAVELEntityPromptData.from_files.\n",
    "entity_type, data_subdir = 'city', 'data'\n",
    "full_entity_dataset = RAVELEntityPromptData.from_files(entity_type, data_subdir, tokenizer)\n",
    "\n",
    "# Bare last expression: the dataset length is shown as the cell output.\n",
    "len(full_entity_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sampled_entity_dataset = full_entity_dataset.downsample(1000)\n",
    "print(f\"Number of prompts remaining: {len(sampled_entity_dataset)}\")\n",
    "\n",
    "# max_length is prompt_max_length plus 8 (presumably room for the generated\n",
    "# continuation; confirm against generate_completions).\n",
    "prompt_max_length = 48\n",
    "completion_max_length = prompt_max_length + 8\n",
    "sampled_entity_dataset.generate_completions(\n",
    "    nnsight_model,\n",
    "    tokenizer,\n",
    "    max_length=completion_max_length,\n",
    "    prompt_max_length=prompt_max_length,\n",
    ")\n",
    "\n",
    "# Score the completions, then keep what filter_correct() returns.\n",
    "sampled_entity_dataset.evaluate_correctness()\n",
    "correct_data = sampled_entity_dataset.filter_correct()\n",
    "\n",
    "# NOTE(review): filtered_data is not referenced again in the visible cells; the\n",
    "# count printed below and the later add_wikipedia_prompts call both use\n",
    "# correct_data. Confirm which of the two is intended downstream.\n",
    "filtered_data = correct_data.filter_top_entities_and_templates(\n",
    "    top_n_entities=400,\n",
    "    top_n_templates_per_attribute=12,\n",
    ")\n",
    "\n",
    "# Accuracy is computed on the sampled dataset, not on the filtered subsets.\n",
    "accuracy = sampled_entity_dataset.calculate_average_accuracy()\n",
    "print(f\"Average accuracy: {accuracy:.2%}\")\n",
    "print(f\"Number of prompts remaining: {len(correct_data)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The result is not assigned, so this call presumably extends correct_data in\n",
    "# place; confirm against the definition of add_wikipedia_prompts.\n",
    "correct_data.add_wikipedia_prompts(\n",
    "    'city',\n",
    "    'data',\n",
    "    tokenizer,\n",
    "    nnsight_model,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experimental Interventions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
