{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "from transformer_lens import HookedTransformer\n",
    "\n",
    "import sae_bench.sae_bench_utils.dataset_utils as dataset_utils\n",
    "\n",
    "dataset_name = \"monology/pile-uncopyrighted\"\n",
    "\n",
    "\n",
    "def get_dataset_list_of_strs(\n",
    "    dataset_name: str, column_name: str, min_row_chars: int, total_chars: int\n",
    ") -> list[str]:\n",
    "    \"\"\"Gather streamed dataset rows until a character budget is exceeded.\n",
    "\n",
    "    Streams the \"train\" split of `dataset_name` and keeps every row whose\n",
    "    `column_name` value is strictly longer than `min_row_chars`. Collection stops\n",
    "    right after the combined length of the kept values first exceeds\n",
    "    `total_chars`, so the result can overshoot the budget by part of the last\n",
    "    kept row. If the stream ends first, everything kept so far is returned.\n",
    "\n",
    "    Args:\n",
    "        dataset_name: Dataset identifier handed to `load_dataset`.\n",
    "        column_name: Key of the field read from each row.\n",
    "        min_row_chars: Values of this length or shorter are skipped.\n",
    "        total_chars: Combined-length threshold that ends collection once exceeded.\n",
    "\n",
    "    Returns:\n",
    "        The kept values, in stream order.\n",
    "    \"\"\"\n",
    "    stream = load_dataset(dataset_name, split=\"train\", streaming=True)\n",
    "\n",
    "    kept_values = []\n",
    "    chars_collected = 0\n",
    "\n",
    "    for row in stream:\n",
    "        value = row[column_name]\n",
    "        if len(value) <= min_row_chars:\n",
    "            continue\n",
    "\n",
    "        kept_values.append(value)\n",
    "        chars_collected += len(value)\n",
    "        if chars_collected > total_chars:\n",
    "            break\n",
    "\n",
    "    return kept_values\n",
    "\n",
    "\n",
    "model_name = \"pythia-70m-deduped\"\n",
    "model = HookedTransformer.from_pretrained_no_processing(model_name)\n",
    "\n",
    "# Keep \"text\" rows longer than 100 characters until the total passes 10M characters.\n",
    "dataset_list_of_strs = get_dataset_list_of_strs(\n",
    "    dataset_name=dataset_name,\n",
    "    column_name=\"text\",\n",
    "    min_row_chars=100,\n",
    "    total_chars=10_000_000,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Options forwarded as keyword arguments to tokenize_and_concat_dataset.\n",
    "tokenize_options = {\"seq_len\": 128, \"add_bos\": True, \"max_tokens\": 2_000_000}\n",
    "\n",
    "tokens = dataset_utils.tokenize_and_concat_dataset(\n",
    "    model.tokenizer, dataset_list_of_strs, **tokenize_options\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Product of the first two dimensions of `tokens`, reported in millions.\n",
    "n_rows, n_cols = tokens.shape[0], tokens.shape[1]\n",
    "tokens_in_millions = (n_rows * n_cols) / 1_000_000\n",
    "print(tokens.shape, tokens_in_millions, \"M tokens\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positional arguments for load_and_tokenize_dataset, gathered in call order.\n",
    "load_and_tokenize_args = (dataset_name, 128, 2_000_000, model.tokenizer)\n",
    "tokens2 = dataset_utils.load_and_tokenize_dataset(*load_and_tokenize_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sae_bench.sae_bench_utils.dataset_info as dataset_info\n",
    "import sae_bench.sae_bench_utils.dataset_utils as dataset_utils\n",
    "\n",
    "dataset_name = \"canrager/amazon_reviews_mcauley_1and5_sentiment\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): 4000 / 1000 / 42 look like train size, test size and random seed;\n",
    "# confirm against dataset_utils.get_multi_label_train_test_data.\n",
    "splits = dataset_utils.get_multi_label_train_test_data(dataset_name, 4000, 1000, 42)\n",
    "train_data, test_data = splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(train_data.keys())\n",
    "# Show the first sample stored under each of the two label keys.\n",
    "for label in (\"1.0\", \"5.0\"):\n",
    "    print(train_data[label][0])\n",
    "print(len(train_data[\"1.0\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_dataset_name = \"codeparrot/github-code\"\n",
    "languages = [\"C\", \"Python\", \"HTML\", \"Java\", \"PHP\"]\n",
    "\n",
    "# Fetch the loader's result as one value, then unpack it into train / test.\n",
    "code_splits = dataset_utils.get_github_code_dataset(\n",
    "    new_dataset_name,\n",
    "    languages,\n",
    "    4000,\n",
    "    1000,\n",
    "    42,\n",
    ")\n",
    "train_code_data, test_code_data = code_splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_dataset_name2 = \"fancyzhx/ag_news\"\n",
    "# Classes registered for this dataset name in dataset_info.\n",
    "chosen_classes = dataset_info.chosen_classes_per_dataset[new_dataset_name2]\n",
    "\n",
    "news_splits = dataset_utils.get_ag_news_dataset(\n",
    "    new_dataset_name2,\n",
    "    chosen_classes,\n",
    "    4000,\n",
    "    1000,\n",
    "    42,\n",
    ")\n",
    "train_news_data, test_news_data = news_splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_dataset_name3 = \"Helsinki-NLP/europarl\"\n",
    "chosen_classes = dataset_info.chosen_classes_per_dataset[new_dataset_name3]\n",
    "\n",
    "# Store the Europarl result under its own names so running this cell does not\n",
    "# overwrite the AG News data held in train_news_data / test_news_data.\n",
    "# NOTE(review): this still goes through get_ag_news_dataset; confirm that loader\n",
    "# is meant to handle Europarl, or switch to a Europarl-specific helper if one exists.\n",
    "train_europarl_data, test_europarl_data = dataset_utils.get_ag_news_dataset(\n",
    "    new_dataset_name3, chosen_classes, 4000, 1000, 42\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "code_labels = train_code_data.keys()\n",
    "print(code_labels)\n",
    "\n",
    "# First sample stored under the \"C\" key.\n",
    "c_samples = train_code_data[\"C\"]\n",
    "print(c_samples[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the data for `dataset_name`, then keep only the classes that dataset_info\n",
    "# lists for it. The call takes the same four arguments as the earlier\n",
    "# get_multi_label_train_test_data cell; the previous version also passed\n",
    "# `train_df` / `test_df`, which are never defined in this notebook and raised\n",
    "# NameError on a fresh kernel.\n",
    "train_data, test_data = dataset_utils.get_multi_label_train_test_data(\n",
    "    dataset_name,\n",
    "    4000,\n",
    "    1000,\n",
    "    42,\n",
    ")\n",
    "\n",
    "chosen_classes = dataset_info.chosen_classes_per_dataset[dataset_name]\n",
    "\n",
    "train_data = dataset_utils.filter_dataset(train_data, chosen_classes)\n",
    "test_data = dataset_utils.filter_dataset(test_data, chosen_classes)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
