{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af0013fe-1e8c-463a-a0c9-dcf828e8d858",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import pickle\n",
    "import pprint\n",
    "import random\n",
    "\n",
    "from collections import Counter, defaultdict\n",
    "from dataclasses import dataclass\n",
    "from nltk.sentiment.vader import SentimentIntensityAnalyzer\n",
    "\n",
    "import nltk\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import wandb\n",
    "\n",
    "from matplotlib import pyplot\n",
    "from scipy.sparse import csr_matrix, vstack\n",
    "from scipy.stats import kendalltau\n",
    "from sklearn.linear_model import Ridge\n",
    "\n",
    "from torch import Tensor\n",
    "from tqdm import tqdm_notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3cc0499-e309-4274-9083-7971ddef23d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "nltk.download('vader_lexicon')\n",
    "sentiment_analyzer = SentimentIntensityAnalyzer()\n",
    "lexicon = sentiment_analyzer.lexicon\n",
    "\n",
    "min_vader_value = min(lexicon.values())\n",
    "max_vader_value = max(lexicon.values())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fdf6b59-1d37-48f1-b374-840c38ea254b",
   "metadata": {},
   "outputs": [],
   "source": [
    "filename = 'linear_probe_training_dataset'\n",
    "model_name = 'pythia_70m'\n",
    "policy_model_name = f'{model_name}_utility_reward'\n",
    "project_name = 'utility_reconstruction'\n",
    "\n",
    "versions_dict = {\"gpt_neo_125m\": 'v1'}\n",
    "version = versions_dict.get(model_name, 'v0')\n",
    "random_seed = 42\n",
    "\n",
    "os.environ['WANDB_API_KEY'] = ''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f241eab9-d377-4240-8ed0-3292ef6e044b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ModelCustomizer:\n",
    "    '''\n",
    "    Used to customize model layer numbers and other network parsing details\n",
    "    '''\n",
    "\n",
    "    def __init__(self):\n",
    "        '''\n",
    "        Initialize\n",
    "        '''\n",
    "        self.target_layers = None\n",
    "\n",
    "    def set_target_layers(self) -> list[str]:\n",
    "        '''\n",
    "        Set target layers\n",
    "        '''\n",
    "\n",
    "    def get_target_layers(self) -> list[str]:\n",
    "        '''\n",
    "        Get target layers.\n",
    "        '''\n",
    "\n",
    "    def parse_layer_name_to_layer_number(self, layer_name) -> str:\n",
    "        '''\n",
    "        Parse layer name to layer number\n",
    "        '''\n",
    "\n",
    "    def convert_ae_dict_keys(self, autoencoders_dict: [str, Tensor]):\n",
    "        '''\n",
    "        Parse ae dict keys to full layer names.\n",
    "        '''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68841422-abc0-4c44-8f32-cefa68bac305",
   "metadata": {},
   "outputs": [],
   "source": [
    "class GPTNeoCustomizer(ModelCustomizer):\n",
    "\n",
    "    def get_target_layers(self) -> list[str]:\n",
    "        if self.target_layers:\n",
    "            return self.target_layers\n",
    "        else:\n",
    "            return [self.layer_num_to_full_name(layer_no) for layer_no in range(12)]\n",
    "\n",
    "    def set_target_layers(self, target_layers):\n",
    "        self.target_layers = target_layers\n",
    "\n",
    "    def layer_num_to_full_name(self, layer_no):\n",
    "        return f'transformer.h.{layer_no}.mlp'\n",
    "\n",
    "    def parse_layer_name_to_layer_number(self, layer_name) -> str:\n",
    "        return layer_name.split('.')[-2]\n",
    "\n",
    "    # Standardize layer names to full names instead of 'int'\n",
    "    def convert_ae_dict_keys(self, autoencoders_dict: [str, Tensor]):\n",
    "        output_dict = {}\n",
    "        for key, autoencoder in autoencoders_dict.items():\n",
    "            output_dict[self.layer_num_to_full_name(key)] = autoencoder\n",
    "        return output_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe52c1d3-6154-44b0-af50-3a7597f8cb9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class PythiaCustomizer(ModelCustomizer):\n",
    "    def __init__(self, num_layers):\n",
    "        super().__init__()\n",
    "        self.num_layers = num_layers\n",
    "        self.target_layers = None\n",
    "\n",
    "    def get_target_layers(self) -> list[str]:\n",
    "        if self.target_layers:\n",
    "            return self.target_layers\n",
    "        else:\n",
    "            return [self.layer_num_to_full_name(layer_no) for layer_no in range(self.num_layers)]\n",
    "\n",
    "    def set_target_layers(self, target_layers):\n",
    "        self.target_layers = target_layers\n",
    "\n",
    "    def layer_num_to_full_name(self, layer_no):\n",
    "        return f'gpt_neox.layers.{layer_no}.mlp'\n",
    "\n",
    "    def parse_layer_name_to_layer_number(self, layer_name) -> str:\n",
    "        return layer_name.split('.')[-2]\n",
    "\n",
    "    # Standardize layer names to full names instead of 'int'\n",
    "    def convert_ae_dict_keys(self, autoencoders_dict: [str, Tensor]):\n",
    "        output_dict = {}\n",
    "        for key, autoencoder in autoencoders_dict.items():\n",
    "            output_dict[self.layer_num_to_full_name(key)] = autoencoder\n",
    "        return output_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91e7dd52-907d-4bfb-9307-0e2509360362",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_customizers = {\n",
    "    \"pythia_70m\": PythiaCustomizer(num_layers=6),\n",
    "    \"pythia_160m\": PythiaCustomizer(num_layers=12),\n",
    "    \"pythia_410m\": PythiaCustomizer(num_layers=24),\n",
    "    \"gpt_neo_125m\": GPTNeoCustomizer()\n",
    "}\n",
    "model_customizer = model_customizers[model_name]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cee2fc8c-70b0-4c0e-ad42-13606a477ddb",
   "metadata": {},
   "source": [
    "### Randomization and other utilities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6b1f21c-46e7-42bc-ae93-6a09bb1895b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def clamp(number, min_value, max_value):\n",
    "    return max(min(number, max_value), min_value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b19642c-bdec-4756-806b-d060f2299ded",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_average_values(list_of_dicts):\n",
    "    # Use defaultdict to simplify code\n",
    "    token_sum = defaultdict(float)\n",
    "    token_count = defaultdict(int)\n",
    "\n",
    "    # Accumulate sums and counts\n",
    "    for d in list_of_dicts:\n",
    "        for token, value in d.items():\n",
    "            token_sum[token] += value\n",
    "            token_count[token] += 1\n",
    "\n",
    "    # Calculate average values using a dictionary comprehension\n",
    "    average_values = {token: round(token_sum[token] / token_count[token], 3) for token in token_sum}\n",
    "\n",
    "    return average_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "decbfacd-2831-4672-b02c-6829a7c44a32",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rescale_value(value, values_list, new_min=min_vader_value, new_max=max_vader_value):\n",
    "    percentile_range = 90\n",
    "\n",
    "    old_max = np.percentile(values_list, percentile_range)\n",
    "    old_min = np.percentile(values_list, 100 - percentile_range)\n",
    "    \n",
    "    # First, normalize the value to a range between 0 and 1\n",
    "    normalized_value = (value - old_min) / (old_max - old_min)\n",
    "    \n",
    "    # Then, scale the normalized value to the new range\n",
    "    new_value = normalized_value * (new_max - new_min) + new_min\n",
    "\n",
    "    new_value = clamp(new_value, new_min, new_max)\n",
    "    \n",
    "    return round(new_value, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e67890f-786f-49b6-a141-b229052f795c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def random_split_list(lst, split_ratio=0.8, seed=random_seed):\n",
    "    if seed is not None:\n",
    "        random.seed(seed)\n",
    "    \n",
    "    shuffled_list = lst[:]\n",
    "    random.shuffle(shuffled_list)\n",
    "    \n",
    "    split_index = int(len(shuffled_list) * split_ratio)\n",
    "    return shuffled_list[:split_index], shuffled_list[split_index:]\n",
    "\n",
    "random_split_list([1,2,3,4,5,6,7,8,9,10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcdedb5d-0f1d-41f6-893f-fb7bf972ba7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def concantenate_matrices(layer_to_csr_dict):\n",
    "    \"\"\"\n",
    "    Given a dictionary of layername_to_features matrices, this flattens and concatenates\n",
    "    the matrices, in canoncial sorted order of the dictionary keys (the layernames).\n",
    "    \"\"\"\n",
    "    sorted_matrices = [\n",
    "        layer_to_csr_dict[key] for key in sorted(layer_to_csr_dict.keys())\n",
    "    ]\n",
    "    concatenated_matrix = vstack(sorted_matrices)\n",
    "    return concatenated_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc8deb1a-299a-4852-974b-7be3315a8154",
   "metadata": {},
   "outputs": [],
   "source": [
    "def euclidean_distance(matrix1: csr_matrix, matrix2: csr_matrix):\n",
    "    # Convert CSR matrices to dense arrays for cdist\n",
    "    dense_matrix1 = matrix1.toarray().flatten()\n",
    "    dense_matrix2 = matrix2.toarray().flatten()\n",
    "\n",
    "    # Compute Euclidean distance using cdist\n",
    "    distance = np.linalg.norm(dense_matrix1 - dense_matrix2)\n",
    "\n",
    "    return distance\n",
    "\n",
    "def euclidean_distance_bw_dicts_of_csr_matrices(\n",
    "    matrix_dict_1: dict[str, csr_matrix], matrix_dict_2: dict[str, csr_matrix]):\n",
    "\n",
    "    feature_matrix_1 = concantenate_matrices(matrix_dict_1)\n",
    "    feature_matrix_2 = concantenate_matrices(matrix_dict_2)\n",
    "\n",
    "    return euclidean_distance(feature_matrix_1, feature_matrix_2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cec088ec-f34c-41a8-a008-bc8d9526d6ce",
   "metadata": {},
   "source": [
    "### Load artifact from wandb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05b19e4f-9fd9-4886-aace-0573ad036fa5",
   "metadata": {},
   "outputs": [],
   "source": [
    "run = wandb.init(project=f'{project_name}_{policy_model_name}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d12e0f28-1948-4248-800a-ebe5d4acc53c",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.run.config['random_seed'] = random_seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c58ec0f-d247-4ae6-beb9-10bcd55ca416",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_linear_probe_training_dataset(policy_model_name=policy_model_name, project_name=project_name, version=version):\n",
    "    artifact_path = f'linear_probe_training_dataset_{policy_model_name}:{version}'\n",
    "    \n",
    "    artifact = run.use_artifact(\n",
    "        f'nlp_and_interpretability/{project_name}/{artifact_path}', type='data'\n",
    "    )\n",
    "    artifact_dir = artifact.download()\n",
    "\n",
    "    with open(f'artifacts/{artifact_path}/{filename}', 'rb') as f_in:\n",
    "        training_dataset = pickle.load(f_in)\n",
    "\n",
    "    return training_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3702f8b9-fda0-4857-ba97-31549bbfe4b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class TextTokensIdsTarget:\n",
    "    attention_mask: list[int]\n",
    "    text: str\n",
    "    tokens: list[str]\n",
    "    ids: list[int]\n",
    "    target_token: str\n",
    "    target_token_id: int\n",
    "    target_token_position: int\n",
    "\n",
    "    @staticmethod\n",
    "    def get_tensorized(datapoints: \"TextTokensIdsTarget\"):\n",
    "        max_length = max([len(datapoint.tokens) for datapoint in datapoints])\n",
    "        \n",
    "        input_ids = [datapoint.ids for datapoint in datapoints]\n",
    "        attention_masks = [datapoint.attention_mask for datapoint in datapoints]\n",
    "\n",
    "        input_ids_padded = pad_list_of_lists(input_ids, tokenizer.encode(tokenizer.pad_token)[0])\n",
    "        attention_masks_padded = pad_list_of_lists(attention_masks, 0)\n",
    "        all_tokenized = {\n",
    "            \"input_ids\": torch.IntTensor(input_ids_padded).cuda(), \"attention_mask\": torch.ByteTensor(attention_masks_padded).cuda()\n",
    "        }\n",
    "        return all_tokenized\n",
    "\n",
    "### Source training point in the wandb artifact\n",
    "class TrainingPoint:\n",
    "\n",
    "    def __init__(self, input_dict: dict, tokenizer=None):\n",
    "        self.input_dict = input_dict\n",
    "        self.positive_text = input_dict['input_text']\n",
    "        self.negative_text = input_dict['output_text']\n",
    "        self.neutral_text = input_dict['neutral_text']\n",
    "        \n",
    "        # Dictionary of layer name to activations by mlp layer.\n",
    "        self.activations: dict = None\n",
    "\n",
    "        # Dictionary of layer name to autoencoder feature by mlp layer\n",
    "        self.autoencoder_feature: dict = None\n",
    "\n",
    "        # Reward value of target_token.\n",
    "        self.target_positive_reward = None\n",
    "        self.target_negative_reward = None\n",
    "\n",
    "        self.positive_text_tokens, self.positive_input_ids = get_tokens_and_ids(self.positive_text)\n",
    "        self.negative_text_tokens, self.negative_token_ids = get_tokens_and_ids(self.negative_text)\n",
    "        \n",
    "        self.positive_words = input_dict['positive_words']\n",
    "        self.negative_words = list(input_dict['new_words'].values())\n",
    "        self.neutral_words = list(input_dict['neutral_words'].values())\n",
    "\n",
    "        self.target_positive_reward = None\n",
    "        self.target_positive_token = None\n",
    "        self.target_positive_token_id = None\n",
    "    \n",
    "        self.target_negative_reward = None\n",
    "        self.target_negative_token = None\n",
    "        self.target_negative_token_id = None\n",
    "\n",
    "        self.target_neutral_token = None\n",
    "        self.target_neutral_token_id = None\n",
    "\n",
    "        try:\n",
    "            self.trimmed_positive_example: \"TextTokensIdTarget\" = trim_example(self.positive_text, self.positive_words)\n",
    "            if self.trimmed_positive_example:\n",
    "                positive_token = self.trimmed_positive_example.target_token.strip().lower()\n",
    "                self.target_positive_reward = lexicon.get(positive_token, None)\n",
    "                self.target_positive_token = positive_token\n",
    "                self.target_positive_token_id = self.trimmed_positive_example.target_token_id\n",
    "        \n",
    "        except Exception as e:\n",
    "            print(f'Caught exception {e} on {input_dict} for positive example.')\n",
    "            self.trimmed_positive_example = None\n",
    "        \n",
    "        try:\n",
    "            self.trimmed_negative_example: \"TextTokensIdTarget\" = trim_example(self.negative_text, self.negative_words)\n",
    "            if self.trimmed_negative_example:\n",
    "                negative_token = self.trimmed_negative_example.target_token.strip().lower()\n",
    "                self.target_negative_reward = lexicon.get(negative_token, None)\n",
    "                self.target_negative_token = negative_token\n",
    "                self.target_negative_token_id = self.trimmed_negative_example.target_token_id\n",
    "\n",
    "        except Exception as e:\n",
    "            print(f'Caught exception {e} on {input_dict} for negative example.')\n",
    "            self.trimmed_negative_example = None\n",
    "\n",
    "        try:\n",
    "            self.trimmed_neutral_example: \"TextTokensIdTarget\" = trim_example(self.neutral_text, self.neutral_words)\n",
    "            if self.trimmed_neutral_example:\n",
    "                self.target_neutral_token = self.trimmed_neutral_example.target_token.strip().lower()\n",
    "                self.target_neutral_token_id = self.trimmed_neutral_example.target_token_id\n",
    "\n",
    "        except Exception as e:\n",
    "            print(f'Caught exception {e} on {input_dict} for neutral example.')\n",
    "            self.trimmed_neutral_example = None\n",
    "\n",
    "    def __str__(self):\n",
    "        return pprint.pformat(self.__dict__)\n",
    "\n",
    "\n",
    "class LinearProbeTrainingPoint:\n",
    "    def __init__(\n",
    "        self, training_point: \"TrainingPoint\",\n",
    "        # positive token\n",
    "        target_positive_token_id: int,\n",
    "        target_positive_token: str,\n",
    "        positive_token_ae_features: [str, Tensor], \n",
    "        # negative token\n",
    "        target_negative_token_id: int,\n",
    "        target_negative_token: str,\n",
    "        negative_token_ae_features: [str, Tensor],\n",
    "        # neutral token\n",
    "        target_neutral_token_id: int,\n",
    "        target_neutral_token: str,\n",
    "        neutral_token_ae_features: [str, Tensor]\n",
    "    ):\n",
    "        self.training_point: \"TrainingPoint\" = training_point\n",
    "\n",
    "        self.target_positive_token = target_positive_token\n",
    "        self.target_positive_token_id = target_positive_token_id\n",
    "        self.target_positive_reward = self.training_point.target_positive_reward\n",
    "        self.positive_token_ae_features = positive_token_ae_features\n",
    "\n",
    "        self.target_negative_token = target_negative_token\n",
    "        self.target_negative_token_id = target_negative_token_id\n",
    "        self.target_negative_reward = self.training_point.target_negative_reward\n",
    "        self.negative_token_ae_features = negative_token_ae_features\n",
    "\n",
    "        self.target_neutral_token = target_neutral_token\n",
    "        self.target_neutral_token_id = target_neutral_token_id\n",
    "        self.neutral_token_ae_features = neutral_token_ae_features\n",
    "\n",
    "    def __str__(self):\n",
    "        return pprint.pformat(self.__dict__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ba321f7-398e-4861-9378-17c1901f16ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "full_training_dataset = load_linear_probe_training_dataset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0e45f08-7a78-4a51-ab05-bcab5629fc20",
   "metadata": {},
   "outputs": [],
   "source": [
    "x =full_training_dataset[4]\n",
    "x.training_point.activations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44bfafb8-67ad-4985-90e6-31fe9ec2c513",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_split_dataset, test_split_dataset = random_split_list(full_training_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d0a98c8-6453-431b-ad59-992584ffb118",
   "metadata": {},
   "source": [
    "### Define linear probe helper classes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d14ff89-3dbc-4120-97ce-6a92061c03ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class LinearProbeFinalInput:\n",
    "    token: str\n",
    "    token_id: int\n",
    "    divergence: float     # Divergence of the token to neutral token\n",
    "    features: csr_matrix  # Corresponds to the features of positive or negative token\n",
    "    point_type: str    # Can be positive or negative\n",
    "    source_training_point: LinearProbeTrainingPoint  #In case we need to inspect/retrieve original features\n",
    "\n",
    "    def __str__(self):\n",
    "        return pprint.pformat(self.__dict__)\n",
    "\n",
    "    def __repr__(self):\n",
    "        return str(self)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c5f5a31-26cf-4976-b800-fe0ab12eabff",
   "metadata": {},
   "source": [
    "### Construct training dataset and linear probe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56d9a73e-a0e0-43ec-b122-0ce6429f24f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def map_lp_training_point_to_pair_of_lp_final_inputs(lp_training_point: LinearProbeTrainingPoint) -> list[LinearProbeFinalInput]:\n",
    "    positive_features = concantenate_matrices(\n",
    "        lp_training_point.positive_token_ae_features)\n",
    "\n",
    "    negative_features = concantenate_matrices(\n",
    "        lp_training_point.negative_token_ae_features)\n",
    "\n",
    "    neutral_features = concantenate_matrices(\n",
    "        lp_training_point.neutral_token_ae_features)\n",
    "\n",
    "    positive_token = lp_training_point.target_positive_token\n",
    "    positive_token_id = lp_training_point.target_positive_token_id\n",
    "    positive_divergence = euclidean_distance(\n",
    "        positive_features, neutral_features\n",
    "    )\n",
    "\n",
    "    # Positive input training example.\n",
    "    positive_probe_final_input = LinearProbeFinalInput(\n",
    "        token=positive_token, token_id=positive_token_id,\n",
    "        divergence=positive_divergence, features=positive_features,\n",
    "        point_type='positive',\n",
    "        source_training_point = lp_training_point\n",
    "    )\n",
    "\n",
    "    negative_token = lp_training_point.target_negative_token\n",
    "    negative_token_id = lp_training_point.target_negative_token_id\n",
    "    negative_divergence = euclidean_distance(\n",
    "        negative_features, neutral_features\n",
    "    )\n",
    "\n",
    "    # Negative input training example - multiply divergence by minus one.\n",
    "    negative_probe_final_input = LinearProbeFinalInput(\n",
    "        token=negative_token, token_id=negative_token_id,\n",
    "        divergence=-1*negative_divergence, features=negative_features,\n",
    "        point_type='negative',\n",
    "        source_training_point = lp_training_point\n",
    "    )\n",
    "\n",
    "    return [positive_probe_final_input, negative_probe_final_input]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abe9a6e0-aa9e-419f-85e8-a654f66a205f",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_point = train_split_dataset[4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bf7c647-5493-4c6c-8676-988dad5e7c3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "positive_test_point, negative_test_point = map_lp_training_point_to_pair_of_lp_final_inputs(test_point)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99107626-892e-45aa-9207-95160b1e988d",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f'\\nPositive point:\\n{pprint.pformat(positive_test_point)}')\n",
    "print(f'\\nNegative point:\\n{pprint.pformat(negative_test_point)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42d1c9ef-42dc-44c1-9353-bc75c7b8cba4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def map_lp_dataset_to_final_input_dataset(\n",
    "    input_dataset: list[LinearProbeTrainingPoint]) -> list[LinearProbeFinalInput]:\n",
    "\n",
    "    final_dataset = []\n",
    "\n",
    "    for datapoint in tqdm_notebook(input_dataset):\n",
    "        positive_point, negative_point = map_lp_training_point_to_pair_of_lp_final_inputs(datapoint)\n",
    "        final_dataset.append(positive_point)\n",
    "        final_dataset.append(negative_point)\n",
    "\n",
    "    return final_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ea55e8a-3c01-4e1d-965b-6a28da79d813",
   "metadata": {},
   "outputs": [],
   "source": [
    "mapped_train_split_dataset: list[LinearProbeFinalInput] = map_lp_dataset_to_final_input_dataset(\n",
    "    input_dataset = train_split_dataset\n",
    ")\n",
    "\n",
    "mapped_test_split_dataset: list[LinearProbeFinalInput] = map_lp_dataset_to_final_input_dataset(\n",
    "    input_dataset = test_split_dataset\n",
    ")\n",
    "\n",
    "mapped_full_split_dataset: list[LinearProbeFinalInput] = map_lp_dataset_to_final_input_dataset(\n",
    "    input_dataset = full_training_dataset\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ae3b9d7-c635-4f37-a6e2-40ff2ab02cea",
   "metadata": {},
   "outputs": [],
   "source": [
    "class FeatureConstructor:\n",
    "\n",
    "    def construct_feature_representation(self, linear_probe_inputs):\n",
    "        feature_rep = np.array([point.features.toarray().flatten() for point in linear_probe_inputs])\n",
    "        return feature_rep\n",
    "\n",
    "feature_constructor = FeatureConstructor()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c574c5f-d26d-4f78-9ce2-95440a035f98",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_linear_model(train_linear_probe_inputs: list[LinearProbeFinalInput], feature_constructor: FeatureConstructor = feature_constructor):\n",
    "    input_points = feature_constructor.construct_feature_representation(train_linear_probe_inputs)\n",
    "\n",
    "    output_points = np.array([point.divergence for point in train_linear_probe_inputs])\n",
    "\n",
    "    print(f'Shapes are {input_points.shape} and {output_points.shape}')\n",
    "\n",
    "    model = Ridge()\n",
    "    wandb.run.summary['linear_model_type'] = 'Ridge'\n",
    "    model.fit(input_points, output_points)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8f77205-df75-4cbf-9aee-70e596f9c98d",
   "metadata": {},
   "outputs": [],
   "source": [
    "linear_model_on_train = train_linear_model(train_linear_probe_inputs=mapped_train_split_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ae92cc7-4afc-43d6-86ca-16c63832ef1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "linear_model_on_full = train_linear_model(train_linear_probe_inputs=mapped_full_split_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "121272b9-3d72-4b07-9313-48c8bc77d6cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_fitted_values(linear_model, test_linear_probe_inputs, feature_constructor: FeatureConstructor = feature_constructor):\n",
    "    \"\"\"\n",
    "    \"\"\"\n",
    "    test_inputs = feature_constructor.construct_feature_representation(test_linear_probe_inputs)\n",
    "    test_values = linear_model.predict(test_inputs)\n",
    "    return test_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13d852ec-1308-4f79-99f9-cded3cfe6ce2",
   "metadata": {},
   "outputs": [],
   "source": [
    "fitted_values_on_test = get_fitted_values(linear_model=linear_model_on_train, test_linear_probe_inputs=mapped_test_split_dataset)\n",
    "fitted_values_on_full = get_fitted_values(linear_model=linear_model_on_full, test_linear_probe_inputs=mapped_full_split_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a56b1855-ae3a-4a49-8bbd-7ba08b6deea0",
   "metadata": {},
   "outputs": [],
   "source": [
    "fitted_values_and_inputs_on_test = list(zip(fitted_values_on_test, mapped_test_split_dataset))\n",
    "fitted_values_and_inputs_on_full = list(zip(fitted_values_on_full, mapped_full_split_dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b471793-7c97-476f-a77d-f69a0cf50834",
   "metadata": {},
   "outputs": [],
   "source": [
    "fitted_values_and_inputs_on_test[15]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1e8c8fd5-95c6-4ea0-82a3-bc8924874dcb",
   "metadata": {},
   "source": [
    "### Do analysis on divergence values viz-a-viz original Vader lexicon."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f405dcd8-5c7f-4173-9797-e0f797aaf3b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def scale_values_and_input_list_to_range(values_and_input_list, min_range=min_vader_value, max_range=max_vader_value):\n",
    "    all_probe_values = []\n",
    "    all_lp_inputs = []\n",
    "    all_tokens = []\n",
    "\n",
    "    for values_and_inputs in values_and_input_list:\n",
    "        fitted_value = values_and_inputs[0]\n",
    "        lp_input = values_and_inputs[1]\n",
    "        token = lp_input.token\n",
    "\n",
    "        all_probe_values.append(fitted_value)\n",
    "        all_lp_inputs.append(lp_input)\n",
    "        all_tokens.append(token)\n",
    "\n",
    "    rescaled_token_to_value_dict_list = [{input.token: rescale_value(value, all_probe_values)} for value, input in values_and_input_list]\n",
    "    \n",
    "    return rescaled_token_to_value_dict_list, all_tokens, all_probe_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f399ba6e-3a0e-4294-ac2d-58c4dea3961d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rescale values to a range and drop outliers.\n",
    "rescaled_token_to_value_dict_list_on_test, all_test_tokens, all_test_probe_values = scale_values_and_input_list_to_range(\n",
    "    values_and_input_list=fitted_values_and_inputs_on_test\n",
    ")\n",
    "rescaled_token_to_value_dict_list_on_full, all_full_tokens, all_full_probe_values = scale_values_and_input_list_to_range(\n",
    "    values_and_input_list=fitted_values_and_inputs_on_full\n",
    ")\n",
    "\n",
    "# These are the full token values.\n",
    "averaged_token_values_on_test = calculate_average_values(rescaled_token_to_value_dict_list_on_test)\n",
    "averaged_token_values_on_full = calculate_average_values(rescaled_token_to_value_dict_list_on_full)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23f303ae-db14-4e01-8a7e-97d6a4a67943",
   "metadata": {},
   "outputs": [],
   "source": [
    "rescaled_fitted_values_and_inputs_on_test = [\n",
    "    (rescale_value(fitted_value, all_test_probe_values), lp_input) for fitted_value, lp_input in fitted_values_and_inputs_on_test\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1afa77e0-ac76-49bd-91c6-1cbec12147d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_positive_test_tokens = [token for token in all_test_tokens if lexicon.get(token, 0) > 0]\n",
    "all_negative_test_tokens = [token for token in all_test_tokens if lexicon.get(token, 0) < 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e67e7a8-f009-4f14-bd52-d9d9fa0f00ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "rescaled_token_to_value_dict_list_on_test = sorted(\n",
    "    rescaled_token_to_value_dict_list_on_test, key=lambda x: list(x.keys())[0]\n",
    ")\n",
    "# rescaled_token_to_value_dict_list_on_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "008baeeb-a857-4a63-bdbe-b4a4e7b1c669",
   "metadata": {},
   "outputs": [],
   "source": [
    "random.seed(random_seed)\n",
    "\n",
    "random_positive_tokens = random.sample(all_positive_test_tokens, 3)\n",
    "random_negative_tokens = random.sample(all_negative_test_tokens, 3)\n",
    "\n",
    "random_positive_token_values = {pos_token: averaged_token_values_on_test[pos_token] for pos_token in random_positive_tokens}\n",
    "random_negative_token_values = {neg_token: averaged_token_values_on_test[neg_token] for neg_token in random_negative_tokens}\n",
    "\n",
    "original_positive_values = {key: lexicon[key] for key in random_positive_tokens}\n",
    "original_negative_values = {key: lexicon[key] for key in random_negative_tokens}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c69485d-0ae1-42ff-8059-b2588426e329",
   "metadata": {},
   "source": [
    "### Find and log \"best\" and \"worst\" reconstructed points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88cc19fb-5fa2-4681-ae2c-822fa4b28290",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_worst_and_best_reconstructed(reconstructed_token_values):\n",
    "    \"\"\"Compare reconstructed utilities against lexicon scores for strongly positive tokens.\n",
    "\n",
    "    Only tokens whose lexicon score is >= 2.0 are kept. Tokens that are absent\n",
    "    from the lexicon are skipped rather than raising KeyError.\n",
    "\n",
    "    Args:\n",
    "        reconstructed_token_values: Mapping of token -> reconstructed utility.\n",
    "\n",
    "    Returns:\n",
    "        (recovered, bad, all_items): three lists of per-token dicts, each sorted\n",
    "        ascending by \"original - reconstructed\". `recovered` holds entries whose\n",
    "        reconstructed score is >= 2.0, `bad` holds entries whose reconstructed\n",
    "        score is < 0, and `all_items` holds every kept token.\n",
    "    \"\"\"\n",
    "    print(f'Original reconstruction has {len(reconstructed_token_values)}')\n",
    "\n",
    "    # A token missing from the lexicon defaults to 0 and therefore never passes the filter.\n",
    "    filtered_token_values = {\n",
    "        token: lexicon[token]\n",
    "        for token in reconstructed_token_values\n",
    "        if lexicon.get(token, 0) >= 2.0\n",
    "    }\n",
    "\n",
    "    print(f'Filtered reconstruction has {len(filtered_token_values)}')\n",
    "\n",
    "    reconstructions = []\n",
    "\n",
    "    for token, original_score in filtered_token_values.items():\n",
    "        reconstructed_score = reconstructed_token_values[token]\n",
    "        # Original - reconstructed\n",
    "        reconstruction_error = round(original_score - reconstructed_score, 3)\n",
    "\n",
    "        reconstructions.append({\n",
    "            \"token\": token, \"original - reconstructed\": reconstruction_error,\n",
    "            \"absolute_error\": abs(reconstruction_error),\n",
    "            \"reconstructed_score\": reconstructed_score,\n",
    "            \"original_score\": original_score\n",
    "        })\n",
    "\n",
    "    reconstructions = sorted(reconstructions, key=lambda item: item[\"original - reconstructed\"])\n",
    "    recovered_reconstructions = [item for item in reconstructions if item[\"reconstructed_score\"] >= 2.0]\n",
    "    bad_reconstructions = [item for item in reconstructions if item[\"reconstructed_score\"] < 0]\n",
    "    return recovered_reconstructions, bad_reconstructions, reconstructions\n",
    "\n",
    "recovered_reconstructions, bad_reconstructions, all_reconstructions = find_worst_and_best_reconstructed(reconstructed_token_values=averaged_token_values_on_full)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a533b54c-4f92-4026-9e2b-fb43c0f8e232",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Payload written to disk here; the \"note\" states the same thresholds the filtering code applies.\n",
    "reconstructions_object = {\n",
    "    \"recovered_reconstructions\": recovered_reconstructions,\n",
    "    \"bad_reconstructions\": bad_reconstructions,\n",
    "    \"all_reconstructions\": all_reconstructions,\n",
    "    \"model_name\": model_name,\n",
    "    \"note\": \"\"\"Of all tokens with Vader utility >=2.0.\n",
    "        We assess tokens which also had reconstructed utility >=2.0 as \"recovered\".\n",
    "        And tokens which had reconstructed utility <0.0 as \"bad\" reconstructions.\"\"\"\n",
    "}\n",
    "\n",
    "# `json` is already imported in the notebook's first cell.\n",
    "with open(f\"{model_name}_reconstructions.json\", \"w\") as f_out:\n",
    "    json.dump(reconstructions_object, f_out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86491fb9-2618-4ac4-99e5-e6e42c83446b",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.log(reconstructions_object)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ed6dab5-ae45-489c-81f1-e2f73a73e703",
   "metadata": {},
   "source": [
    "### Plot distribution of scores for positive and negative values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03cf08fe-0dcf-463b-8081-22d85be0e267",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_original_vs_modified_values(token_values):\n",
    "    \"\"\"Overlay histograms of reconstructed values and their original lexicon values.\n",
    "\n",
    "    Args:\n",
    "        token_values: Mapping of token -> reconstructed value. Every key is\n",
    "            looked up in the module-level `lexicon`, so a key missing from it\n",
    "            raises KeyError.\n",
    "\n",
    "    Side effects: saves the figure as a PDF whose name includes `model_name`\n",
    "    (relative path) and then displays it.\n",
    "    \"\"\"\n",
    "    all_original_token_values = {key: lexicon[key] for key in token_values}\n",
    "    token_values_list = token_values.values()\n",
    "    original_token_values_list = all_original_token_values.values()\n",
    "    sns.set(style='white')\n",
    "\n",
    "    # Plot the distributions (both histograms are drawn onto the same current axes)\n",
    "    # NOTE(review): no `hue` is passed explicitly, so it is unclear whether `palette` has any effect here - confirm.\n",
    "    sns.histplot(\n",
    "        token_values_list, label='Reconstructed', palette=\"rocket\", kde=True, \n",
    "        bins=10, linewidth=0.1\n",
    "    )\n",
    "    sns.histplot(original_token_values_list, label='Original', palette=\"mako\", kde=True, bins=10, linewidth=0.3)\n",
    "    \n",
    "    # Add labels, title and legend, then write the figure to disk before showing it\n",
    "    plt.xlabel('Utility Values')\n",
    "    plt.ylabel('Frequency')\n",
    "    plt.title(f'Utility Distributions for {model_name}')\n",
    "    plt.legend() \n",
    "    plt.savefig(f'utility distributions_for_{model_name}.pdf')\n",
    "\n",
    "       \n",
    "    # Show plot\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "plot_original_vs_modified_values(token_values = averaged_token_values_on_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13533fda-1a83-4b10-91ed-c97e715fde8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_reconstruction_errors(token_values):\n",
    "    \"\"\"Plot the distribution of per-token reconstruction errors.\n",
    "\n",
    "    The error of a token is `token_values[token] - lexicon[token]`, i.e.\n",
    "    reconstructed minus original.\n",
    "\n",
    "    Args:\n",
    "        token_values: Mapping of token -> reconstructed value. Every key is\n",
    "            looked up in the module-level `lexicon`, so a key missing from it\n",
    "            raises KeyError.\n",
    "\n",
    "    Side effects: saves the figure as a PDF whose name includes `model_name`\n",
    "    (relative path) and then displays it.\n",
    "    \"\"\"\n",
    "    differences = [\n",
    "        reconstructed_value - lexicon[token]\n",
    "        for token, reconstructed_value in token_values.items()\n",
    "    ]\n",
    "\n",
    "    # Plot the distribution of differences\n",
    "    sns.histplot(differences, kde=True, palette='rocket', bins=20)\n",
    "    plt.xlabel('Difference')\n",
    "    plt.ylabel('Frequency')\n",
    "    plt.title(f'Distribution of reconstruction errors for {model_name} utility values')\n",
    "    plt.savefig(f'Reconstruction error distributions_for_{model_name}.pdf')\n",
    "\n",
    "    plt.show()\n",
    "\n",
    "plot_reconstruction_errors(token_values = averaged_token_values_on_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38eaeeb3-2d6b-413c-b305-63d142bccca2",
   "metadata": {},
   "source": [
    "### Sampled positive and negative values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb844369-e9aa-47b0-acb3-448cb43f9018",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f'Reconstructed positive values: {random_positive_token_values}')\n",
    "print(f'Original positive values: {original_positive_values}')\n",
    "\n",
    "print(f'Reconstructed negative values: {random_negative_token_values}')\n",
    "print(f'Original negative values: {original_negative_values}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b838ef7d-fbc7-4bae-a4a9-187c0350e84f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def log_dictionary_as_table(table_name: str, dictionary_values: dict, columns=[\"token\", \"value\"]):\n",
    "    \"\"\"Print a dictionary as a two-column DataFrame and log it to Weights & Biases.\n",
    "\n",
    "    Args:\n",
    "        table_name: Key under which the DataFrame is passed to `wandb.log`.\n",
    "        dictionary_values: Mapping whose items become the rows of the table.\n",
    "        columns: Names of the key column and the value column, in that order.\n",
    "            The default list is only read, never mutated.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `columns` does not contain exactly two names.\n",
    "    \"\"\"\n",
    "    # Honour the caller's column names instead of hard-coding them.\n",
    "    key_column, value_column = columns\n",
    "\n",
    "    final_df = pd.DataFrame(\n",
    "        [{key_column: key, value_column: value} for key, value in dictionary_values.items()]\n",
    "    )\n",
    "\n",
    "    print(final_df)\n",
    "\n",
    "    wandb.log({table_name: final_df})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "832fe6dd-3715-4669-83b5-29cf6bba4a90",
   "metadata": {},
   "outputs": [],
   "source": [
    "log_dictionary_as_table(\n",
    "    \"sample_reconstructed_negative_token_utilities\", random_negative_token_values\n",
    ") "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa53f931-4fd3-4fc4-ad57-43b0d89c49a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "log_dictionary_as_table(\n",
    "    \"sample_reconstructed_positive_token_utilities\", random_positive_token_values\n",
    ") "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fd59708-4f5b-40aa-85a5-b81e51204a2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "log_dictionary_as_table(\n",
    "    \"full_reconstructed_token_utilities\", averaged_token_values_on_test\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed666e04-9a80-4870-9935-7bfb7319ca4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "full_original_token_values = {token: lexicon[token] for token in averaged_token_values_on_test}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "046763e5-258c-4ef0-85b8-b0f5dc9ef7e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "log_dictionary_as_table(\n",
    "    \"original_vader_token_utilities\", full_original_token_values\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ffde7ec-5493-4832-b001-8ef966d54ef8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank the negative-sentiment test tokens twice: by reconstructed value and by lexicon value.\n",
    "# Set membership keeps both filters linear instead of rescanning the token list per token.\n",
    "negative_test_token_set = set(all_negative_test_tokens)\n",
    "\n",
    "reconstructed_ranking = sorted(\n",
    "    [(token, value) for token, value in averaged_token_values_on_test.items() if token in negative_test_token_set],\n",
    "    key = lambda x: x[1]\n",
    ")\n",
    "\n",
    "original_ranking = sorted(\n",
    "    [(token, lexicon[token]) for token in averaged_token_values_on_test if token in negative_test_token_set],\n",
    "    key = lambda x: x[1]\n",
    ")\n",
    "\n",
    "reconstructed_ranking_tokens_only = [item[0] for item in reconstructed_ranking]\n",
    "original_ranking_tokens_only = [item[0] for item in original_ranking]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f5e04d8-74eb-473a-a694-d97a6b671779",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Kendall's tau compares paired observations, so pass the two scores of each token\n",
    "# (lexicon vs. reconstructed, in one shared token order) rather than the token strings.\n",
    "kendall_tau_result = kendalltau(\n",
    "    [lexicon[token] for token in original_ranking_tokens_only],\n",
    "    [averaged_token_values_on_test[token] for token in original_ranking_tokens_only],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "662325d5-aa34-4940-86e8-ff149a36ed52",
   "metadata": {},
   "outputs": [],
   "source": [
    "kendall_tau_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6867fbf-dc7e-4d87-ac1e-6f9aeb694d0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.log({\"kendall_tau_result\": kendall_tau_result})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ef6e2eb-0278-4ae6-b309-dc08dddfdf73",
   "metadata": {},
   "source": [
    "### Find \"good\" and \"bad\" reconstructions on full linear probe."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31b4f123-7a2a-4381-bc75-b2870db5bfd9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "303c9889-ae20-4ab5-9000-514dcef91083",
   "metadata": {},
   "source": [
    "### Compute correlation of GPT-4 features with activations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c996bb26-c906-42c3-8082-cfc07adc8bf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "threshold_reward = 3.0\n",
    "high_value_features_key = policy_model_name.replace(\"_utility_reward\", \"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bcb16ce-3964-42ed-a364-b9557a5b1c8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_high_value_features(version='v0'):\n",
    "    \"\"\"Download the high-value-features artifact and return its parsed JSON content.\n",
    "\n",
    "    Args:\n",
    "        version: Artifact version (or alias) appended to the artifact name.\n",
    "\n",
    "    Returns:\n",
    "        Whatever `json.load` produces for the artifact file.\n",
    "\n",
    "    Relies on the module-level W&B `run` object.\n",
    "    \"\"\"\n",
    "    artifact_name = \"high_value_features_artifact\"\n",
    "\n",
    "    artifact = run.use_artifact(f'nlp_and_interpretability/utility_reconstruction/{artifact_name}:{version}', type='data')\n",
    "    # Read from the directory W&B reports instead of rebuilding the path by hand, so the\n",
    "    # lookup still works when the download location differs from 'artifacts/<name>:<version>'.\n",
    "    artifact_dir = artifact.download()\n",
    "\n",
    "    with open(os.path.join(artifact_dir, artifact_name), \"r\") as in_file:\n",
    "        high_value_features = json.load(in_file)\n",
    "\n",
    "    return high_value_features\n",
    "\n",
    "high_value_features = get_high_value_features()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fc13b0a-0db2-4164-9c24-9a94bc2e9ced",
   "metadata": {},
   "outputs": [],
   "source": [
    "high_value_features_for_model = high_value_features[high_value_features_key][1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e48279b2-a9ac-4943-a4b6-99f2043b00d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "high_reward_inputs = [item for item in rescaled_fitted_values_and_inputs_on_test if item[0] >= threshold_reward]\n",
    "lower_reward_inputs = [item for item in rescaled_fitted_values_and_inputs_on_test if item[0] < threshold_reward]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d247deb-39a2-4633-8730-9e2a64c2ce06",
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class ActivationsTokenReward:\n",
    "    # Of the form (layer_num, feature_index) -> boolean (0 if inactive, 1 if active)\n",
    "    activations_dict: dict\n",
    "\n",
    "    # Of the form (layer_name: csr_matrix)\n",
    "    raw_activations_dict: dict\n",
    "    token: str\n",
    "    linear_probe_reward: float\n",
    "\n",
    "    def count_features(self):\n",
    "        return len(self.activations_dict.keys())\n",
    "\n",
    "    def count_all_activations(self):\n",
    "        return sum(list(self.activations_dict.values()))\n",
    "        \n",
    "    def count_activations(self, targeted_features: list[int, int]):\n",
    "        results = [self.activations_dict[tuple(feature)] for feature in targeted_features]\n",
    "        return sum(results)\n",
    "\n",
    "    def __str__(self):\n",
    "        return pprint.pformat(self.__dict__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "408011f8-ffb9-4822-9e8b-e2be5c25785f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_lp_reward_and_lp_final_input_into_atr(\n",
    "    lp_reward_and_lp_final_input: tuple[float, LinearProbeFinalInput],\n",
    "):\n",
    "    \"\"\"Convert a (reward, linear-probe input) pair into an ActivationsTokenReward.\n",
    "\n",
    "    Args:\n",
    "        lp_reward_and_lp_final_input: Pair whose first item is the reward and\n",
    "            whose second item is the linear-probe input it belongs to.\n",
    "\n",
    "    Returns:\n",
    "        An ActivationsTokenReward whose `activations_dict` maps\n",
    "        (layer_num, feature_index) to a boolean taken from the first row of\n",
    "        each layer's activation matrix.\n",
    "    \"\"\"\n",
    "    lp_reward = lp_reward_and_lp_final_input[0]\n",
    "    lp_final_input = lp_reward_and_lp_final_input[1]\n",
    "\n",
    "    token = lp_final_input.token\n",
    "\n",
    "    # Pick the feature set that matches the kind of point this input came from.\n",
    "    source_training_point = lp_final_input.source_training_point\n",
    "    if lp_final_input.point_type == 'positive':\n",
    "        activations_dict = source_training_point.positive_token_ae_features\n",
    "    else:\n",
    "        activations_dict = source_training_point.negative_token_ae_features\n",
    "\n",
    "    # layer_num -> boolean array built from row 0 of that layer's matrix\n",
    "    mapped_dict = {\n",
    "        model_customizer.parse_layer_name_to_layer_number(layer_name):\n",
    "        csr_activations.toarray()[0].astype(bool) for layer_name, csr_activations in activations_dict.items()\n",
    "    }\n",
    "\n",
    "    # Flatten to (layer_num, feature_index) -> boolean\n",
    "    final_mapped_dict = {}\n",
    "\n",
    "    for layer_num, activations_fired_boolean_list in mapped_dict.items():\n",
    "        for index, activation_fired_boolean in enumerate(activations_fired_boolean_list):\n",
    "            final_mapped_dict[(layer_num, index)] = activation_fired_boolean\n",
    "\n",
    "    return ActivationsTokenReward(\n",
    "        activations_dict=final_mapped_dict, raw_activations_dict=activations_dict,\n",
    "        token=token, linear_probe_reward=lp_reward\n",
    "    )\n",
    "\n",
    "\n",
    "# Spot check on a single high-reward point. Index 124 is hard-coded, so this raises\n",
    "# IndexError when fewer than 125 high-reward inputs exist.\n",
    "atr = process_lp_reward_and_lp_final_input_into_atr(\n",
    "    high_reward_inputs[124])\n",
    "\n",
    "# Only the last expression of a cell is displayed, so this first count is computed and discarded.\n",
    "atr.count_activations(targeted_features = high_value_features_for_model)\n",
    "atr.count_features()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc6bbbc2-0941-4e28-82f9-cdaaaa0abdfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert every (reward, linear-probe input) pair into an ActivationsTokenReward.\n",
    "high_reward_atr_points = [\n",
    "    process_lp_reward_and_lp_final_input_into_atr(reward_and_lp_input)\n",
    "    for reward_and_lp_input in high_reward_inputs\n",
    "]\n",
    "lower_reward_atr_points = [\n",
    "    process_lp_reward_and_lp_final_input_into_atr(reward_and_lp_input)\n",
    "    for reward_and_lp_input in lower_reward_inputs\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "654d2ff9-7ee2-4c3e-9320-d1fed0b53bce",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_total_feature_activation_percentage(atr_points):\n",
    "    num_inputs = len(atr_points)\n",
    "    num_features = atr_points[0].count_features()\n",
    "\n",
    "    total_feature_activations = sum([atr.count_all_activations() for atr in atr_points])\n",
    "    fa_percentage = 100*(total_feature_activations / (num_inputs * num_features))\n",
    "\n",
    "    return fa_percentage\n",
    "\n",
    "def get_feature_activation_percentage(atr_points, high_value_features_for_model=high_value_features_for_model):\n",
    "    \"\"\"Return the percentage of (input, targeted feature) pairs that are active.\n",
    "\n",
    "    Args:\n",
    "        atr_points: Sequence of objects exposing `count_activations(targeted_features=...)`.\n",
    "        high_value_features_for_model: Features to count. The default is the\n",
    "            module-level value captured when this function is defined.\n",
    "\n",
    "    Returns:\n",
    "        A percentage, or 0.0 when there are no points or no targeted features\n",
    "        (nothing to divide by).\n",
    "    \"\"\"\n",
    "    num_inputs = len(atr_points)\n",
    "    num_features = len(high_value_features_for_model)\n",
    "    if num_inputs == 0 or num_features == 0:\n",
    "        return 0.0\n",
    "\n",
    "    total_targeted_feature_activations = sum(\n",
    "        atr.count_activations(targeted_features=high_value_features_for_model) for atr in atr_points\n",
    "    )\n",
    "    targeted_fa_percentage = 100*(total_targeted_feature_activations / (num_inputs * num_features))\n",
    "\n",
    "    return targeted_fa_percentage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21bb694a-61b1-4f79-9390-d80c6f37d16e",
   "metadata": {},
   "outputs": [],
   "source": [
    "total_fa_percentage_on_high_reward = get_total_feature_activation_percentage(high_reward_atr_points)\n",
    "high_reward_feature_activation_percentage = get_feature_activation_percentage(high_reward_atr_points)\n",
    "lower_reward_feature_activation_percentage = get_feature_activation_percentage(lower_reward_atr_points)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb9d21b1-3ef1-4a67-acb5-815cc1f948eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(total_fa_percentage_on_high_reward)\n",
    "print(high_reward_feature_activation_percentage)\n",
    "print(lower_reward_feature_activation_percentage)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c07f77b-dd4f-40e4-840a-88419e6af55e",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.run.summary[\"total_fa_percentage_on_high_reward\"] = total_fa_percentage_on_high_reward\n",
    "wandb.run.summary[\"high_reward_feature_activation_percentage\"] = high_reward_feature_activation_percentage\n",
    "wandb.run.summary[\"lower_reward_feature_activation_percentage\"] = lower_reward_feature_activation_percentage"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
