{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "\n",
    "from utils import set_up_plotting\n",
    "plt = set_up_plotting()\n",
    "\n",
    "import sys\n",
    "sys.path.insert(0, '../')\n",
    "sys.path.insert(0, '../datasets')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# results_dir = '../results/budget/'\n",
    "# dataset = 'diabetes'\n",
    "# budgets = [120, 240, 400, 1000]\n",
    "# num_players = 400\n",
    "# seeds = [0,1,2,3,4]\n",
    "# eps = 1.0\n",
    "# model_name = \"logistic_regression\"\n",
    "\n",
    "# results_dir = '../results/budget/'\n",
    "# dataset = 'covertype'\n",
    "# budgets = [40,80,120,160,200,240,280]\n",
    "# num_players = 400\n",
    "# seeds = [0,1,2,3,4,5,6,7,8,9]\n",
    "# eps = 20\n",
    "# model_name = \"mlp\"\n",
    "\n",
    "results_dir = '../results/budget/'\n",
    "dataset = 'wine_quality'\n",
    "budgets = [120, 240, 400, 1000]\n",
    "num_players = 400\n",
    "seeds = [0,1,2,3,4]\n",
    "eps = 10\n",
    "model_name = \"linear_regression\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exp_names = []\n",
    "\n",
    "for file in os.listdir(results_dir):\n",
    "    if file.endswith('.npy') and dataset in file and f\"players_{num_players}\" in file and f\"eps_{eps}\" in file and f\"{model_name}\" in file:\n",
    "        f = '.'.join(file.split('.')[:-1])\n",
    "        exp_names.append(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exp_names.sort()\n",
    "exp_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scores_iid = {}\n",
    "# vars_iid = {}\n",
    "# scores_mom = {}\n",
    "# vars_mom = {}\n",
    "\n",
    "for seed in seeds:\n",
    "    scores_iid[seed] = []\n",
    "    # vars_iid[seed] = []\n",
    "    # scores_mom[seed] = []\n",
    "    # vars_mom[seed] = []\n",
    "    for budget in budgets:\n",
    "        for exp_name in exp_names:\n",
    "            if f\"seed_{seed}\" in exp_name and f\"budget_{budget}_\" in exp_name:\n",
    "                if 'no_momentum' in exp_name:\n",
    "                    if \"scores\" in exp_name:\n",
    "                        scores_iid[seed] += [np.load(results_dir + exp_name + '.npy', allow_pickle=True)]\n",
    "                #     elif \"vars\" in exp_name:\n",
    "                #         vars_iid[seed] += [np.load(results_dir + exp_name + '.npy', allow_pickle=True)]\n",
    "                # elif 'use_momentum' in exp_name:\n",
    "                #     if \"scores\" in exp_name:\n",
    "                #         scores_mom[seed] += [np.load(results_dir + exp_name + '.npy', allow_pickle=True)]\n",
    "                #     elif \"vars\" in exp_name:\n",
    "                #         vars_mom[seed] += [np.load(results_dir + exp_name + '.npy', allow_pickle=True)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(scores_iid[seeds[0]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the experiment data from ../data\n",
    "data_dir = '../data/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the train with torch\n",
    "data_name = f'{dataset}_num_players_{num_players}'\n",
    "train_data = torch.load(data_dir + 'train_' + data_name + '.pt')\n",
    "val_data = torch.load(data_dir + 'val_' + data_name + '.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# initialize model\n",
    "import sys\n",
    "sys.path.append('../models')\n",
    "import torch.nn as nn\n",
    "from resnet import ResNet18\n",
    "from logistic_regression import LogisticRegression\n",
    "\n",
    "def init_weights(m):\n",
    "    if isinstance(m, nn.Conv2d):\n",
    "        torch.nn.init.normal_(m.weight, mean=0.0, std=0.01)\n",
    "        if m.bias is not None:\n",
    "            torch.nn.init.zeros_(m.bias)\n",
    "    elif isinstance(m, nn.BatchNorm2d):\n",
    "        torch.nn.init.ones_(m.weight)\n",
    "        torch.nn.init.zeros_(m.bias)\n",
    "    elif isinstance(m, nn.Linear):\n",
    "        torch.nn.init.normal_(m.weight, mean=0.0, std=0.01)\n",
    "        torch.nn.init.zeros_(m.bias)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:3\" if torch.cuda.is_available() else \"cpu\")\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set random seed\n",
    "# from models.linear_regression import LinearRegression\n",
    "from models.linear_regression import LinearRegression\n",
    "from models.mlp import MLP\n",
    "\n",
    "if model_name == \"resnet18\":\n",
    "    net = ResNet18(num_classes=10).to(device)\n",
    "elif model_name == \"logistic_regression\":\n",
    "    if dataset == \"breast-cancer\":\n",
    "        net = LogisticRegression(30, 2).to(device)\n",
    "    elif dataset == \"diabetes\":\n",
    "        net = LogisticRegression(8, 2).to(device)\n",
    "    elif dataset == \"mnist\":\n",
    "        net = LogisticRegression(784, 10).to(device)\n",
    "    elif dataset == \"covertype\":\n",
    "        net = LogisticRegression(51, 7).to(device)\n",
    "elif model_name == \"mlp\":\n",
    "    if dataset == \"breast-cancer\":\n",
    "        net = MLP(30, 2).to(device)\n",
    "    elif dataset == \"diabetes\":\n",
    "        net = MLP(8, 2).to(device)\n",
    "    elif dataset == \"mnist\":\n",
    "        net = MLP(784, 10).to(device)\n",
    "    elif dataset == \"covertype\":\n",
    "        net = MLP(51, 7).to(device)\n",
    "elif model_name == \"linear_regression\":\n",
    "    if dataset == \"wine_quality\":\n",
    "        net = LinearRegression(11).to(device)\n",
    "    # elif dataset == \"student_performance\":\n",
    "    #     net = LinearRegression(30).to(device)\n",
    "    # elif dataset == \"slice_localization\":\n",
    "    #     net = LinearRegression(379).to(device)\n",
    "else:\n",
    "    raise NotImplementedError"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test with momentum\n",
    "# divide into 10 fractions\n",
    "\n",
    "fractions = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]\n",
    "seeds = [0,1,2,3,4]\n",
    "# fractions = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]\n",
    "epochs = 250\n",
    "lr=1e-1\n",
    "is_low_value = False\n",
    "accuracy = {}\n",
    "for budget in budgets:\n",
    "    accuracy[budget] = np.zeros((len(seeds), len(fractions)))\n",
    "accuracy[\"random\"] = np.zeros((len(seeds), len(fractions)))\n",
    "\n",
    "for i, seed in enumerate(seeds):\n",
    "    scores_seed = scores_iid[seed]\n",
    "    for b_idx, budget in enumerate(budgets + ['random']):\n",
    "        for j, frac in enumerate(fractions):   \n",
    "            torch.manual_seed(0)         \n",
    "            net.apply(init_weights)\n",
    "            optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "            for t in range(epochs):\n",
    "                # if frac == 0:\n",
    "                #     cur_scores = scores_seed[0]\n",
    "                #     cur_scores = np.mean(cur_scores, axis=0)\n",
    "                #     score_ranks = np.argsort(cur_scores)\n",
    "                #     start = 0\n",
    "                #     end = len(cur_scores)\n",
    "                # else:\n",
    "                #     if budget == 'random':\n",
    "                #         cur_scores = scores_seed[0]\n",
    "                #     else:\n",
    "                #         cur_scores = scores_seed[b_idx]\n",
    "                #     cur_scores = np.mean(cur_scores, axis=0)\n",
    "                #     score_ranks = np.argsort(cur_scores)\n",
    "                #     if not is_low_value:\n",
    "                #         score_ranks = score_ranks[::-1]\n",
    "                #     if budget == 'random':\n",
    "                #         # randomoize the scores\n",
    "                #         score_ranks = np.random.permutation(score_ranks)\n",
    "                #     # start = int(len(cur_scores) * fractions[j-1]) if j > 0 else 0\n",
    "                #     # start = 0\n",
    "                #     # end = int(len(cur_scores) * frac)\n",
    "                #     start = int(len(cur_scores) * frac)\n",
    "                #     end = len(cur_scores)\n",
    "\n",
    "                if frac == 0:\n",
    "                    continue\n",
    "                else:\n",
    "                    if budget == 'random':\n",
    "                        cur_scores = scores_seed[0]\n",
    "                    else:\n",
    "                        cur_scores = scores_seed[b_idx]\n",
    "                    cur_scores = np.mean(cur_scores, axis=0)\n",
    "                    score_ranks = np.argsort(cur_scores)\n",
    "                    if not is_low_value:\n",
    "                        score_ranks = score_ranks[::-1]\n",
    "                    if budget == 'random':\n",
    "                        # randomoize the scores\n",
    "                        score_ranks = np.random.permutation(score_ranks)\n",
    "                    # start = int(len(cur_scores) * fractions[j-1]) if j > 0 else 0\n",
    "                    start = 0\n",
    "                    end = int(len(cur_scores) * frac)\n",
    "\n",
    "                datas = torch.utils.data.Subset(train_data, score_ranks[start:end])\n",
    "                # use full batch size\n",
    "                batch_size = len(datas)\n",
    "                data_loader = torch.utils.data.DataLoader(datas, batch_size=batch_size, shuffle=False)   \n",
    "                for _, (data, target) in enumerate(data_loader):\n",
    "                    data, target = data.to(device), target.to(device)\n",
    "                    # train model with data\n",
    "                    net.train()\n",
    "                    output = net(data)\n",
    "                    if model_name == \"linear_regression\":\n",
    "                        target = target.float()\n",
    "                        loss = nn.MSELoss()(output, target)\n",
    "                    else:\n",
    "                        loss = nn.CrossEntropyLoss()(output, target)\n",
    "                    loss.backward()\n",
    "                    # update model\n",
    "                    optimizer.step()\n",
    "                    optimizer.zero_grad()\n",
    "\n",
    "                if (t+1)%50 == 0:\n",
    "                    print(\n",
    "                        f\"epoch: {t + 1}\\t\"\n",
    "                        f\"train loss: {loss.item():.4f}\\t\"\n",
    "                    )\n",
    "            # evaluate the model\n",
    "            if model_name == \"linear_regression\":\n",
    "                net.eval()\n",
    "                mse = 0\n",
    "                with torch.no_grad():\n",
    "                    data_loader = torch.utils.data.DataLoader(val_data, batch_size=len(val_data), shuffle=False)\n",
    "                    for _, (data, target) in enumerate(data_loader):\n",
    "                        images, labels = data.to(device), target.to(device)\n",
    "                        outputs = net(images)\n",
    "                        # mse += nn.MSELoss()(outputs, labels.float()).item()\n",
    "                        # use MAE instead\n",
    "                        mse += nn.L1Loss()(outputs, labels.float()).item()\n",
    "                # get average mse\n",
    "                acc = mse / len(val_data)\n",
    "                accuracy[budget][i][j] = acc\n",
    "                print(\n",
    "                    f\"mse at {frac * 100}% data: {acc}\"\n",
    "                )\n",
    "            else:\n",
    "                # evaluate the model\n",
    "                net.eval()\n",
    "                correct = 0\n",
    "                total = 0\n",
    "                with torch.no_grad():\n",
    "                    data_loader = torch.utils.data.DataLoader(val_data, batch_size=len(val_data), shuffle=False)\n",
    "                    for _, (data, target) in enumerate(data_loader):\n",
    "                        images, labels = data.to(device), target.to(device)\n",
    "                        outputs = net(images)\n",
    "                        _, predicted = torch.max(outputs.data, 1)\n",
    "                        total += labels.size(0)\n",
    "                        correct += (predicted == labels).sum()\n",
    "                acc = (correct / total).item()\n",
    "                accuracy[budget][i][j] = acc\n",
    "                print(\n",
    "                    f\"accuracy at {frac * 100}% data: {100 * correct / total}\"\n",
    "                )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams[\"figure.figsize\"] = (8, 6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot the accuracy vs partition of data\n",
    "# plot the lower and upper quantiles of the accuracy\n",
    "plt.figure()\n",
    "# plt.title(f\"Adding Low Value Data First\")\n",
    "# plt.xlabel(\"Partition of Data\")\n",
    "# plt.xticks(fractions)\n",
    "# add label to the x axis\n",
    "# plt.gca().set_xticklabels([\"33%\", \"66%\", \"100%\"])\n",
    "# plt.ylabel(\"Accuracy\")\n",
    "plt.grid()\n",
    "# set y range\n",
    "# plt.ylim(0.5, 0.80)\n",
    "for b_idx, budget in enumerate(budgets + ['random']):\n",
    "    # plot the error bar\n",
    "    if budget == 1000:\n",
    "        plt.errorbar(fractions, np.mean( accuracy[budget], axis=0), yerr=np.std( accuracy[budget], axis=0), label=rf\"$k$={budget}\", color=\"C0\")\n",
    "    elif budget == \"random\":\n",
    "        plt.errorbar(fractions, np.mean( accuracy[budget], axis=0), yerr=np.std( accuracy[budget], axis=0), label=f\"Random\", color=\"C5\", linestyle='dotted', marker='^', markersize=15)\n",
    "    else:\n",
    "        plt.errorbar(fractions, np.mean( accuracy[budget], axis=0), yerr=np.std( accuracy[budget], axis=0), label=rf\"$k$={budget}\", color=f\"C{b_idx+1}\")\n",
    "plt.legend()\n",
    "plt.savefig(f\"../figs/acc_{dataset}_{num_players}_{eps}_{model_name}_add_high.pdf\", dpi=300, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dv_dp",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
