{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c6bff33e-dce3-4991-8498-6c7b206be648",
   "metadata": {},
   "source": [
    "### INSTALLS AND IMPORTS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c72eb777-2427-48b4-a519-9a57273d172c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import re\n",
    "\n",
    "from datasets import Dataset, DatasetDict, load_dataset, load_from_disk\n",
    "import wandb\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from collections import OrderedDict\n",
    "from transformers import AutoModel, AutoTokenizer\n",
    "import requests\n",
    "from scipy.sparse import csr_matrix\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c4cc52d-34d5-493c-8e39-0964285dae8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import huggingface_hub\n",
    "import transformers\n",
    "from transformers import AutoModel, AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c35cc3d5-8bc5-4e5e-a06e-9788f79ef95a",
   "metadata": {},
   "outputs": [],
   "source": [
    "! rm -rf Interpreting-Reward-Models || true\n",
    "! git clone XXXX\n",
    "! cd Interpreting-Reward-Models && pip install ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3e39d1e-c04a-4100-8577-fb4e923777f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from reward_analyzer import SparseAutoencoder, TaskConfig\n",
    "from reward_analyzer.utils.model_storage_utils import load_autoencoders_for_artifact, load_latest_model_from_hub, download_folder_from_hub\n",
    "from reward_analyzer.utils.transformer_utils import batch\n",
    "from reward_analyzer.configs.project_configs import HuggingfaceConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebf72db3-b682-4874-8731-8c66072fe0d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "contrastive_triples_path = 'data/contrastive_triples_rlhf.dataset/2024-05-15_14'\n",
    "download_folder_from_hub(folder_path=contrastive_triples_path, local_folder=contrastive_triples_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8fa5688-968b-44ca-b751-96f4d70bb815",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Override with your own config if not using Amir's huggingface hub account.\n",
    "huggingface_config = HuggingfaceConfig()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7da93d60-8239-43c4-a92f-07b157c38d75",
   "metadata": {},
   "outputs": [],
   "source": [
    "contrastive_dataset = load_from_disk(contrastive_triples_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "186db6d0-a4ef-477a-9a44-fb39cfefc72d",
   "metadata": {},
   "outputs": [],
   "source": [
    "supported_model_names = ['EleuterAI/pythia-70m', 'EleutherAI/pythia-160m', 'google/gemma-2b-it', 'EleutherAI/gpt-neo-125m']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8e8d862-1e7d-4e7f-8418-c0b98345da56",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = 'EleutherAI/pythia-70m'\n",
    "#model_name = 'EleutherAI/pythia-160m'\n",
    "#model_name = 'EleutherAI/gpt-neo-125m'\n",
    "#model_name = 'google/gemma-2b-it'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee0e6019-585c-4cdd-b163-e63e33b1f1b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9c70b4e-ff12-4055-8fa8-ed6ba51b4dd7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Task whose RLHF-tuned model and autoencoder artifacts are used below.\n",
    "task = TaskConfig.HH_RLHF\n",
    "task_name = task.name\n",
    "version = 'v0'\n",
    "\n",
    "# `layer_name_stem` is the module-name template for the MLP blocks, with `{}` standing in\n",
    "# for the layer key. Pythia and Gemma share one template; GPT-Neo uses another.\n",
    "# The 'pythia' marker takes precedence over 'neo', which takes precedence over 'gemma'.\n",
    "if 'pythia' in model_name or ('neo' not in model_name and 'gemma' in model_name):\n",
    "    layer_name_stem = 'layers.{}.mlp'\n",
    "elif 'neo' in model_name:\n",
    "    layer_name_stem = 'h.{}.mlp'\n",
    "else:\n",
    "    raise Exception(f'Not familiar with model name family of {model_name}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb831298-a97a-4d7c-b215-9cbd8255c181",
   "metadata": {},
   "source": [
    "### Load model and autoencoder artifacts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f78e2e80-8efd-4d44-9572-ab3c3cdfe7b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "# Only fall back to EOS when the tokenizer has no pad token of its own, so a tokenizer\n",
    "# that defines a dedicated pad token keeps it.\n",
    "tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token\n",
    "\n",
    "# Note that this downloads the RLHF tuned version of similar stem from our huggingface hub repo, not the base model.\n",
    "model = load_latest_model_from_hub(model_name=model_name, task_config=task)\n",
    "model.device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35fb5987-4d65-4542-a124-9b44a6fe0b92",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.init(project = f'Autoencoder_training_{task.name}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a36957a-00dc-4b57-a237-817306b7e5d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "autoencoders_dict = load_autoencoders_for_artifact(f'nlp_and_interpretability/Autoencoder_training_{task.name}/autoencoders_{model_name.split(\"/\")[-1].replace(\"-\", \"_\")}_{task_name}:{version}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9875733d-1e7e-4412-9eb0-1b71a6e55bcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "contrastive_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cc62c1a-59ac-4a78-b3b2-85e601004028",
   "metadata": {},
   "outputs": [],
   "source": [
    "rlhf_small = autoencoders_dict['rlhf_small']\n",
    "\n",
    "# Only move the autoencoders to the GPU when one is actually present, so this cell\n",
    "# also runs on CPU-only machines.\n",
    "if torch.cuda.is_available():\n",
    "    for autoencoder in rlhf_small.values():\n",
    "        autoencoder.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f08b322-2a22-4018-9f26-6a0479e924af",
   "metadata": {},
   "outputs": [],
   "source": [
    "def dump_data_to_jsonl(data: dict, filename: str):\n",
    "    list_lengths = [len(value_list) for value_list in data.values()]\n",
    "\n",
    "    assert min(list_lengths) == max(list_lengths), f'Expected list lengths to be the same! Instead got {list_lengths}'\n",
    "    n = max(list_lengths)\n",
    "\n",
    "    # Open a file to write JSON Lines\n",
    "    with open(filename, 'a+') as jsonl_file:\n",
    "        # Iterate over the index of the lists\n",
    "        all_lines = []\n",
    "        for i in range(n):\n",
    "            # Create a dictionary for the current JSON object\n",
    "            json_object = {key: values[i] for key, values in data.items()}\n",
    "\n",
    "            all_lines.append(json.dumps(json_object) + '\\n')\n",
    "        \n",
    "        # Write the JSON object as a line in the JSONL file\n",
    "        jsonl_file.writelines(all_lines)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "023e195a-56dc-479c-ba2e-a8ec885c6c0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def features_from_single_input(single_input):\n",
    "    return torch.mean(single_input, dim=0)\n",
    "\n",
    "def extract_and_process_activations(texts, model, tokenizer, layer_name_stem, autoencoders_dict, with_full_activations=False):\n",
    "    \"\"\"Run one batch of texts through `model` and encode the hooked activations.\n",
    "\n",
    "    Forward hooks capture the output of every module named\n",
    "    `layer_name_stem.format(key)` for `key` in `autoencoders_dict`; each captured\n",
    "    activation is then passed through the matching autoencoder on the CPU.\n",
    "\n",
    "    Args:\n",
    "        texts: A list of strings (a single string is treated as a batch of one).\n",
    "        model: Model called as `model(**inputs)`; its output must expose\n",
    "            `last_hidden_state`.\n",
    "        tokenizer: Tokenizer; inputs are padded and truncated to 128 tokens.\n",
    "        layer_name_stem: Format string for the hooked module names.\n",
    "        autoencoders_dict: Mapping from layer key to an autoencoder whose call\n",
    "            returns a `(features, ...)` pair.\n",
    "        with_full_activations: If True, also return the raw hooked activations\n",
    "            under `activations_<layer key>`.\n",
    "\n",
    "    Returns:\n",
    "        A dict of lists with one entry per input text: `texts`, `token_ids`,\n",
    "        `token_embeddings`, `tokens` and `averaged_reprs_<layer key>`.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: If a target layer name matches no module of `model`.\n",
    "    \"\"\"\n",
    "    if isinstance(texts, str):\n",
    "        texts = [texts]\n",
    "\n",
    "    inputs = tokenizer(texts, return_tensors=\"pt\", padding=True, truncation=True, max_length=128)\n",
    "    # Keep the batch dimension: an argument-less squeeze() would drop it for a batch of one\n",
    "    # text and turn the per-text lists below into per-token lists.\n",
    "    token_ids = inputs[\"input_ids\"].tolist()\n",
    "\n",
    "    # Put the inputs wherever the model lives; fall back to the CUDA-if-available rule\n",
    "    # for models that do not expose a `device` attribute.\n",
    "    device = getattr(model, 'device', None)\n",
    "    if device is None:\n",
    "        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    inputs = {key: value.to(device) for key, value in inputs.items()}\n",
    "\n",
    "    activations = {}\n",
    "    target_layer_names = [layer_name_stem.format(key) for key in autoencoders_dict]\n",
    "\n",
    "    def get_activation(name):\n",
    "        def hook(module, hook_input, output):\n",
    "            activations[name] = output.detach()\n",
    "        return hook\n",
    "\n",
    "    hooks = [\n",
    "        module.register_forward_hook(get_activation(name))\n",
    "        for name, module in model.named_modules()\n",
    "        if name in target_layer_names\n",
    "    ]\n",
    "\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            outputs = model(**inputs)\n",
    "    finally:\n",
    "        # Always detach the hooks, even if the forward pass fails, so they do not\n",
    "        # pile up on the model across calls.\n",
    "        for hook in hooks:\n",
    "            hook.remove()\n",
    "\n",
    "    missing_layers = [name for name in target_layer_names if name not in activations]\n",
    "    if missing_layers:\n",
    "        raise KeyError(f'No activations were captured for layers {missing_layers}; check layer_name_stem.')\n",
    "\n",
    "    # NOTE(review): position 0 of each sequence is what gets stored under 'token_embeddings';\n",
    "    # confirm that the first rather than the last token is intended.\n",
    "    first_position_embeddings = outputs.last_hidden_state[:, 0, :].detach().cpu().tolist()\n",
    "    tokens = [tokenizer.convert_ids_to_tokens(local_token_ids) for local_token_ids in token_ids]\n",
    "\n",
    "    all_features = {\n",
    "        \"texts\": texts,\n",
    "        \"token_ids\": token_ids,\n",
    "        \"token_embeddings\": first_position_embeddings,\n",
    "        \"tokens\": tokens\n",
    "    }\n",
    "\n",
    "    for layer_num, autoencoder in autoencoders_dict.items():\n",
    "        activation_values = activations[layer_name_stem.format(layer_num)].cpu()\n",
    "        autoencoder = autoencoder.cpu()\n",
    "        with torch.no_grad():\n",
    "            batch_features, _ = autoencoder(activation_values)\n",
    "        batch_features = batch_features.detach()\n",
    "\n",
    "        if with_full_activations:\n",
    "            all_features[f'activations_{layer_num}'] = activation_values.tolist()\n",
    "\n",
    "        # NOTE(review): the per-text average runs over every position the tokenizer returned,\n",
    "        # which presumably includes padding; confirm that is acceptable.\n",
    "        all_features[f'averaged_reprs_{layer_num}'] = [\n",
    "            features_from_single_input(single_feature).cpu().tolist()\n",
    "            for single_feature in batch_features\n",
    "        ]\n",
    "\n",
    "    return all_features\n",
    "\n",
    "import shutil\n",
    "\n",
    "def remove_path_if_exists(path):\n",
    "    # Check if the path exists\n",
    "    if os.path.exists(path):\n",
    "        # Check if the path is a file\n",
    "        if os.path.isfile(path):\n",
    "            os.remove(path)\n",
    "            print(f\"File {path} has been deleted.\")\n",
    "        # Check if the path is a directory\n",
    "        elif os.path.isdir(path):\n",
    "            shutil.rmtree(path)\n",
    "            print(f\"Directory {path} and all its contents have been deleted.\")\n",
    "    else:\n",
    "        print(f\"The path {path} does not exist.\")\n",
    "\n",
    "\n",
    "def extract_features_batched(texts, model_name, model, tokenizer, layer_name_stem, autoencoders_dict, source='', output_filestem=None, batch_size=8, with_full_activations=False):\n",
    "    \"\"\"Extract features for `texts` batch by batch and save them as a dataset.\n",
    "\n",
    "    Each batch is appended to `<output_filestem>.jsonl`; the finished file is loaded\n",
    "    with `load_dataset` and saved to `<output_filestem>.hf`. Both paths are removed\n",
    "    first if they already exist. When `output_filestem` is not given, a default is\n",
    "    built from the model name, the module-level `task` and `source`.\n",
    "\n",
    "    Returns:\n",
    "        `(dataset, filename)`, where `filename` is the `.hf` path the dataset was\n",
    "        saved to.\n",
    "    \"\"\"\n",
    "    short_model_name = model_name.split(\"/\")[-1].replace(\"-\", \"_\")\n",
    "    if not output_filestem:\n",
    "        output_filestem = f'./{short_model_name}_{task.name}_{source}_activations_dataset'\n",
    "\n",
    "    jsonl_path = f'{output_filestem}.jsonl'\n",
    "    hf_path = f'{output_filestem}.hf'\n",
    "\n",
    "    # Start from a clean slate: the JSONL file is written in append mode below.\n",
    "    for stale_path in (jsonl_path, hf_path):\n",
    "        remove_path_if_exists(stale_path)\n",
    "\n",
    "    for curr_batch in tqdm(batch(texts, n=batch_size)):\n",
    "        dump_data_to_jsonl(\n",
    "            extract_and_process_activations(\n",
    "                curr_batch, model, tokenizer, layer_name_stem, autoencoders_dict,\n",
    "                with_full_activations=with_full_activations,\n",
    "            ),\n",
    "            filename=jsonl_path,\n",
    "        )\n",
    "\n",
    "    dataset = load_dataset(\"json\", data_files=jsonl_path, split='train')\n",
    "    dataset.save_to_disk(hf_path)\n",
    "    return dataset, hf_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20d36177-9e4d-48ab-be02-b203a9b12bb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_datasets(source_dataset, model_name, model, tokenizer, layer_name_stem, autoencoders_dict, task, with_full_activations=False):\n",
    "    \"\"\"Build one activation dataset per text column and save them as a `DatasetDict`.\n",
    "\n",
    "    The 'chosen' and 'new_rejected' columns of `source_dataset` are each passed to\n",
    "    `extract_features_batched`; the resulting datasets become the entries of a\n",
    "    `DatasetDict`, which is saved to disk.\n",
    "\n",
    "    Returns:\n",
    "        `(merged_dataset, filename)`, where `filename` is the path the merged\n",
    "        dataset was saved to.\n",
    "    \"\"\"\n",
    "    split_datasets = {}\n",
    "    for column_name in ('chosen', 'new_rejected'):\n",
    "        # The per-column output path returned alongside the dataset is not needed here.\n",
    "        column_dataset, _ = extract_features_batched(\n",
    "            texts=source_dataset[column_name],\n",
    "            model_name=model_name,\n",
    "            model=model,\n",
    "            tokenizer=tokenizer,\n",
    "            layer_name_stem=layer_name_stem,\n",
    "            autoencoders_dict=autoencoders_dict,\n",
    "            source=column_name,\n",
    "            with_full_activations=with_full_activations,\n",
    "        )\n",
    "        split_datasets[column_name] = column_dataset\n",
    "        print(f'Dataset is of type {type(column_dataset)}')\n",
    "\n",
    "    merged_dataset = DatasetDict(split_datasets)\n",
    "    short_model_name = model_name.split(\"/\")[-1].replace(\"-\", \"_\")\n",
    "    merged_path = f'merged_contrastive_{short_model_name}_{task.name}_activations_and_features.hf'\n",
    "\n",
    "    merged_dataset.save_to_disk(merged_path)\n",
    "    return merged_dataset, merged_path\n",
    "\n",
    "merged_dataset, filename = generate_datasets(\n",
    "    source_dataset=contrastive_dataset, model_name=model_name, model=model, tokenizer=tokenizer,\n",
    "    layer_name_stem=layer_name_stem, autoencoders_dict=rlhf_small, task=task\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af2a3594-1055-4741-a4ef-114058c71cf9",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(merged_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8ed9eab-0324-433e-b864-65c7a00dcd7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "api = huggingface_hub.HfApi()\n",
    "\n",
    "api.upload_folder(\n",
    "    repo_id=huggingface_config.repo_id,\n",
    "    folder_path=filename,\n",
    "    path_in_repo=f'data/{filename}',\n",
    "    repo_type=None\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27d871f1-f4a5-4e5f-8f2f-45f081ce51ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_training_dataset_to_wandb(training_dataset: Dataset, model_name, dataset_name, task):\n",
    "    \"\"\"Save `training_dataset` to disk and log the saved directory as a W&B artifact.\n",
    "\n",
    "    Args:\n",
    "        training_dataset: Dataset to save; a `DatasetDict` is accepted as well.\n",
    "        model_name: Model id; only the part after the last '/' is used, with '-'\n",
    "            replaced by '_', because artifact names may not contain '/'.\n",
    "        dataset_name: Directory the dataset is written to with `save_to_disk`.\n",
    "        task: Task config; its `name` becomes part of the artifact name.\n",
    "    \"\"\"\n",
    "    # save_to_disk returns nothing; the data ends up in the `dataset_name` directory.\n",
    "    training_dataset.save_to_disk(dataset_name)\n",
    "\n",
    "    safe_model_name = model_name.split(\"/\")[-1].replace(\"-\", \"_\")\n",
    "    my_artifact = wandb.Artifact(f\"logistic_probe_training_dataset_{safe_model_name}_{task.name}\", type=\"data\")\n",
    "\n",
    "    # The saved dataset is a directory, so it is added with add_dir rather than add_file.\n",
    "    my_artifact.add_dir(local_path=dataset_name, name=\"logistic_probe_training_dataset\")\n",
    "\n",
    "    # `num_rows` is an int for a Dataset and a {split: rows} mapping for a DatasetDict;\n",
    "    # len() of a DatasetDict would count splits instead of examples.\n",
    "    num_rows = training_dataset.num_rows\n",
    "    num_examples = sum(num_rows.values()) if isinstance(num_rows, dict) else num_rows\n",
    "\n",
    "    metadata_dict = {\n",
    "        \"description\": \"Training dataset, with activations and rewards\",\n",
    "        \"source\": \"Generated by my script\",\n",
    "        \"num_examples\": num_examples,\n",
    "        \"split\": \"full\"\n",
    "    }\n",
    "\n",
    "    my_artifact.metadata.update(metadata_dict)\n",
    "\n",
    "    # Log the artifact to the run\n",
    "    wandb.log_artifact(my_artifact)\n",
    "\n",
    "save_training_dataset_to_wandb(\n",
    "    merged_dataset,\n",
    "    model_name=model_name,\n",
    "    dataset_name=f'logistic_probe_training_dataset_{model_name.split(\"/\")[-1].replace(\"-\", \"_\")}_{task.name}.hf',\n",
    "    task=task,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
