{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 658,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "#!/usr/bin/env python3\n",
    "\n",
    "import numpy as np\n",
    "import os\n",
    "from datetime import datetime\n",
    "\n",
    "import argparse\n",
    "\n",
    "import importlib\n",
    "try: import setGPU\n",
    "except ImportError: pass\n",
    "\n",
    "import torch\n",
    "\n",
    "import mle, mle_net, policy_net, task_net, plot\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import tkinter\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "matplotlib.use('TkAgg')\n",
    "\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 659,
   "metadata": {},
   "outputs": [],
   "source": [
    "from setup import init_newsvendor_params, init_theta_true, gen_data, log_error_and_write"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 702,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = init_newsvendor_params()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 660,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = init_newsvendor_params()\n",
    "true_model_types = ['linear', 'nonlinear']\n",
    "\n",
    "Theta_true_lin, Theta_true_sq = init_theta_true(params, is_linear=False, with_seed=True)\n",
    "\n",
    "np.random.seed(1)\n",
    "X, Y = gen_data(1000, params, Theta_true_lin, Theta_true_sq, with_seed=False) \n",
    "X_test, Y_test = gen_data(500, params, Theta_true_lin, Theta_true_sq, with_seed=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 694,
   "metadata": {},
   "outputs": [],
   "source": [
    "D = np.sum(params['d'] * Y, axis = 1)\n",
    "P = Problem(X, D, params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from newsvendor_caps import Problem\n",
    "from sklearn import tree\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "def test_discretization(X, Y, X_test, Y_test, phi, type_ = 'knn'): \n",
    "    D_test = np.sum(params['d'] * Y_test, axis = 1)\n",
    "\n",
    "    P.init_n()\n",
    "    if type_ == 'knn':\n",
    "        P.init_models(KNeighborsClassifier(n_neighbors=10))\n",
    "    elif type_ == 'linear':\n",
    "        P.init_models(LogisticRegression())\n",
    "    else: \n",
    "        raise Exception(\"Choose either 'linear' or 'knn'\")\n",
    "\n",
    "    costs = []\n",
    "    t = 0\n",
    "    for (x,d) in zip(X_test, D_test):\n",
    "        pred = P.predict(x, phi)\n",
    "        c = P.objective(pred, d)\n",
    "        costs.append(c)\n",
    "        if t % 100 == 0: print(t)\n",
    "        t += 1\n",
    "    return P, costs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 715,
   "metadata": {},
   "outputs": [],
   "source": [
    "from robust_knn import setup_knn, knn_robust_kl, knn_mean, knn_robust_wass, knn_robust_budget\n",
    "\n",
    "def test_kl_divergence(K, epsilon): \n",
    "    D = np.sum(params['d'] * Y, axis = 1)\n",
    "\n",
    "    neigh = setup_knn(X, K)\n",
    "    all_costs = []\n",
    "    for x_test, y_test in zip(X_test, Y_test): \n",
    "        d_test = np.sum(y_test * params['d'])\n",
    "\n",
    "        w = knn_robust_kl(x_test, X, D, K, params, neigh, epsilon)\n",
    "        cur_cost = P.objective(w, d_test)\n",
    "        all_costs.append(cur_cost)\n",
    "    return all_costs\n",
    "\n",
    "\n",
    "def test_wass(K, epsilon): \n",
    "    D = np.sum(params['d'] * Y, axis = 1)\n",
    "\n",
    "    neigh = setup_knn(X, K)\n",
    "    all_costs = []\n",
    "    t = 0\n",
    "    for x_test, y_test in zip(X_test, Y_test): \n",
    "        d_test = np.sum(y_test * params['d'])\n",
    "\n",
    "        w = knn_robust_wass(x_test, X, D, K, params, neigh, epsilon)\n",
    "        cur_cost = P.objective(w, d_test)\n",
    "        all_costs.append(cur_cost)\n",
    "\n",
    "        if t % 100 == 0: print(t)\n",
    "        t += 1\n",
    "\n",
    "    return all_costs\n",
    "\n",
    "\n",
    "def test_knn(K): \n",
    "    neigh = setup_knn(X, K)\n",
    "    all_costs = []\n",
    "    for x_test, y_test in zip(X_test, Y_test): \n",
    "        d_test = np.sum(y_test * params['d'])\n",
    "\n",
    "        w = knn_mean(x_test, X, D, K, params, neigh)\n",
    "        cur_cost = P.objective(w, d_test)\n",
    "        all_costs.append(cur_cost)\n",
    "    return all_costs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# discretization method for multiple phi\n",
    "\n",
    "all_discretization_costs = []\n",
    "for phi in np.arange(0, 200, 10):\n",
    "    print(\"Phi:\", phi)\n",
    "    _, scores = test_discretization(X, Y, X_test, Y_test, phi, \"knn\")\n",
    "    all_discretization_costs.append(scores)\n",
    "    print(\"Mean: \", np.mean(scores), \"Quantile 99: \", np.quantile(scores, 0.99))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# KNN + KL Divergence\n",
    "\n",
    "k_nn = 10\n",
    "all_kl_costs = []\n",
    "for epsilon in np.arange(0.1,1.1,0.05):\n",
    "    print(\"KL-divergence radius: \", epsilon)\n",
    "    scores = test_kl_divergence(k_nn, epsilon)\n",
    "    all_kl_costs.append(scores)\n",
    "    print(\"Mean: \", np.mean(scores), \"Quantile 99: \", np.quantile(scores, 0.99))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 838,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Wasserstein distance:  50.0\n",
      "0\n",
      "100\n",
      "200\n",
      "300\n",
      "400\n",
      "Mean:  300.42 Quantile 99:  600.0\n",
      "Wasserstein distance:  50.1\n",
      "0\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_15316/1413555867.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      4\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mepsilon\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0marange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m50\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m100\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m0.1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      5\u001b[0m     \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"Wasserstein distance: \"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mepsilon\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 6\u001b[1;33m     \u001b[0mscores\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtest_wass\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mk_nn\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mepsilon\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      7\u001b[0m     \u001b[0mall_wass_costs\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mscores\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      8\u001b[0m     \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"Mean: \"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mscores\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m\"Quantile 99: \"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mquantile\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mscores\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m0.99\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_15316/4270266773.py\u001b[0m in \u001b[0;36mtest_wass\u001b[1;34m(K, epsilon)\u001b[0m\n\u001b[0;32m     24\u001b[0m         \u001b[0md_test\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msum\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0my_test\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mparams\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m'd'\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     25\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 26\u001b[1;33m         \u001b[0mw\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mknn_robust_wass\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx_test\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mD\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mK\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mneigh\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mepsilon\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     27\u001b[0m         \u001b[0mcur_cost\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mP\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mobjective\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mw\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0md_test\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     28\u001b[0m         \u001b[0mall_costs\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcur_cost\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\Users\\rares\\Dropbox (MIT)\\Documents\\documents\\mit\\research\\spo\\e2e-model-learning-master\\e2e-model-learning-master\\newsvendor\\robust_knn.py\u001b[0m in \u001b[0;36mknn_robust_wass\u001b[1;34m(x_test, X_train, demands, K, params, neigh, epsilon)\u001b[0m\n\u001b[0;32m    197\u001b[0m     \u001b[1;31m# model.minsup(E((c_lin * w + c_quad * w2 + b_lin * yb + b_quad * yb2 + h_lin * yh + h_quad * yh2).sum()), fset)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    198\u001b[0m     \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mminsup\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mc_lin\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0mw\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mE\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mb_lin\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0myb\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mh_lin\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0myh\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msum\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfset\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 199\u001b[1;33m     \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msolve\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mgrb\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdisplay\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[1;33m)\u001b[0m                            \u001b[1;31m# solve the model by Gurobi\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    200\u001b[0m     \u001b[1;32mreturn\u001b[0m \u001b[0mw\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\Users\\rares\\anaconda3\\lib\\site-packages\\rsome\\dro.py\u001b[0m in \u001b[0;36msolve\u001b[1;34m(self, solver, display, params)\u001b[0m\n\u001b[0;32m    743\u001b[0m             \u001b[0msolution\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mdef_sol\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdo_math\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdisplay\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    744\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 745\u001b[1;33m             \u001b[0msolution\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msolver\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msolve\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdo_math\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdisplay\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    746\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    747\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msolution\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mSolution\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\Users\\rares\\anaconda3\\lib\\site-packages\\rsome\\dro.py\u001b[0m in \u001b[0;36mdo_math\u001b[1;34m(self, primal)\u001b[0m\n\u001b[0;32m    439\u001b[0m             \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mro_model\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mst\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mro_constr_list\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    440\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 441\u001b[1;33m         \u001b[0mformula\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mro_model\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdo_math\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mprimal\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    442\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mprimal\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    443\u001b[0m             \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mprimal\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mformula\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\Users\\rares\\anaconda3\\lib\\site-packages\\rsome\\ro.py\u001b[0m in \u001b[0;36mdo_math\u001b[1;34m(self, primal)\u001b[0m\n\u001b[0;32m    366\u001b[0m             \u001b[1;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mconstr\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mRoConstr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    367\u001b[0m                 \u001b[1;32mif\u001b[0m \u001b[0mconstr\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msupport\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 368\u001b[1;33m                     \u001b[0mrc_constrs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mconstr\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mle_to_rc\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    369\u001b[0m                 \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    370\u001b[0m                     \u001b[0mrc_constrs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mconstr\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mle_to_rc\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mobj_support\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\Users\\rares\\anaconda3\\lib\\site-packages\\rsome\\lp.py\u001b[0m in \u001b[0;36mle_to_rc\u001b[1;34m(self, support)\u001b[0m\n\u001b[0;32m   2108\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   2109\u001b[0m         \u001b[0mleft\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mdual_var\u001b[0m \u001b[1;33m@\u001b[0m \u001b[0msupport\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mnum_rand\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mT\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 2110\u001b[1;33m         \u001b[0mleft\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mleft\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mraffine\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m:\u001b[0m\u001b[0mnum_rand\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0msupport\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconst\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mnum_rand\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   2111\u001b[0m         \u001b[0msense2\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtile\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msupport\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msense\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mnum_rand\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnum_constr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   2112\u001b[0m         \u001b[0mnum_rc_constr\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mleft\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconst\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msize\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\Users\\rares\\anaconda3\\lib\\site-packages\\rsome\\lp.py\u001b[0m in \u001b[0;36m__getitem__\u001b[1;34m(self, item)\u001b[0m\n\u001b[0;32m    967\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    968\u001b[0m         \u001b[0mindices\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msparray\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mitem\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 969\u001b[1;33m         \u001b[0mlinear\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msv_to_csr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mindices\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m@\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    970\u001b[0m         \u001b[0mconst\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconst\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mitem\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    971\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\Users\\rares\\anaconda3\\lib\\site-packages\\rsome\\subroutines.py\u001b[0m in \u001b[0;36msv_to_csr\u001b[1;34m(array)\u001b[0m\n\u001b[0;32m    164\u001b[0m         \u001b[0mindptr\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m+\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mindptr\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mall_items\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalue\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m-\u001b[0m \u001b[0mzero_counts\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    165\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 166\u001b[1;33m     \u001b[1;32mreturn\u001b[0m \u001b[0mcsr_matrix\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mindices\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mindptr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0msize\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mall_items\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnvar\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    167\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    168\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\Users\\rares\\anaconda3\\lib\\site-packages\\scipy\\sparse\\compressed.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, arg1, shape, dtype, copy)\u001b[0m\n\u001b[0;32m     68\u001b[0m                                                 check_contents=True)\n\u001b[0;32m     69\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 70\u001b[1;33m                     self.indices = np.array(indices, copy=copy,\n\u001b[0m\u001b[0;32m     71\u001b[0m                                             dtype=idx_dtype)\n\u001b[0;32m     72\u001b[0m                     \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mindptr\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0marray\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mindptr\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mcopy\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0midx_dtype\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# KNN + Wasserstein\n",
    "k_nn = 10\n",
    "all_wass_costs = []\n",
    "for epsilon in np.arange(50,100,0.1):\n",
    "    print(\"Wasserstein distance: \", epsilon)\n",
    "    scores = test_wass(k_nn, epsilon)\n",
    "    all_wass_costs.append(scores)\n",
    "    print(\"Mean: \", np.mean(scores), \"Quantile 99: \", np.quantile(scores, 0.99))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Running MLE\")\n",
    "Theta_est = mle.linear_softmax_reg(X, Y, params)\n",
    "f_eval_mle, z_buy, f_opt = mle.newsvendor_eval(X_test, Y_test, Theta_est, np.zeros((params['n'], len(params['d']))), params)\n",
    "\n",
    "print(\"Running Policy Model\")\n",
    "policy_lin_score = policy_net.run_policy_net(X, Y, X_test, Y_test, params)\n",
    "\n",
    "print(\"Running Task-Based Model\")\n",
    "e2e_lin_score = task_net.run_task_net(X, Y, X_test, Y_test, params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 826,
   "metadata": {},
   "outputs": [],
   "source": [
    "QUANTILE = .99\n",
    "\n",
    "cap_means = [np.mean(scores) for scores in all_discretization_costs[:9]+all_discretization_costs[14:]]\n",
    "cap_q = [np.quantile(scores, QUANTILE) for scores in all_discretization_costs[:9]+all_discretization_costs[14:]]\n",
    "\n",
    "# cap_means = [np.mean(scores) for scores in all_discretization_costs]\n",
    "# cap_q = [np.quantile(scores, QUANTILE) for scores in all_discretization_costs]\n",
    "plt.scatter(cap_means, cap_q, label='Discretization')\n",
    "\n",
    "means_kl = [np.mean(scores) for scores in all_kl_costs[2:]]\n",
    "q95_kl = [np.quantile(scores, QUANTILE) for scores in all_kl_costs[2:]]\n",
    "plt.scatter(means_kl, q95_kl, label='KNN + KL Divergence', marker='x')\n",
    "\n",
    "# cap_means_linear = [np.mean(scores) for scores in all_cap_scores_linear]\n",
    "# cap_q_linear = [np.quantile(scores, QUANTILE) for scores in all_cap_scores_linear]\n",
    "# plt.plot(cap_means_linear, cap_q_linear, label='Discrete (linear)')\n",
    "\n",
    "# # plt.scatter(np.mean(knn_scores), np.quantile(knn_scores, QUANTILE), label='knn', color = 'green')\n",
    "plt.scatter(np.mean(e2e_lin_score), np.quantile(e2e_lin_score, QUANTILE), label='Task-Based', color='black', marker='v' )\n",
    "plt.scatter(np.mean(f_eval_mle), np.quantile(f_eval_mle, QUANTILE), label='MLE', color='red', marker='^')\n",
    "plt.scatter(np.mean(policy_lin_score), np.quantile(policy_lin_score, QUANTILE), label='Policy', color='green', marker='1')\n",
    "\n",
    "plt.xlabel('mean')\n",
    "plt.ylabel('99th quantile'.format(QUANTILE))\n",
    "plt.legend(loc='lower left')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 815,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\rares\\AppData\\Local\\Temp/ipykernel_15316/3667626580.py:1: MatplotlibDeprecationWarning: hatch must consist of a string of \"*+-./OX\\ox|\" or None, but found the following invalid values \"//\". Passing invalid values is deprecated since 3.4 and will become an error two minor releases later.\n",
      "  plt.hist([all_discretization_costs[5], all_kl_costs[5]], label=['Discretization', 'KNN + KL Divergence'], hatch=['//'], linewidth=2)\n"
     ]
    }
   ],
   "source": [
    "plt.hist([all_discretization_costs[5], all_kl_costs[5]], label=['Discretization', 'KNN + KL Divergence'], hatch=['//'], linewidth=2)\n",
    "plt.xlabel(\"Cost Incurred\")\n",
    "plt.ylabel(\"Frequency\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 797,
   "metadata": {},
   "outputs": [],
   "source": [
    "xs = np.arange(0,1,0.01)\n",
    "\n",
    "cap_q = [np.quantile(all_discretization_costs[5], q) for q in xs]\n",
    "# policy_q = [np.quantile(policy_lin_score, q) for q in xs]\n",
    "\n",
    "# e2e_q = [np.quantile(e2e_lin_score, q) for q in xs]\n",
    "kl_q = [np.quantile(all_kl_costs[5], q) for q in xs]\n",
    "\n",
    "plt.plot(xs, cap_q, label = 'Discretization', linestyle='-')\n",
    "plt.plot(xs, kl_q, label = 'KNN + KL-Divergence', linestyle='--')\n",
    "\n",
    "plt.xlabel('quantile')\n",
    "plt.ylabel('cost')\n",
    "plt.legend()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 816,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "249.6893333333333"
      ]
     },
     "execution_count": 816,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(all_discretization_costs[5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 818,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "250.99608073843527"
      ]
     },
     "execution_count": 818,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(all_kl_costs[5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 823,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "task-based 75th q: 386.16624450683594\n",
      "MLE        75th q: 501.24991102376055\n",
      "policy     75th q: 394.9963912963867\n",
      "KL Div     75th q: 292.7886426068769\n",
      "KL Div     75th q: 288.0\n"
     ]
    }
   ],
   "source": [
    "q = 0.75\n",
    "print(\"task-based 75th q:\", np.quantile(e2e_lin_score, q))\n",
    "print(\"MLE        75th q:\", np.quantile(f_eval_mle, q))\n",
    "print(\"policy     75th q:\", np.quantile(policy_lin_score, q))\n",
    "print(\"KL Div     75th q:\", np.quantile(all_kl_costs[5], q))\n",
    "print(\"KL Div     75th q:\", np.quantile(all_discretization_costs[5], q))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 831,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "task-based 99th q: 561.6400146484375\n",
      "MLE        99th q: 561.6395193923181\n",
      "policy     99th q: 513.4334222412109\n",
      "KL Div     99th q: 304.0454719526291\n",
      "discre     99th q: 300.0000000000001\n"
     ]
    }
   ],
   "source": [
    "q = .99\n",
    "print(\"task-based 99th q:\", np.quantile(e2e_lin_score, q))\n",
    "print(\"MLE        99th q:\", np.quantile(f_eval_mle, q))\n",
    "print(\"policy     99th q:\", np.quantile(policy_lin_score, q))\n",
    "print(\"KL Div     99th q:\", np.quantile(all_kl_costs[15], q))\n",
    "print(\"discre     99th q:\", np.quantile(all_discretization_costs[15], q))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 827,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "task-based mean: 221.34091\n",
      "MLE        mean: 247.35413444208743\n",
      "policy     mean: 239.73256\n",
      "KL Div     mean: 220.02771755337614\n",
      "discre     mean: 238.14\n"
     ]
    }
   ],
   "source": [
    "print(\"task-based mean:\", np.mean(e2e_lin_score))\n",
    "print(\"MLE        mean:\", np.mean(f_eval_mle))\n",
    "print(\"policy     mean:\", np.mean(policy_lin_score))\n",
    "print(\"KL Div     mean:\", np.mean(all_kl_costs[0]))\n",
    "print(\"discre     mean:\", np.mean(all_discretization_costs[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 667,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Being solved by Gurobi...\n",
      "Solution status: 2\n",
      "Running time: 0.0042s\n"
     ]
    }
   ],
   "source": [
    "from rsome import dro\n",
    "from rsome import norm\n",
    "from rsome import E\n",
    "from rsome import grb_solver as grb\n",
    "import numpy as np\n",
    "import numpy.random as rd\n",
    "\n",
    "K = 10\n",
    "\n",
    "c_lin = params['c_lin']\n",
    "c_quad = params['c_quad']\n",
    "b_lin = params['b_lin']\n",
    "b_quad = params['b_quad']\n",
    "h_lin = params['h_lin']\n",
    "h_quad = params['h_quad']\n",
    "\n",
    "#  get k-nearest neighbors as the empirical data\n",
    "indx = neigh.kneighbors([x_test], K, False)\n",
    "\n",
    "D_ = D[indx][0]\n",
    "\n",
    "N = len(D_) \n",
    "\n",
    "model = dro.Model(N)\n",
    "w = model.dvar()\n",
    "d = model.rvar()\n",
    "u = model.rvar()\n",
    "\n",
    "fset = model.ambiguity()                    # create an ambiguity set\n",
    "for s in range(N):\n",
    "    fset[s].suppset(d - D_[s] <= u, -(d - D_[s]) >= u, d >= 0) # define the support for each scenario\n",
    "fset.exptset(E(u) <= theta)                 # the Wasserstein metric constraint\n",
    "pr = model.p                                # an array of scenario probabilities\n",
    "fset.probset(pr == 1/N)                     # support of scenario probabilities\n",
    "\n",
    "yb = model.dvar(N)                           # define first-stage decisions\n",
    "yh = model.dvar(N)                           # define decision rule variables\n",
    "yb.adapt(d)                                  # y affinely adapts to z\n",
    "yh.adapt(d)                                  # y affinely adapts to u\n",
    "for s in range(N):\n",
    "    yh.adapt(s)                              # y adapts to each scenario s\n",
    "    yb.adapt(s)                              # y adapts to each scenario s\n",
    "\n",
    "model.st(yb >= 0, yh >= 0)\n",
    "model.st(yb >= d - w)\n",
    "model.st(yh >= w - d)\n",
    "\n",
    "# w2 = model.dvar()\n",
    "# yh2 = model.dvar(N) \n",
    "# yb2 = model.dvar(N) \n",
    "\n",
    "# model.st(w2  >= square(w))\n",
    "# model.st(yh2 >= square(yh))\n",
    "# model.st(yb2 >= square(yb))\n",
    "\n",
    "# model.minsup(E((c_lin * w + c_quad * w2 + b_lin * yb + b_quad * yb2 + h_lin * yh + h_quad * yh2).sum()), fset)\n",
    "model.minsup(c_lin*w + E((b_lin * yb + h_lin * yh).sum()), fset)\n",
    "model.solve(grb)                            # solve the model by Gurobi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 656,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "15.0"
      ]
     },
     "execution_count": 656,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "w.get()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
