{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f7e319593d6880f3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-15T12:35:17.517457Z",
     "start_time": "2025-05-15T12:35:13.446699Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import pickle\n",
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
    "from datasets import load_dataset\n",
    "import GPUtil\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "initial_id",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-15T12:35:17.834503Z",
     "start_time": "2025-05-15T12:35:17.831801Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "reward_model_name = \"Skywork/Skywork-Reward-Llama-3.1-8B-v0.2\"\n",
    "reward_model_name = \"LxzGordon/URM-LLaMa-3.1-8B\"\n",
    "dataset_name = \"tatsu-lab/alpaca_farm\"\n",
    "dataset_subset = \"alpaca_human_preference\"\n",
    "dataset_split = \"preference\"\n",
    "\n",
    "output_hidden_states = False # only work for \"Skywork/Skywork-Reward-Llama-3.1-8B-v0.2\"\n",
    "if reward_model_name == \"Skywork/Skywork-Reward-Llama-3.1-8B-v0.2\":\n",
    "    output_hidden_states = True  # only work for \"Skywork/Skywork-Reward-Llama-3.1-8B-v0.2\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "535fc8a4fcdef78c",
   "metadata": {},
   "source": [
    "## Set util functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e620b88b54195da9",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-15T12:35:17.893926Z",
     "start_time": "2025-05-15T12:35:17.887308Z"
    }
   },
   "outputs": [],
   "source": [
    "def release_gpu_memory():\n",
    "    \"\"\"\n",
    "    Function to release GPU memory and print GPU memory usage before and after releasing.\n",
    "    \"\"\"\n",
    "    print(\"Releasing GPU memory...\")\n",
    "    # GPU memory usage\n",
    "    gpus = GPUtil.getGPUs()\n",
    "    print(\"  Before releasing GPU memory:\")\n",
    "    for gpu in gpus:\n",
    "        print(f\"   - GPU {gpu.id}: {gpu.memoryFree}MB free | {gpu.memoryUsed}MB used | {gpu.memoryTotal}MB total\")\n",
    "\n",
    "    # release GPU memory\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "    # GPU memory usage after releasing\n",
    "    gpus = GPUtil.getGPUs()\n",
    "    print(\"  After releasing GPU memory:\")\n",
    "    for gpu in gpus:\n",
    "        print(f\"   - GPU {gpu.id}: {gpu.memoryFree}MB free | {gpu.memoryUsed}MB used | {gpu.memoryTotal}MB total\")\n",
    "\n",
    "\n",
    "def get_short_name(hf_name: str) -> str:\n",
    "    \"\"\"\n",
    "    Get a short name from the full huggingface name.\n",
    "\n",
    "    Args:\n",
    "        hf_name: Full huggingface name.\n",
    "\n",
    "    Returns:\n",
    "        Short name.\n",
    "    \"\"\"\n",
    "    # split the model name by '/' and take the last part\n",
    "    short_name = hf_name.split('/')[-1]\n",
    "\n",
    "    # Replace '-' with '_'\n",
    "    short_name = short_name.replace(\":\", \"_\").replace(\"\\\\\", \"_\")\n",
    "\n",
    "    return short_name\n",
    "\n",
    "\n",
    "def get_reward_dir(data_name: str,\n",
    "                 data_subset: str,\n",
    "                 data_split: str) -> str:\n",
    "    # Get short dataset name\n",
    "    short_dataset_name = get_short_name(data_name)\n",
    "    return f\"../data/split_{data_split}/reward/\"\n",
    "\n",
    "\n",
    "def get_hidden_state_dir(data_name: str,\n",
    "                         data_subset: str,\n",
    "                         data_split: str) -> str:\n",
    "    # Get short dataset name\n",
    "    short_dataset_name = get_short_name(data_name)\n",
    "    return f\"../data/split_{data_split}/hidden_state/\"\n",
    "\n",
    "\n",
    "def get_reward_filename(rm_name: str,\n",
    "                        data_name: str,\n",
    "                        data_subset: str,\n",
    "                        data_split: str) -> str:\n",
    "    # Get short rm name\n",
    "    short_rm_name = get_short_name(rm_name)\n",
    "\n",
    "    # Get save directory\n",
    "    save_dir = get_reward_dir(data_name, data_subset, data_split)\n",
    "    os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "    return save_dir + \"/\" + short_rm_name + \".pkl\"\n",
    "\n",
    "\n",
    "def get_hidden_state_filename(rm_name: str,\n",
    "                               data_name: str,\n",
    "                               data_subset: str,\n",
    "                               data_split: str) -> str:\n",
    "    # Get short rm name\n",
    "    short_rm_name = get_short_name(rm_name)\n",
    "\n",
    "    # Get save directory\n",
    "    save_dir = get_hidden_state_dir(data_name, data_subset, data_split)\n",
    "    os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "    return save_dir + \"/\" + short_rm_name + \".pkl\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "945c8763f88049f6",
   "metadata": {},
   "source": [
    "## Load reward model and dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4c4d2f6c6f0e50bf",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-15T12:35:23.736228Z",
     "start_time": "2025-05-15T12:35:17.933858Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.51it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "device: cpu\n",
      "reward model: LxzGordon/URM-LLaMa-3.1-8B\n",
      "Releasing GPU memory...\n",
      "  Before releasing GPU memory:\n",
      "  After releasing GPU memory:\n"
     ]
    }
   ],
   "source": [
    "# Load reward model and its tokenizer\n",
    "if reward_model_name in {\"Skywork/Skywork-Reward-Llama-3.1-8B-v0.2\"}:\n",
    "    reward_model = AutoModelForSequenceClassification.from_pretrained(\n",
    "        reward_model_name,\n",
    "        torch_dtype=torch.bfloat16,\n",
    "        # device_map=\"cuda:0\",\n",
    "        attn_implementation=\"flash_attention_2\",\n",
    "        num_labels=1,\n",
    "    )\n",
    "    tokenizer = AutoTokenizer.from_pretrained(reward_model_name)\n",
    "\n",
    "elif reward_model_name in {\"LxzGordon/URM-LLaMa-3.1-8B\"}:\n",
    "    reward_model = AutoModelForSequenceClassification.from_pretrained(\n",
    "        reward_model_name,\n",
    "        # device_map=\"cuda:0\",\n",
    "        trust_remote_code=True\n",
    "    )\n",
    "    tokenizer = AutoTokenizer.from_pretrained(reward_model_name, use_fast=True)\n",
    "else:\n",
    "    NotImplementedError(f\"Model {reward_model_name} is not implemented.\")\n",
    "\n",
    "print(f\"device: {reward_model.device}\")\n",
    "print(f\"reward model: {reward_model_name}\")\n",
    "\n",
    "# release GPU memory\n",
    "release_gpu_memory()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8d84e151f6f0414e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-15T12:35:26.014821Z",
     "start_time": "2025-05-15T12:35:23.758094Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sample prompts from Alpaca:\n",
      "\n",
      "Prompt 0:\n",
      "  [instruction] Append the following sentence to the end of the input.\n",
      "  [input] The house saw strange events that night.\n",
      "  [output_1] The house saw strange events that night that remained a mystery.\n",
      "  [output_2] The house saw strange events that night, and something was lurking in the shadows.\n",
      "  [preference] 2\n"
     ]
    }
   ],
   "source": [
    "# Load dataset\n",
    "dataset = load_dataset(dataset_name, dataset_subset, split=dataset_split)\n",
    "\n",
    "# Print dataset example\n",
    "print(\"Sample prompts from Alpaca:\")\n",
    "for i in range(1):\n",
    "    print(f\"\\nPrompt {i}:\")\n",
    "    print(f\"  [instruction] {dataset[i]['instruction']}\")\n",
    "    print(f\"  [input] {dataset[i]['input']}\")\n",
    "    print(f\"  [output_1] {dataset[i]['output_1']}\")\n",
    "    print(f\"  [output_2] {dataset[i]['output_2']}\")\n",
    "    print(f\"  [preference] {dataset[i]['preference']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17a3232bb2e6de5e",
   "metadata": {},
   "source": [
    "## Prepare prompts and responses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "36efc8412fb151a2",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-15T12:35:26.416845Z",
     "start_time": "2025-05-15T12:35:26.040255Z"
    }
   },
   "outputs": [],
   "source": [
    "# Prepare prompts (=instruction + input) and responses for reward model\n",
    "prompts = []\n",
    "responses_1 = []\n",
    "responses_2 = []\n",
    "\n",
    "for sample in dataset:\n",
    "    # prompt\n",
    "    if sample[\"input\"]:\n",
    "        prompt = f\"{sample['instruction']}\\n{sample['input']}\"\n",
    "    else:\n",
    "        prompt = sample[\"instruction\"]\n",
    "    prompts.append(prompt)\n",
    "\n",
    "    # responses\n",
    "    responses_1.append(sample[\"output_1\"])\n",
    "    responses_2.append(sample[\"output_2\"])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9743665e68f8177c",
   "metadata": {},
   "source": [
    "## Compute reward scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a197fb82c4f9914",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-15T13:24:46.446030Z",
     "start_time": "2025-05-15T12:35:26.442731Z"
    }
   },
   "outputs": [],
   "source": [
    "reward_model.eval()\n",
    "batch_size = 5\n",
    "\n",
    "rewards_1 = []\n",
    "rewards_2 = []\n",
    "hidden_states_1 = []\n",
    "hidden_states_2 = []\n",
    "\n",
    "pad_token_id = tokenizer.pad_token_id  # used to extract the last token position\n",
    "\n",
    "for i in tqdm(range(0, len(prompts), batch_size)):\n",
    "    prompt_batch = prompts[i:i + batch_size] + prompts[i:i + batch_size]\n",
    "    response_batch = responses_1[i:i + batch_size] + responses_2[i:i + batch_size]\n",
    "    actual_batch_size = int(len(prompt_batch) / 2)\n",
    "\n",
    "    if reward_model_name in {\"Skywork/Skywork-Reward-Llama-3.1-8B-v0.2\"}:\n",
    "        # prepare chat massages\n",
    "        chat_messages = [[{\"role\": \"user\", \"content\": prompt},\n",
    "                          {\"role\": \"assistant\", \"content\": response}]\n",
    "                         for prompt, response in zip(prompt_batch, response_batch)]\n",
    "\n",
    "        # tokenize inputs\n",
    "        inputs = tokenizer.apply_chat_template(\n",
    "            chat_messages,\n",
    "            return_tensors=\"pt\",\n",
    "            tokenize=True,\n",
    "            padding=True\n",
    "        ).to(reward_model.device)\n",
    "\n",
    "        # model inference\n",
    "        with torch.no_grad():\n",
    "            outputs = reward_model(inputs, output_hidden_states=output_hidden_states)\n",
    "\n",
    "        # Get reward scores\n",
    "        # if reward_model_name in {\"Skywork/Skywork-Reward-Llama-3.1-8B-v0.2\"}:\n",
    "        rewards = outputs.logits.cpu().float()\n",
    "\n",
    "    elif reward_model_name == \"LxzGordon/URM-LLaMa-3.1-8B\":\n",
    "        # prepare chat massages\n",
    "        chat_messages = [[{\"role\": \"user\", \"content\": prompt},\n",
    "                          {\"role\": \"assistant\", \"content\": response}]\n",
    "                         for prompt, response in zip(prompt_batch, response_batch)]\n",
    "\n",
    "        # format and tokenize chat messages\n",
    "        inputs = tokenizer.apply_chat_template(chat_messages, tokenize=False)\n",
    "        inputs = tokenizer(inputs, return_tensors='pt', padding=True).to(reward_model.device)\n",
    "\n",
    "        # model inference and get reward scores\n",
    "        with torch.no_grad():\n",
    "            rewards = reward_model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits.cpu().float()\n",
    "\n",
    "    else:\n",
    "        raise NotImplementedError(f\"Model {reward_model_name} is not implemented.\")\n",
    "\n",
    "    rewards_1.append(rewards[:actual_batch_size, 0])\n",
    "    rewards_2.append(rewards[actual_batch_size:, 0])\n",
    "\n",
    "    # Get hidden states\n",
    "    if output_hidden_states:\n",
    "        # get the final token position\n",
    "        attention_mask = inputs != pad_token_id\n",
    "        last_token_indices = attention_mask.sum(dim=1) - 1\n",
    "        full_hidden_states = outputs.hidden_states[-1].detach().cpu().float()\n",
    "        # get the hidden states of the last token\n",
    "        hidden_states = torch.stack([\n",
    "            full_hidden_states[i, idx, :] for i, idx in enumerate(last_token_indices)\n",
    "        ])\n",
    "        # divide the hidden states into two responses\n",
    "        hidden_states_1.append(hidden_states[:actual_batch_size,:])\n",
    "        hidden_states_2.append(hidden_states[actual_batch_size:,:])\n",
    "\n",
    "# Finalize\n",
    "rewards_1 = torch.cat(rewards_1)\n",
    "rewards_2 = torch.cat(rewards_2)\n",
    "if output_hidden_states:\n",
    "    reward_weight = reward_model.score.weight.squeeze().cpu().detach().float()\n",
    "    hidden_states_1 = torch.cat(hidden_states_1)\n",
    "    hidden_states_2 = torch.cat(hidden_states_2)\n",
    "\n",
    "# release GPU memory\n",
    "release_gpu_memory()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c663a2567d5394af",
   "metadata": {},
   "source": [
    "### check the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "245ab06135ce55f6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-15T13:24:46.498120Z",
     "start_time": "2025-05-15T13:24:46.494653Z"
    }
   },
   "outputs": [],
   "source": [
    "# check the shape of reward\n",
    "print(f\"rewards_1: {rewards_1.shape}\")\n",
    "print(f\"rewards_2: {rewards_2.shape}\")\n",
    "\n",
    "# check the shape of hidden states\n",
    "if output_hidden_states:\n",
    "    print(f\"hidden_states_1: {hidden_states_1.shape}\")\n",
    "    print(f\"hidden_states_2: {hidden_states_2.shape}\")\n",
    "    print(f\"reward_weight: {reward_weight.shape}\")\n",
    "\n",
    "# Check the validity of weight and hidden states\n",
    "if output_hidden_states:\n",
    "    idx = 6\n",
    "    print(f\"model output (responses_1): {rewards_1[:idx]}\")\n",
    "    print(f\"weight * hidden (responses_1): {hidden_states_1[:idx] @ reward_weight}\")\n",
    "    print(f\"model output (responses_2): {rewards_2[:idx]}\")\n",
    "    print(f\"weight * hidden (responses_2): {hidden_states_2[:idx] @ reward_weight}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a684678a6abc328",
   "metadata": {},
   "source": [
    "## Save results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "460030117d88b128",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-15T13:24:46.533599Z",
     "start_time": "2025-05-15T13:24:46.526607Z"
    }
   },
   "outputs": [],
   "source": [
    "# save reward scores\n",
    "reward_filename = get_reward_filename(\n",
    "    reward_model_name,\n",
    "    dataset_name,\n",
    "    dataset_subset,\n",
    "    dataset_split\n",
    ")\n",
    "\n",
    "with open(reward_filename, \"wb\") as f:\n",
    "    pickle.dump({\n",
    "        \"rewards_1\": rewards_1,\n",
    "        \"rewards_2\": rewards_2\n",
    "    }, f)\n",
    "\n",
    "\n",
    "if output_hidden_states:\n",
    "    hidden_state_filename = get_hidden_state_filename(\n",
    "        reward_model_name,\n",
    "        dataset_name,\n",
    "        dataset_subset,\n",
    "        dataset_split\n",
    "    )\n",
    "\n",
    "    with open(hidden_state_filename, \"wb\") as f:\n",
    "        pickle.dump({\n",
    "            \"weight\": reward_weight,\n",
    "            \"hidden_states_1\": hidden_states_1,\n",
    "            \"hidden_states_2\": hidden_states_2\n",
    "        }, f)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
