{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.benchmark import get_cross_sections, get_feature_dicts\n",
    "from utils.setup import init_model\n",
    "from utils.data_generator import init_ioi, init_ind\n",
    "from utils.data_generator import evaluate_example_generator\n",
    "from utils.attribute_values import names, pythia_names, countries, animals, numbers\n",
    "from utils.data_generator import evaluate_example_generator\n",
    "import gc\n",
    "import random\n",
    "\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_cross_section_train_batches=5\n",
    "n_cross_sections = 300\n",
    "n_feature_train_batches = 200\n",
    "n_features = 10\n",
    "batch_size = 50\n",
    "\n",
    "\n",
    "name_features = random.sample(names,n_features)\n",
    "pythia_name_features = random.sample(pythia_names,n_features)\n",
    "gemma_name_features = random.sample(names,n_features)\n",
    "country_features = random.sample(countries,n_features)\n",
    "animal_features = random.sample(animals,n_features)\n",
    "number_features = random.sample(numbers,n_features)\n",
    "\n",
    "upstream_components = [\"attn.hook_result\"]\n",
    "downstream_components = [\"hook_q_input\",\"hook_k_input\",\"hook_v_input\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#model = init_model(\"gpt2-small\", n_devices=2)\n",
    "model = init_model(\"EleutherAI/pythia-70m-deduped\", device=\"cuda:0\")\n",
    "#model = init_model(\"gemma-2-2b\", n_devices=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_ioi_train_examples, n_ioi_test_examples = init_ioi(model, names, batch_size, train=True)\n",
    "print(f\"IOI Train Examples Prediction Accuracy: {evaluate_example_generator(model, get_ioi_train_examples, 'io', 50)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_ind_train_examples_gemma, _ = init_ind(model, names, batch_size, train=True, cross_entropy_threshhold=5)\n",
    "print(f\"IND Train Examples Prediction Accuracy: {evaluate_example_generator(model, get_ind_train_examples_gemma, 'ind2', 50)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_ind_train_examples_pythia, _ = init_ind(model, pythia_names, batch_size, train=True, cross_entropy_threshhold=3, seq_length=5)\n",
    "print(f\"IND Train Examples Prediction Accuracy: {evaluate_example_generator(model, get_ind_train_examples_pythia, 'ind2', 50)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_ind_train_examples_names, n_ind_test_examples_names = init_ind(model, names, batch_size, train=True, cross_entropy_threshhold=3)\n",
    "print(f\"IND Train Examples Prediction Accuracy: {evaluate_example_generator(model, get_ind_train_examples_names, 'ind2', 50)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_ind_train_examples_animals, _ = init_ind(model, animals, batch_size, train=True, cross_entropy_threshhold=3)\n",
    "print(f\"IND Train Examples Prediction Accuracy: {evaluate_example_generator(model, get_ind_train_examples_animals, 'ind2', 50)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_ind_train_examples_countries, _ = init_ind(model, countries, batch_size, train=True, cross_entropy_threshhold=3)\n",
    "print(f\"IND Train Examples Prediction Accuracy: {evaluate_example_generator(model, get_ind_train_examples_countries, 'ind2', 50)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_ind_train_examples_numbers, _ = init_ind(model, numbers, batch_size, train=True, cross_entropy_threshhold=3)\n",
    "print(f\"IND Train Examples Prediction Accuracy: {evaluate_example_generator(model, get_ind_train_examples_numbers, 'ind2', 50)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "correct = \"io\"\n",
    "incorrect = \"subject\"\n",
    "\n",
    "position_offset = get_ioi_train_examples(0)[0][\"value-positions\"][-1][0]\n",
    "ioi_cross_sections = get_cross_sections(\n",
    "    model, get_ioi_train_examples, correct, incorrect, upstream_components, downstream_components, n_cross_section_train_batches, n_cross_sections, position_offset, incl_ap=False\n",
    "    )\n",
    "torch.save(ioi_cross_sections,\"./cross_sections/gpt2/ioi\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "ioi_cross_sections = torch.load(\"./cross_sections/gpt2/ioi\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_feature_dicts(model, ioi_cross_sections, get_ioi_train_examples, n_feature_train_batches, name_features, method=\"mse\", \n",
    "                                                path=\"./feature_dicts/gpt2/ioi/names/ioi_names\", step=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "correct = \"ind2\"\n",
    "incorrect = \"ind1\"\n",
    "\n",
    "get_ind_examples = lambda x: random.choice([get_ind_train_examples_names(x),get_ind_train_examples_countries(x),get_ind_train_examples_numbers(x)])\n",
    "\n",
    "position_offset = get_ind_train_examples_names(0)[0][\"value-positions\"][-1][0]\n",
    "ind_cross_sections = get_cross_sections(\n",
    "    model, get_ind_train_examples_names, correct, incorrect, upstream_components, downstream_components, n_cross_section_train_batches, n_cross_sections, position_offset, incl_ap=False\n",
    "    )\n",
    "torch.save(ind_cross_sections,\"./cross_sections/gpt2/ind\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "ind_cross_sections = torch.load(\"./cross_sections/gpt2/ind\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_feature_dicts(model, ind_cross_sections, get_ind_train_examples_names, n_feature_train_batches, name_features, method=\"mse\"\n",
    "                                                ,path=\"./feature_dicts/gpt2/ind/names/ind_names\", step=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_feature_dicts(model, ind_cross_sections, get_ind_train_examples_countries, n_feature_train_batches, country_features, method=\"mse\"\n",
    "                                                ,path=\"./feature_dicts/gpt2/ind/countries/ind_countries\", step=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_feature_dicts(model, ind_cross_sections, get_ind_train_examples_animals, n_feature_train_batches, animal_features, method=\"mse\"\n",
    "                                                ,path=\"./feature_dicts/gpt2/ind/animals/ind_animals\", step=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_feature_dicts(model, ind_cross_sections, get_ind_train_examples_numbers, n_feature_train_batches, number_features, method=\"mse\"\n",
    "                                                ,path=\"./feature_dicts/gpt2/ind/numbers/ind_numbers\", step=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "correct = \"ind2\"\n",
    "incorrect = \"ind1\"\n",
    "\n",
    "position_offset = get_ind_train_examples_pythia(0)[0][\"value-positions\"][-1][0]\n",
    "ind_cross_sections_pythia = get_cross_sections(\n",
    "    model, get_ind_train_examples_pythia, correct, incorrect, upstream_components, downstream_components, n_cross_section_train_batches, n_cross_sections, position_offset, incl_ap=False\n",
    "    )\n",
    "torch.save(ind_cross_sections_pythia,\"./cross_sections/pythia/pythia\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "ind_cross_sections_pythia = torch.load(\"./cross_sections/pythia/pythia\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_feature_dicts(model, ind_cross_sections_pythia, get_ind_train_examples_pythia, n_feature_train_batches, pythia_name_features, method=\"mse\"\n",
    "                                                ,path=\"./feature_dicts/pythia/ind_names\", step=30)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Gemma"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "correct = \"ind2\"\n",
    "incorrect = \"ind1\"\n",
    "\n",
    "position_offset = get_ind_train_examples_gemma(0)[0][\"value-positions\"][-1][0]\n",
    "ind_cross_sections_gemma = get_cross_sections(\n",
    "    model, get_ind_train_examples_gemma, correct, incorrect, upstream_components, downstream_components, n_cross_section_train_batches, n_cross_sections, position_offset, incl_ap=False\n",
    "    )\n",
    "torch.save(ind_cross_sections_gemma,\"./cross_sections/gemma/gemma\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "ind_cross_sections_gemma = torch.load(\"./cross_sections/gemma/gemma\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_feature_dicts(model, ind_cross_sections_gemma, get_ind_train_examples_gemma, n_feature_train_batches, name_features, method=\"mse\"\n",
    "                                                ,path=\"./feature_dicts/gemma/ind_names\", step=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dissenv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
