{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0b05712d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification\n",
    "import torch.nn.functional as F\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7963fe98",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Load the raw log: a headerless CSV in which every N consecutive rows form one chunk.\n",
    "mode = 'your_output_log.csv'  # path of the log file to analyse\n",
    "N = 2048  # rows per chunk; the cells below slice df in blocks of N rows\n",
    "\n",
    "# Malformed rows are still dropped, but with a warning instead of silently:\n",
    "# every dropped row shifts the N-row chunk boundaries used by the cells below.\n",
    "df = pd.read_csv(mode, header=None, on_bad_lines='warn')\n",
    "\n",
    "# Preview: mean of the numeric columns for each of the first 60 chunks.\n",
    "df.groupby(df.index // N).mean(numeric_only=True).iloc[:60]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "587f11f6",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Build per-chunk curves from df (N rows per chunk) and plot them against sqrt(kl).\n",
    "# `chunk` holds df columns 2 onward as floats; within it:\n",
    "#   chunk[:,1] feeds the y_proxy* series, chunk[:,3] the y_gold* series, chunk[:,4] the kl* series.\n",
    "thresholding = False     # if True, both reward columns are replaced by the indicator (value > 0)\n",
    "train_test_split = True  # if True, rows [:split_idx] of each chunk feed the plain series\n",
    "                         # and rows [split_idx:] feed the *_train series\n",
    "if train_test_split:\n",
    "    split_idx = N // 2\n",
    "\n",
    "n_chunks = len(df) // N  # a trailing partial chunk is ignored\n",
    "if n_chunks == 0:\n",
    "    raise ValueError(f'df has {len(df)} rows; at least N={N} are needed to form one chunk')\n",
    "\n",
    "x, y_proxy, y_gold, kl = [], [], [], []\n",
    "y_proxy_train, y_gold_train, kl_train = [], [], []\n",
    "y_proxy_individual, y_gold_individual = [], []\n",
    "y_proxy_individual_train, y_gold_individual_train = [], []\n",
    "responses = []\n",
    "std = 0  # divisor applied to chunk[:,1] (1 or 2); 0 means not chosen yet\n",
    "for i in range(n_chunks):\n",
    "    rows = df.iloc[i*N:(i+1)*N].values\n",
    "    prompt_chunk = rows[:, :2].astype(str)\n",
    "    chunk = rows[:, 2:].astype(float)\n",
    "    if not std:\n",
    "        # The divisor is picked once, from the first chunk, and reused for every chunk.\n",
    "        # NOTE(review): the spread is measured on chunk[:,0] but the divisor is applied\n",
    "        # to chunk[:,1] -- confirm this is intended.\n",
    "        std = 2 if chunk[:,0].std() > 1.5 else 1\n",
    "        print('std', chunk[:,0].std(), 'using', std)\n",
    "    chunk[:,1] /= std\n",
    "\n",
    "    x.append(i)\n",
    "\n",
    "    if thresholding:\n",
    "        # Booleans are stored back into the float array as 0.0 / 1.0.\n",
    "        chunk[:,1] = chunk[:,1] > 0\n",
    "        chunk[:,3] = chunk[:,3] > 0\n",
    "\n",
    "    responses.append(prompt_chunk)\n",
    "    if train_test_split:\n",
    "        first, second = chunk[:split_idx], chunk[split_idx:]\n",
    "        y_proxy_train.append(second[:,1].mean())\n",
    "        y_proxy_individual_train.append(second[:,1])\n",
    "        y_gold_train.append(second[:,3].mean())\n",
    "        y_gold_individual_train.append(second[:,3])\n",
    "        kl_train.append(second[:,4].mean())\n",
    "    else:\n",
    "        first = chunk\n",
    "    y_proxy.append(first[:,1].mean())\n",
    "    y_proxy_individual.append(first[:,1])\n",
    "    y_gold.append(first[:,3].mean())\n",
    "    y_gold_individual.append(first[:,3])\n",
    "    kl.append(first[:,4].mean())\n",
    "\n",
    "# After the transpose, each row is one within-chunk position and each column is one chunk.\n",
    "if train_test_split:\n",
    "    y_proxy_individual_train = np.array(y_proxy_individual_train).T\n",
    "    y_gold_individual_train = np.array(y_gold_individual_train).T\n",
    "y_proxy_individual = np.array(y_proxy_individual).T\n",
    "y_gold_individual = np.array(y_gold_individual).T\n",
    "responses = np.array(responses)\n",
    "\n",
    "# Shift the proxy curves so that they start at the same value as the gold curves.\n",
    "# The list is converted to an array explicitly so the scalar offset is broadcast.\n",
    "y_proxy = np.asarray(y_proxy) + (y_gold[0] - y_proxy[0])\n",
    "if train_test_split:\n",
    "    y_proxy_train = np.asarray(y_proxy_train) + (y_gold_train[0] - y_proxy_train[0])\n",
    "\n",
    "fig, ax1 = plt.subplots()\n",
    "\n",
    "lines = []  # handles from both y-axes, collected for one shared legend\n",
    "lines += ax1.plot(x, y_proxy, label='proxy', color='blue')\n",
    "lines += ax1.plot(x, y_gold, label='gold', color='orange')\n",
    "if train_test_split:\n",
    "    lines += ax1.plot(x, y_proxy_train, label='proxy_train', linestyle='--', color='blue')\n",
    "    lines += ax1.plot(x, y_gold_train, label='gold_train', linestyle='--', color='orange')\n",
    "\n",
    "ax2 = ax1.twinx()\n",
    "# The right-hand axis shows the square root of the per-chunk kl mean.\n",
    "lines += ax2.plot(x, np.sqrt(kl), label='kl', color='g')\n",
    "\n",
    "ax1.set_xlabel('Steps')\n",
    "ax1.set_ylabel('Reward')\n",
    "ax2.set_ylabel('KL', color='g')\n",
    "ax2.set_ylim(0,500)\n",
    "\n",
    "labels = [line.get_label() for line in lines]\n",
    "plt.legend(lines, labels, bbox_to_anchor=(1.1, 1.03), loc='upper left')\n",
    "plt.show()\n",
    "\n",
    "# Proxy-minus-gold gap at the last chunk, and the chunk where the train proxy curve peaks.\n",
    "# The *_train series only exist when train_test_split is enabled.\n",
    "li = -1\n",
    "if train_test_split:\n",
    "    print(y_proxy_train[li] - y_gold_train[li])\n",
    "print(y_proxy[li] - y_gold[li])\n",
    "if train_test_split:\n",
    "    print(np.argmax(y_proxy_train))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "480b54ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the individual reward curves into a 'gamed' and a 'learned' subset, plot both\n",
    "# subsets, and collect summary numbers in `ms`.\n",
    "# Requires the previous cell to have run with train_test_split = True (uses the *_train arrays).\n",
    "ms = []\n",
    "\n",
    "\n",
    "def _smooth(values, window):\n",
    "    \"\"\"Centered rolling mean of a 1-D sequence, returned as a NumPy array.\"\"\"\n",
    "    return pd.Series(values).rolling(window, min_periods=1, center=True).mean().to_numpy()\n",
    "\n",
    "\n",
    "def _is_gamed(p, g, lo, hi, neg_threshold, high_corr=0.95):\n",
    "    \"\"\"Return True when curves p and g are flagged as disagreeing on [lo, hi).\n",
    "\n",
    "    A pair is flagged when the correlation coefficient over [lo, hi//2), [hi//2, hi)\n",
    "    or [lo, hi) is below neg_threshold, and the coefficient over [lo, hi) is also\n",
    "    below high_corr. NaN coefficients compare as False.\n",
    "    \"\"\"\n",
    "    mid = hi // 2\n",
    "    c_first = np.corrcoef(p[lo:mid], g[lo:mid])[0, 1]\n",
    "    c_second = np.corrcoef(p[mid:hi], g[mid:hi])[0, 1]\n",
    "    c_full = np.corrcoef(p[lo:hi], g[lo:hi])[0, 1]\n",
    "    any_low = c_first < neg_threshold or c_second < neg_threshold or c_full < neg_threshold\n",
    "    return bool(any_low and c_full < high_corr)\n",
    "\n",
    "\n",
    "########### CUTOFF ############\n",
    "\n",
    "start = 0\n",
    "# Only the steps before y_proxy_train reaches its maximum are analysed, capped at 500.\n",
    "cutoff = int(min(500, np.argmax(y_proxy_train)))\n",
    "if cutoff < 3:\n",
    "    raise ValueError(f'cutoff={cutoff}: need at least 3 steps before the proxy peak to compare slopes')\n",
    "window_size = max(int(cutoff * 0.1), 1)  # rolling window must be at least 1 because min_periods=1\n",
    "step = max(cutoff // 100, 1)\n",
    "threshold = -0.  # correlations below this value flag a curve pair as gamed\n",
    "\n",
    "# Truncate to the analysis window, then shift the proxy curves so that their mean at\n",
    "# step 0 matches the gold mean at step 0. New arrays are built; nothing is edited in place.\n",
    "y_gold_individual = y_gold_individual[:, :cutoff]\n",
    "y_gold_individual_train = y_gold_individual_train[:, :cutoff]\n",
    "y_proxy_individual = y_proxy_individual[:, :cutoff]\n",
    "y_proxy_individual_train = y_proxy_individual_train[:, :cutoff]\n",
    "y_proxy_individual = y_proxy_individual + (\n",
    "    y_gold_individual[:, 0].mean() - y_proxy_individual[:, 0].mean())\n",
    "y_proxy_individual_train = y_proxy_individual_train + (\n",
    "    y_gold_individual_train[:, 0].mean() - y_proxy_individual_train[:, 0].mean())\n",
    "\n",
    "# From here on `cutoff` is the index of the last retained step.\n",
    "cutoff -= 1\n",
    "cutoff = min(cutoff, y_proxy_individual.shape[1])\n",
    "\n",
    "ms.append(cutoff)\n",
    "ms.append(window_size)\n",
    "ms.append(step)\n",
    "ms.append(y_gold_individual_train[:, cutoff].mean())\n",
    "ms.append(y_gold_individual[:, cutoff].mean())\n",
    "\n",
    "low_corr, low_corr_train = [], []\n",
    "normal_p, normal_g = [], []\n",
    "gamed_p, gamed_g = [], []\n",
    "normal_p_train, normal_g_train = [], []\n",
    "gamed_p_train, gamed_g_train = [], []\n",
    "\n",
    "curves = zip(y_proxy_individual_train, y_gold_individual_train, y_proxy_individual, y_gold_individual)\n",
    "for i, raw in enumerate(curves):\n",
    "    p_train_s, g_train_s, p_s, g_s = (_smooth(v, window_size) for v in raw)\n",
    "    if _is_gamed(p_train_s, g_train_s, start, cutoff, threshold):\n",
    "        low_corr_train.append(i)\n",
    "        gamed_p_train.append(p_train_s)\n",
    "        gamed_g_train.append(g_train_s)\n",
    "    else:\n",
    "        normal_p_train.append(p_train_s)\n",
    "        normal_g_train.append(g_train_s)\n",
    "    if _is_gamed(p_s, g_s, start, cutoff, threshold):\n",
    "        low_corr.append(i)\n",
    "        gamed_p.append(p_s)\n",
    "        gamed_g.append(g_s)\n",
    "    else:\n",
    "        normal_p.append(p_s)\n",
    "        normal_g.append(g_s)\n",
    "\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12,4))\n",
    "print('GAMED:', len(low_corr_train), len(low_corr_train) / len(y_proxy_individual))\n",
    "print('GAMED:', len(low_corr), len(low_corr) / len(y_proxy_individual))\n",
    "\n",
    "# Left panel: mean curves of the pairs that were not flagged.\n",
    "if normal_p and normal_p_train:\n",
    "    learned_p, learned_g = np.array(normal_p).mean(0), np.array(normal_g).mean(0)\n",
    "    learned_p_train, learned_g_train = np.array(normal_p_train).mean(0), np.array(normal_g_train).mean(0)\n",
    "    ax1.plot(learned_p_train[:cutoff], label='learned_proxy_train', linestyle='--', color='blue')\n",
    "    ax1.plot(learned_g_train[:cutoff], label='learned_gold_train', linestyle='--', color='orange')\n",
    "    ax1.plot(learned_p[:cutoff], label='learned_proxy', color='blue')\n",
    "    ax1.plot(learned_g[:cutoff], label='learned_gold', color='orange')\n",
    "    ax1.set_title(f'Learned subset ({100 - round(len(low_corr) / len(y_proxy_individual) * 100, 1)}%)')\n",
    "    ms.append(len(low_corr_train))\n",
    "else:\n",
    "    # One of the unflagged groups is empty, so there is nothing to average.\n",
    "    print('oops')\n",
    "    ms.append(0)\n",
    "\n",
    "# Right panel: mean curves of the flagged pairs.\n",
    "if gamed_p and gamed_p_train:\n",
    "    p, g = np.array(gamed_p).mean(0), np.array(gamed_g).mean(0)\n",
    "    p_train, g_train = np.array(gamed_p_train).mean(0), np.array(gamed_g_train).mean(0)\n",
    "    ax2.plot(p_train[:cutoff], label='gamed_proxy_train', linestyle='--', color='blue')\n",
    "    ax2.plot(g_train[:cutoff], label='gamed_gold_train', linestyle='--', color='orange')\n",
    "    ax2.plot(p[:cutoff], label='gamed_proxy', color='blue')\n",
    "    ax2.plot(g[:cutoff], label='gamed_gold', color='orange')\n",
    "    ax2.set_title(f'Gamed subset ({round(len(low_corr) / len(y_proxy_individual) * 100, 1)}%)')\n",
    "    ms.append(len(low_corr))\n",
    "else:\n",
    "    # One of the flagged groups is empty, so there is nothing to average.\n",
    "    print('oops')\n",
    "    ms.append(0)\n",
    "\n",
    "# Each panel gets a legend built from its own curves; empty panels get none.\n",
    "for ax in (ax1, ax2):\n",
    "    if ax.get_legend_handles_labels()[0]:\n",
    "        ax.legend()\n",
    "plt.show()\n",
    "\n",
    "# Per-curve agreement metrics on the train half, using smoothed curves sampled every `step` points.\n",
    "corrs = []\n",
    "diffs = []\n",
    "metrics = [[] for _ in range(3)]\n",
    "\n",
    "for p, g in zip(y_proxy_individual_train, y_gold_individual_train):\n",
    "    p = pd.Series(p).rolling(window_size, min_periods=1, center=True, step=step)\n",
    "    p = p.mean()[start:cutoff].to_numpy()\n",
    "    g = pd.Series(g).rolling(window_size, min_periods=1, center=True, step=step)\n",
    "    g = g.mean()[start:cutoff].to_numpy()\n",
    "    mid = len(p) // 2\n",
    "    # An undefined (NaN) half-window correlation counts as 1, an undefined full one as 0.\n",
    "    corr1 = np.corrcoef(p[start:mid], g[start:mid])[0,1]\n",
    "    corr1 = 1 if np.isnan(corr1) else corr1\n",
    "    corr2 = np.corrcoef(p[mid:cutoff], g[mid:cutoff])[0,1]\n",
    "    corr2 = 1 if np.isnan(corr2) else corr2\n",
    "    corr3 = np.corrcoef(p[start:cutoff], g[start:cutoff])[0,1]\n",
    "    corr3 = 0 if np.isnan(corr3) else corr3\n",
    "    if corr3 > 0.95:\n",
    "        # A very high full-window correlation overrides both half-window values.\n",
    "        corr1 = corr2 = corr3\n",
    "    corrs.extend([corr1, corr2])\n",
    "    diff = np.diff(p[:cutoff]) - np.diff(g[:cutoff])\n",
    "    metrics[0].append(np.abs(diff).max())  # slope max diff\n",
    "    metrics[1].append(np.abs(diff).mean())  # mean absolute slope diff\n",
    "    metrics[2].append((diff ** 2).mean() ** (1/2))  # L2 diff of first derivative\n",
    "\n",
    "    diffs.append(diff)\n",
    "\n",
    "print(np.mean(corrs), np.array(metrics).mean(1))\n",
    "ms.append(np.mean(corrs))\n",
    "ms.extend(np.array(metrics).mean(1)*100)\n",
    "saved_corrs = diffs  # note: despite the name, this holds the slope differences\n",
    "\n",
    "print(ms)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
