{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4diD3eGnC1z0"
      },
      "outputs": [],
      "source": [
        "import random\n",
        "import torch\n",
        "from datasets import load_dataset\n",
        "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
        "from tqdm import tqdm\n",
        "\n",
        "# =========================\n",
        "# Config\n",
        "# =========================\n",
        "device = \"cuda:0\"  # device that receives the tokenized inputs\n",
        "model_name = \"Skywork/Skywork-Reward-Llama-3.1-8B\"\n",
        "dataset_name = \"Skywork/Skywork-Reward-Preference-80K-v0.1\"\n",
        "num_samples = 5000\n",
        "max_length = 16384  # truncation limit, in tokens, for each conversation\n",
        "seed = 0  # seeds index sampling so reruns score the same rows\n",
        "\n",
        "# =========================\n",
        "# Load model & tokenizer\n",
        "# =========================\n",
        "rm = AutoModelForSequenceClassification.from_pretrained(\n",
        "    model_name,\n",
        "    torch_dtype=torch.bfloat16,\n",
        "    device_map=\"auto\",\n",
        "    num_labels=1,\n",
        ")\n",
        "rm_tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
        "\n",
        "# =========================\n",
        "# Load dataset\n",
        "# =========================\n",
        "dataset = load_dataset(dataset_name, split=\"train\")\n",
        "\n",
        "# =========================\n",
        "# Randomly sample indices (seeded for reproducibility)\n",
        "# =========================\n",
        "rng = random.Random(seed)\n",
        "indices = rng.sample(range(len(dataset)), num_samples)\n",
        "\n",
        "\n",
        "def score_conversation(conversation):\n",
        "    \"\"\"Return the reward model's scalar score for one chat conversation.\n",
        "\n",
        "    The conversation is rendered with the tokenizer's chat template and then\n",
        "    tokenized (truncated to `max_length` tokens). If the rendered text already\n",
        "    begins with the BOS token, the tokenizer is told not to add special tokens,\n",
        "    so the model does not see a duplicated BOS.\n",
        "    \"\"\"\n",
        "    formatted = rm_tokenizer.apply_chat_template(conversation, tokenize=False)\n",
        "    bos = rm_tokenizer.bos_token\n",
        "    has_bos = bool(bos) and formatted.startswith(bos)\n",
        "    tokenized = rm_tokenizer(\n",
        "        formatted,\n",
        "        return_tensors=\"pt\",\n",
        "        truncation=True,\n",
        "        max_length=max_length,\n",
        "        add_special_tokens=not has_bos,\n",
        "    ).to(device)\n",
        "    with torch.inference_mode():\n",
        "        return rm(**tokenized).logits[0][0].item()\n",
        "\n",
        "\n",
        "# =========================\n",
        "# Main loop with progress bar\n",
        "# =========================\n",
        "scores = []  # one (chosen_score, rejected_score) tuple per sampled row\n",
        "for idx in tqdm(indices, desc=\"Scoring samples\", total=num_samples):\n",
        "    sample = dataset[idx]\n",
        "    chosen_score = score_conversation(sample[\"chosen\"])\n",
        "    rejected_score = score_conversation(sample[\"rejected\"])\n",
        "    scores.append((chosen_score, rejected_score))"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "\n",
        "import torch\n",
        "\n",
        "# Output directory for the saved rewards.\n",
        "SAVE_DIR = \"outputs\"  # TODO: set to the desired output directory\n",
        "os.makedirs(SAVE_DIR, exist_ok=True)\n",
        "SAVE_PATH = os.path.join(SAVE_DIR, f\"Skywork_relative_rewards_max{max_length}_2.pt\")\n",
        "\n",
        "# Build the reward lists from `scores` (produced by the scoring cell) here, so\n",
        "# this cell does not depend on a cell that appears later in the notebook.\n",
        "reward1 = [chosen for chosen, _ in scores]  # scores of the chosen responses\n",
        "reward2 = [rejected for _, rejected in scores]  # scores of the rejected responses\n",
        "reward = [abs(chosen - rejected) for chosen, rejected in scores]  # |chosen - rejected|\n",
        "\n",
        "torch.save(\n",
        "    {\n",
        "        \"rewards\": reward,\n",
        "        \"rewards1\": reward1,\n",
        "        \"rewards2\": reward2,\n",
        "        \"max_length\": max_length,\n",
        "        \"num_samples\": len(reward),\n",
        "    },\n",
        "    SAVE_PATH,\n",
        ")"
      ],
      "metadata": {
        "id": "M2jl2SzEC-GE"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Split the (chosen, rejected) score pairs into per-side lists, plus the\n",
        "# absolute reward gap |chosen - rejected| for every sampled pair.\n",
        "reward1 = [chosen for chosen, _ in scores]\n",
        "reward2 = [rejected for _, rejected in scores]\n",
        "reward = [abs(chosen - rejected) for chosen, rejected in scores]"
      ],
      "metadata": {
        "id": "FcMNRaUoDBId"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import torch\n",
        "\n",
        "# =========================\n",
        "# Config\n",
        "# =========================\n",
        "PATH1 = \"\"  # TODO: rewards .pt file plotted as \"Skywork\"\n",
        "PATH2 = \"\"  # TODO: rewards .pt file plotted as \"UltraRM\"\n",
        "OUT_PATH = \"delta_r_comparison.pdf\"  # where the figure is written\n",
        "SKYWORK_COLOR = \"#1f77b4\"  # matplotlib default C0 (blue)\n",
        "ULTRARM_COLOR = \"#ff7f0e\"  # matplotlib default C1 (orange)\n",
        "MARKER_END_FRAC = 0.1  # max-value marker spans the first 10% of the x-axis\n",
        "\n",
        "for _name, _path in ((\"PATH1\", PATH1), (\"PATH2\", PATH2)):\n",
        "    if not _path:\n",
        "        raise ValueError(f\"{_name} is empty; set it to a saved rewards .pt file.\")\n",
        "\n",
        "\n",
        "def load_rewards(path):\n",
        "    \"\"\"Load the \"rewards\" entry of a saved .pt file as a float64 NumPy array.\n",
        "\n",
        "    The entry may be a list of numbers or a torch.Tensor (any float dtype,\n",
        "    including bfloat16, which NumPy cannot represent directly).\n",
        "    \"\"\"\n",
        "    data = torch.load(path, map_location=\"cpu\", weights_only=True)\n",
        "    rewards = data[\"rewards\"]\n",
        "    if isinstance(rewards, torch.Tensor):\n",
        "        return rewards.detach().cpu().to(torch.float64).numpy()\n",
        "    return np.asarray(rewards, dtype=np.float64)\n",
        "\n",
        "\n",
        "def annotate_max(ax, y, color, x_end, text_offset):\n",
        "    \"\"\"Draw a dashed line at height `y` over [0, x_end] and label it with `y`.\"\"\"\n",
        "    ax.hlines(y=y, xmin=0, xmax=x_end, colors=color, linestyles=\"dashed\", linewidth=2)\n",
        "    ax.text(\n",
        "        x_end + text_offset,\n",
        "        y,\n",
        "        rf\"$\\max \\,\\Delta r = {y:.1f}$\",\n",
        "        color=color,\n",
        "        fontsize=16,\n",
        "        va=\"center\",\n",
        "    )\n",
        "\n",
        "\n",
        "# Sort each set of reward gaps from largest to smallest.\n",
        "r1_sorted = np.sort(load_rewards(PATH1))[::-1]\n",
        "r2_sorted = np.sort(load_rewards(PATH2))[::-1]\n",
        "\n",
        "n = min(len(r1_sorted), len(r2_sorted))\n",
        "if n == 0:\n",
        "    raise ValueError(\"Both reward files must contain at least one value.\")\n",
        "x = np.arange(n)\n",
        "\n",
        "fig, ax = plt.subplots(figsize=(4.7, 3.55))\n",
        "\n",
        "# Skywork is drawn first; UltraRM is drawn second and therefore sits in front.\n",
        "ax.bar(\n",
        "    x, r1_sorted[:n],\n",
        "    width=1.0,\n",
        "    linewidth=0,\n",
        "    edgecolor=\"none\",\n",
        "    color=SKYWORK_COLOR,\n",
        "    label=\"Skywork\",\n",
        ")\n",
        "ax.bar(\n",
        "    x, r2_sorted[:n],\n",
        "    width=1.0,\n",
        "    linewidth=0,\n",
        "    edgecolor=\"none\",\n",
        "    color=ULTRARM_COLOR,\n",
        "    label=\"UltraRM\",\n",
        ")\n",
        "\n",
        "# Label the x-axis as a fraction of the n plotted samples.\n",
        "ticks = np.linspace(0, n, 6)\n",
        "ax.set_xticks(ticks)\n",
        "ax.set_xticklabels([f\"{t / n:.1f}\" for t in ticks], fontsize=14)\n",
        "ax.set_xlim(0, n)\n",
        "ax.set_xlabel(\"Cumulative fraction\", fontsize=18)\n",
        "ax.set_ylabel(r\"$\\Delta r$\", fontsize=18)\n",
        "ax.legend(frameon=True, fontsize=15, bbox_to_anchor=(0.48, 0.87))\n",
        "\n",
        "# Mark each model's largest gap (the first element after the descending sort).\n",
        "marker_end = int(MARKER_END_FRAC * n)\n",
        "text_offset = 0.01 * n  # small gap between the dashed line and its label\n",
        "annotate_max(ax, r1_sorted[0], SKYWORK_COLOR, marker_end, text_offset)\n",
        "annotate_max(ax, r2_sorted[0], ULTRARM_COLOR, marker_end, text_offset)\n",
        "\n",
        "ax.tick_params(axis=\"y\", labelsize=14)\n",
        "# 10% headroom above the largest gap so the max labels stay inside the axes.\n",
        "ax.set_ylim(0, 1.1 * max(r1_sorted[0], r2_sorted[0]))\n",
        "\n",
        "fig.tight_layout()\n",
        "os.makedirs(os.path.dirname(OUT_PATH) or \".\", exist_ok=True)\n",
        "fig.savefig(OUT_PATH, bbox_inches=\"tight\")"
      ],
      "metadata": {
        "id": "KQPsYdeIDBsJ"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}