{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "import pandas as pd\n",
    "import torch\n",
    "train_dataset = load_dataset('PKU-Alignment/BeaverTails', split='30k_train')\n",
    "test_dataset = load_dataset('PKU-Alignment/BeaverTails', split='30k_test')\n",
    "y_train_df = pd.DataFrame(train_dataset['category'])\n",
    "categories = y_train_df.columns.tolist()\n",
    "y_train = torch.tensor(y_train_df.values)\n",
    "y_test = torch.tensor(pd.DataFrame(test_dataset['category']).values)\n",
    "x_train = [prompt + ' Response: ' + response for prompt, response in zip(train_dataset['prompt'], train_dataset['response'])]\n",
    "x_test = [prompt + ' Response: ' + response for prompt, response in zip(test_dataset['prompt'], test_dataset['response'])]\n",
    "\n",
    "# Create a smaller balanced subset with 300 samples per category\n",
    "import numpy as np\n",
    "import random\n",
    "\n",
    "# Set random seed for reproducibility\n",
    "random.seed(42)\n",
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "\n",
    "# Maximum samples per category\n",
    "samples_per_category = 300\n",
    "\n",
    "# Convert to numpy for easier manipulation\n",
    "y_train_np = y_train.numpy()\n",
    "\n",
    "# Create subset indices\n",
    "subset_indices = []\n",
    "\n",
    "# For each category, select up to samples_per_category positive examples\n",
    "for cat_idx in range(y_train_np.shape[1]):\n",
    "    # Get indices of positive samples for this category\n",
    "    positive_indices = np.where(y_train_np[:, cat_idx] == 1)[0]\n",
    "    \n",
    "    # If we have more than samples_per_category, sample randomly\n",
    "    if len(positive_indices) > samples_per_category:\n",
    "        selected_indices = np.random.choice(positive_indices, samples_per_category, replace=False)\n",
    "    else:\n",
    "        selected_indices = positive_indices\n",
    "    \n",
    "    # Add selected indices to our subset\n",
    "    subset_indices.extend(selected_indices)\n",
    "\n",
    "# Remove duplicates (some samples may belong to multiple categories)\n",
    "subset_indices = list(set(subset_indices))\n",
    "print(f\"Selected {len(subset_indices)} unique samples across all categories\")\n",
    "\n",
    "# Create the balanced subset\n",
    "x_train_small = [x_train[i] for i in subset_indices]\n",
    "y_train_small = y_train[subset_indices]\n",
    "\n",
    "# Print category distribution in the smaller dataset\n",
    "category_counts = y_train_small.sum(dim=0)\n",
    "print(\"\\nCategory distribution in smaller dataset:\")\n",
    "for i, category in enumerate(categories):\n",
    "    print(f\"{category}: {category_counts[i].item()} samples\")\n",
    "\n",
    "print(f\"\\nOriginal training set: {len(x_train)} samples\")\n",
    "print(f\"Smaller balanced training set: {len(x_train_small)} samples\")\n",
    "print(f\"Test set remains unchanged: {len(x_test)} samples\")\n",
    "\n",
    "exp_dir = \"output/experiments/safety/beaver_tails/v0.2\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ar_model.search(inputs=data_openai['x_train'], labels=data_openai['y_train'], reset_cache=False, batch_size=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Trees"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ar.eval import evaluate_model\n",
    "from ar.config import LogicConfig\n",
    "\n",
    "# Assuming data['categories'] contains your categories and data['y_test'] contains your test labels\n",
    "# If you're working with custom data, adjust these variables accordingly\n",
    "tree_concept_config = LogicConfig(\n",
    "   search_concept_type='sentence', \n",
    "    search_concept_token='all', \n",
    "    search_strategy='tree', \n",
    "    search_tree_depth=5,  # same depth as before\n",
    "    detection_top_k_output=32,  # different detection parameters\n",
    ")\n",
    "\n",
    "\n",
    "llama3_1 = {\n",
    "'model_name': \"meta-llama/Meta-Llama-3.1-8B\",\n",
    "'sae_name': \"EleutherAI/sae-llama-3.1-8b-64x\",\n",
    "'layer': 23,\n",
    "'hookpoint': 'layers.23',\n",
    "'cache_dir': exp_dir,\n",
    "}\n",
    "\n",
    "results = evaluate_model(\n",
    "    test_data=x_test,\n",
    "    test_labels=y_test,\n",
    "    train_data=x_train,\n",
    "    train_labels=y_train,\n",
    "    concepts=categories,\n",
    "    rules={},\n",
    "    config=tree_concept_config,\n",
    "    model_kwargs=llama3_1,\n",
    "    batch_size=8,\n",
    "    verbose=True,\n",
    "    save_path=\"output/experiments/safety/eval_metrics\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ar.eval import evaluate_model\n",
    "from ar.config import LogicConfig\n",
    "\n",
    "# Assuming data['categories'] contains your categories and data['y_test'] contains your test labels\n",
    "# If you're working with custom data, adjust these variables accordingly\n",
    "tree_concept_config = LogicConfig(\n",
    "   search_concept_type='sentence', \n",
    "    search_concept_token='all', \n",
    "    search_strategy='tree', \n",
    "    search_tree_depth=10,  # same depth as before\n",
    "    detection_top_k_output=32,  # different detection parameters\n",
    ")\n",
    "\n",
    "\n",
    "llama3_1 = {\n",
    "'model_name': \"meta-llama/Meta-Llama-3.1-8B\",\n",
    "'sae_name': \"EleutherAI/sae-llama-3.1-8b-64x\",\n",
    "'layer': 23,\n",
    "'hookpoint': 'layers.23',\n",
    "'cache_dir': exp_dir,\n",
    "}\n",
    "\n",
    "results = evaluate_model(\n",
    "    test_data=x_test,\n",
    "    test_labels=y_test,\n",
    "    train_data=x_train,\n",
    "    train_labels=y_train,\n",
    "    concepts=categories,\n",
    "    rules={},\n",
    "    config=tree_concept_config,\n",
    "    model_kwargs=llama3_1,\n",
    "    batch_size=8,\n",
    "    verbose=True,\n",
    "    save_path=\"output/experiments/safety/eval_metrics\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ar.eval import evaluate_model\n",
    "from ar.config import LogicConfig\n",
    "\n",
    "# Assuming data['categories'] contains your categories and data['y_test'] contains your test labels\n",
    "# If you're working with custom data, adjust these variables accordingly\n",
    "tree_concept_config = LogicConfig(\n",
    "   search_concept_type='sentence', \n",
    "    search_concept_token='all', \n",
    "    search_strategy='tree', \n",
    "    search_tree_depth=5,  # same depth as before\n",
    "    detection_top_k_output=32,  # different detection parameters\n",
    ")\n",
    "\n",
    "\n",
    "llama3_1 = {\n",
    "'model_name': \"meta-llama/Meta-Llama-3.1-8B\",\n",
    "'sae_name': \"EleutherAI/sae-llama-3.1-8b-64x\",\n",
    "'layer': 23,\n",
    "'hookpoint': 'layers.23',\n",
    "'cache_dir': exp_dir,\n",
    "}\n",
    "\n",
    "results = evaluate_model(\n",
    "    test_data=x_test,\n",
    "    test_labels=y_test,\n",
    "    train_data=x_train_small,\n",
    "    train_labels=y_train_small,\n",
    "    concepts=categories,\n",
    "    rules={},\n",
    "    config=tree_concept_config,\n",
    "    model_kwargs=llama3_1,\n",
    "    batch_size=8,\n",
    "    verbose=True,\n",
    "    save_path=\"output/experiments/safety/eval_metrics\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ar.eval import evaluate_model\n",
    "from ar.config import LogicConfig\n",
    "\n",
    "# Assuming data['categories'] contains your categories and data['y_test'] contains your test labels\n",
    "# If you're working with custom data, adjust these variables accordingly\n",
    "sae_concept_vectors_config = LogicConfig(\n",
    "    search_concept_type='sentence', \n",
    "    search_concept_token='all', \n",
    "    search_strategy='top_k',\n",
    "    detection_top_k_output=32,  # different detection parameters\n",
    "    detection_top_k_concepts=5,\n",
    "    detection_allow_multi=True,\n",
    "    # search_top_k_order=\"original_order\"\n",
    ")\n",
    "\n",
    "\n",
    "llama3_1 = {\n",
    "'model_name': \"meta-llama/Meta-Llama-3.1-8B\",\n",
    "'sae_name': \"EleutherAI/sae-llama-3.1-8b-64x\",\n",
    "'layer': 23,\n",
    "'hookpoint': 'layers.23',\n",
    "'cache_dir': exp_dir,\n",
    "}\n",
    "\n",
    "results = evaluate_model(\n",
    "    test_data=x_test,\n",
    "    test_labels=y_test,\n",
    "    train_data=x_train_small,\n",
    "    train_labels=y_train_small,\n",
    "    concepts=categories,\n",
    "    config=sae_concept_vectors_config,\n",
    "    model_kwargs=llama3_1,\n",
    "    batch_size=8,\n",
    "    verbose=True,\n",
    "    save_path=\"output/experiments/safety/eval_metrics\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ar.eval import evaluate_model\n",
    "from ar.config import LogicConfig\n",
    "\n",
    "# Assuming data['categories'] contains your categories and data['y_test'] contains your test labels\n",
    "# If you're working with custom data, adjust these variables accordingly\n",
    "sae_concept_config = LogicConfig(\n",
    "    search_concept_type='sentence', \n",
    "    search_concept_token='all', \n",
    "    search_strategy='top_k',\n",
    "    detection_top_k_output=32,  # different detection parameters\n",
    "    detection_top_k_concepts=1,\n",
    "    detection_allow_multi=True,\n",
    "    search_top_k_order=\"original_order\"\n",
    ")\n",
    "\n",
    "\n",
    "llama3_1 = {\n",
    "'model_name': \"meta-llama/Meta-Llama-3.1-8B\",\n",
    "'sae_name': \"EleutherAI/sae-llama-3.1-8b-64x\",\n",
    "'layer': 23,\n",
    "'hookpoint': 'layers.23',\n",
    "'cache_dir': exp_dir,\n",
    "}\n",
    "\n",
    "results = evaluate_model(\n",
    "    test_data=x_test,\n",
    "    test_labels=y_test,\n",
    "    train_data=x_train_small,\n",
    "    train_labels=y_train_small,\n",
    "    concepts=categories,\n",
    "    config=sae_concept_config,\n",
    "    model_kwargs=llama3_1,\n",
    "    batch_size=8,\n",
    "    verbose=True,\n",
    "    save_path=\"output/experiments/safety/eval_metrics\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ar.eval import evaluate_model\n",
    "from ar.config import LogicConfig\n",
    "\n",
    "# Assuming data['categories'] contains your categories and data['y_test'] contains your test labels\n",
    "# If you're working with custom data, adjust these variables accordingly\n",
    "sae_concept_config = LogicConfig(\n",
    "    search_concept_type='sentence', \n",
    "    search_concept_token='all', \n",
    "    search_strategy='top_k',\n",
    "    detection_top_k_output=1,  # different detection parameters\n",
    "    detection_top_k_concepts=1,\n",
    "    detection_allow_multi=True,\n",
    "    search_top_k_order=\"original_order\"\n",
    ")\n",
    "\n",
    "\n",
    "llama3_1 = {\n",
    "'model_name': \"meta-llama/Meta-Llama-3.1-8B\",\n",
    "'sae_name': \"EleutherAI/sae-llama-3.1-8b-64x\",\n",
    "'layer': 23,\n",
    "'hookpoint': 'layers.23',\n",
    "'cache_dir': exp_dir,\n",
    "}\n",
    "\n",
    "results = evaluate_model(\n",
    "    test_data=x_test,\n",
    "    test_labels=y_test,\n",
    "    train_data=x_train_small,\n",
    "    train_labels=y_train_small,\n",
    "    concepts=categories,\n",
    "    config=sae_concept_config,\n",
    "    model_kwargs=llama3_1,\n",
    "    batch_size=8,\n",
    "    verbose=True,\n",
    "    save_path=\"output/experiments/safety/eval_metrics\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
