{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shlex, subprocess\n",
    "import os\n",
    "import gzip\n",
    "import re\n",
    "import scipy as sp\n",
    "import numpy.random as rd\n",
    "import matplotlib\n",
    "import numpy as np\n",
    "import numpy as np\n",
    "import matplotlib.mlab as mlab\n",
    "import matplotlib.pyplot as plt\n",
    "import scipy.stats as stats\n",
    "from operator import itemgetter\n",
    "from matplotlib.colors import LogNorm\n",
    "import matplotlib.patches as mpatches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.chdir('/mnt/ryan/confidence_estimate_experiment')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prob_to_logodds(myprob):\n",
    "    return sp.log(myprob/(1-myprob))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parsing(inp):    \n",
    "    DATAFILE_PATTERN = '^(.+) (.+)'\n",
    "    match = re.search(DATAFILE_PATTERN, inp)\n",
    "    return (float(match.group(1)), float(match.group(2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#predictions from the \"ground truth\" model\n",
    "pred1 = []\n",
    "with open(\"pred_original.out\", \"r\") as f:\n",
    "      for line in f:\n",
    "        pred1.append(prob_to_logodds(float(line)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#predictions and uncertainty scores from the ftrl model and --ftrl_confidence option\n",
    "pred_with_ce = []\n",
    "with open(\"pred_with_ce_original.out\", \"r\") as f:\n",
    "      for line in f:\n",
    "        pred_with_ce.append(parsing(line))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np_errors = np.array(errors)\n",
    "ce = map(lambda x: x[1][1]  ,zip(pred1, pred_with_ce))\n",
    "np_ce = np.array(ce)\n",
    "from scipy.stats import rankdata\n",
    "np_ce = rankdata(np_ce) / len(np_ce)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "errors = map(lambda x: sp.absolute(x[1][0]-x[0]),zip(pred1, pred_with_ce))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "errors_ce = sorted(zip(np_errors, np_ce), key=itemgetter(1))\n",
    "errors_ce_errors= map(lambda x: x[0], errors_ce)\n",
    "cuts = 1000\n",
    "myrange = np.array(range(0,cuts + 1))/float(cuts)*len(errors_ce_errors)\n",
    "#calculate 25, 50, and 75 percentiles\n",
    "q25 = []\n",
    "q50= []\n",
    "q75= []\n",
    "for i in range(0,cuts):\n",
    "    filt_dat = errors_ce_errors[int(myrange[i]):int(myrange[i+1])]\n",
    "    q25.append(np.percentile(filt_dat, 25))\n",
    "    q50.append(np.percentile(filt_dat, 50))\n",
    "    q75.append(np.percentile(filt_dat, 75))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "myxrange = np.array(range(0,cuts))/float(cuts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "myxrange = np.array(range(0,cuts))/float(cuts)\n",
    "red_patch = mpatches.Patch(color='red', label='75% Percentile')\n",
    "blue_patch = mpatches.Patch(color='blue', label='25% Percentile')\n",
    "green_patch = mpatches.Patch(color='green', label='50% Percentile')\n",
    "\n",
    "\n",
    "plt.figure(figsize=(15,15))\n",
    "plt.plot(myxrange, q25)\n",
    "plt.plot(myxrange, np.array(q50))\n",
    "plt.plot(myxrange, np.array(q75))\n",
    "plt.legend(handles=[blue_patch, green_patch, red_patch])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline \n",
    "plt.figure(figsize=(15,15))\n",
    "filter_expr = (np_errors < 5) & (np_ce < 2)\n",
    "plt.hist2d( np_ce[filter_expr],  np_errors[filter_expr], bins=300, norm=LogNorm())\n",
    "plt.plot(myxrange, q25, 'blue')\n",
    "plt.plot(myxrange, q50, 'green')\n",
    "plt.plot(myxrange, q75, 'red')\n",
    "plt.colorbar()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
