{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b3633acd-cf3d-425e-b72f-5f677b50cce2",
   "metadata": {},
   "source": [
    "### Setup imports and environment variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2b6b274-06b0-40cb-9a9b-c3f1cd866ce9",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "! rm -rf Interpreting-Reward-Models || true\n",
    "! git clone XXXX\n",
    "! cd Interpreting-Reward-Models && pip install ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff617388-7094-42b4-acda-ebc4f138e5cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from dataclasses import dataclass, field\n",
    "from typing import Dict, Optional\n",
    "\n",
    "import huggingface_hub\n",
    "import torch\n",
    "import wandb\n",
    "\n",
    "from datasets import Dataset, load_dataset\n",
    "from huggingface_hub import hf_hub_download, upload_file, upload_folder, HfApi\n",
    "from torch.optim import Adam\n",
    "from tqdm import tqdm\n",
    "from transformers import (\n",
    "    AutoModelForCausalLM,\n",
    "    AutoTokenizer,\n",
    "    HfArgumentParser,\n",
    "    LlamaTokenizer,\n",
    "    pipeline\n",
    ")\n",
    "\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments\n",
    "from trl import AutoModelForCausalLMWithValueHead, DPOTrainer\n",
    "from trl import DPOTrainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75aa4d18-5bf8-4f6a-aadd-ba0f8155a885",
   "metadata": {},
   "outputs": [],
   "source": [
    "repo_id = 'amirabdullah19852020/interpreting_reward_models'\n",
    "tqdm.pandas()\n",
    "huggingface_hub.login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b75104d7-ffee-4c0c-8308-0213952eca23",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3761782-7b2b-4048-a53d-00374501a499",
   "metadata": {},
   "outputs": [],
   "source": [
    "from reward_analyzer.configs.rlhf_training_config import DPOTrainingConfig"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "235eee04-f0b9-4754-8b38-568625e55eab",
   "metadata": {},
   "source": [
    "### Set up DPO arguments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae3e2a0c-b9fb-4aab-b219-21679dfec0da",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass, field\n",
    "from typing import Dict, Optional\n",
    "\n",
    "import torch\n",
    "from datasets import Dataset, load_dataset\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments\n",
    "\n",
    "from trl import DPOTrainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bfc3c1c-41c9-46bb-9096-8abf3cfddb16",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name=\"EleutherAI/gpt-neo-125m\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4eb123ba-a00a-45e2-adf1-d1e379af7387",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"PKU-Alignment/PKU-SafeRLHF\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "226bd007-cab9-4543-a72a-520268d625c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset['train'][:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c88a369e-57d4-4aaa-875b-77e21fe7f39c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_safe_rlhf_model(model_name):\n",
    "    script_args = DPOTrainingConfig(model_name_or_path=model_name)\n",
    "    model = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path).cuda()\n",
    "\n",
    "    if script_args.ignore_bias_buffers:\n",
    "        # torch distributed hack\n",
    "        model._ddp_params_and_buffers_to_ignore = [\n",
    "            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool\n",
    "        ]\n",
    "\n",
    "    model_ref = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path).cpu()\n",
    "    tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path)\n",
    "    if tokenizer.pad_token is None:\n",
    "        tokenizer.pad_token = tokenizer.eos_token\n",
    "    train_dataset = get_hh(\"train\", sanity_check=script_args.sanity_check)\n",
    "    eval_dataset = get_hh(\"test\", sanity_check=True)\n",
    "\n",
    "    training_args = TrainingArguments(\n",
    "            per_device_train_batch_size=script_args.per_device_train_batch_size, max_steps=script_args.max_steps,\n",
    "            remove_unused_columns=False, gradient_accumulation_steps=script_args.gradient_accumulation_steps,\n",
    "            learning_rate=script_args.learning_rate, push_to_hub=True,\n",
    "            hub_model_id=script_args.huggingface_hub_name, evaluation_strategy=\"steps\",\n",
    "            logging_first_step=True, logging_steps=10,\n",
    "            eval_steps=2000, output_dir=\"./test\",\n",
    "            optim=\"adamw_hf\", warmup_steps=150,\n",
    "            report_to=script_args.report_to, bf16=True,\n",
    "            gradient_checkpointing=script_args.gradient_checkpointing\n",
    "    )\n",
    "\n",
    "    dpo_trainer = DPOTrainer(\n",
    "        model, model_ref,\n",
    "        args=training_args, beta=script_args.beta,\n",
    "        train_dataset=train_dataset, eval_dataset=eval_dataset,\n",
    "        tokenizer=tokenizer,max_length=script_args.max_length,\n",
    "        max_target_length=script_args.max_target_length, max_prompt_length=script_args.max_prompt_length,\n",
    "        generate_during_eval=True\n",
    "    )\n",
    "    dpo_trainer.train()\n",
    "    return dpo_trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a6c5861-18da-4a12-a816-539f8417ce38",
   "metadata": {},
   "outputs": [],
   "source": [
    "dpo_trainer = train_anthropic_model(model_name)\n",
    "dpo_trained.model.save_pretrained(model_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "400a4bf5-6aee-49e6-bf66-8bc7d620af7c",
   "metadata": {},
   "source": [
    "### Setup dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b1acce0-2578-46b9-97bd-4bee25d8298e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import huggingface_hub\n",
    "huggingface_hub.login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df529c1d-83b6-4a14-8ad7-401f16220e9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datetime import datetime\n",
    "\n",
    "# Get the current datetime\n",
    "current_datetime = datetime.now()\n",
    "\n",
    "# Format it as ISO 8601 string\n",
    "isoformatted_datetime = current_datetime.isoformat()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b172e9b7-e2fd-47c1-a034-3bb9f5d540bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "api = HfApi()\n",
    "repo_url = api.create_repo(repo_id=repo_id, repo_type=None, exist_ok=True, token=token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d64051cf-2f59-4279-b48e-c923cd6ca7a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "api.upload_folder(\n",
    "    repo_id=repo_url.repo_id,\n",
    "    folder_path=f'./{model_name}',\n",
    "    path_in_repo=f'models/{model_name}/{isoformatted_datetime}',\n",
    "    token=token,\n",
    "    repo_type=None\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rlhf",
   "language": "python",
   "name": "rlhf"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
