{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/aptx4869/opt/anaconda3/envs/grl/lib/python3.8/site-packages/haiku/_src/data_structures.py:37: FutureWarning: jax.tree_structure is deprecated, and will be removed in a future release. Use jax.tree_util.tree_structure instead.\n",
      "  PyTreeDef = type(jax.tree_structure(None))\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os, sys, glob\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import plotly.graph_objects as go\n",
    "import plotly.express as px\n",
    "import jax\n",
    "from jax import lax, nn, random\n",
    "import jax.numpy as jnp\n",
    "import haiku as hk\n",
    "import optax\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _moving_average(values, window):\n",
    "    \"\"\"Centered moving average of `values` with a `window`-wide box kernel.\n",
    "\n",
    "    Near the ends the kernel hangs off the data, so each output is divided by\n",
    "    the number of samples actually covered rather than by `window`. The window\n",
    "    is clipped to len(values) so the output length always equals the input length\n",
    "    (np.convolve 'same' returns max(len(a), len(v)) samples).\n",
    "    \"\"\"\n",
    "    values = np.asarray(values, dtype=float)\n",
    "    window = min(window, len(values))\n",
    "    if window <= 1:\n",
    "        return values\n",
    "    kernel = np.ones(window)\n",
    "    counts = np.convolve(np.ones(len(values)), kernel, 'same')\n",
    "    return np.convolve(values, kernel, 'same') / counts\n",
    "\n",
    "\n",
    "def plot_data(data, x, y, outdir, tag, title, timing=None, smooth=1):\n",
    "    \"\"\"Plot column `y` against column `x` (one line per `tag` value) and save it as a PNG.\n",
    "\n",
    "    Args:\n",
    "        data: a DataFrame or a list of DataFrames (e.g. from get_datasets);\n",
    "            the frames are concatenated before plotting.\n",
    "        x, y: column names for the horizontal and vertical axes.\n",
    "        outdir: directory the PNG is written to; created if missing.\n",
    "        tag: column name passed to seaborn as `hue`.\n",
    "        title: plot title, also used as the file name.\n",
    "        timing: if truthy, keep only rows whose `timing` column equals it,\n",
    "            drop that column, and append the value to the title.\n",
    "        smooth: width of a moving-average window applied to `y` within each\n",
    "            DataFrame; values <= 1 (default 1) disable smoothing. For\n",
    "            smooth = 2k+1, smoothed_y[t] = average(y[t-k], ..., y[t+k]).\n",
    "\n",
    "    The caller's DataFrames are never modified, so re-running a cell does not\n",
    "    smooth the same data twice.\n",
    "    \"\"\"\n",
    "    frames = list(data) if isinstance(data, list) else [data]\n",
    "    if smooth > 1:\n",
    "        # Local names must not shadow the `x` / `y` parameters: they hold the\n",
    "        # column names that sns.lineplot needs below.\n",
    "        frames = [frame.assign(**{y: _moving_average(frame[y], smooth)}) for frame in frames]\n",
    "    data = pd.concat(frames, ignore_index=True)\n",
    "    if timing:\n",
    "        data = data[data['timing'] == timing].drop('timing', axis=1)\n",
    "\n",
    "    os.makedirs(outdir, exist_ok=True)\n",
    "    # Configure the theme before creating the figure so it applies to this plot.\n",
    "    sns.set(style='whitegrid', font_scale=1.5)\n",
    "    sns.set_palette('Set2')  # or husl\n",
    "    fig, ax = plt.subplots(figsize=(20, 10))\n",
    "    style = 'timing' if 'timing' in data.columns else None\n",
    "    sns.lineplot(x=x, y=y, ax=ax, data=data, hue=tag, style=style)\n",
    "    ax.grid(True, alpha=0.8, linestyle=':')\n",
    "    ax.legend(loc='best').set_draggable(True)\n",
    "    ax.spines['top'].set_visible(False)\n",
    "    ax.spines['right'].set_visible(False)\n",
    "    if timing:\n",
    "        title = f'{title}-{timing}'\n",
    "    outpath = os.path.join(outdir, f'{title}.png')\n",
    "    ax.set_title(title)\n",
    "    fig.savefig(outpath)\n",
    "    plt.show()\n",
    "    print(f'Plot Path: {outpath}')\n",
    "\n",
    "def get_datasets(files, tag, condition=None):\n",
    "    \"\"\"Read tab-separated `log.txt` files into a list of DataFrames.\n",
    "\n",
    "    Each DataFrame gets an extra last column named `tag` filled with\n",
    "    `condition`, so runs loaded under different conditions stay\n",
    "    distinguishable after concatenation (e.g. as the hue in plot_data).\n",
    "    \"\"\"\n",
    "    datasets = []\n",
    "    for path in files:\n",
    "        assert path.endswith('log.txt'), f'expected a log.txt file, got {path!r}'\n",
    "        frame = pd.read_csv(path, sep='\\t')\n",
    "        # insert() raises if `tag` is already a column instead of overwriting it.\n",
    "        frame.insert(len(frame.columns), tag, condition)\n",
    "        datasets.append(frame)\n",
    "    return datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[6.00746017 2.82912789 5.57424059 3.57977993 6.55085934 5.51474829\n",
      " 6.03082554 5.35784754] [6.00746017 2.82912789 5.57424059 3.57977993 6.55085934 5.51474829\n",
      " 6.03082554 5.35784754]\n"
     ]
    },
    {
     "ename": "AssertionError",
     "evalue": "(array([6.00746017, 2.82912789, 5.57424059, 3.57977993, 6.55085934,\n       5.51474829, 6.03082554, 5.35784754]), array([6.00746017, 2.82912789, 5.57424059, 3.57977993, 6.55085934,\n       5.51474829, 6.03082554, 5.35784754]))",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAssertionError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m/var/folders/zq/p22_4nf93xbc8nsbrt5ft09r0000gp/T/ipykernel_94644/2877876190.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     39\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     40\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m100000\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m         \u001b[0mq1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpayoff\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp2\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mpi2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     42\u001b[0m         \u001b[0mq2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpayoff\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp1\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mpi1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     43\u001b[0m         \u001b[0mnew_pi1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_prob\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp1\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meta\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mq1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/var/folders/zq/p22_4nf93xbc8nsbrt5ft09r0000gp/T/ipykernel_94644/2877876190.py\u001b[0m in \u001b[0;36mpayoff\u001b[0;34m(p1, p2)\u001b[0m\n\u001b[1;32m     11\u001b[0m         \u001b[0mp2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mps1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mp2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     12\u001b[0m         \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mp2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m         \u001b[0;32massert\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mp2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mp2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     14\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     15\u001b[0m         \u001b[0mp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mps2\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexpand_dims\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mAssertionError\u001b[0m: (array([6.00746017, 2.82912789, 5.57424059, 3.57977993, 6.55085934,\n       5.51474829, 6.03082554, 5.35784754]), array([6.00746017, 2.82912789, 5.57424059, 3.57977993, 6.55085934,\n       5.51474829, 6.03082554, 5.35784754]))"
     ]
    }
   ],
   "source": [
    "SEED = 0\n",
    "np.random.seed(SEED)  # make the random game and every trial below reproducible\n",
    "\n",
    "A = np.random.randint(2, 10)  # number of actions per player\n",
    "V = 10  # payoffs are drawn uniformly from [0, V)\n",
    "\n",
    "# Payoff matrices indexed [player-1 action, player-2 action].\n",
    "ps1 = np.random.uniform(0, V, (A, A))  # player 1's payoffs\n",
    "ps2 = np.random.uniform(0, V, (A, A))  # player 2's payoffs\n",
    "eta = 1 / V  # scales payoffs inside the exponential update and the objective\n",
    "\n",
    "def payoff(*, p1=None, p2=None):\n",
    "    \"\"\"Expected payoff of every pure action against an opponent's mixed strategy.\n",
    "\n",
    "    Pass exactly one keyword argument:\n",
    "        p2: player 2's strategy -> player 1's payoffs, ps1 @ p2 (one entry per row of ps1).\n",
    "        p1: player 1's strategy -> player 2's payoffs, p1 @ ps2 (one entry per column of ps2).\n",
    "    The elementwise-sum result is cross-checked against the matmul form up to\n",
    "    floating-point tolerance (exact == fails on rounding differences).\n",
    "    \"\"\"\n",
    "    if (p1 is None) == (p2 is None):\n",
    "        raise ValueError('pass exactly one of p1=... or p2=...')\n",
    "    if p1 is None:\n",
    "        q = np.sum(ps1 * np.expand_dims(p2, 0), -1)\n",
    "        expected = np.matmul(ps1, p2)\n",
    "    else:\n",
    "        q = np.sum(ps2 * np.expand_dims(p1, 1), 0)\n",
    "        expected = np.matmul(p1, ps2)\n",
    "    assert np.allclose(q, expected), (q, expected)\n",
    "\n",
    "    return q\n",
    "\n",
    "def get_random_pi():\n",
    "    \"\"\"Draw a random mixed strategy: A uniform(0, 1) weights rescaled to sum to 1.\"\"\"\n",
    "    weights = np.random.uniform(0, 1, A)\n",
    "    total = np.sum(weights)\n",
    "    return weights / total\n",
    "\n",
    "def get_prob(x):\n",
    "    \"\"\"Rescale the weights `x` so that they sum to 1.\"\"\"\n",
    "    normalizer = np.sum(x)\n",
    "    return x / normalizer\n",
    "\n",
    "def compute_obj(q, x, p):\n",
    "    \"\"\"Regularized objective sum(eta * x * q) - sum(x * (log x - log p)).\n",
    "\n",
    "    The second sum is KL(x || p) when x and p are probability vectors.\n",
    "    \"\"\"\n",
    "    gain = eta * x * q\n",
    "    kl_terms = x * (np.log(x) - np.log(p))\n",
    "    return np.sum(gain - kl_terms)\n",
    "\n",
    "\n",
    "N_TRIALS = 100\n",
    "MAX_ITERS = 100_000\n",
    "TOL = 1e-7  # max per-entry change between iterates that counts as converged\n",
    "\n",
    "for trial in range(N_TRIALS):\n",
    "    # Random priors (p1, p2) and initial strategies (pi1, pi2).\n",
    "    p1 = get_random_pi()\n",
    "    p2 = get_random_pi()\n",
    "    pi1 = get_random_pi()\n",
    "    pi2 = get_random_pi()\n",
    "\n",
    "    # Iterate pi <- normalize(prior * exp(eta * payoff against the opponent's current pi)).\n",
    "    for i in range(MAX_ITERS):\n",
    "        q1 = payoff(p2=pi2)\n",
    "        q2 = payoff(p1=pi1)\n",
    "        new_pi1 = get_prob(p1 * np.exp(eta * q1))\n",
    "        new_pi2 = get_prob(p2 * np.exp(eta * q2))\n",
    "        if np.all(np.abs(new_pi1 - pi1) < TOL) and np.all(np.abs(new_pi2 - pi2) < TOL):\n",
    "            break\n",
    "        pi1 = new_pi1\n",
    "        pi2 = new_pi2\n",
    "    else:\n",
    "        # Without this the loop falls through silently and the checks below\n",
    "        # report a confusing mismatch instead of non-convergence.\n",
    "        print(f'trial {trial}: not converged after {MAX_ITERS} iterations')\n",
    "    print(i)\n",
    "\n",
    "    # The result should be a fixed point of the update.\n",
    "    q1 = payoff(p2=pi2)\n",
    "    q2 = payoff(p1=pi1)\n",
    "    tp1 = get_prob(p1 * np.exp(eta * q1))\n",
    "    tp2 = get_prob(p2 * np.exp(eta * q2))\n",
    "\n",
    "    np.testing.assert_allclose(pi1, tp1, 1e-4)\n",
    "    np.testing.assert_allclose(pi2, tp2, 1e-4)\n",
    "\n",
    "    # Objective value at the fixed point...\n",
    "    pv1 = compute_obj(q1, pi1, p1)\n",
    "    pv2 = compute_obj(q2, pi2, p2)\n",
    "\n",
    "    # ...must not be exceeded by a randomly perturbed (renormalized) strategy.\n",
    "    pi1_pt = get_prob(pi1+np.random.uniform(0, .1, pi1.shape))\n",
    "    pi2_pt = get_prob(pi2+np.random.uniform(0, .1, pi2.shape))\n",
    "\n",
    "    pv1_pt = compute_obj(q1, pi1_pt, p1)\n",
    "    pv2_pt = compute_obj(q2, pi2_pt, p2)\n",
    "\n",
    "    assert pv1_pt <= pv1, (pv1_pt, pv1)\n",
    "    assert pv2_pt <= pv2, (pv2_pt, pv2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "ename": "SyntaxError",
     "evalue": "invalid syntax (894332461.py, line 5)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;36m  File \u001b[0;32m\"/var/folders/zq/p22_4nf93xbc8nsbrt5ft09r0000gp/T/ipykernel_18058/894332461.py\"\u001b[0;36m, line \u001b[0;32m5\u001b[0m\n\u001b[0;31m    def set_a(self, a)\u001b[0m\n\u001b[0m                      ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): this cell previously held the incomplete statement `ps = `,\n",
    "# a SyntaxError that stops Restart and Run All. The payoff matrices used above\n",
    "# are ps1 and ps2; delete this cell if nothing else is planned for it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 297,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[-0.01285148  0.01285153]] [[-0.01285148  0.01285154]]\n"
     ]
    }
   ],
   "source": [
    "from jax import lax, random, nn\n",
    "\n",
    "b = 1  # batch size (rows averaged by kl.mean())\n",
    "d = 2  # number of logits per row\n",
    "rng = random.PRNGKey(0)\n",
    "rngs = random.split(rng)\n",
    "x = random.normal(rngs[0], (b, d))\n",
    "y = random.normal(rngs[1], (b, d))\n",
    "\n",
    "def kl_loss(logits, target_logits):\n",
    "    \"\"\"Batch mean of KL(softmax(target_logits) || softmax(logits)) over the last axis.\"\"\"\n",
    "    target_prob = nn.softmax(target_logits)\n",
    "    target_logprob = nn.log_softmax(target_logits)\n",
    "    logprob = nn.log_softmax(logits)\n",
    "    kl = jnp.sum(target_prob * (target_logprob - logprob), -1)\n",
    "    return kl.mean()\n",
    "\n",
    "# Autodiff gradient w.r.t. the first argument (logits=x, targets=y).\n",
    "std_grads = jax.grad(kl_loss)(x, y)\n",
    "\n",
    "# Closed form: softmax(x) - softmax(y) per row, divided by b because of the batch mean.\n",
    "xprob = nn.softmax(x)\n",
    "yprob = nn.softmax(y)\n",
    "manual_grads = (xprob - yprob) / b\n",
    "print(std_grads, manual_grads)\n",
    "# `chex` is not imported anywhere in this notebook; numpy's checker needs no extra dependency.\n",
    "np.testing.assert_allclose(std_grads, manual_grads, rtol=1e-5, atol=1e-6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-3.9999999999999996"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def f(t, gamma=.99, r=-0.04, v=-4):\n",
    "    \"\"\"t-step discounted return of a constant reward r, bootstrapped with value v.\n",
    "\n",
    "    Computes sum_{k=0}^{t-1} gamma**k * r + gamma**t * v. With the defaults\n",
    "    v == r / (1 - gamma), so the result is -4 (up to rounding) for any t.\n",
    "    \"\"\"\n",
    "    # Separate loop variable: reusing `t` made the bootstrap exponent wrong for t == 0.\n",
    "    v_hat = sum(gamma**k * r for k in range(t))\n",
    "    return v_hat + gamma**t * v\n",
    "\n",
    "f(10)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
