{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0508fb7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import cvxpy as cp\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import pandas as pd\n",
    "import pickle\n",
    "import re\n",
    "import scipy\n",
    "import seaborn as sns\n",
    "\n",
    "\n",
    "from collections import defaultdict\n",
    "from itertools import product\n",
    "from sklearn.cluster import KMeans\n",
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "349e2543",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch (disabled): column sums (presumably per-paper coverage) of a saved adv_gesw allocation.\n",
    "# NOTE(review): uses `dataset`, which is only assigned in a later cell.\n",
    "# alloc_fname = os.path.join(\"..\", \"outputs\", \"outputs\", dataset, \"adv_gesw_0.30_1_alloc.npy\")\n",
    "# x = np.load(alloc_fname)\n",
    "# x.sum(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "489638d6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adda2bf5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f396f0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load per-seed evaluation metrics for every allocation type on one dataset.\n",
    "# Datasets: AAMAS1, AAMAS2, AAMAS3, cs, gAAMASX\n",
    "dataset = \"gAAMAS3\"\n",
    "\n",
    "alloc_types = [\"exp_usw_max\", \"exp_gesw_max\", \"cvar_usw\", \"cvar_gesw\", \"adv_usw\", \"adv_gesw\"]\n",
    "\n",
    "# metrics[alloc][seed] for the expectation-maximizing allocations and\n",
    "# metrics[alloc][level][seed] for the CVaR (\"0.01\") and adversarial (\"0.30\") ones.\n",
    "metrics = {}\n",
    "for alloc in alloc_types:\n",
    "    metrics[alloc] = {}\n",
    "    if alloc.startswith(\"cvar\") or alloc.startswith(\"adv\"):\n",
    "        conf_level = \"0.01\" if alloc.startswith(\"cvar\") else \"0.30\"\n",
    "        metrics[alloc][conf_level] = {}\n",
    "        for seed in range(1, 6):\n",
    "            fname = os.path.join(\"..\", \"outputs\", \"outputs\", dataset, \"%s_%s_%d_metrics.pkl\" % (alloc, conf_level, seed))\n",
    "            if alloc.startswith(\"adv\") and not os.path.exists(fname):\n",
    "                # Some adversarial runs never finished; skip them, but report it instead of\n",
    "                # swallowing every exception (including corrupt pickles) with a bare `except`.\n",
    "                print(\"Missing adversarial metrics, skipping: %s\" % fname)\n",
    "                continue\n",
    "            with open(fname, 'rb') as f:  # closes the handle; only unpickle trusted local outputs\n",
    "                metrics[alloc][conf_level][seed] = pickle.load(f)\n",
    "    else:\n",
    "        for seed in range(1, 6):\n",
    "            fname = os.path.join(\"..\", \"outputs\", \"outputs\", dataset, \"%s_%d_metrics.pkl\" % (alloc, seed))\n",
    "            with open(fname, 'rb') as f:\n",
    "                metrics[alloc][seed] = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b235d55",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the loaded adversarial-USW metrics; seeds whose pickle could not be loaded are absent.\n",
    "metrics['adv_usw']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "299e00a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LaTeX table: rows = allocation objective, columns = evaluation objective. Each entry is the\n",
    "# per-seed value normalized by the column's best allocation, reported as mean +- std over seeds.\n",
    "print(\"Stats for %s dataset\\n\\n\" % dataset)\n",
    "\n",
    "cvar_level = 0.01\n",
    "adv_level = 0.3\n",
    "# Seeds aggregated in the table; seed 2 is left out (presumably missing for some allocation -- TODO confirm).\n",
    "table_seeds = [1, 3, 4, 5]\n",
    "\n",
    "m_dicts = [metrics['exp_usw_max'], \n",
    "           metrics['exp_gesw_max'],\n",
    "           metrics['cvar_usw'][\"%.2f\" % cvar_level],\n",
    "           metrics['cvar_gesw'][\"%.2f\" % cvar_level],\n",
    "           metrics['adv_usw'][\"%.2f\" % adv_level],\n",
    "           metrics['adv_gesw'][\"%.2f\" % adv_level]\n",
    "          ]\n",
    "\n",
    "allocs = [\"USW\", \"GESW\", \"CVaR USW\", \"CVaR GESW\", \"Rob. USW\", \"Rob. GESW\"]\n",
    "table_str = \"\\\\multirow{2}*{Allocation} & \\\\multicolumn{6}{c}{Evaluation Objective} \\\\\\\\\\n\"\n",
    "table_str += \" & USW & GESW & CVaR USW & CVaR GESW & Rob. USW & Rob. GESW \\\\\\\\\\n\\\\midrule\\n\"\n",
    "\n",
    "\n",
    "def seed_matrix(seed):\n",
    "    \"\"\"Raw (allocation x evaluation objective) metric values for one seed.\"\"\"\n",
    "    return np.array([[md[seed]['usw'],\n",
    "                      md[seed]['gesw'],\n",
    "                      md[seed]['cvar_usw'][cvar_level],\n",
    "                      md[seed]['cvar_gesw'][cvar_level],\n",
    "                      md[seed]['adv_usw'][adv_level],\n",
    "                      md[seed]['adv_gesw'][adv_level]] for md in m_dicts], dtype=float)\n",
    "\n",
    "\n",
    "# Build exactly one matrix per aggregated seed. Allocating a np.ones placeholder per loaded\n",
    "# seed and filling only some of them let the skipped seed's all-ones matrix leak into mean/std.\n",
    "metric_matrices = [seed_matrix(seed) for seed in table_seeds]\n",
    "\n",
    "# Normalize each column by its maximum; zeros become 1e-8 so an all-zero column never divides by 0.\n",
    "for mm in metric_matrices:\n",
    "    mm[mm == 0] = 1e-8\n",
    "    mm *= 1/np.max(mm, axis=0)\n",
    "\n",
    "final_means = np.mean(metric_matrices, axis=0)\n",
    "final_stds = np.std(metric_matrices, axis=0)\n",
    "print(final_stds)\n",
    "\n",
    "for row_idx, aname in enumerate(allocs):\n",
    "    tstr_l = []\n",
    "    for m, s in zip(final_means[row_idx], final_stds[row_idx]):\n",
    "        s_str = \"%.2f\" % s if s > 0.005 else \"0\"\n",
    "        if m >= .995:\n",
    "            # Entries that round to the column optimum (1.00) are bolded.\n",
    "            tstr_l.append(\"$\\\\mathbf{%.2f \\\\pm %s}$\" % (m, s_str))\n",
    "        elif m <= 0.004:\n",
    "            tstr_l.append(\"$0 \\\\pm %s$\" % (s_str))\n",
    "        else:\n",
    "            tstr_l.append(\"$%.2f \\\\pm %s$\" % (m, s_str))\n",
    "    table_str += aname + \" & \" + \" & \".join(tstr_l) + \"\\\\\\\\\\n\"\n",
    "\n",
    "print(table_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "003960d0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1b57b33",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Investigate why CVaR is so high for the naive (expectation-maximizing) allocations.\n",
    "data_dir = \"../data/AAMAS\"\n",
    "print(data_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d08ec37a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the mean (mu) and standard-deviation (zeta) score matrices for dataset index `didx`.\n",
    "didx = 2\n",
    "mufile = os.path.join(data_dir, \"mu_matrix_%d.npy\" % didx)\n",
    "mu_matrix = np.load(mufile)\n",
    "zetafile = os.path.join(data_dir, \"zeta_matrix_%d.npy\" % didx)\n",
    "zeta_matrix = np.load(zetafile)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f212e765",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_samples = 1000\n",
    "# Multiplier applied to the std-dev matrix when sampling (deliberately inflated noise for this probe).\n",
    "noise_scale = 100\n",
    "rng = np.random.default_rng(seed=0)\n",
    "samples = [rng.normal(mu_matrix, zeta_matrix*noise_scale) for _ in range(num_samples)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ba38046",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Max-expected-USW allocation computed on the matching Gaussian dataset gAAMAS<didx>.\n",
    "alloc = np.load(\"../outputs/outputs/gAAMAS%d/exp_usw_max_alloc.npy\" % didx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f9e2ef3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# USW of the allocation under each sampled score matrix, divided by the number of columns (papers).\n",
    "usws = [np.sum(s*alloc)/alloc.shape[1] for s in samples]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c962a5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampled USW distribution (matplotlib.pyplot is already imported in the first cell).\n",
    "plt.hist(usws, density=True, bins=100)\n",
    "plt.xlabel(\"USW\")\n",
    "plt.ylabel(\"Density\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34de6653",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Empirical CVaR at 1%: mean of the lowest 1% of sampled USW values.\n",
    "np.mean(sorted(usws)[:int(num_samples*.01)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28166fc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CVaR-USW allocation for the same dataset. NOTE(review): its level is 0.30 while the tail\n",
    "# evaluated below is 1% -- confirm this pairing is intended.\n",
    "cvar_alloc = np.load(\"../outputs/outputs/gAAMAS%d/cvar_usw_0.30_alloc.npy\" % didx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45fe0103",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalize by the paper count of the allocation actually being evaluated.\n",
    "usw_cvar = [np.sum(s*cvar_alloc)/cvar_alloc.shape[1] for s in samples]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3fe6f15",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: histogram of the CVaR allocation's sampled USW.\n",
    "# plt.hist(usw_cvar, density=True, bins=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "290592b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Empirical CVaR at 1% (mean of the lowest 1% of samples) for the CVaR allocation.\n",
    "np.mean(sorted(usw_cvar)[:int(num_samples*.01)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc74bc5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column (paper) index inspected by the following cells.\n",
    "paper_of_interest = 7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "055080aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows (reviewers) that `alloc` assigns to `paper_of_interest`, then their mean scores for it.\n",
    "print(np.where(alloc[:, paper_of_interest]))\n",
    "print(mu_matrix[np.where(alloc[:, paper_of_interest])[0], paper_of_interest])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e90ade54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows (reviewers) that `cvar_alloc` assigns to `paper_of_interest`, then their mean scores for it.\n",
    "print(np.where(cvar_alloc[:, paper_of_interest]))\n",
    "print(mu_matrix[np.where(cvar_alloc[:, paper_of_interest])[0], paper_of_interest])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e22629db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean scores of the pairs selected by `alloc`.\n",
    "sns.histplot(mu_matrix[np.where(alloc)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b3965fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Std-dev (zeta) values of the pairs selected by `alloc`.\n",
    "sns.histplot(zeta_matrix[np.where(alloc)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6296c463",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean scores of the pairs selected by `cvar_alloc`.\n",
    "sns.histplot(mu_matrix[np.where(cvar_alloc)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51d37f70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Std-dev (zeta) values of the pairs selected by `cvar_alloc`.\n",
    "sns.histplot(zeta_matrix[np.where(cvar_alloc)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ffcb5a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Frozen skew-normal per entry: shape -1 (longer lower tail), loc = mu_matrix, scale = zeta_matrix.\n",
    "from scipy.stats import skewnorm\n",
    "rv = skewnorm(-1*np.ones_like(mu_matrix), loc=mu_matrix, scale=zeta_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "833a4680",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 100 draws of entry (0, 0) from the same skew-normal as `rv`, sampled directly instead of\n",
    "# drawing (and discarding) a full matrix each time, and seeded so the histogram is reproducible.\n",
    "first = list(skewnorm.rvs(-1, loc=mu_matrix[0, 0], scale=zeta_matrix[0, 0], size=100,\n",
    "                          random_state=np.random.default_rng(seed=0)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc6b4787",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histogram of the sampled values of entry (0, 0).\n",
    "sns.histplot(first)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49079235",
   "metadata": {},
   "outputs": [],
   "source": [
    "# phi(Phi^-1(0.2)): standard-normal density at its 20% quantile (numerator of the Gaussian CVaR factor).\n",
    "from scipy.stats import norm\n",
    "print(norm.pdf(norm.ppf(.2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5340c8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gurobipy as gp\n",
    "from scipy.stats import norm\n",
    "\n",
    "\n",
    "def solve_cvar_usw_gauss(mu_matrix, sigma_matrix, covs_lb, covs_ub, loads, conf_level, coi_mask):\n",
    "    \"\"\"Maximize the Gaussian lower-tail CVaR of USW over 0/1 reviewer-paper allocations.\n",
    "\n",
    "    With independent Gaussian scores, the USW of an allocation x is Gaussian with mean\n",
    "    sum(mu * x) and standard deviation sqrt(sum(sigma**2 * x**2)), and its CVaR over the\n",
    "    worst `conf_level` fraction of outcomes is mean - std * pdf(ppf(conf_level)) / conf_level.\n",
    "\n",
    "    Args:\n",
    "        mu_matrix: (reviewers x papers) matrix of mean scores.\n",
    "        sigma_matrix: matching matrix of per-entry standard deviations.\n",
    "        covs_lb, covs_ub: per-paper lower/upper bounds on the number of assigned reviewers.\n",
    "        loads: per-reviewer upper bound on the number of assigned papers.\n",
    "        conf_level: tail fraction in (0, 1), e.g. 0.01 for the worst 1% of outcomes.\n",
    "        coi_mask: element-wise upper bound on the allocation; 0 entries forbid an assignment.\n",
    "\n",
    "    Returns:\n",
    "        The optimal allocation as a numpy array shaped like `mu_matrix`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `conf_level` is not strictly between 0 and 1.\n",
    "    \"\"\"\n",
    "    if not 0 < conf_level < 1:\n",
    "        raise ValueError(\"conf_level must be in (0, 1), got %r\" % (conf_level,))\n",
    "\n",
    "    m = gp.Model(\"TPMS\")\n",
    "\n",
    "    alloc = m.addMVar(mu_matrix.shape, vtype=gp.GRB.BINARY, name='alloc')\n",
    "\n",
    "    m.addConstr(alloc.sum(axis=0) >= covs_lb)\n",
    "    m.addConstr(alloc.sum(axis=0) <= covs_ub)\n",
    "    m.addConstr(alloc.sum(axis=1) <= loads)\n",
    "    m.addConstr(alloc <= coi_mask)\n",
    "\n",
    "    # aux = standard deviation of the USW. Variances add, so the per-entry standard\n",
    "    # deviations are squared before summing.\n",
    "    aux = m.addVar(lb=0)\n",
    "    m.addConstr(aux**2 == (alloc * np.square(sigma_matrix) * alloc).sum())\n",
    "\n",
    "    # Lower-tail factor phi(Phi^-1(alpha)) / alpha for alpha = conf_level. Dividing by\n",
    "    # (1 - conf_level) instead treats conf_level as a confidence level 1 - alpha and, at\n",
    "    # conf_level=0.01, shrinks the risk penalty ~99x.\n",
    "    frac = norm.pdf(norm.ppf(conf_level))/conf_level\n",
    "    obj = (alloc * mu_matrix).sum() - frac*aux\n",
    "    # The quadratic equality constraint is non-convex; Gurobi only accepts it with NonConvex=2.\n",
    "    m.setParam(\"NonConvex\", 2)\n",
    "    m.setObjective(obj, gp.GRB.MAXIMIZE)\n",
    "\n",
    "    m.optimize()\n",
    "\n",
    "    return alloc.x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76d064e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Small run: exactly 2 reviewers per paper, at most 15 papers per reviewer, no conflicts, conf_level 0.01.\n",
    "solve_cvar_usw_gauss(mu_matrix, zeta_matrix, \n",
    "                     np.array([2]*mu_matrix.shape[1]), np.array([2]*mu_matrix.shape[1]), \n",
    "                     np.array([15]*mu_matrix.shape[0]),\n",
    "                    .01,\n",
    "                    np.ones_like(mu_matrix))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd2a28b5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "913b4ab6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bfd25b9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdc6a038",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the sampled USW distributions of two AAMAS1 allocations.\n",
    "dset_name = \"aamas1\"\n",
    "# Lower-tail fraction for the printed empirical CVaR (presumably 1 - 0.70, the level in the\n",
    "# CVaR allocation's filename -- TODO confirm).\n",
    "tail_frac = 0.3\n",
    "\n",
    "\n",
    "def get_samples(central_estimate, std_devs, dset_name, num_samples=100):\n",
    "    \"\"\"Draw `num_samples` score matrices with a fixed seed (0).\n",
    "\n",
    "    For the AAMAS and cs datasets `central_estimate` holds probabilities and each sample is a\n",
    "    boolean Bernoulli draw; otherwise samples are Gaussian with standard deviation `std_devs`.\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed=0)\n",
    "    if dset_name.startswith(\"aamas\") or dset_name == 'cs':\n",
    "        return [rng.uniform(size=central_estimate.shape) < central_estimate for _ in range(num_samples)]\n",
    "    return [rng.normal(central_estimate, std_devs) for _ in range(num_samples)]\n",
    "\n",
    "\n",
    "# Scores and samples do not depend on the allocation, so load/draw them once (get_samples is\n",
    "# seeded, so redrawing per allocation produced identical samples anyway).\n",
    "ce = np.load(\"../data/AAMAS/prob_up_1.npy\")\n",
    "samples = get_samples(ce, None, dset_name, num_samples=1000)\n",
    "\n",
    "all_lines = []\n",
    "for alloc_name in [\"cvar_usw_0.70_alloc\", \"exp_usw_max_alloc\"]:\n",
    "    alloc = np.load(\"../outputs/outputs/AAMAS1/%s.npy\" % alloc_name)\n",
    "    usws = [np.sum(alloc*s)/alloc.shape[1] for s in samples]\n",
    "    cutoff = int(len(usws)*tail_frac)\n",
    "    print(np.mean(sorted(usws)[:cutoff]))\n",
    "    all_lines.extend([alloc_name, u] for u in usws)\n",
    "\n",
    "df = pd.DataFrame(all_lines, columns=[\"alloc\", \"usw\"])\n",
    "\n",
    "sns.kdeplot(data=df, x=\"usw\", hue=\"alloc\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb4e567d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload both AAMAS1 allocations under separate names for side-by-side inspection.\n",
    "alloc = np.load(\"../outputs/outputs/AAMAS1/%s.npy\" % \"cvar_usw_0.70_alloc\")\n",
    "alloc_usw = np.load(\"../outputs/outputs/AAMAS1/%s.npy\" % \"exp_usw_max_alloc\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea0fc5bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total of the CVaR allocation matrix (number of assignments if it is 0/1).\n",
    "np.sum(alloc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d552acd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-pair probability scores for AAMAS1 (the same file used as `ce` when sampling above).\n",
    "prob_up = np.load(\"../data/AAMAS/prob_up_1.npy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa638a59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scores of the pairs selected by the CVaR allocation.\n",
    "sns.histplot(prob_up[np.where(alloc)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08fe4710",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scores of the pairs selected by the max-expected-USW allocation.\n",
    "sns.histplot(prob_up[np.where(alloc_usw)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c4b52c2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20625e7f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a0b59e6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7de88f62",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe1b5226",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the CVaR over level of noise for USW and GESW, for each of the Gaussian datasets (6 plots). Do the naive and the CVaR version.\n",
    "# AAMAS1, AAMAS2, AAMAS3, cs, gAAMASX\n",
    "dataset = \"gAAMAS1\"\n",
    "\n",
    "alloc_types = [\"exp_usw_max\", \"exp_gesw_max\", \"cvar_usw\", \"cvar_gesw\"]\n",
    "\n",
    "# metrics[alloc][noise_level] -> metrics dict of the single seed below.\n",
    "metrics = defaultdict(dict)\n",
    "\n",
    "noise_levels = [1.0, 2.0, 4.0, 6, 8]\n",
    "\n",
    "seed = 5\n",
    "\n",
    "for noise_level in noise_levels:\n",
    "    for alloc in alloc_types:\n",
    "        # CVaR allocations also carry their confidence level (0.01) in the filename.\n",
    "        level_part = \"0.01_\" if alloc.startswith(\"cvar\") else \"\"\n",
    "        fname = os.path.join(\"..\", \"outputs\", \"outputs\", dataset, \"%s_%s%.2f_%d_metrics.pkl\" % (alloc, level_part, noise_level, seed))\n",
    "        with open(fname, 'rb') as f:\n",
    "            metrics[alloc][noise_level] = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24951de0",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams[\"font.family\"] = \"serif\"\n",
    "plt.rcParams[\"font.serif\"] = [\"Times New Roman\"]\n",
    "fs = 26\n",
    "\n",
    "# Red: evaluated by CVaR of USW; blue: by CVaR of GESW. Solid: CVaR-optimal allocation;\n",
    "# dashed: expectation-maximizing allocation.\n",
    "cvar_at_01 = [metrics['cvar_usw'][nl]['cvar_usw'][0.01] for nl in noise_levels]\n",
    "max_exp_usw = [metrics['exp_usw_max'][nl]['cvar_usw'][0.01] for nl in noise_levels]\n",
    "cvar_gesw_at_01 = [metrics['cvar_gesw'][nl]['cvar_gesw'][0.01] for nl in noise_levels]\n",
    "max_exp_gesw = [metrics['exp_gesw_max'][nl]['cvar_gesw'][0.01] for nl in noise_levels]\n",
    "sns.lineplot(x=noise_levels, y=cvar_at_01, color=\"#C41751\", label=\"CVaR USW\")\n",
    "sns.lineplot(x=noise_levels, y=max_exp_usw, linestyle=\"--\", color=\"#C41751\", label=\"USW\")\n",
    "sns.lineplot(x=noise_levels, y=cvar_gesw_at_01, color=\"#1D71C3\", label=\"CVaR GESW\")\n",
    "sns.lineplot(x=noise_levels, y=max_exp_gesw, linestyle=\"--\", color=\"#1D71C3\", label=\"GESW\")\n",
    "plt.ylabel(\"CVaR$_{0.01}[$W$]$\", fontsize=fs+2)\n",
    "plt.xlabel(\"Standard Deviation Multiplier\", fontsize=fs+2)\n",
    "plt.xlim([.8, 9])\n",
    "plt.tick_params(axis='both', which='major', labelsize=fs)\n",
    "plt.legend(loc=0, prop={'size': 15})\n",
    "plt.savefig(\"../Plots/cvar_over_noise_%s.pdf\" % dataset, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04d5e036",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = \"gAAMAS3\"\n",
    "\n",
    "alloc_types = [\"cvar_usw\", \"cvar_gesw\",\"exp_usw_max\", \"exp_gesw_max\"]\n",
    "\n",
    "metrics = defaultdict(dict)\n",
    "# Long-format rows: [alloc, noise_level, welfare metric, CVaR_0.01 value, is_naive, seed].\n",
    "all_entries = []\n",
    "\n",
    "noise_levels = [1.0, 2.0, 4.0, 6, 8]\n",
    "\n",
    "# CVaR_0.01 of the CVaR-optimal allocation, keyed by (seed, noise_level).\n",
    "cvar_usw_val_by_seed = {}\n",
    "cvar_gesw_val_by_seed = {}\n",
    "\n",
    "for seed in range(1, 6):\n",
    "    for noise_level in noise_levels:\n",
    "        for alloc in alloc_types:\n",
    "            # CVaR allocations also carry their confidence level (0.01) in the filename.\n",
    "            level_part = \"0.01_\" if alloc.startswith(\"cvar\") else \"\"\n",
    "            fname = os.path.join(\"..\", \"outputstuesday\", \"outputs\", dataset, \"%s_%s%.2f_%d_metrics.pkl\" % (alloc, level_part, noise_level, seed))\n",
    "            with open(fname, 'rb') as f:\n",
    "                metrics[alloc][noise_level] = pickle.load(f)\n",
    "\n",
    "            # USW-oriented allocations are scored by CVaR of USW, GESW-oriented ones by CVaR of GESW.\n",
    "            welfare = 'cvar_usw' if alloc in ['exp_usw_max', 'cvar_usw'] else 'cvar_gesw'\n",
    "            value = metrics[alloc][noise_level][welfare][0.01]\n",
    "            if alloc == \"cvar_usw\":\n",
    "                cvar_usw_val_by_seed[(seed, noise_level)] = value\n",
    "            elif alloc == \"cvar_gesw\":\n",
    "                cvar_gesw_val_by_seed[(seed, noise_level)] = value\n",
    "            all_entries.append([alloc, noise_level, welfare, value, not alloc.startswith(\"cvar\"), seed])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ff88f6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled: turn each CVaR value into its gap to the CVaR-optimal allocation for the same\n",
    "# (seed, noise level). NOTE(review): the GESW branch uses `*=` where the USW branch assigns\n",
    "# with `=`; that looks like a typo -- confirm before re-enabling.\n",
    "# for e in all_entries:\n",
    "#     if e[0] in ['exp_usw_max', 'cvar_usw']:\n",
    "#         e[3] = cvar_usw_val_by_seed[(e[-1], e[1])] - e[3]\n",
    "#     else:\n",
    "#         e[3] *= cvar_gesw_val_by_seed[(e[-1], e[1])] - e[3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc278651",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Long-format results frame: one row per (seed, noise level, allocation).\n",
    "df = pd.DataFrame(all_entries).set_axis(['alloc', 'noise_level', 'Welfare', 'metric_value', 'Robust Concept', 'seed'], axis=1)\n",
    "print(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7bfade7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Human-readable legend labels. Assign whole columns instead of chained indexing\n",
    "# (`df[col][mask] = ...`): pandas warns about chained assignment, and under Copy-on-Write it\n",
    "# silently leaves `df` unchanged. `replace` also makes re-running this cell a no-op.\n",
    "df[\"Welfare\"] = df[\"Welfare\"].replace({\"cvar_usw\": \"USW\", \"cvar_gesw\": \"GESW\"})\n",
    "# 'Robust Concept' holds `not alloc.startswith(\"cvar\")`: True = naive allocation, False = CVaR.\n",
    "df[\"Robust Concept\"] = df[\"Robust Concept\"].replace({True: \"None\", False: \"CVaR\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd63c3f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the long-format results frame.\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a166ad3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams[\"font.family\"] = \"serif\"\n",
    "plt.rcParams[\"font.serif\"] = [\"Times New Roman\"]\n",
    "# Set the default size *before* creating the figure; a figure created first (e.g. by\n",
    "# plt.clf()) ignores it on a fresh kernel and only later figures would be 6x4.\n",
    "plt.rcParams[\"figure.figsize\"] = (6,4)\n",
    "fig, ax = plt.subplots()\n",
    "fs = 26\n",
    "sns.lineplot(data=df, x='noise_level', y='metric_value', hue=\"Welfare\", style='Robust Concept', errorbar=None, ax=ax)\n",
    "ax.set_ylabel(\"CVaR$_{0.01}[$W$]$\", fontsize=fs+2)\n",
    "ax.set_xlabel(\"Standard Deviation Scale\", fontsize=fs+2)\n",
    "ax.set_xlim([.8, 8.2])\n",
    "ax.tick_params(axis='both', which='major', labelsize=fs)\n",
    "ax.legend(loc=0, prop={'size': 15})\n",
    "fig.savefig(\"../Plots/cvar_over_noise_%s.pdf\" % dataset, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db1fd342",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6aa72c3b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a77828e3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afa3d76b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26723d9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the robust-USW objective value over wall-clock time for each solver.\n",
    "dataset = \"gAAMAS1\"\n",
    "\n",
    "algos = [\"IQP\", \"SubgradAsc\"]\n",
    "\n",
    "conf_level = 0.30\n",
    "\n",
    "plt.rcParams[\"font.family\"] = \"serif\"\n",
    "plt.rcParams[\"font.serif\"] = [\"Times New Roman\"]\n",
    "fs = 26\n",
    "plt.ylabel(\"Rob. USW\", fontsize=fs+2)\n",
    "plt.xlabel(\"Time (s)\", fontsize=fs+2)\n",
    "plt.tick_params(axis='both', which='major', labelsize=fs)\n",
    "\n",
    "for alg in algos:\n",
    "    prefix = os.path.join(\"..\", \"outputsmonday\", \"outputs\", dataset, \"adv_usw_%s_%.2f\" % (alg, conf_level))\n",
    "    with open(prefix + \"_timestamps.pkl\", 'rb') as f:\n",
    "        times = pickle.load(f)\n",
    "    with open(prefix + \"_obj_vals.pkl\", 'rb') as f:\n",
    "        ovs = [float(x) for x in pickle.load(f)]\n",
    "\n",
    "    # Keep the most recent n entries of both series so they line up; slicing only `times`\n",
    "    # by len(times) - len(ovs) breaks when there are fewer timestamps than objective values.\n",
    "    n = min(len(times), len(ovs))\n",
    "    times, ovs = times[len(times) - n:], ovs[len(ovs) - n:]\n",
    "    sns.lineplot(x=times, y=ovs, label=alg, linestyle=\"-\" if alg == \"IQP\" else \"--\")\n",
    "plt.xlim([-5, 200])\n",
    "plt.legend(loc=0, prop={'size': 15})\n",
    "plt.savefig(\"../Plots/conv_%s.pdf\" % dataset, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77050fc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Robust-GESW objective value over wall-clock time for each solver.\n",
    "dataset = \"gAAMAS2\"\n",
    "\n",
    "algos = [\"ProjGD\", \"SubgradAsc\"]\n",
    "\n",
    "conf_level = 0.30\n",
    "\n",
    "for alg in algos:\n",
    "    prefix = os.path.join(\"..\", \"outputs\", \"outputs\", dataset, \"adv_gesw_%s_%.2f\" % (alg, conf_level))\n",
    "    with open(prefix + \"_timestamps.pkl\", 'rb') as f:\n",
    "        times = pickle.load(f)\n",
    "    with open(prefix + \"_obj_vals.pkl\", 'rb') as f:\n",
    "        ovs = [float(x) for x in pickle.load(f)]\n",
    "\n",
    "    # Keep the most recent n entries of both series so they have equal length whichever\n",
    "    # list is longer.\n",
    "    n = min(len(times), len(ovs))\n",
    "    times, ovs = times[len(times) - n:], ovs[len(ovs) - n:]\n",
    "    sns.lineplot(x=times, y=ovs, label=alg)\n",
    "plt.xlabel(\"Time (s)\")\n",
    "plt.ylabel(\"Rob. GESW\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e07d016",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
