{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable IPython's autoreload extension in mode 2 so edits to imported\n",
    "# project modules are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "import torch\n",
    "import copy\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from src import models, data\n",
    "from src.metrics import recall\n",
    "from src.attributelens.attributelens import Attribute_Lens\n",
    "import src.attributelens.utils as lens_utils\n",
    "from src.operators import JacobianIclMeanEstimator\n",
    "import plotly.graph_objects as go\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the \"gptj\" model through the project loader onto the first CUDA\n",
    "# device, then report the loaded model's dtype, device and memory footprint.\n",
    "device = \"cuda:0\"\n",
    "mt = models.load_model(\"gptj\", device=device)\n",
    "\n",
    "print(\n",
    "    f\"dtype: {mt.model.dtype}, \"\n",
    "    f\"device: {mt.model.device}, \"\n",
    "    f\"memory: {mt.model.get_memory_footprint()}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline view: apply the attribute lens to the prompt with no relation\n",
    "# operator, print the last entry of the returned 'nextwords', then build the\n",
    "# visualization, save it as a PDF and display it.\n",
    "lens = Attribute_Lens(mt=mt, top_k=10)\n",
    "att_info = lens.apply_attribute_lens(prompt=\" Bill Bradley was a\", relation_operator=None)\n",
    "# Optional manual override, left disabled:\n",
    "# att_info['subject_range'] = (8, 13)\n",
    "print('prediction:', att_info['nextwords'][-1])\n",
    "\n",
    "p = lens_utils.visualize_attribute_lens(att_info, layer_skip=3, must_have_layers=[])\n",
    "p.write_image('bill_bradley_lens.pdf')\n",
    "p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the dataset and print each entry's name on its own line.\n",
    "dataset = data.load_dataset()\n",
    "print('\\n'.join(relation.name for relation in dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the \"plays pro sport\" entry and report how many samples it has.\n",
    "datums = [d for d in dataset if d.name == \"plays pro sport\"][0]\n",
    "print(len(datums.samples))\n",
    "\n",
    "# Draw 5 distinct sample indices under a fixed seed so the selection is repeatable.\n",
    "np.random.seed(4)\n",
    "indices = np.random.choice(len(datums.samples), 5, replace=False)\n",
    "samples = [datums.samples[idx] for idx in indices]\n",
    "\n",
    "# Rebuild a Relation from a deep copy of the entry's fields, with only the\n",
    "# drawn samples in place of the full sample list.\n",
    "training_samples = data.Relation(\n",
    "    **{**copy.deepcopy(datums.__dict__), \"samples\": samples}\n",
    ")\n",
    "print(training_samples.samples)\n",
    "\n",
    "# bias_scale_factor=0.5 so that the bias doesn't knock out the prediction\n",
    "# too much in the direction of the training examples.\n",
    "mean_estimator = JacobianIclMeanEstimator(mt=mt, h_layer=15, bias_scale_factor=0.5)\n",
    "operator = mean_estimator(training_samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same prompt as the baseline, now passing `operator` (built from the\n",
    "# \"plays pro sport\" samples) as the relation operator; rendered with the\n",
    "# 'greens' colorscale and saved to its own PDF.\n",
    "lens = Attribute_Lens(mt=mt, top_k=10)\n",
    "att_info = lens.apply_attribute_lens(prompt=\" Bill Bradley was a\", relation_operator=operator)\n",
    "# Optional manual override, left disabled:\n",
    "# att_info['subject_range'] = (8, 13)\n",
    "print('prediction:', att_info['nextwords'][-1])\n",
    "\n",
    "p = lens_utils.visualize_attribute_lens(\n",
    "    att_info,\n",
    "    layer_skip=3,\n",
    "    must_have_layers=[],\n",
    "    colorscale='greens',\n",
    ")\n",
    "p.write_image('bill_bradley_sport.pdf')\n",
    "p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a second relation operator from 5 samples of the\n",
    "# \"person went to university\" entry (fixed seed for a repeatable draw).\n",
    "datums = [d for d in dataset if d.name == \"person went to university\"][0]\n",
    "mean_estimator = JacobianIclMeanEstimator(\n",
    "    mt=mt,\n",
    "    h_layer=15,\n",
    "    # so that the bias doesn't knock out the prediction too much in the\n",
    "    # direction of the training examples\n",
    "    bias_scale_factor=0.5,\n",
    ")\n",
    "\n",
    "np.random.seed(8)\n",
    "indices = np.random.choice(range(len(datums.samples)), 5, replace=False)\n",
    "# Alternative fixed selection, left disabled:\n",
    "# indices = np.array([ 1, 20,  3,  7,  0])\n",
    "\n",
    "samples = [datums.samples[i] for i in indices]\n",
    "\n",
    "# Copy the entry's fields, swap in only the drawn samples, and rebuild a Relation.\n",
    "training_samples = copy.deepcopy(datums.__dict__)\n",
    "training_samples[\"samples\"] = samples\n",
    "training_samples = data.Relation(**training_samples)\n",
    "\n",
    "# A bare expression in the middle of a cell is evaluated and discarded, so\n",
    "# print explicitly to actually show the selected training samples.\n",
    "print(training_samples.samples)\n",
    "operator2 = mean_estimator(training_samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same prompt again, now passing `operator2` (built from the\n",
    "# \"person went to university\" samples) as the relation operator; rendered\n",
    "# with the 'oranges' colorscale and saved to its own PDF.\n",
    "lens = Attribute_Lens(mt=mt, top_k=10)\n",
    "att_info = lens.apply_attribute_lens(prompt=\" Bill Bradley was a\", relation_operator=operator2)\n",
    "# Optional manual override, left disabled:\n",
    "# att_info['subject_range'] = (8, 13)\n",
    "print('prediction:', att_info['nextwords'][-1])\n",
    "\n",
    "p = lens_utils.visualize_attribute_lens(\n",
    "    att_info,\n",
    "    layer_skip=3,\n",
    "    must_have_layers=[],\n",
    "    colorscale='oranges',\n",
    ")\n",
    "p.write_image('bill_bradley_school.pdf')\n",
    "p"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
