{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4adc49c1-5b23-4c47-897f-0364d28dd620",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d90093a-7127-4a1c-a9ff-0c9c0f93f037",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Compare the non-robust and the robust (`robust_type`) proxy reward models under\n",
    "# the `attack` adversarial attack: plot the per-step change in mean proxy and\n",
    "# gold reward. numpy / matplotlib are imported in the first cell.\n",
    "# Globals reused by the RMS cell below: attack, model_size, proxy_scores,\n",
    "# gold_scores, y_proxy_individual.\n",
    "\n",
    "robust_type = 'pez'\n",
    "attack = 'gbda'\n",
    "runid = f\"{robust_type}_on_{attack}\"\n",
    "base_dir = f\"/data/long_phan/proxy_gaming/src/adv/data/final_runs/{runid}\"\n",
    "\n",
    "N = 1000\n",
    "num_steps = 200\n",
    "eval_steps = 10\n",
    "pez_mode = 'prepend'\n",
    "tokens = 4\n",
    "# NOTE(review): undocumented divisor applied to the mean gold curve; presumably\n",
    "# a reference gold score used for normalisation -- confirm its origin.\n",
    "GOLD_NORMALIZER = 0.9648238992568243\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12, 8))\n",
    "ax.set_xlabel('n')\n",
    "ax.set_ylabel('Reward')\n",
    "\n",
    "# One score matrix of shape (rows, total_steps) per proxy variant, in loop order.\n",
    "proxy_scores = []\n",
    "gold_scores = []\n",
    "for model_size in ['large']:\n",
    "    for robust in ['nonrobust', 'robust']:\n",
    "        if robust == 'robust':\n",
    "            file_name = f'{base_dir}/{model_size}_proxy{robust}_{robust_type}_{pez_mode}_{N}N_{num_steps}sattack_{eval_steps}seval_{tokens}tokens.npy'\n",
    "        else:\n",
    "            file_name = f'{base_dir}/{model_size}_proxynonrobust_{pez_mode}_{N}N_{num_steps}sattack_{eval_steps}seval_{tokens}tokens.npy'\n",
    "\n",
    "        # FGSM result files carry no PEZ-mode / token-count suffix and hold\n",
    "        # only the proxy scores.\n",
    "        if attack == 'fgsm':\n",
    "            file_name = file_name.replace(f\"_{pez_mode}\", \"\").replace(f\"_{tokens}tokens\", \"\")\n",
    "\n",
    "        print(file_name)\n",
    "        # Several arrays are stored back to back in one file; read them in order.\n",
    "        with open(file_name, 'rb') as f:\n",
    "            all_proxy = np.load(f)\n",
    "            all_gold = np.load(f) if attack != 'fgsm' else []\n",
    "            adv_prompts = np.load(f) if attack != 'fgsm' else []\n",
    "\n",
    "        total_steps = num_steps // eval_steps\n",
    "        # PEZ runs record one extra leading evaluation (presumably the raw,\n",
    "        # pre-attack score); it is dropped from the plotted curves.\n",
    "        if attack == 'pez':\n",
    "            total_steps += 1\n",
    "            cutoff_steps = 1\n",
    "        else:\n",
    "            cutoff_steps = 0\n",
    "\n",
    "        y_proxy_individual = np.array(all_proxy).T.reshape((-1, total_steps))\n",
    "        y_gold_individual = np.array(all_gold).T.reshape((-1, total_steps))\n",
    "        adv_prompt_individual = np.array(adv_prompts).T.reshape((-1, total_steps))\n",
    "\n",
    "        # Keep the full per-row scores (no cutoff, no shift) for the RMS cell.\n",
    "        proxy_scores.append(y_proxy_individual)\n",
    "        gold_scores.append(y_gold_individual)\n",
    "\n",
    "        total_steps -= cutoff_steps\n",
    "        y_proxy = y_proxy_individual.mean(0)[cutoff_steps:]\n",
    "        y_gold = (y_gold_individual.mean(0) / GOLD_NORMALIZER)[cutoff_steps:]\n",
    "\n",
    "        # Shift both curves to start at 0 so the plot shows the change in reward.\n",
    "        y_proxy = y_proxy - y_proxy[0]\n",
    "        y_gold = y_gold - y_gold[0]\n",
    "\n",
    "        x = range(total_steps)\n",
    "        linestyle = '--' if robust == 'nonrobust' else '-'\n",
    "        ax.plot(x, y_proxy, label=f\"{model_size} {robust} proxy\", color='blue', linestyle=linestyle)\n",
    "        ax.plot(x, y_gold, label=f\"{model_size} {robust} gold\", color='orange', linestyle=linestyle)\n",
    "\n",
    "        print(f\"{robust} Change Score\", np.linalg.norm(np.diff(y_proxy) - np.diff(y_gold)))\n",
    "        print('====')\n",
    "\n",
    "ax.legend(bbox_to_anchor=(1.1, 1.03), loc='upper left')\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c152032-ddb5-4ff2-bac7-e5fda4cbcdd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# RMS (x100) of the step-to-step mismatch between each row's smoothed proxy\n",
    "# curve and its gold curve, averaged over the rows of each proxy variant.\n",
    "# Uses attack, model_size, proxy_scores and gold_scores from the plotting cell.\n",
    "window_size = 2\n",
    "\n",
    "# PEZ has raw score from original steps; skip that first evaluation.\n",
    "skip_step = 1 if attack == 'pez' else 0\n",
    "\n",
    "# Per-variant list of per-row RMS values, keyed by 'nonrobust' / 'robust'.\n",
    "metrics = {}\n",
    "for proxy_score, gold_score, robust in zip(proxy_scores, gold_scores, ['nonrobust', 'robust']):\n",
    "    print(f\"{model_size} {robust}\")\n",
    "    gold_score = gold_score[:, skip_step:]\n",
    "    proxy_score = proxy_score[:, skip_step:]\n",
    "\n",
    "    p_score = proxy_score.mean(0)\n",
    "    g_score = gold_score.mean(0)\n",
    "    # Mean gold score at the step where the mean proxy score peaks.\n",
    "    final_gold_score = g_score[np.argmax(p_score)]\n",
    "\n",
    "    # Fresh list per variant so each RMS averages only this variant's rows.\n",
    "    rms_values = []\n",
    "    for p, g in zip(proxy_score, gold_score):\n",
    "        # NOTE(review): only the proxy curve is smoothed -- confirm intended.\n",
    "        p = pd.Series(p).rolling(window_size, min_periods=1, center=True).mean().to_numpy()\n",
    "        # Shift out of place: `g` is a view into gold_scores, so an in-place\n",
    "        # `+=` would alter the plotting cell's data and break re-runs.\n",
    "        p = p - p[0]\n",
    "        g = g - g[0]\n",
    "\n",
    "        diff = np.diff(p) - np.diff(g)\n",
    "        rms_values.append(np.sqrt((diff ** 2).mean()) * 100)\n",
    "\n",
    "    metrics[robust] = rms_values\n",
    "    print(\"RMS\", np.mean(rms_values))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fb6eff1-7a9c-40bd-ab05-e7b0cc37d771",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
