{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "Consider the following extensive form game. Let $p$ be the probability that Blue chooses \"right\" and $q$ be the probability that Red chooses \"right\" (which is i.i.d. at each timestep, as Red has a single information set).\n",
    "\n",
    "![](extensive_form_diagram.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from tueplots import bundles\n",
    "\n",
    "# The local package is imported relative to this notebook, so the path two\n",
    "# levels up must be on sys.path before the import below runs.\n",
    "sys.path.insert(0, \"../..\")\n",
    "from ai_control_games.presentation_utils import plot_with_zoom\n",
    "\n",
    "############################################################################\n",
    "############################## plot set-up #################################\n",
    "############################################################################\n",
    "\n",
    "matplotlib.rcParams['figure.dpi'] = 600\n",
    "plt.rcParams.update(bundles.aaai2024(column=\"half\", nrows=1, ncols=1))\n",
    "\n",
    "# Read the figure size after the bundle is applied so later cells can size\n",
    "# figures relative to it.\n",
    "default_width, default_height = plt.rcParams[\"figure.figsize\"]\n",
    "plt.rcParams.update({\n",
    "    \"text.usetex\": False,\n",
    "})\n",
    "lw = plt.rcParams['lines.linewidth']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Both teams' mixed strategies are sampled on the same uniform grid over [0, 1].\n",
    "ps = np.linspace(start=0, stop=1, num=201)  # Blue team strategies\n",
    "qs = np.linspace(start=0, stop=1, num=201)  # Red team strategies\n",
    "\n",
    "# 2-D coordinate arrays built from the two strategy grids.\n",
    "P, Q = np.meshgrid(ps, qs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "The value of the game is Blue's expected utility: for each leaf, multiply the probability of reaching it by Blue's utility there, then sum over all leaves."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def f(p, q):\n",
    "    \"\"\"Blue's expected utility, accumulated leaf by leaf.\n",
    "\n",
    "    p is the probability that Blue plays right; q is the probability that Red\n",
    "    plays right at each of its two moves. Leaves not listed contribute 0.\n",
    "    \"\"\"\n",
    "    # Each entry: (Blue's loss at the leaf, P(Blue's move), P(Red's two moves))\n",
    "    leaves = (\n",
    "        (1, 1 - p, (1 - q) ** 2),  # Blue left, Red left left\n",
    "        (0.5, 1 - p, q ** 2),  # Blue left, Red right right\n",
    "        (0.5, p, (1 - q) ** 2),  # Blue right, Red left left\n",
    "        (1, p, q ** 2),  # Blue right, Red right right\n",
    "    )\n",
    "    val = 0\n",
    "    for loss, blue_prob, red_prob in leaves:\n",
    "        val -= loss * blue_prob * red_prob\n",
    "    return val\n",
    "\n",
    "\n",
    "# Same value as f, with the leaf sum expanded and like terms collected\n",
    "def alternate_f(p, q):\n",
    "    \"\"\"Return V(p, q) = -1 + p (0.5 - q) + 2 q - 1.5 q^2.\"\"\"\n",
    "    blue_term = p * (0.5 - 1 * q)\n",
    "    red_linear = 2 * q\n",
    "    red_quadratic = 1.5 * q ** 2\n",
    "    return -1 + blue_term + red_linear - red_quadratic\n",
    "\n",
    "\n",
    "def f_del_p(p, q):\n",
    "    \"\"\"Partial derivative of V with respect to p; p itself is not used.\"\"\"\n",
    "    slope_in_p = 0.5 - q\n",
    "    return slope_in_p\n",
    "\n",
    "def best_p_response(q):\n",
    "    \"\"\"Blue's maximising p against q: 0.0 when q > 0.5, otherwise 1.0.\"\"\"\n",
    "    # V is linear in p with slope 0.5 - q, so the maximiser is an endpoint.\n",
    "    if q > 0.5:\n",
    "        return 0.0\n",
    "    return 1.0\n",
    "\n",
    "def max_p_f(q):\n",
    "    \"\"\"Value of alternate_f at q when p is best_p_response(q).\"\"\"\n",
    "    p_star = best_p_response(q)\n",
    "    return alternate_f(p_star, q)\n",
    "\n",
    "def f_del_q(p, q):\n",
    "    \"\"\"Partial derivative of V with respect to q: 2 - p - 3 q.\"\"\"\n",
    "    quadratic_part = 3 * q\n",
    "    return 2 - p - quadratic_part\n",
    "\n",
    "def best_q_response(p):\n",
    "    \"\"\"Red's minimising q against p: 0.0 when p < 0.5, otherwise 1.0.\"\"\"\n",
    "    # V is concave in q, so its minimum over [0, 1] sits at an endpoint.\n",
    "    if p < 0.5:\n",
    "        return 0.0\n",
    "    return 1.0\n",
    "\n",
    "def min_q_f(p):\n",
    "    \"\"\"Value of alternate_f at p when q is best_q_response(p).\"\"\"\n",
    "    q_star = best_q_response(p)\n",
    "    return alternate_f(p, q_star)\n",
    "\n",
    "\n",
    "# Tabulate V(p, q) over the whole strategy grid. At every point the leaf-by-leaf\n",
    "# definition f is checked against the rearranged form alternate_f, so an algebra\n",
    "# slip in either one fails loudly here.\n",
    "data_dicts = []\n",
    "\n",
    "for p in ps:\n",
    "    for q in qs:\n",
    "        v1 = f(p, q)\n",
    "        v2 = alternate_f(p, q)\n",
    "        assert np.isclose(v1, v2), f'{v1} != {v2}'\n",
    "        data_dicts.append({'p': p, 'q': q, 'v': v1})\n",
    "\n",
    "df = pd.DataFrame(data_dicts)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute the maximin and minimax values of this game, and plot $V(p, q)$ against $p$ for several values of $q$ (top panel) and against $q$ for several values of $p$ (bottom panel)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Size the figure from the width read out of rcParams in the set-up cell.\n",
    "figsize = (default_width, default_width * 1.318)\n",
    "\n",
    "# sharex/sharey here already link both panels' axes.\n",
    "fig, (ax1, ax2) = plt.subplots(2, 1, figsize=figsize, sharex=True, sharey=True)\n",
    "\n",
    "# Evaluate V once on the whole grid: rows index p, columns index q.\n",
    "values = f(ps[:, np.newaxis], qs[np.newaxis, :])\n",
    "mins = values.min(axis=1)  # min over q for each p\n",
    "maxs = values.max(axis=0)  # max over p for each q\n",
    "maximin = float(mins.max())\n",
    "minimax = float(maxs.min())\n",
    "\n",
    "# Top panel: V against p, one curve per Red strategy q.\n",
    "# Use a red-ish map for red\n",
    "cmap = sns.color_palette(\"flare\", as_cmap=True)\n",
    "for q in np.linspace(0, 1, 5):\n",
    "    ax1.plot(ps, f(ps, q), label=f'$q={q:.2f}$', color=cmap(q), alpha=0.5)\n",
    "\n",
    "ax1.plot(ps, mins, color='black', markersize=1, label=\"$\\\\min_{q} V(p, q)$\", linestyle=':')\n",
    "ax1.axhline(maximin, color='blue', linestyle='--', label=f\"${maximin:.3f}$\")\n",
    "ax1.set_yticks([-1.0, -0.8, -0.6, -0.4])\n",
    "ax1.legend(loc=(1.05, 0.), fontsize=\"medium\")\n",
    "ax1.set_ylabel(\"$V(p,q)$\", fontsize=\"medium\")\n",
    "ax1.set_xlabel(\"$p$\", fontsize=\"medium\")\n",
    "ax1.tick_params(axis=\"both\", which=\"major\", labelsize=\"small\")\n",
    "ax1.grid(which=\"major\", alpha=0.5)\n",
    "\n",
    "# Bottom panel: V against q, one curve per Blue strategy p.\n",
    "cmap = plt.get_cmap('viridis')\n",
    "for p in np.linspace(0, 1, 5):\n",
    "    ax2.plot(qs, f(p, qs), label=f'$p={p:.2f}$', color=cmap(p), alpha=0.5)\n",
    "\n",
    "ax2.plot(qs, maxs, color='black', markersize=1, label=\"$\\\\max_{p} V(p, q)$\", linestyle=':')\n",
    "ax2.axhline(minimax, color='red', linestyle='--', label=f\"${minimax:.3f}$\")\n",
    "ax2.set_xticks([0.0, 0.5, 1.0])\n",
    "ax2.set_yticks([-1.0, -0.8, -0.6, -0.4])\n",
    "ax2.legend(loc=(1.05, 0.0), fontsize=\"medium\")\n",
    "ax2.set_ylabel(\"$V(p,q)$\", fontsize=\"medium\")\n",
    "ax2.set_xlabel(\"$q$\", fontsize=\"medium\")\n",
    "ax2.tick_params(axis=\"both\", which=\"major\", labelsize=\"small\")\n",
    "ax2.grid(which=\"major\", alpha=0.5)\n",
    "\n",
    "plt.savefig('appendixE_minimax_maximin.pdf')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
