{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import sys\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "sys.path.insert(0, '../')\n",
    "sys.path.insert(0, '../datasets')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Location of the saved experiment outputs (.npy files).\n",
    "results_dir = '../results/dataset_val/'\n",
    "\n",
    "# Identifiers used to select the matching result files.\n",
    "dataset, model_name = 'cifar10', 'resnet34'\n",
    "budget = 1000\n",
    "num_players, dataset_size = 100, 32\n",
    "eps = 10  # DP runs are kept only for this epsilon (no-DP runs are always kept)\n",
    "metric = 'shapley'\n",
    "seeds = [0, 1, 2, 4]  # NOTE(review): seed 3 is not included - confirm this is intentional\n",
    "\n",
    "# Fraction of players treated as label-flipped by the analysis cells\n",
    "# (the first int(num_players * flip_ratio) player indices).\n",
    "flip_ratio = 0.3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# List the available result files to see which runs can be selected.\n",
    "os.listdir(results_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _has_token(text, token):\n",
    "    \"\"\"Return True if `token` occurs in `text` as a standalone token.\n",
    "\n",
    "    A plain substring test is too loose for these file names: 'cifar10' is a\n",
    "    substring of 'cifar100' and 'eps_10' of 'eps_100'. Here the match must not\n",
    "    be preceded by an alphanumeric character, nor followed by an alphanumeric\n",
    "    character or a decimal part (e.g. 'eps_10.5').\n",
    "    \"\"\"\n",
    "    pattern = r'(?<![0-9A-Za-z])' + re.escape(str(token)) + r'(?![0-9A-Za-z]|\\.[0-9])'\n",
    "    return re.search(pattern, text) is not None\n",
    "\n",
    "\n",
    "# Collect the stems (file name without '.npy') of all runs matching the config above.\n",
    "exp_names = []\n",
    "for file in os.listdir(results_dir):\n",
    "    if not file.endswith('.npy'):\n",
    "        continue\n",
    "    stem = file[:-len('.npy')]\n",
    "    same_setup = (\n",
    "        _has_token(stem, dataset)\n",
    "        and f'budget_{budget}_' in stem\n",
    "        and f'players_{num_players}_dataset_size_{dataset_size}_' in stem\n",
    "        and _has_token(stem, model_name)\n",
    "    )\n",
    "    if not same_setup:\n",
    "        continue\n",
    "    # keep the no-DP baseline runs and the DP runs at the selected epsilon only\n",
    "    if not ('no_dp' in stem or _has_token(stem, f'eps_{eps}')):\n",
    "        continue\n",
    "    exp_names.append(stem)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sort so the selected runs are listed in a stable order.\n",
    "exp_names = sorted(exp_names)\n",
    "exp_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-seed score arrays for the three kinds of run:\n",
    "#   scores_no_dp  <- names containing 'no_dp'\n",
    "#   scores_no_mom <- names containing 'no_momentum'\n",
    "#   scores        <- all other matching runs\n",
    "scores = {}\n",
    "scores_no_dp = {}\n",
    "scores_no_mom = {}\n",
    "\n",
    "for seed in seeds:\n",
    "    # 'seed_1' must not also match 'seed_10', 'seed_11', ...\n",
    "    seed_pattern = re.compile(rf'seed_{seed}(?![0-9])')\n",
    "    for exp_name in exp_names:\n",
    "        if 'scores' not in exp_name or metric not in exp_name:\n",
    "            continue\n",
    "        if seed_pattern.search(exp_name) is None:\n",
    "            continue\n",
    "        if 'no_dp' in exp_name:\n",
    "            target = scores_no_dp\n",
    "        elif 'no_momentum' in exp_name:\n",
    "            target = scores_no_mom\n",
    "        else:\n",
    "            target = scores\n",
    "        # allow_pickle=True can execute arbitrary code: only load trusted result files.\n",
    "        target[seed] = np.load(os.path.join(results_dir, exp_name + '.npy'), allow_pickle=True)\n",
    "\n",
    "# Fail early with a clear message instead of a KeyError in the analysis cells.\n",
    "for label, loaded in (('scores', scores), ('scores_no_dp', scores_no_dp), ('scores_no_mom', scores_no_mom)):\n",
    "    missing_seeds = [s for s in seeds if s not in loaded]\n",
    "    assert not missing_seeds, f'{label}: no score file found for seeds {missing_seeds}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recall of the flipped players among the `num_flipped` lowest-valued players,\n",
    "# averaged over seeds. Entries of `warmups` (and of ys / y_error):\n",
    "#   'no_dp' -> scores_no_dp, 'iid' -> scores_no_mom,\n",
    "#   p in partitions -> scores with the first fraction p of rows (first axis) dropped.\n",
    "ys = []       # mean recall over seeds, aligned with `warmups`\n",
    "y_error = []  # std of recall over seeds, aligned with `warmups`\n",
    "num_data = num_players\n",
    "num_flipped = int(num_data * flip_ratio)\n",
    "if num_flipped == 0:\n",
    "    raise ValueError('int(num_players * flip_ratio) must be >= 1 to compute recall')\n",
    "partitions = np.arange(0, 1.0, 0.01)\n",
    "warmups = ['no_dp', 'iid', *partitions]\n",
    "# The flipped players are taken to be the first `num_flipped` indices.\n",
    "flipped_idx = np.arange(num_flipped)\n",
    "\n",
    "for warmup in warmups:\n",
    "    cur_y = []\n",
    "    for seed in seeds:\n",
    "        if warmup == 'no_dp':\n",
    "            cur_scores = scores_no_dp[seed]\n",
    "        elif warmup == 'iid':\n",
    "            cur_scores = scores_no_mom[seed]\n",
    "        else:\n",
    "            # drop the first `warmup` fraction of rows before averaging\n",
    "            cur_scores = scores[seed][int(len(scores[seed]) * warmup):]\n",
    "        cur_scores = np.mean(cur_scores, axis=0)\n",
    "        lowest_idx = np.argsort(cur_scores)\n",
    "        # vectorised membership test instead of a per-index `in` scan\n",
    "        num_correct = int(np.count_nonzero(np.isin(lowest_idx[:num_flipped], flipped_idx)))\n",
    "        cur_y.append(num_correct / num_flipped)\n",
    "    ys.append(np.mean(cur_y))\n",
    "    y_error.append(np.std(cur_y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project-local plotting helper. This rebinds `plt` to the object returned by\n",
    "# set_up_plotting() - presumably matplotlib.pyplot with project styling applied (TODO confirm).\n",
    "from notebooks.utils import set_up_plotting\n",
    "plt = set_up_plotting()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recall vs. warmup ratio for the correlated-noise run (mean over seeds,\n",
    "# shaded band = +/- one std); the no-DP and i.i.d.-noise recalls are drawn as\n",
    "# dashed horizontal reference lines. ys[0] / ys[1] are the 'no_dp' / 'iid'\n",
    "# entries and ys[2:] line up with `partitions`.\n",
    "recall_corr = np.array(ys[2:])\n",
    "recall_corr_err = np.array(y_error[2:])\n",
    "\n",
    "plt.grid()\n",
    "plt.ylabel('Recall')\n",
    "plt.xlabel('Warmup ratio')\n",
    "plt.title(f'k={budget}, {model_name}, {dataset}')\n",
    "plt.xticks(np.arange(min(partitions), max(partitions) + 0.2, 0.3))\n",
    "plt.plot(partitions, recall_corr, 'o-', markersize=1, label='corr. noise')\n",
    "plt.fill_between(partitions, recall_corr - recall_corr_err, recall_corr + recall_corr_err, alpha=0.3)\n",
    "# fix the y axis range to [0.1, 0.9]\n",
    "plt.ylim(0.1, 0.9)\n",
    "plt.axhline(y=ys[0], color='cyan', linestyle='--', label='no DP')\n",
    "plt.axhline(y=ys[1], color='red', linestyle='--', label='i.i.d. noise')\n",
    "plt.legend()\n",
    "# plt.savefig(f'../figs/label_flip_{dataset}_{model_name}_{budget}_acc.pdf', dpi=300, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ys[0] / ys[1] are the 'no_dp' / 'iid' entries, so partition k is stored at ys[2 + k].\n",
    "warmup_idx = 2 + int(np.argmin(np.abs(partitions - 0.2)))\n",
    "ys[warmup_idx], y_error[warmup_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ys[0] / ys[1] are the 'no_dp' / 'iid' entries, so partition k is stored at ys[2 + k].\n",
    "warmup_idx = 2 + int(np.argmin(np.abs(partitions - 0.8)))\n",
    "ys[warmup_idx], y_error[warmup_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (mean, std) recall over seeds for the no-DP run\n",
    "(ys[0], y_error[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (mean, std) recall over seeds for the i.i.d.-noise ('iid') run\n",
    "(ys[1], y_error[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (mean, std) recall over seeds for the correlated-noise run at warmup ratio partitions[0]\n",
    "(ys[2], y_error[2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recall of the flipped players when flagging the `cur_flipped` lowest-valued\n",
    "# players, swept over the flagged fraction `ratio`, for three runs:\n",
    "#   y_no_dop -> 'no_dp' (scores_no_dp), y_iid -> 'iid' (scores_no_mom),\n",
    "#   ys       -> scores with the first `warmup` fraction of rows dropped.\n",
    "# Each value is the mean over seeds; the *_error lists hold the std over seeds.\n",
    "y_no_dop = []\n",
    "y_no_dop_error = []\n",
    "y_iid = []\n",
    "y_iid_error = []\n",
    "ys = []\n",
    "y_error = []\n",
    "num_data = num_players\n",
    "start = 0\n",
    "end = 1.0\n",
    "warmup = 0.9\n",
    "# Loop-invariant setup (hoisted); the flipped players are the first `num_flipped` indices.\n",
    "num_flipped = int(num_data * flip_ratio)\n",
    "if num_flipped == 0:\n",
    "    raise ValueError('int(num_players * flip_ratio) must be >= 1 to compute recall')\n",
    "flipped_idx = np.arange(num_flipped)\n",
    "warmups = ['no_dp', 'iid', warmup]\n",
    "\n",
    "for ratio in np.arange(start, end, 0.05):\n",
    "    # round() guards against float error, e.g. 0.29 * 100 == 28.999999999999996\n",
    "    cur_flipped = int(round(num_data * ratio))\n",
    "    # `variant` (not `warmup`) so the loop does not rebind the configured warmup\n",
    "    for variant in warmups:\n",
    "        cur_y = []\n",
    "        for seed in seeds:\n",
    "            if variant == 'no_dp':\n",
    "                cur_scores = scores_no_dp[seed]\n",
    "            elif variant == 'iid':\n",
    "                cur_scores = scores_no_mom[seed]\n",
    "            else:\n",
    "                cur_scores = scores[seed][int(len(scores[seed]) * variant):]\n",
    "            cur_scores = np.mean(cur_scores, axis=0)\n",
    "            lowest_idx = np.argsort(cur_scores)\n",
    "            num_correct = int(np.count_nonzero(np.isin(lowest_idx[:cur_flipped], flipped_idx)))\n",
    "            cur_y.append(num_correct / num_flipped)\n",
    "        if variant == 'no_dp':\n",
    "            y_no_dop.append(np.mean(cur_y))\n",
    "            y_no_dop_error.append(np.std(cur_y))\n",
    "        elif variant == 'iid':\n",
    "            y_iid.append(np.mean(cur_y))\n",
    "            y_iid_error.append(np.std(cur_y))\n",
    "        else:\n",
    "            ys.append(np.mean(cur_y))\n",
    "            y_error.append(np.std(cur_y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean recall over seeds of the i.i.d.-noise ('iid') run, one value per flagged fraction.\n",
    "y_iid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recall vs. flagged fraction for the three runs (mean over seeds,\n",
    "# shaded band = +/- one std). The x value is `ratio` from the sweep: the\n",
    "# fraction of lowest-valued players flagged; the true flip ratio stays flip_ratio.\n",
    "ratios = np.arange(start, end, 0.05)\n",
    "\n",
    "\n",
    "def _plot_with_band(x, mean, std, label, color):\n",
    "    \"\"\"Plot `mean` against `x` with a shaded +/- `std` band in the same color.\"\"\"\n",
    "    mean = np.asarray(mean)\n",
    "    std = np.asarray(std)\n",
    "    plt.plot(x, mean, 'o-', markersize=1, label=label, color=color)\n",
    "    plt.fill_between(x, mean - std, mean + std, alpha=0.3, color=color)\n",
    "\n",
    "\n",
    "plt.grid()\n",
    "plt.ylabel('Recall')\n",
    "plt.xlabel('Label flip ratio')\n",
    "plt.title(f'k={budget}, {model_name}, {dataset}')\n",
    "plt.xticks(ratios)\n",
    "_plot_with_band(ratios, y_no_dop, y_no_dop_error, 'no DP', 'green')\n",
    "_plot_with_band(ratios, y_iid, y_iid_error, 'i.i.d. noise', 'C0')\n",
    "_plot_with_band(ratios, ys, y_error, 'corr. noise', 'orange')\n",
    "# y = x: expected recall of a uniformly random ranking (flagging a fraction r of\n",
    "# the players finds, on average, a fraction r of the flipped ones).\n",
    "plt.plot(ratios, ratios, '--', color='red', label='y=x')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ROC-AUC of detecting the flipped players by ranking players from lowest to\n",
    "# highest mean score; ROC points come from flagging the lowest\n",
    "# int(n_players * p) players for p in auc_partitions.\n",
    "# ys[0]: raw AUC of the 'no_dp' run. ys[1] ('iid') and ys[2] (`warmup` fraction of\n",
    "# rows dropped): AUC minus the no-DP AUC of the same seed. y_error: std over seeds.\n",
    "y_no_dp = []\n",
    "ys = []\n",
    "y_error = []\n",
    "num_data = num_players\n",
    "num_flipped = int(num_data * flip_ratio)\n",
    "num_clean = num_data - num_flipped\n",
    "if num_flipped == 0 or num_clean == 0:\n",
    "    raise ValueError('need at least one flipped and one clean player to compute an AUC')\n",
    "warmup = 0.9\n",
    "# 'no_dp' must come first: the other variants subtract its per-seed AUC.\n",
    "warmups = ['no_dp', 'iid', warmup]\n",
    "flipped_idx = np.arange(num_flipped)\n",
    "auc_partitions = np.arange(0, 1.02, 0.02)\n",
    "\n",
    "for variant in warmups:\n",
    "    cur_y = []\n",
    "    for j, seed in enumerate(seeds):\n",
    "        if variant == 'no_dp':\n",
    "            cur_scores = scores_no_dp[seed]\n",
    "        elif variant == 'iid':\n",
    "            cur_scores = scores_no_mom[seed]\n",
    "        else:\n",
    "            cur_scores = scores[seed][int(len(scores[seed]) * variant):]\n",
    "        cur_scores = np.mean(cur_scores, axis=0)\n",
    "        lowest_idx = np.argsort(cur_scores)\n",
    "        is_flipped = np.isin(lowest_idx, flipped_idx)\n",
    "        tpr = []\n",
    "        fpr = []\n",
    "        for partition in auc_partitions:\n",
    "            k = int(len(lowest_idx) * partition)\n",
    "            num_correct = int(np.count_nonzero(is_flipped[:k]))\n",
    "            tpr.append(num_correct / num_flipped)\n",
    "            fpr.append((k - num_correct) / num_clean)\n",
    "        # trapezoidal rule over consecutive (fpr, tpr) points; own index name so the\n",
    "        # outer loop variables are not shadowed\n",
    "        auc = sum((tpr[p] + tpr[p + 1]) * (fpr[p + 1] - fpr[p]) / 2 for p in range(len(tpr) - 1))\n",
    "        if variant == 'no_dp':\n",
    "            y_no_dp.append(auc)\n",
    "        else:\n",
    "            auc = auc - y_no_dp[j]\n",
    "        cur_y.append(auc)\n",
    "    ys.append(np.mean(cur_y))\n",
    "    y_error.append(np.std(cur_y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# AUC summary: ys[0] is the no-DP AUC, ys[1:] are differences to it; y_error = std over seeds\n",
    "(ys, y_error)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dv_dp",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
