{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "188d001d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.backends.backend_pdf import PdfPages\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bc937f36",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 1000 # number of different prompts\n",
    "cutoff = 10000 # best of n, max 10000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0c97982",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory that holds the pickled score files.\n",
    "base_dir = 'your_path/scores/'\n",
    "\n",
    "# using Small proxy as an example\n",
    "fname2 = base_dir + 'combined_scores_llama_7b_small_sam0.pkl'\n",
    "# Statistics used below to standardise the proxy scores (fill in before running).\n",
    "proxy_mean = <your_mean>\n",
    "proxy_std = <your_std>\n",
    "\n",
    "# Split the path into directory and file name; only the file name is kept.\n",
    "_, save_filename = os.path.split(fname2)\n",
    "\n",
    "# NOTE: pickle.load can execute arbitrary code, so only open files you trust.\n",
    "with open(fname2, 'rb') as f:\n",
    "    proxy_data = np.array(pickle.load(f))\n",
    "\n",
    "# Standardise, then keep the first N rows and the first `cutoff` columns.\n",
    "proxy_data = ((proxy_data - proxy_mean) / proxy_std)[:N, :cutoff]\n",
    "proxy_data.mean(), proxy_data.std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45f2d052",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pickled score files of the two gold models.\n",
    "fname1 = 'your_path/gold_model1'\n",
    "fname2 = 'your_path/gold_model2'\n",
    "\n",
    "# Per-model standardisation statistics (fill in before running).\n",
    "mean1 = <mean1>\n",
    "std1 = <std1>\n",
    "mean2 = <mean2>\n",
    "std2 = <std2>\n",
    "\n",
    "# NOTE: pickle.load can execute arbitrary code, so only open files you trust.\n",
    "with open(fname1, 'rb') as f:\n",
    "    data1 = np.array(pickle.load(f))\n",
    "with open(fname2, 'rb') as f:\n",
    "    data2 = np.array(pickle.load(f))\n",
    "\n",
    "# Standardise each model's scores with its own mean and std.\n",
    "data1 = (data1 - mean1) / std1\n",
    "data2 = (data2 - mean2) / std2\n",
    "# The shape check runs on the full arrays, before any truncation.\n",
    "assert data1.shape == data2.shape\n",
    "data1 = data1[:N, :cutoff]\n",
    "data2 = data2[:N, :cutoff]\n",
    "# Average the two models, then rescale by the normalizing constant.\n",
    "gold_data = (data1 + data2) / 2 / <normalizing_constant>\n",
    "gold_data.std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39f75c43",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both score matrices are flattened the same way, so they must line up element-wise.\n",
    "assert proxy_data.shape == gold_data.shape, 'proxy and gold score matrices must have the same shape'\n",
    "\n",
    "# Transposing before flattening groups the values by original column: each\n",
    "# consecutive run of proxy_data.shape[0] entries comes from one column.\n",
    "col2 = proxy_data.T.reshape((-1))\n",
    "col3 = gold_data.T.reshape((-1))\n",
    "# Size the placeholder text column from the data instead of cutoff * N, so the\n",
    "# frame still builds when the pickles hold fewer entries than the configured limits.\n",
    "col1 = [\"\"] * len(col2)\n",
    "\n",
    "# Column 0: empty placeholder strings, column 1: proxy scores, column 2: gold scores.\n",
    "df = pd.DataFrame({0: col1, 1: col2, 2: col3})\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34c61381",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Running best-of-n curves: for every block of N rows in df, keep per slot the\n",
    "# highest proxy score seen so far together with the gold score of that same row.\n",
    "x, y_proxy, y_gold, kl = [], [], [], []\n",
    "y_proxy_individual, y_gold_individual = [], []\n",
    "\n",
    "# None marks that no block has been processed yet. highest_gold_reward is also\n",
    "# initialised so the final display does not raise NameError if the loop never runs.\n",
    "highest_proxy_idx, highest_proxy_reward, highest_gold_reward = None, None, None\n",
    "\n",
    "# NOTE(review): the loop below reads df as consecutive blocks of N rows, while these\n",
    "# reshape to (N, -1); confirm the std is taken over the intended axis.\n",
    "# This also rebinds proxy_std, which the loading cell used for normalisation.\n",
    "proxy_std, gold_std = df[1].values.reshape((N,-1)).std(0).mean(), df[2].values.reshape((N,-1)).std(0).mean()\n",
    "\n",
    "# Convert each score column once; slicing df.iloc[...].values inside the loop would\n",
    "# rebuild a mixed-dtype object array on every iteration.\n",
    "proxy_all = df[1].values.astype(float)\n",
    "gold_all = df[2].values.astype(float)\n",
    "\n",
    "for i in tqdm(range(len(df) // N)):\n",
    "    # Rows [i*N, (i+1)*N) form the i-th block.\n",
    "    proxy_scores = proxy_all[i*N: (i+1)*N]\n",
    "    gold_scores = gold_all[i*N: (i+1)*N]\n",
    "    x.append(i)\n",
    "    # kl entry for n = i + 1: log(n) - (n - 1) / n\n",
    "    kl.append(np.log(i+1) - (i)/(i+1))\n",
    "\n",
    "    if highest_proxy_idx is not None:\n",
    "        # Strict '>' keeps the earliest block on ties.\n",
    "        mask = proxy_scores > highest_proxy_rewards\n",
    "        highest_proxy_rewards[mask] = proxy_scores[mask]\n",
    "        highest_proxy_idx[mask] = i\n",
    "        # The gold score follows whichever row won on the proxy score.\n",
    "        highest_gold_rewards[mask] = gold_scores[mask]\n",
    "    else:\n",
    "        # First block: it is the best seen so far by definition.\n",
    "        highest_proxy_idx = np.zeros(N, dtype=int)\n",
    "        highest_proxy_rewards = proxy_scores.copy()\n",
    "        highest_gold_rewards = gold_scores.copy()\n",
    "\n",
    "    # Snapshot copies, because the running arrays are updated in place.\n",
    "    y_proxy_individual.append(highest_proxy_rewards.copy())\n",
    "    y_gold_individual.append(highest_gold_rewards.copy())\n",
    "    highest_proxy_reward = highest_proxy_rewards.mean()\n",
    "    highest_gold_reward = highest_gold_rewards.mean()\n",
    "\n",
    "    y_proxy.append(highest_proxy_reward)\n",
    "    y_gold.append(highest_gold_reward)\n",
    "\n",
    "# Transpose so each row is one slot's curve over the processed blocks.\n",
    "y_proxy_individual = np.array(y_proxy_individual).T\n",
    "y_gold_individual = np.array(y_gold_individual).T\n",
    "(proxy_std, gold_std), highest_gold_reward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a283143",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the smoothed proxy and gold curves row by row.\n",
    "corrs = []\n",
    "# metrics[0]: RMS of the step-wise proxy/gold difference (x100), normalised by the\n",
    "#             number of non-zero gold steps.\n",
    "# metrics[1]: mean absolute gap between the smoothed proxy and gold curves.\n",
    "metrics = [[] for _ in range(2)]\n",
    "# Window is 10% of cutoff. int(cutoff * 0.1) is 0 for cutoff < 10, which\n",
    "# rolling(..., min_periods=1) rejects, so clamp it to at least 1.\n",
    "window_size = max(int(cutoff * 0.1), 1)\n",
    "step = max(round(cutoff // 100), 1)\n",
    "\n",
    "cutoff = min(cutoff, y_proxy_individual.shape[1])\n",
    "for p, g in zip(y_proxy_individual, y_gold_individual):\n",
    "    p = pd.Series(p).rolling(window_size, min_periods=1, center=True, step=step).mean()\n",
    "    g = pd.Series(g).rolling(window_size, min_periods=1, center=True, step=step).mean()\n",
    "    p, g = p[:cutoff], g[:cutoff]\n",
    "\n",
    "    corr = np.corrcoef(p, g)[0,1]\n",
    "    if np.isnan(corr):\n",
    "        # A NaN correlation (e.g. from a constant curve) is counted as 1.\n",
    "        corr = 1\n",
    "    corrs.append(corr)\n",
    "\n",
    "    gold_steps = np.diff(g)\n",
    "    n_gold_changes = (gold_steps != 0).sum()\n",
    "    if n_gold_changes == 0:\n",
    "        # Flat gold curve: the RMS normaliser would be zero. Record 0 in every metric\n",
    "        # list (the old code indexed metrics[3..5], which do not exist) so all lists\n",
    "        # keep the same length for np.array(metrics) below.\n",
    "        for metric_values in metrics:\n",
    "            metric_values.append(0)\n",
    "        continue\n",
    "    diff = np.diff(p) - gold_steps\n",
    "    metrics[0].append(((diff ** 2).sum() / n_gold_changes) ** (1/2) * 100) # RMS\n",
    "    metrics[1].append(np.abs(p - g).mean())\n",
    "\n",
    "# One mean per metric, averaged over all rows.\n",
    "print(np.array(metrics).mean(1))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
