{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4545988f-e394-4a30-a80d-3d5c3540eaef",
   "metadata": {},
   "source": [
    "### Load training dataset and VADER sentiment analyzer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a57fc8f-46ca-4ebf-a4e0-ecdadb9098c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import dataclasses\n",
    "import gc\n",
    "import json\n",
    "import os\n",
    "import pickle\n",
    "import pprint\n",
    "\n",
    "from time import time\n",
    "\n",
    "import nltk\n",
    "import spacy\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import tqdm\n",
    "import transformers\n",
    "import wandb\n",
    "\n",
    "from scipy.sparse import csr_matrix\n",
    "from torch import FloatTensor, LongTensor, Tensor\n",
    "\n",
    "from dataclasses import dataclass\n",
    "from nltk.sentiment.vader import SentimentIntensityAnalyzer\n",
    "from torch import nn\n",
    "from tqdm import tqdm_notebook\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ddb89c9-9856-48b1-b346-a6f0bd8f2474",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = 'pythia-160m'\n",
    "policy_model_name  = 'pythia_160m_utility_reward'\n",
    "os.environ['WANDB_API_KEY'] = ''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8b6a48e-84cc-4db9-8501-dada3c0569a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start a W&B run; it is the default `run` that `load_autoencoders_for_artifact` uses to fetch artifacts.\n",
    "run = wandb.init(project=\"utility_reconstruction\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a0975b7-96d6-49ab-8df9-390378364143",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Identifier passed to `from_pretrained` for both the model and its tokenizer.\n",
    "full_model_name = f'EleutherAI/{model_name}'\n",
    "\n",
    "# Load the causal LM, move it to the GPU (a CUDA device is required) and switch it to eval mode.\n",
    "model = AutoModelForCausalLM.from_pretrained(full_model_name)\n",
    "model.cuda()\n",
    "model.eval()\n",
    "\n",
    "# The end-of-sequence token doubles as the padding token.\n",
    "tokenizer = AutoTokenizer.from_pretrained(full_model_name)\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "#model.resize_token_embeddings(len(tokenizer))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afc9a76f-aa78-49f1-ae8a-23b3d715a5f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# VADER analyzer from NLTK. Its lexicon is later queried with a target token to obtain that token's reward.\n",
    "sentiment_analyzer = SentimentIntensityAnalyzer()\n",
    "lexicon = sentiment_analyzer.lexicon\n",
    "# Largest and smallest value stored in the lexicon.\n",
    "max_value = max(lexicon.values())\n",
    "min_value = min(lexicon.values())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd4ba3ec-e7a8-4180-8ad4-e9c1280b5048",
   "metadata": {},
   "source": [
    "## Tokenization and torch utilities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b5533bf-faaf-47c0-9fd9-3d1c903586d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def batch(iterable, n=1):\n",
    "    l = len(iterable)\n",
    "    for ndx in range(0, l, n):\n",
    "        yield iterable[ndx:min(ndx + n, l)]\n",
    "\n",
    "def clear_gpu_memory():\n",
    "    start_time = time()\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()\n",
    "    end_time = time()\n",
    "    total_time = round(end_time-start_time, 2)\n",
    "    print(f'Took {total_time} seconds to clear cache.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5dbbfcf-0577-4f2f-b987-7cc0ca71fafc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_list_of_lists(list_of_lists, pad_token):\n",
    "    max_length = max(len(lst) for lst in list_of_lists)\n",
    "    padded_list = [lst + [pad_token] * (max_length - len(lst)) for lst in list_of_lists]\n",
    "    return padded_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95fb77c1-dfaa-437c-b344-7950146656d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_number_of_tokens(word, tokenizer=tokenizer):\n",
    "    \"\"\"Return how many input ids `tokenizer` produces for `word`.\"\"\"\n",
    "    encoding = tokenizer(word)\n",
    "    return len(encoding['input_ids'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61533633-3263-4edb-9c96-d73e5bc07c8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_tokens_and_ids(text, tokenizer=tokenizer):\n",
    "    \"\"\"Tokenize the lower-cased `text` and return its normalized tokens together with their ids.\n",
    "\n",
    "    Args:\n",
    "        text: Raw input string; it is lower-cased before tokenization.\n",
    "        tokenizer: Tokenizer to use; it is called with `truncation=True`.\n",
    "\n",
    "    Returns:\n",
    "        `(tokens, input_ids)`, two lists of equal length. Each token is the decoded form of the id at\n",
    "        the same index, lower-cased and stripped of surrounding whitespace.\n",
    "    \"\"\"\n",
    "    input_ids = tokenizer(text.lower(), truncation=True)['input_ids']\n",
    "\n",
    "    # Decoding ids one by one produces artifacts such as a \" positive\" token instead of \"positive\",\n",
    "    # so every decoded token is lower-cased and stripped.\n",
    "    tokens = [tokenizer.decode(input_id).lower().strip() for input_id in input_ids]\n",
    "    return tokens, input_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e30a4f4c-22dd-4196-a7a8-b35bc9799bd9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_single_target_token_id(word, tokenizer=tokenizer):\n",
    "    \"\"\"Return the id of the first token of `word`, preferring a spelling that is a single token.\n",
    "\n",
    "    The word is lower-cased and stripped. When it is split into several tokens, the variant with one\n",
    "    leading space is tokenized instead. In both cases the id of the first token is returned; the\n",
    "    variant with the leading space is not checked again, so it may still span several tokens.\n",
    "\n",
    "    Args:\n",
    "        word: Target word.\n",
    "        tokenizer: Tokenizer used both for counting the tokens and for producing the id.\n",
    "    \"\"\"\n",
    "    word = word.lower().strip()\n",
    "    # Count with the tokenizer given by the caller, not with the default of `check_number_of_tokens`.\n",
    "    if check_number_of_tokens(word, tokenizer=tokenizer) > 1:\n",
    "        # Backoff to include a single space.\n",
    "        word = f' {word}'\n",
    "\n",
    "    return tokenizer(word)['input_ids'][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e30cd968-56d3-41db-a0bd-598d9ea7b7db",
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class TextTokensIdsTarget:\n",
    "    attention_mask: list[int]\n",
    "    text: str\n",
    "    tokens: list[str]\n",
    "    ids: list[int]\n",
    "    target_token: str\n",
    "    target_token_id: int\n",
    "    target_token_position: int\n",
    "\n",
    "    @staticmethod\n",
    "    def get_tensorized(datapoints: \"TextTokensIdsTarget\"):\n",
    "        max_length = max([len(datapoint.tokens) for datapoint in datapoints])\n",
    "        \n",
    "        input_ids = [datapoint.ids for datapoint in datapoints]\n",
    "        attention_masks = [datapoint.attention_mask for datapoint in datapoints]\n",
    "\n",
    "        input_ids_padded = pad_list_of_lists(input_ids, tokenizer.encode(tokenizer.pad_token)[0])\n",
    "        attention_masks_padded = pad_list_of_lists(attention_masks, 0)\n",
    "        all_tokenized = {\n",
    "            \"input_ids\": torch.IntTensor(input_ids_padded).cuda(), \"attention_mask\": torch.ByteTensor(attention_masks_padded).cuda()\n",
    "        }\n",
    "        return all_tokenized\n",
    "\n",
    "def trim_example(input_text: str, target_words: list[str], verbose=False, tokenizer=tokenizer):\n",
    "    \"\"\"Cut `input_text` right after the first token that matches one of `target_words`.\n",
    "\n",
    "    Every target word is reduced to one token id (see `get_single_target_token_id`) and decoded back to\n",
    "    a lower-cased, stripped token. The tokens of `input_text` are kept up to and including the first\n",
    "    one that equals any of those target tokens.\n",
    "\n",
    "    Returns:\n",
    "        A `TextTokensIdsTarget` whose last token is the matched target, or None when no token matches\n",
    "        or when the kept tokens outnumber `tokenizer.model_max_length`.\n",
    "    \"\"\"\n",
    "    candidate_ids = [get_single_target_token_id(word.strip().lower()) for word in target_words]\n",
    "    # Falsy ids are discarded before decoding.\n",
    "    target_forms = [\n",
    "        tokenizer.decode(candidate_id).strip().lower() for candidate_id in candidate_ids if candidate_id\n",
    "    ]\n",
    "\n",
    "    all_tokens, all_token_ids = get_tokens_and_ids(input_text)\n",
    "\n",
    "    kept_tokens, kept_ids = [], []\n",
    "    for token, token_id in zip(all_tokens, all_token_ids):\n",
    "        kept_tokens.append(token)\n",
    "        kept_ids.append(token_id)\n",
    "        if token.strip().lower() in target_forms:\n",
    "            break\n",
    "\n",
    "    assert len(kept_ids) == len(kept_tokens), \"Num of tokens and token ids should be equal\"\n",
    "\n",
    "    if len(kept_tokens) > tokenizer.model_max_length:\n",
    "        print(f'Dropping example since exceed model max length. Input text was:\\n{input_text}')\n",
    "        return None\n",
    "\n",
    "    final_token = kept_tokens[-1].lower().strip() if kept_tokens else None\n",
    "    if not final_token or final_token not in target_forms:\n",
    "        if verbose:\n",
    "            print(f'last token was {final_token} in {kept_tokens}, and was not in target tokens.')\n",
    "        return None\n",
    "\n",
    "    return TextTokensIdsTarget(\n",
    "        attention_mask=[1] * len(kept_tokens),\n",
    "        text=tokenizer.decode(kept_ids),\n",
    "        tokens=kept_tokens,\n",
    "        ids=kept_ids,\n",
    "        target_token=final_token,\n",
    "        target_token_id=kept_ids[-1],\n",
    "        target_token_position=len(kept_ids) - 1,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3da5c75-e068-444c-9297-6a2bb7ef1924",
   "metadata": {},
   "source": [
    "### Load training examples for linear probe."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "913efa22-d8d2-4d1a-9f82-770e1052f59d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_wandb_json_artifact(\n",
    "    project_name='utility_reconstruction', artifact_name = 'contrastive_sentiment_pairs', version='v6'\n",
    "):\n",
    "    \"\"\"Download a W&B data artifact and return the parsed JSON file stored inside it.\n",
    "\n",
    "    Args:\n",
    "        project_name: W&B project under the `nlp_and_interpretability` entity.\n",
    "        artifact_name: Name of the artifact; the JSON file inside it carries the same name.\n",
    "        version: Artifact version or alias.\n",
    "\n",
    "    Returns:\n",
    "        The deserialized JSON content.\n",
    "    \"\"\"\n",
    "    api = wandb.Api()\n",
    "    artifact = api.artifact(f'nlp_and_interpretability/{project_name}/{artifact_name}:{version}', type='data')\n",
    "    # Read from the directory reported by `download()` rather than rebuilding the path by hand: the\n",
    "    # hand-built 'artifacts/<name>:<version>' path only matches the default download location.\n",
    "    artifact_dir = artifact.download()\n",
    "\n",
    "    with open(os.path.join(artifact_dir, artifact_name), 'r') as f_in:\n",
    "        return json.load(f_in)\n",
    "\n",
    "# Raw contrastive examples; they are iterated as dicts by the cells below.\n",
    "all_input_dicts = load_wandb_json_artifact()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90f7cfaa-6817-481a-a5b0-f6cad586dcf5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def clean_capitalization_neutral_terms(input_dicts):\n",
    "    for input_dict in input_dicts:\n",
    "        neutral_words = list(input_dict['neutral_words'].values())\n",
    "        for neutral_word in neutral_words:\n",
    "            input_dict['neutral_text'] = input_dict['neutral_text'].replace(neutral_word, neutral_word.lower())\n",
    "\n",
    "    return input_dicts\n",
    "\n",
    "# Normalize the neutral texts; the dicts are edited in place and the same list is bound again.\n",
    "all_input_dicts = clean_capitalization_neutral_terms(all_input_dicts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b517c39-b9d3-42c9-a121-b7310df0c3b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TrainingPoint:\n",
    "    \"\"\"A contrastive example: a positive, a negative and a neutral variant of one text.\n",
    "\n",
    "    Each variant is trimmed right after its first target word (see `trim_example`). The positive and\n",
    "    negative target tokens are additionally looked up in the VADER `lexicon` to obtain a reward, which\n",
    "    stays None when the token is not in the lexicon. A variant that cannot be trimmed, or whose\n",
    "    trimming raises, leaves its `trimmed_*_example` and `target_*` attributes at None.\n",
    "\n",
    "    Note: the `tokenizer` argument is accepted but not used; the helpers called here rely on their\n",
    "    own default tokenizer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dict: dict, tokenizer=tokenizer):\n",
    "        self.input_dict = input_dict\n",
    "        self.positive_text = input_dict['input_text']\n",
    "        self.negative_text = input_dict['output_text']\n",
    "        self.neutral_text = input_dict['neutral_text']\n",
    "\n",
    "        # Dictionary of layer name to activations by mlp layer.\n",
    "        self.activations: dict = None\n",
    "\n",
    "        # Dictionary of layer name to autoencoder feature by mlp layer\n",
    "        self.autoencoder_feature: dict = None\n",
    "\n",
    "        self.positive_text_tokens, self.positive_input_ids = get_tokens_and_ids(self.positive_text)\n",
    "        self.negative_text_tokens, self.negative_token_ids = get_tokens_and_ids(self.negative_text)\n",
    "\n",
    "        self.positive_words = input_dict['positive_words']\n",
    "        self.negative_words = list(input_dict['new_words'].values())\n",
    "        self.neutral_words = list(input_dict['neutral_words'].values())\n",
    "\n",
    "        # Trimmed example, target token, target token id and reward of the target token, per variant.\n",
    "        (\n",
    "            self.trimmed_positive_example, self.target_positive_token,\n",
    "            self.target_positive_token_id, self.target_positive_reward,\n",
    "        ) = self._trim_variant(self.positive_text, self.positive_words, 'positive')\n",
    "\n",
    "        (\n",
    "            self.trimmed_negative_example, self.target_negative_token,\n",
    "            self.target_negative_token_id, self.target_negative_reward,\n",
    "        ) = self._trim_variant(self.negative_text, self.negative_words, 'negative')\n",
    "\n",
    "        # The neutral variant has no reward, so the lexicon lookup is skipped.\n",
    "        (\n",
    "            self.trimmed_neutral_example, self.target_neutral_token,\n",
    "            self.target_neutral_token_id, _,\n",
    "        ) = self._trim_variant(self.neutral_text, self.neutral_words, 'neutral', with_reward=False)\n",
    "\n",
    "    def _trim_variant(self, text, words, label, with_reward=True):\n",
    "        \"\"\"Trim one variant and return `(example, target token, target token id, reward)`.\n",
    "\n",
    "        Any exception is printed together with the input dict and turned into four Nones, so that a\n",
    "        single bad example does not abort the construction of the whole dataset.\n",
    "        \"\"\"\n",
    "        try:\n",
    "            example: TextTokensIdsTarget = trim_example(text, words)\n",
    "            if not example:\n",
    "                return example, None, None, None\n",
    "\n",
    "            token = example.target_token.strip().lower()\n",
    "            reward = lexicon.get(token, None) if with_reward else None\n",
    "            return example, token, example.target_token_id, reward\n",
    "\n",
    "        except Exception as e:\n",
    "            print(f'Caught exception {e} on {self.input_dict} for {label} example.')\n",
    "            return None, None, None, None\n",
    "\n",
    "    def __str__(self):\n",
    "        return pprint.pformat(self.__dict__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65316eee-8b79-411d-abc4-e7ece073ea7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build one TrainingPoint per input dict; trimming failures are caught and printed by TrainingPoint itself.\n",
    "training_points = []\n",
    "for input_dict in tqdm_notebook(all_input_dicts):\n",
    "    training_points.append(TrainingPoint(input_dict))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1887651b-5d96-4570-ae7b-cb5399f4f0c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check one training point. `autoencoder_feature` is initialised to None in TrainingPoint.__init__.\n",
    "x = training_points[10]\n",
    "x.autoencoder_feature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cc71f1e-d09d-415a-97f9-38a4fc911f1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the points whose three variants were trimmed successfully and whose positive and negative\n",
    "# target tokens both received a reward from the lexicon.\n",
    "successful_training_points = [\n",
    "    training_point for training_point in training_points if \n",
    "    training_point.trimmed_positive_example and training_point.trimmed_negative_example\n",
    "    and training_point.trimmed_neutral_example\n",
    "    and training_point.target_positive_reward is not None\n",
    "    and training_point.target_negative_reward is not None\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3297b711-b1f5-4a45-bde6-6fbb817247da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of usable training points.\n",
    "len(successful_training_points)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b1f11bd-eaf4-4ee6-a5dd-eea748148853",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Small subset (first 10 usable points) for quick experiments.\n",
    "sample_training_points = successful_training_points[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52c1cf7f-b633-4bcf-bdc2-db6586ae1a3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Points rejected by the filter above. Membership is tested on object identity through a set, which\n",
    "# matches the default equality of TrainingPoint and avoids scanning the whole list for every point.\n",
    "_successful_point_ids = {id(training_point) for training_point in successful_training_points}\n",
    "bad_training_points = [\n",
    "    training_point for training_point in training_points if id(training_point) not in _successful_point_ids\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c613185-b7fc-4f40-91b0-ea187fc94e46",
   "metadata": {},
   "source": [
    "### Load autoencoders for linear probe."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52ff3e64-1584-4af4-9bbc-25dd78a0a3b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SparseAutoencoder(nn.Module):\n",
    "    def __init__(self, input_size, hidden_size, l1_coef):\n",
    "        super(SparseAutoencoder, self).__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.input_size = input_size\n",
    "\n",
    "        self.kwargs = {'input_size': input_size, 'hidden_size': hidden_size, 'l1_coef': l1_coef}\n",
    "        self.l1_coef = l1_coef\n",
    "\n",
    "        self.encoder_weight = nn.Parameter(torch.randn(hidden_size, input_size))\n",
    "        nn.init.orthogonal_(self.encoder_weight)\n",
    "\n",
    "        self.encoder_bias = nn.Parameter(torch.zeros(self.hidden_size))\n",
    "        self.decoder_bias = nn.Parameter(torch.zeros(input_size))\n",
    "\n",
    "    def forward(self, x):\n",
    "        normalized_encoder_weight = F.normalize(self.encoder_weight, p=2, dim=1)\n",
    "\n",
    "        features = F.linear(x, normalized_encoder_weight, self.encoder_bias)\n",
    "        features = F.relu(features)\n",
    "\n",
    "        # reconstruction = F.linear(features, normalized_encoder_weight.t(), self.decoder_bias)\n",
    "        return features.detach()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65622ee1-3281-45a6-bb29-0f80cc056a91",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pieces of the W&B artifact path assembled in `load_autoencoders_for_artifact`:\n",
    "# '<entity_name>/<project_prefix>_<policy>/<artifact_prefix>_<policy>:<alias>'.\n",
    "entity_name = 'nlp_and_interpretability'\n",
    "project_prefix = 'Autoencoder_training'\n",
    "artifact_prefix = 'autoencoders'\n",
    "\n",
    "def load_autoencoders_for_artifact(policy_model_name, alias='latest', run=run):\n",
    "    '''\n",
    "    Download the autoencoder artifact of one policy model and load its four autoencoder groups.\n",
    "\n",
    "    The artifact path is partly hardcoded through the module-level `entity_name`, `project_prefix`\n",
    "    and `artifact_prefix`. Example:\n",
    "    autoencoders_dict = load_autoencoders_for_artifact('pythia_70m_sentiment_reward')\n",
    "\n",
    "    Returns a dict with the keys 'base_big', 'base_small', 'rlhf_big' and 'rlhf_small', each mapping\n",
    "    to the result of `load_models_from_folder` for the sub-folder of that name.\n",
    "    '''\n",
    "    simplified_policy_model_name = policy_model_name.split('/')[-1].replace('-', '_')\n",
    "    full_path = (\n",
    "        f'{entity_name}/{project_prefix}_{simplified_policy_model_name}/'\n",
    "        f'{artifact_prefix}_{simplified_policy_model_name}:{alias}'\n",
    "    )\n",
    "    print(f'Loading artifact from {full_path}')\n",
    "\n",
    "    save_dir = f'{run.use_artifact(full_path).download()}/saves'\n",
    "\n",
    "    # Same loading order as the keys of the returned dict.\n",
    "    return {\n",
    "        variant: load_models_from_folder(load_dir=f'{save_dir}/{variant}', given_device='cpu')\n",
    "        for variant in ('base_big', 'base_small', 'rlhf_big', 'rlhf_small')\n",
    "    }\n",
    "\n",
    "def load_models_from_folder(load_dir, given_device=None):\n",
    "    \"\"\"\n",
    "    Load every SparseAutoencoder checkpoint found directly inside a directory.\n",
    "\n",
    "    Args:\n",
    "        load_dir (str): Directory whose entries are checkpoints holding a `(kwargs, state_dict)` pair.\n",
    "        given_device: Device used as `map_location`; when falsy, CUDA is picked if available, else CPU.\n",
    "\n",
    "    Returns:\n",
    "        model_dict (dict): Maps each entry name of `load_dir` (in sorted order) to its model in eval mode.\n",
    "    \"\"\"\n",
    "    device = given_device or torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "    model_dict = {}\n",
    "    for entry_name in sorted(os.listdir(load_dir)):\n",
    "        checkpoint_path = os.path.join(load_dir, entry_name)\n",
    "\n",
    "        # NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.\n",
    "        kwargs, state = torch.load(checkpoint_path, map_location=device)\n",
    "\n",
    "        autoencoder = SparseAutoencoder(**kwargs)\n",
    "        autoencoder.load_state_dict(state)\n",
    "        # NOTE(review): the trailing .cuda() moves the model to the GPU whatever `device` is.\n",
    "        autoencoder.to(device).cuda().eval()\n",
    "\n",
    "        model_dict[entry_name] = autoencoder\n",
    "        print(f\"Loaded {entry_name} from {checkpoint_path}\")\n",
    "\n",
    "    return model_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f914c2d1-5726-4a89-b98e-4e73cbe618fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Group name ('base_big', 'base_small', 'rlhf_big', 'rlhf_small') -> {checkpoint name -> autoencoder}.\n",
    "autoencoders_dictionaries = load_autoencoders_for_artifact(policy_model_name=policy_model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77e4d581-de44-4b8d-865c-63a95d93fe4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ModelCustomizer:\n",
    "    '''\n",
    "    Used to customize model layer numbers and other network parsing details\n",
    "    '''\n",
    "\n",
    "    def __init__(self):\n",
    "        '''\n",
    "        Initialize\n",
    "        '''\n",
    "        self.target_layers = None\n",
    "\n",
    "    def set_target_layers(self) -> list[str]:\n",
    "        '''\n",
    "        Set target layers\n",
    "        '''\n",
    "\n",
    "    def get_target_layers(self) -> list[str]:\n",
    "        '''\n",
    "        Get target layers.\n",
    "        '''\n",
    "\n",
    "    def parse_layer_name_to_layer_number(self, layer_name) -> str:\n",
    "        '''\n",
    "        Parse layer name to layer number\n",
    "        '''\n",
    "\n",
    "    def convert_ae_dict_keys(self, autoencoders_dict: [str, Tensor]):\n",
    "        '''\n",
    "        Parse ae dict keys to full layer names.\n",
    "        '''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "413ca384-19aa-4a72-ad37-7c9120adcc48",
   "metadata": {},
   "outputs": [],
   "source": [
    "class GPTNeoCustomizer(ModelCustomizer):\n",
    "    '''\n",
    "    Customizer that names MLP layers `transformer.h.<layer number>.mlp`.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, num_layers=12):\n",
    "        '''\n",
    "        Args:\n",
    "            num_layers: Number of layers enumerated when no target layers are set. The default of 12\n",
    "                is the value that used to be hard-coded.\n",
    "        '''\n",
    "        super().__init__()\n",
    "        self.num_layers = num_layers\n",
    "\n",
    "    def get_target_layers(self) -> list[str]:\n",
    "        '''\n",
    "        Return the target layers set explicitly, or the MLP of every layer when none (or an empty\n",
    "        list) was set.\n",
    "        '''\n",
    "        if self.target_layers:\n",
    "            return self.target_layers\n",
    "        return [self.layer_num_to_full_name(layer_no) for layer_no in range(self.num_layers)]\n",
    "\n",
    "    def set_target_layers(self, target_layers):\n",
    "        '''\n",
    "        Replace the target layers with `target_layers` (full layer names).\n",
    "        '''\n",
    "        self.target_layers = target_layers\n",
    "\n",
    "    def layer_num_to_full_name(self, layer_no):\n",
    "        '''\n",
    "        Build the full module name of the MLP of layer `layer_no`.\n",
    "        '''\n",
    "        return f'transformer.h.{layer_no}.mlp'\n",
    "\n",
    "    def parse_layer_name_to_layer_number(self, layer_name) -> str:\n",
    "        '''\n",
    "        Return the second-to-last dot-separated component of `layer_name`, i.e. the layer number as a string.\n",
    "        '''\n",
    "        return layer_name.split('.')[-2]\n",
    "\n",
    "    # Standardize layer names to full names instead of 'int'\n",
    "    def convert_ae_dict_keys(self, autoencoders_dict: dict[str, nn.Module]):\n",
    "        '''\n",
    "        Return a new dict whose keys are the full layer names built from the keys of `autoencoders_dict`.\n",
    "        '''\n",
    "        return {\n",
    "            self.layer_num_to_full_name(key): autoencoder for key, autoencoder in autoencoders_dict.items()\n",
    "        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5083034d-036b-47ae-beb3-f07d31002780",
   "metadata": {},
   "outputs": [],
   "source": [
    "class PythiaCustomizer(ModelCustomizer):\n",
    "    '''\n",
    "    Customizer that names MLP layers `gpt_neox.layers.<layer number>.mlp`.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, num_layers):\n",
    "        '''\n",
    "        Args:\n",
    "            num_layers: Number of layers enumerated when no target layers are set.\n",
    "        '''\n",
    "        # The base initializer already resets `target_layers` to None.\n",
    "        super().__init__()\n",
    "        self.num_layers = num_layers\n",
    "\n",
    "    def get_target_layers(self) -> list[str]:\n",
    "        '''\n",
    "        Return the target layers set explicitly, or the MLP of every layer when none (or an empty\n",
    "        list) was set.\n",
    "        '''\n",
    "        if self.target_layers:\n",
    "            return self.target_layers\n",
    "        return [self.layer_num_to_full_name(layer_no) for layer_no in range(self.num_layers)]\n",
    "\n",
    "    def set_target_layers(self, target_layers):\n",
    "        '''\n",
    "        Replace the target layers with `target_layers` (full layer names).\n",
    "        '''\n",
    "        self.target_layers = target_layers\n",
    "\n",
    "    def layer_num_to_full_name(self, layer_no):\n",
    "        '''\n",
    "        Build the full module name of the MLP of layer `layer_no`.\n",
    "        '''\n",
    "        return f'gpt_neox.layers.{layer_no}.mlp'\n",
    "\n",
    "    def parse_layer_name_to_layer_number(self, layer_name) -> str:\n",
    "        '''\n",
    "        Return the second-to-last dot-separated component of `layer_name`, i.e. the layer number as a string.\n",
    "        '''\n",
    "        return layer_name.split('.')[-2]\n",
    "\n",
    "    # Standardize layer names to full names instead of 'int'\n",
    "    def convert_ae_dict_keys(self, autoencoders_dict: dict[str, nn.Module]):\n",
    "        '''\n",
    "        Return a new dict whose keys are the full layer names built from the keys of `autoencoders_dict`.\n",
    "        '''\n",
    "        return {\n",
    "            self.layer_num_to_full_name(key): autoencoder for key, autoencoder in autoencoders_dict.items()\n",
    "        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7946cef0-ef7b-4588-9d07-4a69c2dfb1b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One customizer per supported base model, keyed by the value expected in `model_name`.\n",
    "model_customizers = {\n",
    "    \"pythia-70m\": PythiaCustomizer(num_layers=6),\n",
    "    \"pythia-160m\": PythiaCustomizer(num_layers=12),\n",
    "    \"pythia-410m\": PythiaCustomizer(num_layers=24),\n",
    "    \"gpt-neo-125m\": GPTNeoCustomizer()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c52f42f5-bb54-4efa-98b7-9a180d715c27",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Customizer of the selected model. No target layers are set yet, so this lists the MLP of every layer.\n",
    "model_customizer = model_customizers[model_name]\n",
    "model_target_layers = model_customizer.get_target_layers()\n",
    "model_target_layers"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "747dae15-10ca-4a46-9c2d-b294baba4c0f",
   "metadata": {},
   "source": [
    "**Convert layer numbers to full layer names.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8f8b072-77c2-4ff8-aac3-f4c8d5109e16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filled by the next cell: autoencoder group name -> {full layer name -> autoencoder}.\n",
    "mapped_dictionaries = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8993ba7e-fa07-450b-a7b6-234354b9568e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-key every autoencoder group by full layer name.\n",
    "for key, ae_dict in autoencoders_dictionaries.items():\n",
    "    mapped_dictionaries[key] = model_customizer.convert_ae_dict_keys(ae_dict)\n",
    "\n",
    "# NOTE: this rebinds `autoencoders_dictionaries`, so re-running this cell converts the keys a second time.\n",
    "autoencoders_dictionaries = mapped_dictionaries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e47c5055-16ac-4963-807a-517e0e0039aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "rlhf_small = autoencoders_dictionaries['rlhf_small']\n",
    "rlhf_big = autoencoders_dictionaries['rlhf_big']\n",
    "\n",
    "# Restrict the target layers to those that have an autoencoder in the `rlhf_small` group.\n",
    "model_customizer.set_target_layers(list(rlhf_small.keys()))\n",
    "model_customizer.get_target_layers()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43c54d67-22d7-4662-8752-9a6f20236e43",
   "metadata": {},
   "source": [
    "### Extract Activations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a8a82ad-bc28-47c8-8bd5-3ff5657e5795",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ActivationsHook:\n",
    "    def __init__(self):\n",
    "        self.activations = []\n",
    "\n",
    "    def clear_activations(self):\n",
    "        for tensor in self.activations:\n",
    "            tensor = tensor.detach().cpu()\n",
    "        self.activations.clear()\n",
    "        self.activations = []\n",
    "\n",
    "    def hook_fn(self, module, input, output):\n",
    "        new_activations = torch.split(output.detach().cpu(), 1, dim=0)\n",
    "        self.activations.extend(new_activations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "285f2c36-5214-40a4-aba6-f6fb67199221",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ActivationsExtractor:\n",
    "    \"\"\"Caches the outputs of selected layers of `model` through forward hooks.\n",
    "\n",
    "    One `ActivationsHook` is registered for every name in `target_layers`. Every forward pass of the\n",
    "    model appends to those caches until they are cleared.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model, tokenizer, target_layers):\n",
    "        self.model = model\n",
    "        self.target_layers = target_layers\n",
    "        self.tokenizer = tokenizer\n",
    "\n",
    "        # One ActivationsHook per target layer, keyed by layer name.\n",
    "        self.activation_hooks = {}\n",
    "        # Handles returned by `register_forward_hook`, kept so that the hooks can be removed again.\n",
    "        self.hook_handles = []\n",
    "\n",
    "        # Resolve the module names once instead of rebuilding the mapping for every layer.\n",
    "        modules_by_name = dict(model.named_modules())\n",
    "        for layer_name in self.target_layers:\n",
    "            activation_hook = ActivationsHook()\n",
    "            self.activation_hooks[layer_name] = activation_hook\n",
    "            # Register the forward hook to the chosen layer\n",
    "            handle = modules_by_name[layer_name].register_forward_hook(activation_hook.hook_fn)\n",
    "            self.hook_handles.append(handle)\n",
    "\n",
    "    def remove_hooks(self):\n",
    "        \"\"\"\n",
    "        Detach every forward hook registered by this extractor from the model.\n",
    "        \"\"\"\n",
    "        for handle in self.hook_handles:\n",
    "            handle.remove()\n",
    "        self.hook_handles = []\n",
    "\n",
    "    def clear_all_activations(self):\n",
    "        \"\"\"\n",
    "        Clear the cache of every hook.\n",
    "        \"\"\"\n",
    "        for activation_hook in self.activation_hooks.values():\n",
    "            activation_hook.clear_activations()\n",
    "\n",
    "    def get_activations(self):\n",
    "        \"\"\"\n",
    "        Retrieve all the cached activations as `{layer name: list of tensors}`.\n",
    "\n",
    "        The lists are copies, so clearing the caches later does not empty a result returned earlier.\n",
    "        \"\"\"\n",
    "        return {\n",
    "            layer_name: list(activation_hook.activations)\n",
    "            for layer_name, activation_hook in self.activation_hooks.items()\n",
    "        }\n",
    "\n",
    "    def compute_activations_from_raw_texts(self, raw_texts: list[str]):\n",
    "        \"\"\"\n",
    "        Tokenize `raw_texts`, run the model on them one text at a time and return the cached activations.\n",
    "\n",
    "        A single string is accepted as well and treated as a list of one text.\n",
    "        \"\"\"\n",
    "        self.clear_all_activations()\n",
    "\n",
    "        if isinstance(raw_texts, str):\n",
    "            raw_texts = [raw_texts]\n",
    "\n",
    "        # The inputs have to live on the same device as the model.\n",
    "        device = next(self.model.parameters()).device\n",
    "\n",
    "        # Forward pass your input through the model; the hooks fill the caches as a side effect.\n",
    "        for text_batch in batch(raw_texts):\n",
    "            input_data = self.tokenizer(text_batch, return_tensors='pt', padding=True).to(device)\n",
    "            with torch.no_grad():\n",
    "                self.model(**input_data)\n",
    "\n",
    "        return self.get_activations()\n",
    "\n",
    "    def _flatten_activations(self, final_activations, num_samples):\n",
    "        \"\"\"\n",
    "        Turn `{layer name: [activation per sample]}` into one `{layer name: [activation]}` dict per sample.\n",
    "        \"\"\"\n",
    "        return [\n",
    "            {layer_name: [activations_list[i]] for layer_name, activations_list in final_activations.items()}\n",
    "            for i in range(num_samples)\n",
    "        ]\n",
    "\n",
    "    def compute_activations_from_text_tokens_ids_target(\n",
    "        self, samples: list[TextTokensIdsTarget], target_token_only=True, flatten=True\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Run the model on `samples`, one sample at a time, and return the cached activations.\n",
    "\n",
    "        Args:\n",
    "            samples: Trimmed examples to run through the model.\n",
    "            target_token_only: Keep only the activation at each sample's `target_token_position`.\n",
    "            flatten: Return a list with one `{layer name: [activation]}` dict per sample instead of a\n",
    "                single `{layer name: list of activations}` dict.\n",
    "        \"\"\"\n",
    "        self.clear_all_activations()\n",
    "\n",
    "        # Forward pass your input through the model; `batch` is used with its default size of 1.\n",
    "        for text_batch in batch(samples):\n",
    "            tensorized = TextTokensIdsTarget.get_tensorized(text_batch)\n",
    "\n",
    "            with torch.no_grad():\n",
    "                self.model(**tensorized)\n",
    "\n",
    "        all_activations = self.get_activations()\n",
    "        activations_per_layer = [len(value) for value in all_activations.values()]\n",
    "\n",
    "        assert max(activations_per_layer) == min(activations_per_layer) == len(samples), 'Each layer should have num_samples activations'\n",
    "\n",
    "        if target_token_only:\n",
    "            final_activations = {}\n",
    "            for layer_name, layer_activations in all_activations.items():\n",
    "                assert len(layer_activations) == len(samples), \"Each layer should have same activations as num samples!\"\n",
    "\n",
    "                # Every cached tensor keeps the leading dimension of size 1 left by `torch.split`.\n",
    "                final_activations[layer_name] = [\n",
    "                    activations[:, sample.target_token_position, :]\n",
    "                    for activations, sample in zip(layer_activations, samples)\n",
    "                ]\n",
    "        else:\n",
    "            final_activations = all_activations\n",
    "\n",
    "        if flatten:\n",
    "            final_activations = self._flatten_activations(final_activations, num_samples=len(samples))\n",
    "        else:\n",
    "            print(f'Returning a dictionary mapping layer name to list of activations')\n",
    "\n",
    "        return final_activations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d6f718e-b4f0-44b1-aa93-e526a7618e74",
   "metadata": {},
   "outputs": [],
   "source": [
    "extractor = ActivationsExtractor(model=model, target_layers=model_customizer.get_target_layers(), tokenizer=tokenizer)\n",
    "\n",
    "sample_positives = [point.trimmed_positive_example for point in sample_training_points]\n",
    "\n",
    "sample_target_token_activations = extractor.compute_activations_from_text_tokens_ids_target(sample_positives)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c6fed50-6ee6-4bc3-abce-08f91fed8166",
   "metadata": {},
   "outputs": [],
   "source": [
    "class AutoencoderManager:\n",
    "    def __init__(self, model, tokenizer, autoencoders_dict):\n",
    "        self.model = model\n",
    "        self.tokenizer = tokenizer\n",
    "        self.autoencoders_dict = autoencoders_dict\n",
    "\n",
    "    def get_dictionary_features(self, activations, layer_name):\n",
    "        \"\"\"\n",
    "        Returns raw dictionary features for activations at a layer number.\n",
    "        \"\"\"\n",
    "        with torch.no_grad():\n",
    "            features = self.autoencoders_dict[layer_name](activations.cuda())\n",
    "            return features\n",
    "\n",
    "    def get_all_dictionary_features_for_list(self, activations_dict_list: list[dict[str, list[Tensor]]]):\n",
    "        return [self.get_all_dictionary_features_for_point(point) for point in activations_dict_list]\n",
    "        \n",
    "    def get_all_dictionary_features_for_point(self, activations_dict: dict[str, list[Tensor]]):\n",
    "        all_features = {}\n",
    "        for layer_name, autoencoder in self.autoencoders_dict.items():\n",
    "            activations = activations_dict[layer_name]\n",
    "            assert len(activations) == 1, \"Can only do conversion for single elements right now\"\n",
    "            curr_dict_features = self.get_dictionary_features(activations[0], layer_name)[0].tolist()\n",
    "            all_features[layer_name] = csr_matrix(curr_dict_features)\n",
    "        return all_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e6c9eb2-341b-40c6-a047-8efef12ea4d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "autoencoder_manager = AutoencoderManager(\n",
    "    model=model, tokenizer=tokenizer, autoencoders_dict=rlhf_small\n",
    ")\n",
    "sample_features = autoencoder_manager.get_all_dictionary_features_for_list(\n",
    "    activations_dict_list = sample_target_token_activations\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c004d5c9-8101-486d-8cb8-2215715e93f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LinearProbeTrainingPoint:\n",
    "    def __init__(\n",
    "        self, training_point: TrainingPoint,\n",
    "        # positive token\n",
    "        target_positive_token_id: int,\n",
    "        target_positive_token: str,\n",
    "        positive_activations: [str, Tensor],   # dictionary of layer_num to positive token activations\n",
    "        positive_token_ae_features: [str, Tensor], \n",
    "        # negative token\n",
    "        target_negative_token_id: int,\n",
    "        target_negative_token: str,\n",
    "        negative_activations: [str, Tensor],  # dictionary of layer_num to negative token activations\n",
    "        negative_token_ae_features: [str, Tensor],\n",
    "        # neutral token\n",
    "        target_neutral_token_id: int,\n",
    "        target_neutral_token: str,\n",
    "        neutral_activations: [str, Tensor],   # dictionary of layer_num to neutral activations\n",
    "        neutral_token_ae_features: [str, Tensor]\n",
    "    ):\n",
    "        self.training_point: TrainingPoint = training_point\n",
    "\n",
    "        self.target_positive_token = target_positive_token\n",
    "        self.target_positive_token_id = target_positive_token_id\n",
    "        self.target_positive_reward = self.training_point.target_positive_reward\n",
    "        self.positive_token_ae_features = positive_token_ae_features\n",
    "        self.positive_activations = positive_activations\n",
    "\n",
    "        self.target_negative_token = target_negative_token\n",
    "        self.target_negative_token_id = target_negative_token_id\n",
    "        self.target_negative_reward = self.training_point.target_negative_reward\n",
    "        self.negative_token_ae_features = negative_token_ae_features\n",
    "        self.negative_activations = negative_activations\n",
    "\n",
    "        self.target_neutral_token = target_neutral_token\n",
    "        self.target_neutral_token_id = target_neutral_token_id\n",
    "        self.neutral_token_ae_features = neutral_token_ae_features\n",
    "        self.neutral_activations = neutral_activations\n",
    "\n",
    "    def __str__(self):\n",
    "        return pprint.pformat(self.__dict__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd423a28-2889-4cae-8ef8-7959ea6da95a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LinearProbeTrainingDataManager:\n",
    "    \"\"\"\n",
    "    This takes all the sample training data, and calculates all activations for all training samples\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self, training_data: list[TrainingPoint], autoencoders_dict,\n",
    "        target_layers: list[str], model=model, tokenizer=tokenizer,\n",
    "    ):\n",
    "        self.activations_extractor = ActivationsExtractor(\n",
    "            model=model, tokenizer=tokenizer, target_layers=target_layers)\n",
    "        self.autoencoders_dict = autoencoders_dict\n",
    "        self.autoencoder_manager = AutoencoderManager(\n",
    "            model=model, tokenizer=tokenizer, autoencoders_dict=autoencoders_dict\n",
    "        )\n",
    "        self.training_data = training_data\n",
    "\n",
    "    def compute_training_points_single_batch(self, single_batch: list[TrainingPoint]):\n",
    "        positive_samples = [item.trimmed_positive_example for item in single_batch]\n",
    "        negative_samples = [item.trimmed_negative_example for item in single_batch]\n",
    "        neutral_samples = [item.trimmed_neutral_example for item in single_batch]\n",
    "    \n",
    "        positive_activations_list = self.activations_extractor.compute_activations_from_text_tokens_ids_target(\n",
    "            positive_samples, target_token_only=True, flatten=True\n",
    "        )\n",
    "        negative_activations_list = self.activations_extractor.compute_activations_from_text_tokens_ids_target(\n",
    "            negative_samples, target_token_only=True, flatten=True\n",
    "        )\n",
    "\n",
    "        neutral_activations_list = self.activations_extractor.compute_activations_from_text_tokens_ids_target(\n",
    "            neutral_samples, target_token_only=True, flatten=True\n",
    "        )\n",
    "\n",
    "        positive_dictionary_features = self.autoencoder_manager.get_all_dictionary_features_for_list(positive_activations_list)\n",
    "        negative_dictionary_features = self.autoencoder_manager.get_all_dictionary_features_for_list(negative_activations_list)\n",
    "        neutral_dictionary_features = self.autoencoder_manager.get_all_dictionary_features_for_list(neutral_activations_list)\n",
    "\n",
    "        assert (\n",
    "            len(positive_samples) == len(negative_samples) == len(neutral_samples) == \n",
    "            len(positive_activations_list) == len(negative_activations_list) ==\n",
    "            len(positive_dictionary_features) == len(negative_dictionary_features) == len(neutral_dictionary_features)\n",
    "        ), \"All samples, activations and dict features should align in length.\"\n",
    "\n",
    "        linear_probe_training_batch = []\n",
    "        zipped_point_and_features = zip(\n",
    "            single_batch, positive_dictionary_features, positive_activations_list, \n",
    "            negative_dictionary_features, negative_activations_list, \n",
    "            neutral_dictionary_features, neutral_activations_list\n",
    "        )\n",
    "\n",
    "        for training_point, positive_features, positive_activations, negative_features, negative_activations, neutral_features, neutral_activations in zipped_point_and_features:\n",
    "            linear_probe_training_point = LinearProbeTrainingPoint(\n",
    "                training_point=training_point,\n",
    "                positive_token_ae_features=positive_features,\n",
    "                positive_activations=positive_activations,\n",
    "                negative_token_ae_features=negative_features,\n",
    "                negative_activations=negative_activations,\n",
    "                neutral_token_ae_features=neutral_features,\n",
    "                neutral_activations=neutral_activations,\n",
    "                # Positive token\n",
    "                target_positive_token_id=training_point.target_positive_token_id,\n",
    "                target_positive_token=training_point.target_positive_token,\n",
    "                # Negative token\n",
    "                target_negative_token_id=training_point.target_negative_token_id,\n",
    "                target_negative_token=training_point.target_negative_token,\n",
    "                # Neutral token\n",
    "                target_neutral_token_id=training_point.target_neutral_token_id,\n",
    "                target_neutral_token=training_point.target_neutral_token,\n",
    "            )\n",
    "            linear_probe_training_batch.append(linear_probe_training_point)\n",
    "\n",
    "        return linear_probe_training_batch\n",
    "\n",
    "    def construct_training_dataset(self, all_training_data: list[TrainingPoint], option: str = None):\n",
    "        \"\"\"\n",
    "        Constructs final training dataset, consisting of input and expected output. \n",
    "        \"\"\"\n",
    "        all_results = []\n",
    "        for index, single_batch in tqdm_notebook(enumerate(batch(all_training_data))):\n",
    "            if index % 250 == 0:\n",
    "                print(f'Clearing cuda cache on batch {index}')\n",
    "                clear_gpu_memory()\n",
    "    \n",
    "            current_results = self.compute_training_points_single_batch(single_batch)\n",
    "            all_results.extend(current_results)\n",
    "\n",
    "        assert len(all_results) == len(all_training_data), \"Not all training points were converted to probe inputs!\"\n",
    "\n",
    "        return all_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d897bdb9-3af1-4293-94c7-b4a9efede36a",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_training_data = successful_training_points[:5]\n",
    "\n",
    "lp_training_data_manager = LinearProbeTrainingDataManager(\n",
    "    training_data=sample_training_data, autoencoders_dict=rlhf_small,\n",
    "    model=model, tokenizer=tokenizer, target_layers = model_customizer.get_target_layers()    \n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d047d5e-692e-4bcd-b645-a8dbff5db40b",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_probe_inputs = lp_training_data_manager.construct_training_dataset(sample_training_data)\n",
    "x =[input_point.target_positive_token_id for input_point in sample_probe_inputs]\n",
    "tokenizer.decode(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3b7330a-8273-4c0a-8e91-b388ac543eba",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "full_training_dataset = lp_training_data_manager.construct_training_dataset(successful_training_points)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbd01613-7581-41a2-b782-f68b71780743",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(full_training_dataset[10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f43506f8-f2c8-4294-ad1a-ff365bc0a312",
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_training_dataset_to_wandb(training_dataset: list[LinearProbeTrainingPoint]):\n",
    "    out_filename = \"training_dataset.pkl\"\n",
    "\n",
    "    with open(out_filename, \"wb\") as f_out:\n",
    "        pickle.dump(training_dataset, f_out)\n",
    "    \n",
    "    my_artifact = wandb.Artifact(f\"linear_probe_training_dataset_{policy_model_name}\", type=\"data\")\n",
    "    \n",
    "    # Add the list to the artifact\n",
    "    my_artifact.add_file(local_path=out_filename, name=\"linear_probe_training_dataset\")\n",
    "\n",
    "    metadata_dict = {\n",
    "        \"description\": \"Training dataset, with activations and rewards\",\n",
    "        \"source\": \"Generated by my script\",\n",
    "        \"num_examples\": len(training_dataset),\n",
    "        \"split\": \"full\"\n",
    "    }\n",
    "\n",
    "    my_artifact.metadata.update(metadata_dict)\n",
    "\n",
    "    # Log the artifact to the run\n",
    "    wandb.log_artifact(my_artifact)\n",
    "\n",
    "save_training_dataset_to_wandb(full_training_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e6b7727-f8c7-4270-8e36-e1856b5f8ca2",
   "metadata": {},
   "source": [
    "### Upload high value features to wandb. (Used for adhoc uploads)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e84f4d0-8fc5-4f88-b813-232e03ea43d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "gpt_neo_125m_features = [\n",
    "    ('10', 732), ('10', 131), ('10', 56), ('10', 600), ('10', 651), ('10', 433),\n",
    "    ('10', 648), ('10', 613), ('10', 391), ('10', 273), ('10', 84), ('10', 106),\n",
    "    ('10', 729), ('10', 405), ('10', 370), ('10', 59), ('10', 23), ('10', 736),\n",
    "    ('10', 417), ('10', 69), ('10', 401), ('10', 738), ('10', 389), ('11', 573),\n",
    "    ('11', 295), ('11', 276), ('11', 274), ('11', 95), ('11', 653), ('11', 598),\n",
    "    ('11', 469), ('11', 397), ('11', 635), ('11', 124), ('11', 661), ('11', 498),\n",
    "    ('11', 201), ('11', 135), ('11', 85), ('11', 401), ('11', 651), ('11', 119),\n",
    "    ('11', 506), ('11', 224), ('11', 288), ('11', 455), ('11', 24), ('11', 533),\n",
    "    ('11', 346), ('11', 33), ('7', 250), ('7', 389), ('7', 587), ('7', 134),\n",
    "    ('7', 332), ('7', 123), ('7', 489), ('7', 435), ('7', 602), ('7', 574),\n",
    "    ('7', 753), ('7', 68), ('7', 408), ('7', 36), ('7', 124), ('7', 301), ('7', 12),\n",
    "    ('7', 333), ('7', 223), ('7', 434), ('7', 122), ('7', 588), ('7', 335),\n",
    "    ('8', 732), ('8', 345), ('8', 400), ('8', 214), ('8', 348), ('8', 447), ('8', 541),\n",
    "    ('8', 155), ('8', 172), ('8', 156), ('8', 658), ('8', 463), ('8', 507), ('8', 735),\n",
    "    ('8', 551), ('8', 635), ('8', 434), ('8', 146), ('8', 662), ('8', 653), ('8', 743),\n",
    "    ('8', 566), ('8', 380), ('8', 505), ('9', 566), ('9', 423), ('9', 494), ('9', 48),\n",
    "    ('9', 426), ('9', 653), ('9', 457), ('9', 385), ('9', 23), ('9', 421), ('9', 572),\n",
    "    ('9', 3), ('9', 649), ('9', 678), ('9', 11), ('9', 84), ('9', 717), ('9', 429),\n",
    "    ('9', 356), ('9', 404)\n",
    "]\n",
    "\n",
    "gpt_neo_125m_version = 'v3'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5749212c-6bec-4b19-8446-00783d2f1a08",
   "metadata": {},
   "outputs": [],
   "source": [
    "pythia_410m_features = [\n",
    "    ('16', 759), ('16', 661), ('16', 120), ('16', 154), ('16', 551), ('16', 651), \n",
    "    ('16', 923), ('16', 801), ('16', 380), ('16', 480), ('16', 705), ('16', 825), \n",
    "    ('16', 166), ('16', 750), ('16', 694), ('16', 140), ('16', 261), ('16', 866), \n",
    "    ('16', 571), ('16', 469), ('16', 691), ('16', 852), ('16', 966), ('17', 634), \n",
    "    ('17', 981), ('17', 994), ('17', 471), ('17', 761), ('17', 726), ('17', 21), \n",
    "    ('17', 1002), ('17', 605), ('17', 68), ('17', 16), ('17', 466), ('17', 185), \n",
    "    ('17', 413), ('17', 97), ('17', 238), ('17', 522), ('17', 518), ('17', 629), \n",
    "    ('17', 860), ('17', 213), ('17', 41), ('21', 314), ('21', 884), ('21', 911), \n",
    "    ('21', 876), ('21', 897), ('21', 416), ('21', 424), ('21', 829), ('21', 246), \n",
    "    ('21', 192), ('21', 850), ('21', 210), ('21', 627), ('21', 174), ('21', 927), \n",
    "    ('21', 28), ('21', 631), ('21', 781), ('21', 806), ('21', 628), ('21', 115), \n",
    "    ('22', 391), ('22', 329), ('22', 759), ('22', 349), ('22', 946), ('22', 819), \n",
    "    ('22', 289), ('22', 287), ('22', 1012), ('22', 318), ('22', 138), ('22', 267), \n",
    "    ('22', 430), ('22', 159), ('22', 276), ('22', 197), ('22', 632), ('22', 1015), \n",
    "    ('22', 63), ('22', 701), ('22', 200), ('22', 715), ('22', 964), ('22', 563), \n",
    "    ('23', 139), ('23', 626), ('23', 899), ('23', 738), ('23', 370), ('23', 37), \n",
    "    ('23', 749), ('23', 487), ('23', 826), ('23', 621), ('23', 213), ('23', 552), \n",
    "    ('23', 1013), ('23', 321), ('23', 635), ('23', 500), ('23', 215), ('23', 1000), \n",
    "    ('23', 1023), ('23', 478), ('23', 581), ('23', 947), ('23', 1022), ('23', 577)\n",
    "]\n",
    "\n",
    "pythia_410m_version = 'v2'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4d0fd13-2bd0-4aa2-b8ef-4cde3062c55c",
   "metadata": {},
   "outputs": [],
   "source": [
    "pythia_160m_features = [\n",
    "    ('10', 169), ('10', 507), ('10', 248), ('10', 215), ('10', 98), ('10', 145),\n",
    "    ('10', 40), ('10', 714), ('10', 241), ('10', 541), ('10', 445), ('10', 315),\n",
    "    ('10', 251), ('10', 116), ('11', 185), ('11', 261), ('11', 410), ('11', 207),\n",
    "    ('11', 575), ('11', 198), ('11', 331), ('11', 212), ('11', 590), ('11', 99),\n",
    "    ('11', 502), ('11', 471), ('11', 754), ('11', 218), ('11', 492), ('11', 55),\n",
    "    ('11', 513), ('11', 70), ('7', 415), ('7', 644), ('7', 546), ('7', 52),\n",
    "    ('7', 364), ('7', 260), ('7', 290), ('7', 472), ('7', 429), ('7', 123),\n",
    "    ('7', 61), ('7', 43), ('7', 387), ('7', 236), ('7', 469), ('7', 15), ('7', 501),\n",
    "    ('7', 379), ('8', 427), ('8', 284), ('8', 575), ('8', 498), ('8', 403),\n",
    "    ('8', 410), ('8', 148), ('8', 680), ('8', 144), ('8', 516), ('8', 670),\n",
    "    ('8', 102), ('8', 69), ('8', 260), ('9', 664), ('9', 556), ('9', 542),\n",
    "    ('9', 560), ('9', 158), ('9', 268), ('9', 70), ('9', 547), ('9', 569),\n",
    "    ('9', 193), ('9', 546), ('9', 589), ('9', 16), ('9', 583), ('9', 411),\n",
    "    ('9', 186), ('9', 634)\n",
    "]\n",
    "\n",
    "pythia_160m_version = 'v4'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b79d755-3bf9-4fbf-829b-adabd5a6810f",
   "metadata": {},
   "outputs": [],
   "source": [
    "pythia_70m_features = [\n",
    "    ('1', 45), ('1', 314), ('1', 161), ('1', 43), ('1', 254), ('1', 391),\n",
    "    ('1', 422), ('1', 420), ('1', 482), ('1', 127), ('1', 162), ('1', 193),\n",
    "    ('2', 183), ('2', 164), ('2', 291), ('2', 380), ('2', 97), ('2', 415),\n",
    "    ('2', 269), ('2', 229), ('2', 220), ('2', 457), ('2', 129), ('2', 96),\n",
    "    ('3', 23), ('3', 252), ('3', 255), ('3', 104), ('3', 379), ('3', 141),\n",
    "    ('3', 170), ('3', 128), ('3', 117), ('3', 244), ('3', 93), ('3', 130),\n",
    "    ('4', 161), ('4', 22), ('4', 303), ('4', 119), ('4', 404), ('4', 368),\n",
    "    ('4', 301), ('4', 96), ('4', 23), ('5', 9), ('5', 75), ('5', 377),\n",
    "    ('5', 68), ('5', 93), ('5', 381), ('5', 22), ('5', 39), ('5', 189),\n",
    "    ('5', 221), ('5', 231), ('5', 251)\n",
    "]\n",
    "\n",
    "pythia_70m_version = 'v11'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bad08658-4ec2-4623-87a6-cac12b21ed16",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_high_value_features_artifact = {\n",
    "    \"pythia_70m\": (pythia_70m_version, pythia_70m_features),\n",
    "    \"pythia_160m\": (pythia_160m_version, pythia_160m_features),\n",
    "    \"pythia_410m\": (pythia_410m_version, pythia_410m_features),\n",
    "    \"gpt_neo_125m\": (gpt_neo_125m_version, gpt_neo_125m_features)\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b802d444-2f69-44c0-8111-fba3c00e5bc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_high_value_features_dataset_to_wandb(output_high_value_features_artifact):\n",
    "    out_filename = \"high_value_features_dataset.json\"\n",
    "\n",
    "    with open(out_filename, \"w\") as f_out:\n",
    "        json.dump(output_high_value_features_artifact, f_out)\n",
    "    \n",
    "    my_artifact = wandb.Artifact(\"high_value_features_artifact\", type=\"data\")\n",
    "    \n",
    "    # Add the list to the artifact\n",
    "    my_artifact.add_file(local_path=out_filename, name=\"high_value_features_artifact\")\n",
    "\n",
    "    metadata_dict = {\n",
    "        \"description\": \"High value features for pythia70m, pythia160m, pythia410m, and gpt_neo_125m. Includes source versions for these.\",\n",
    "        \"source\": \"Generated by Marc's experiments using GPT-4\"\n",
    "    }\n",
    "\n",
    "    my_artifact.metadata.update(metadata_dict)\n",
    "\n",
    "    # Log the artifact to the run\n",
    "    wandb.log_artifact(my_artifact)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9fc2b22-93e8-47a6-af7c-329316baaeff",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_high_value_features_dataset_to_wandb(output_high_value_features_artifact)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
