{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31f039ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Extra_MWU,exp1\n",
    "# Extra-gradient multiplicative weights (Extra-MWU) on a zero-sum game whose payoff\n",
    "# matrix alternates with period 2. Player x maximises x^T A(t) y; player y maximises\n",
    "# y^T A_21(t) x with A_21(t) = -A(t)^T.\n",
    "\n",
    "from scipy.linalg import expm, sinm, cosm\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import math\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "%matplotlib inline\n",
    "%matplotlib auto\n",
    "\n",
    "# 'seaborn-white' was renamed 'seaborn-v0_8-white' in matplotlib 3.6; the old name is gone in 3.8.\n",
    "plt.style.use('seaborn-v0_8-white' if 'seaborn-v0_8-white' in plt.style.available else 'seaborn-white')\n",
    "%config InlineBackend.figure_format = 'svg'\n",
    "\n",
    "\n",
    "#training steps\n",
    "T = 2800\n",
    "\n",
    "#step sizes: s for the extrapolation (hat) step, eta for the actual update\n",
    "s = 0.3\n",
    "eta = 0.3\n",
    "\n",
    "# x,y: mixed strategies of the two players, one row per step\n",
    "x = np.zeros((T,3))\n",
    "y = np.zeros((T,3))\n",
    "\n",
    "# extrapolated (intermediate) strategies\n",
    "x_hat = np.zeros((T,3))\n",
    "y_hat = np.zeros((T,3))\n",
    "\n",
    "\n",
    "# k_1, k_2, k_3 store the coordinates of y's strategy at every step (for plotting)\n",
    "k_1 = np.zeros(T)\n",
    "k_2 = np.zeros(T)\n",
    "k_3 = np.zeros(T)\n",
    "\n",
    "\n",
    "x[0] = [0.8, 0.1, 0.1]\n",
    "y[0] = [0.1, 0.05, 0.85]\n",
    "\n",
    "# Bug fix: k_2[0] / k_3[0] used to read y[1][1] / y[2][2], which are still zero here.\n",
    "k_1[0], k_2[0], k_3[0] = y[0]\n",
    "\n",
    "def A(t):\n",
    "    \"\"\"Payoff matrix of player x at step t (alternates with period 2).\n",
    "\n",
    "    Both matrices have the equilibrium x* = (1/2, 1/4, 1/4), y* = (1/4, 3/8, 3/8).\n",
    "    \"\"\"\n",
    "    if t % 2 == 0:\n",
    "        return np.array([[0, 0.25, 0.75], [1.5, 0, 0], [0, 1, 0]])\n",
    "    return np.array([[0, 0.75, 0.25], [1.5, 0, 0], [0, 0, 1]])\n",
    "\n",
    "\n",
    "def A_21(t):\n",
    "    \"\"\"Payoff matrix of player y at step t: -A(t)^T, since the game is zero-sum.\"\"\"\n",
    "    return -A(t).T\n",
    "\n",
    "\n",
    "def mwu_step(p, payoff, step):\n",
    "    \"\"\"Multiplicative-weights update: p_i * exp(step * payoff_i), renormalised to sum to 1.\"\"\"\n",
    "    weights = p * np.exp(step * payoff)\n",
    "    return weights / weights.sum()\n",
    "\n",
    "\n",
    "for t in range(0,T-1):\n",
    "    # Extrapolation: both players react to the payoffs at (x[t], y[t]).\n",
    "    y_hat[t+1] = mwu_step(y[t], A_21(t) @ x[t], s)\n",
    "    x_hat[t+1] = mwu_step(x[t], A(t) @ y[t], s)\n",
    "    # Update: restart from (x[t], y[t]) but use the payoffs against the extrapolated opponent.\n",
    "    y[t+1] = mwu_step(y[t], A_21(t) @ x_hat[t+1], eta)\n",
    "    x[t+1] = mwu_step(x[t], A(t) @ y_hat[t+1], eta)\n",
    "\n",
    "    k_1[t+1], k_2[t+1], k_3[t+1] = y[t+1]\n",
    "\n",
    "\n",
    "fig = plt.figure()\n",
    "ax1 = plt.axes(projection='3d')\n",
    "ax1.scatter(k_1, k_2, k_3, s=0.3, label='Trajectory of  Extra_MWU')\n",
    "ax1.scatter(y[0][0], y[0][1], y[0][2], s=40)  # starting point of y\n",
    "ax1.scatter(1/4, 3/8, 3/8, s=40)  # equilibrium y* of both matrices\n",
    "# boundary of the probability simplex\n",
    "ax1.plot([1, 0], [0, 1], [0, 0], 'k-', lw=0.5)\n",
    "ax1.plot([0, 0], [1, 0], [0, 1], 'k-', lw=0.5)\n",
    "ax1.plot([1, 0], [0, 0], [0, 1], 'k-', lw=0.5)\n",
    "\n",
    "ax1.legend()\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "505c68d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot each coordinate of y's strategy against the step index.\n",
    "time_steps = range(T)\n",
    "\n",
    "plt.figure(figsize=(10, 6))\n",
    "for label, series in (('Strategy 1', k_1), ('Strategy 2', k_2), ('Strategy 3', k_3)):\n",
    "    plt.plot(time_steps, series, label=label)\n",
    "plt.legend()\n",
    "\n",
    "plt.title('Strategies Over Time, Extra-MWU')\n",
    "plt.xlabel('Time Steps')\n",
    "plt.ylabel('Strategy Probability')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c846ccbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import math\n",
    "\n",
    "def kl(x1, x2, x3, y1, y2, y3, x_star=(1/2, 1/4, 1/4), y_star=(1/4, 3/8, 3/8)):\n",
    "    \"\"\"Return KL(x* || x) + KL(y* || y) for x = (x1, x2, x3), y = (y1, y2, y3).\n",
    "\n",
    "    The defaults are the equilibrium of the exp1 game; pass x_star / y_star for another game.\n",
    "    Zero entries of x_star / y_star contribute nothing (0 * log 0 = 0).\n",
    "    \"\"\"\n",
    "    return sum(p * math.log(p / q)\n",
    "               for p, q in zip((*x_star, *y_star), (x1, x2, x3, y1, y2, y3))\n",
    "               if p > 0)\n",
    "\n",
    "\n",
    "# Evaluate the divergence at every stored step, including the first and the last one\n",
    "# (previously the last step was skipped and value[0] was overwritten with value[1]).\n",
    "value = np.array([kl(*x[i], *y[i]) for i in range(T)])\n",
    "\n",
    "t = np.arange(0, T)\n",
    "fig = plt.figure()\n",
    "\n",
    "plt.plot(t, value)\n",
    "\n",
    "\n",
    "plt.title('KL-divergence Over Time, Extra-MWU')\n",
    "plt.xlabel('Time Steps')\n",
    "plt.ylabel('KL-divergence')\n",
    "\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c4052cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "#OMWU, exp1\n",
    "# Optimistic multiplicative weights (OMWU) on the same period-2 game as the Extra-MWU exp1 cell.\n",
    "# Each player updates with the optimistic payoff 2 * (payoff now) - (payoff at the previous step).\n",
    "\n",
    "from scipy.linalg import expm, sinm, cosm\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import math\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "%matplotlib inline\n",
    "%matplotlib auto\n",
    "# 'seaborn-white' was renamed 'seaborn-v0_8-white' in matplotlib 3.6; the old name is gone in 3.8.\n",
    "plt.style.use('seaborn-v0_8-white' if 'seaborn-v0_8-white' in plt.style.available else 'seaborn-white')\n",
    "%config InlineBackend.figure_format = 'svg'\n",
    "\n",
    "\n",
    "#training steps\n",
    "T = 4000\n",
    "\n",
    "#step size\n",
    "s = 0.3\n",
    "\n",
    "\n",
    "# x,y: mixed strategies of the two players, one row per step\n",
    "x = np.zeros((T,3))\n",
    "y = np.zeros((T,3))\n",
    "\n",
    "\n",
    "# k_1, k_2, k_3 store the coordinates of y's strategy at every step (for plotting)\n",
    "k_1 = np.zeros(T)\n",
    "k_2 = np.zeros(T)\n",
    "k_3 = np.zeros(T)\n",
    "\n",
    "\n",
    "# OMWU needs two initial iterates: steps 0 and 1\n",
    "\n",
    "x[0] = [0.7, 0.2, 0.1]\n",
    "y[0] = [0.5, 0.3, 0.2]\n",
    "\n",
    "x[1] = [0.5, 0.3, 0.2]\n",
    "y[1] = [0.7, 0.2, 0.1]\n",
    "\n",
    "\n",
    "def A(t):\n",
    "    \"\"\"Payoff matrix of player x at step t (alternates with period 2).\n",
    "\n",
    "    Both matrices have the equilibrium x* = (1/2, 1/4, 1/4), y* = (1/4, 3/8, 3/8).\n",
    "    \"\"\"\n",
    "    if t % 2 == 0:\n",
    "        return np.array([[0, 0.25, 0.75], [1.5, 0, 0], [0, 1, 0]])\n",
    "    return np.array([[0, 0.75, 0.25], [1.5, 0, 0], [0, 0, 1]])\n",
    "\n",
    "\n",
    "def A_21(t):\n",
    "    \"\"\"Payoff matrix of player y at step t: -A(t)^T, since the game is zero-sum.\"\"\"\n",
    "    return -A(t).T\n",
    "\n",
    "\n",
    "def mwu_step(p, payoff, step):\n",
    "    \"\"\"Multiplicative-weights update: p_i * exp(step * payoff_i), renormalised to sum to 1.\"\"\"\n",
    "    weights = p * np.exp(step * payoff)\n",
    "    return weights / weights.sum()\n",
    "\n",
    "\n",
    "for t in range(1,T-1):\n",
    "    # Both players only use iterates t and t-1, so the two updates are simultaneous.\n",
    "    w = 2 * (A(t) @ y[t]) - A(t-1) @ y[t-1]\n",
    "    z = 2 * (A_21(t) @ x[t]) - A_21(t-1) @ x[t-1]\n",
    "    x[t+1] = mwu_step(x[t], w, s)\n",
    "    y[t+1] = mwu_step(y[t], z, s)\n",
    "\n",
    "# Bug fix: copy every row of y into k. Previously k_2[0] / k_3[0] read y[1][1] / y[2][2]\n",
    "# and k_*[1] was never written, so the plots contained a spurious point at the origin.\n",
    "k_1[:], k_2[:], k_3[:] = y.T\n",
    "\n",
    "\n",
    "fig = plt.figure()\n",
    "ax1 = plt.axes(projection='3d')\n",
    "ax1.scatter(k_1, k_2, k_3, s=0.3, label='Trajectory of strategy')\n",
    "ax1.scatter(y[0][0], y[0][1], y[0][2], s=40)  # starting point of y\n",
    "ax1.scatter(1/4, 3/8, 3/8, s=40)  # equilibrium y* of both matrices\n",
    "# boundary of the probability simplex\n",
    "ax1.plot([1, 0], [0, 1], [0, 0], 'k-', lw=0.5)\n",
    "ax1.plot([0, 0], [1, 0], [0, 1], 'k-', lw=0.5)\n",
    "ax1.plot([1, 0], [0, 0], [0, 1], 'k-', lw=0.5)\n",
    "\n",
    "ax1.legend()\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26c3a4e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot each coordinate of y's strategy against the step index.\n",
    "time_steps = range(T)\n",
    "\n",
    "plt.figure(figsize=(10, 6))\n",
    "for label, series in (('Strategy 1', k_1), ('Strategy 2', k_2), ('Strategy 3', k_3)):\n",
    "    plt.plot(time_steps, series, label=label)\n",
    "plt.legend()\n",
    "\n",
    "plt.title('Strategies Over Time, OMWU')\n",
    "plt.xlabel('Time Steps')\n",
    "plt.ylabel('Strategy Probability')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7871eac",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import math\n",
    "\n",
    "def kl(x1, x2, x3, y1, y2, y3, x_star=(1/2, 1/4, 1/4), y_star=(1/4, 3/8, 3/8)):\n",
    "    \"\"\"Return KL(x* || x) + KL(y* || y) for x = (x1, x2, x3), y = (y1, y2, y3).\n",
    "\n",
    "    The defaults are the equilibrium of the exp1 game; pass x_star / y_star for another game.\n",
    "    Zero entries of x_star / y_star contribute nothing (0 * log 0 = 0).\n",
    "    \"\"\"\n",
    "    return sum(p * math.log(p / q)\n",
    "               for p, q in zip((*x_star, *y_star), (x1, x2, x3, y1, y2, y3))\n",
    "               if p > 0)\n",
    "\n",
    "\n",
    "# Evaluate the divergence at every stored step, including the first and the last one\n",
    "# (previously the last step was skipped and value[0] was overwritten with value[1]).\n",
    "value = np.array([kl(*x[i], *y[i]) for i in range(T)])\n",
    "\n",
    "t = np.arange(0, T)\n",
    "fig = plt.figure()\n",
    "\n",
    "plt.plot(t, value)\n",
    "\n",
    "plt.title('KL-divergence Over Time, OMWU')\n",
    "plt.xlabel('Time Steps')\n",
    "plt.ylabel('KL-divergence')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "812fcef7",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Extra_MWU,exp2\n",
    "# Extra-MWU on a zero-sum game whose payoff matrix cycles with period 4.\n",
    "\n",
    "from scipy.linalg import expm, sinm, cosm\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import math\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "%matplotlib inline\n",
    "%matplotlib auto\n",
    "\n",
    "# 'seaborn-white' was renamed 'seaborn-v0_8-white' in matplotlib 3.6; the old name is gone in 3.8.\n",
    "plt.style.use('seaborn-v0_8-white' if 'seaborn-v0_8-white' in plt.style.available else 'seaborn-white')\n",
    "%config InlineBackend.figure_format = 'svg'\n",
    "\n",
    "\n",
    "#training steps\n",
    "T = 600\n",
    "\n",
    "#step sizes: s for the extrapolation (hat) step, eta for the actual update\n",
    "s = 0.3\n",
    "eta = 0.1\n",
    "\n",
    "# x,y: mixed strategies of the two players, one row per step\n",
    "x = np.zeros((T,3))\n",
    "y = np.zeros((T,3))\n",
    "\n",
    "# extrapolated (intermediate) strategies\n",
    "x_hat = np.zeros((T,3))\n",
    "y_hat = np.zeros((T,3))\n",
    "\n",
    "\n",
    "# k_1, k_2, k_3 store the coordinates of y's strategy at every step (for plotting)\n",
    "k_1 = np.zeros(T)\n",
    "k_2 = np.zeros(T)\n",
    "k_3 = np.zeros(T)\n",
    "\n",
    "\n",
    "x[0] = [0.8, 0.1, 0.1]\n",
    "y[0] = [0.1, 0.05, 0.85]\n",
    "\n",
    "# Bug fix: k_2[0] / k_3[0] used to read y[1][1] / y[2][2], which are still zero here.\n",
    "k_1[0], k_2[0], k_3[0] = y[0]\n",
    "\n",
    "def A(t):\n",
    "    \"\"\"Payoff matrix of player x at step t (cycles through four matrices).\n",
    "\n",
    "    Every matrix has zero row and column sums, so the uniform strategies are an\n",
    "    equilibrium of each of them.\n",
    "    \"\"\"\n",
    "    if t % 4 == 0:\n",
    "        return np.array([[0, -1, 1], [1, 0, -1], [-1, 1, 0]])\n",
    "    if t % 4 == 1:\n",
    "        return np.array([[0, 1, -1], [-1, 0, 1], [1, -1, 0]])\n",
    "    if t % 4 == 2:\n",
    "        return np.array([[1, -3, 2], [-2, 1, 1], [1, 2, -3]])\n",
    "    return np.array([[-1, 2, -1], [2, -1, -1], [-1, -1, 2]])\n",
    "\n",
    "\n",
    "def A_21(t):\n",
    "    \"\"\"Payoff matrix of player y at step t: -A(t)^T, since the game is zero-sum.\"\"\"\n",
    "    return -A(t).T\n",
    "\n",
    "\n",
    "def mwu_step(p, payoff, step):\n",
    "    \"\"\"Multiplicative-weights update: p_i * exp(step * payoff_i), renormalised to sum to 1.\"\"\"\n",
    "    weights = p * np.exp(step * payoff)\n",
    "    return weights / weights.sum()\n",
    "\n",
    "\n",
    "for t in range(0,T-1):\n",
    "    # Extrapolation: both players react to the payoffs at (x[t], y[t]).\n",
    "    y_hat[t+1] = mwu_step(y[t], A_21(t) @ x[t], s)\n",
    "    x_hat[t+1] = mwu_step(x[t], A(t) @ y[t], s)\n",
    "    # Update: restart from (x[t], y[t]) but use the payoffs against the extrapolated opponent.\n",
    "    y[t+1] = mwu_step(y[t], A_21(t) @ x_hat[t+1], eta)\n",
    "    x[t+1] = mwu_step(x[t], A(t) @ y_hat[t+1], eta)\n",
    "\n",
    "    k_1[t+1], k_2[t+1], k_3[t+1] = y[t+1]\n",
    "\n",
    "\n",
    "time_steps = range(T)\n",
    "\n",
    "plt.figure(figsize=(10, 6))\n",
    "plt.plot(time_steps, k_1, label='Strategy 1')\n",
    "plt.plot(time_steps, k_2, label='Strategy 2')\n",
    "plt.plot(time_steps, k_3, label='Strategy 3')\n",
    "\n",
    "plt.legend()\n",
    "plt.title('Strategies Over Time, Extra-MWU')\n",
    "plt.xlabel('Time Steps')\n",
    "plt.ylabel('Strategy Probability')\n",
    "\n",
    "plt.show()\n",
    "\n",
    "def kl(x1, x2, x3, y1, y2, y3, x_star=(1/3, 1/3, 1/3), y_star=(1/3, 1/3, 1/3)):\n",
    "    \"\"\"Return KL(x* || x) + KL(y* || y); the defaults are the uniform equilibrium of this game.\n",
    "\n",
    "    Zero entries of x_star / y_star contribute nothing (0 * log 0 = 0).\n",
    "    \"\"\"\n",
    "    return sum(p * math.log(p / q)\n",
    "               for p, q in zip((*x_star, *y_star), (x1, x2, x3, y1, y2, y3))\n",
    "               if p > 0)\n",
    "\n",
    "\n",
    "# Evaluate the divergence at every stored step, including the first and the last one\n",
    "# (previously the last step was skipped and value[0] was overwritten with value[1]).\n",
    "value = np.array([kl(*x[i], *y[i]) for i in range(T)])\n",
    "\n",
    "t = np.arange(0, T)\n",
    "fig = plt.figure()\n",
    "\n",
    "plt.plot(t, value)\n",
    "\n",
    "plt.title('KL-divergence Over Time, Extra-MWU')\n",
    "plt.xlabel('Time Steps')\n",
    "plt.ylabel('KL-divergence')\n",
    "\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60659ff9",
   "metadata": {},
   "outputs": [],
   "source": [
    "#OMWU, exp2\n",
    "# OMWU on the same period-4 game as the Extra-MWU exp2 cell.\n",
    "\n",
    "from scipy.linalg import expm, sinm, cosm\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import math\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "%matplotlib inline\n",
    "%matplotlib auto\n",
    "# 'seaborn-white' was renamed 'seaborn-v0_8-white' in matplotlib 3.6; the old name is gone in 3.8.\n",
    "plt.style.use('seaborn-v0_8-white' if 'seaborn-v0_8-white' in plt.style.available else 'seaborn-white')\n",
    "%config InlineBackend.figure_format = 'svg'\n",
    "\n",
    "\n",
    "#training steps\n",
    "T = 1000\n",
    "\n",
    "#step size\n",
    "s = 0.1\n",
    "\n",
    "\n",
    "# x,y: mixed strategies of the two players, one row per step\n",
    "x = np.zeros((T,3))\n",
    "y = np.zeros((T,3))\n",
    "\n",
    "\n",
    "# k_1, k_2, k_3 store the coordinates of y's strategy at every step (for plotting)\n",
    "k_1 = np.zeros(T)\n",
    "k_2 = np.zeros(T)\n",
    "k_3 = np.zeros(T)\n",
    "\n",
    "\n",
    "# OMWU needs two initial iterates: steps 0 and 1\n",
    "\n",
    "x[0] = [0.7, 0.2, 0.1]\n",
    "y[0] = [0.5, 0.3, 0.2]\n",
    "\n",
    "x[1] = [0.5, 0.3, 0.2]\n",
    "y[1] = [0.7, 0.2, 0.1]\n",
    "\n",
    "\n",
    "def A(t):\n",
    "    \"\"\"Payoff matrix of player x at step t (cycles through four matrices).\n",
    "\n",
    "    Every matrix has zero row and column sums, so the uniform strategies are an\n",
    "    equilibrium of each of them.\n",
    "    \"\"\"\n",
    "    if t % 4 == 0:\n",
    "        return np.array([[0, -1, 1], [1, 0, -1], [-1, 1, 0]])\n",
    "    if t % 4 == 1:\n",
    "        return np.array([[0, 1, -1], [-1, 0, 1], [1, -1, 0]])\n",
    "    if t % 4 == 2:\n",
    "        return np.array([[1, -3, 2], [-2, 1, 1], [1, 2, -3]])\n",
    "    return np.array([[-1, 2, -1], [2, -1, -1], [-1, -1, 2]])\n",
    "\n",
    "\n",
    "def A_21(t):\n",
    "    \"\"\"Payoff matrix of player y at step t: -A(t)^T, since the game is zero-sum.\"\"\"\n",
    "    return -A(t).T\n",
    "\n",
    "\n",
    "def mwu_step(p, payoff, step):\n",
    "    \"\"\"Multiplicative-weights update: p_i * exp(step * payoff_i), renormalised to sum to 1.\"\"\"\n",
    "    weights = p * np.exp(step * payoff)\n",
    "    return weights / weights.sum()\n",
    "\n",
    "\n",
    "for t in range(1,T-1):\n",
    "    # Optimistic payoffs: 2 * (payoff now) - (payoff at the previous step).\n",
    "    w = 2 * (A(t) @ y[t]) - A(t-1) @ y[t-1]\n",
    "    z = 2 * (A_21(t) @ x[t]) - A_21(t-1) @ x[t-1]\n",
    "    x[t+1] = mwu_step(x[t], w, s)\n",
    "    y[t+1] = mwu_step(y[t], z, s)\n",
    "\n",
    "# Bug fix: copy every row of y into k. Previously k_2[0] / k_3[0] read y[1][1] / y[2][2]\n",
    "# and k_*[1] was never written, so the curves dropped to 0 at step 1.\n",
    "k_1[:], k_2[:], k_3[:] = y.T\n",
    "\n",
    "\n",
    "time_steps = range(T)\n",
    "\n",
    "plt.figure(figsize=(10, 6))\n",
    "plt.plot(time_steps, k_1, label='Strategy 1')\n",
    "plt.plot(time_steps, k_2, label='Strategy 2')\n",
    "plt.plot(time_steps, k_3, label='Strategy 3')\n",
    "\n",
    "plt.legend()\n",
    "plt.title('Strategies Over Time, OMWU')\n",
    "plt.xlabel('Time Steps')\n",
    "plt.ylabel('Strategy Probability')\n",
    "\n",
    "plt.show()\n",
    "\n",
    "def kl(x1, x2, x3, y1, y2, y3, x_star=(1/3, 1/3, 1/3), y_star=(1/3, 1/3, 1/3)):\n",
    "    \"\"\"Return KL(x* || x) + KL(y* || y); the defaults are the uniform equilibrium of this game.\n",
    "\n",
    "    Zero entries of x_star / y_star contribute nothing (0 * log 0 = 0).\n",
    "    \"\"\"\n",
    "    return sum(p * math.log(p / q)\n",
    "               for p, q in zip((*x_star, *y_star), (x1, x2, x3, y1, y2, y3))\n",
    "               if p > 0)\n",
    "\n",
    "# Evaluate the divergence at every stored step, including the first and the last one\n",
    "# (previously the last step was skipped and value[0] was overwritten with value[1]).\n",
    "value = np.array([kl(*x[i], *y[i]) for i in range(T)])\n",
    "\n",
    "t = np.arange(0, T)\n",
    "fig = plt.figure()\n",
    "\n",
    "plt.plot(t, value)\n",
    "\n",
    "plt.title('KL-divergence Over Time, OMWU')\n",
    "plt.xlabel('Time Steps')\n",
    "plt.ylabel('KL-divergence')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "98a82aa8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using matplotlib backend: MacOSX\n"
     ]
    }
   ],
   "source": [
    "#ani,extra-mwu\n",
    "# Extra-MWU on a zero-sum game whose payoff matrix cycles with period 3, animated on the simplex.\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.animation as animation\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "%matplotlib inline\n",
    "%matplotlib auto\n",
    "\n",
    "# 'seaborn-white' was renamed 'seaborn-v0_8-white' in matplotlib 3.6; the old name is gone in 3.8.\n",
    "plt.style.use('seaborn-v0_8-white' if 'seaborn-v0_8-white' in plt.style.available else 'seaborn-white')\n",
    "%config InlineBackend.figure_format = 'svg'\n",
    "\n",
    "\n",
    "\n",
    "# Define constants\n",
    "T = 1500  # Number of training steps\n",
    "s = 0.5  # Step size of the extrapolation step\n",
    "eta = 0.1  # Step size of the update step\n",
    "\n",
    "# x[t + 1] for every step t, grouped by which matrix (t % 3) produced it\n",
    "x_strategy_history_mod_0 = []\n",
    "x_strategy_history_mod_1 = []\n",
    "x_strategy_history_mod_2 = []\n",
    "\n",
    "# Define the payoff matrices\n",
    "def get_matrix_A(t):\n",
    "    \"\"\"Payoff matrix of player x at step t (cycles with period 3).\"\"\"\n",
    "    if t % 3 == 0:\n",
    "        return np.array([[0, 1, -1], [-1, 0, 1], [1, -1, 0]])  # Rock-Paper-Scissors\n",
    "    elif t % 3 == 1:\n",
    "        return np.array([[0, -1, 1], [1, 0, -1], [-1, 1, 0]])  # Variant of RPS\n",
    "    else:\n",
    "        return np.array([[0, 0.25, 0.75], [1.5, 0, 0], [0, 1, 0]])  # Another variant\n",
    "\n",
    "# Initialize strategies\n",
    "x, y = np.zeros((T, 3)), np.zeros((T, 3))\n",
    "x_hat, y_hat = np.zeros((T, 3)), np.zeros((T, 3))\n",
    "x[0], y[0] = [0.3, 0.2, 0.5], [0.1, 0.4, 0.5]\n",
    "\n",
    "for t in range(T - 1):\n",
    "    A_t = get_matrix_A(t)\n",
    "    A_21_t = -A_t.T  # payoff matrix of player y (zero-sum game)\n",
    "\n",
    "    # Extrapolation step from (x[t], y[t]) with step size s.\n",
    "    z = np.dot(A_21_t, x[t])\n",
    "    d_1 = np.sum(y[t] * np.exp(s * z))\n",
    "    y_hat[t + 1] = y[t] * np.exp(s * z) / d_1\n",
    "\n",
    "    w = np.dot(A_t, y[t])\n",
    "    d_2 = np.sum(x[t] * np.exp(s * w))\n",
    "    x_hat[t + 1] = x[t] * np.exp(s * w) / d_2\n",
    "\n",
    "    # Update step from (x[t], y[t]) using the payoffs against the extrapolated opponent.\n",
    "    z_2 = np.dot(A_21_t, x_hat[t + 1])\n",
    "    d_3 = np.sum(y[t] * np.exp(eta * z_2))\n",
    "    y[t + 1] = y[t] * np.exp(eta * z_2) / d_3\n",
    "\n",
    "    w_2 = np.dot(A_t, y_hat[t + 1])\n",
    "    d_4 = np.sum(x[t] * np.exp(eta * w_2))\n",
    "    x[t + 1] = x[t] * np.exp(eta * w_2) / d_4\n",
    "\n",
    "    if t % 3 == 0:\n",
    "        x_strategy_history_mod_0.append(x[t + 1])\n",
    "    elif t % 3 == 1:\n",
    "        x_strategy_history_mod_1.append(x[t + 1])\n",
    "    else:\n",
    "        x_strategy_history_mod_2.append(x[t + 1])\n",
    "\n",
    "# Create a 3D plot\n",
    "fig = plt.figure()\n",
    "ax = fig.add_subplot(111, projection='3d')\n",
    "\n",
    "\n",
    "elevation_angle = 30\n",
    "azimuth_angle = 45\n",
    "ax.view_init(elev=elevation_angle, azim=azimuth_angle)\n",
    "\n",
    "\n",
    "\n",
    "# Special points to mark: the uniform strategy (equilibrium of the two RPS matrices) and\n",
    "# (1/2, 1/4, 1/4) (x's equilibrium strategy for the third matrix)\n",
    "special_points = np.array([[0.3333, 0.3333, 0.3333], [0.5, 0.25, 0.25]])\n",
    "\n",
    "# Function to update the plot for each frame\n",
    "def update(num, x, y, plot):\n",
    "    \"\"\"Redraw frame num: simplex edges, x[num] (red), the special points (green) and the\n",
    "    last entry of each x_strategy_history_mod_* list (triangles).\n",
    "\n",
    "    y and plot are unused. The history lists are filled before the animation starts,\n",
    "    so the triangles are the same in every frame.\n",
    "    \"\"\"\n",
    "    ax.clear()\n",
    "    # Draw simplex boundaries\n",
    "    ax.plot([1, 0], [0, 1], [0, 0], 'k-')  # Line from (1, 0, 0) to (0, 1, 0)\n",
    "    ax.plot([0, 0], [1, 0], [0, 1], 'k-')  # Line from (0, 1, 0) to (0, 0, 1)\n",
    "    ax.plot([1, 0], [0, 0], [0, 1], 'k-')  # Line from (1, 0, 0) to (0, 0, 1)\n",
    "\n",
    "    # Draw points of the trajectory\n",
    "    ax.scatter(x[num, 0], x[num, 1], x[num, 2], color='red', s=10)\n",
    "    ax.scatter(special_points[:, 0], special_points[:, 1], special_points[:, 2], color='green', s=50)\n",
    "\n",
    "    # Draw the last point of each strategy history with distinct colors\n",
    "    if len(x_strategy_history_mod_0) > 0:\n",
    "        ax.scatter(x_strategy_history_mod_0[-1][0], x_strategy_history_mod_0[-1][1], x_strategy_history_mod_0[-1][2], color='gold', s=50, marker='^')\n",
    "    if len(x_strategy_history_mod_1) > 0:\n",
    "        ax.scatter(x_strategy_history_mod_1[-1][0], x_strategy_history_mod_1[-1][1], x_strategy_history_mod_1[-1][2], color='cyan', s=50, marker='^')\n",
    "    if len(x_strategy_history_mod_2) > 0:\n",
    "        ax.scatter(x_strategy_history_mod_2[-1][0], x_strategy_history_mod_2[-1][1], x_strategy_history_mod_2[-1][2], color='magenta', s=50, marker='^')\n",
    "\n",
    "# Creating the animation (keep the reference in ani so it is not garbage-collected)\n",
    "ani = animation.FuncAnimation(fig, update, T, fargs=(x, y, None), interval=10)\n",
    "\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "9b125152",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trim the three per-matrix histories to a common length.\n",
    "min_length = min(map(len, (x_strategy_history_mod_0, x_strategy_history_mod_1, x_strategy_history_mod_2)))\n",
    "x_strategy_history_mod_0 = x_strategy_history_mod_0[:min_length]\n",
    "x_strategy_history_mod_1 = x_strategy_history_mod_1[:min_length]\n",
    "x_strategy_history_mod_2 = x_strategy_history_mod_2[:min_length]\n",
    "\n",
    "\n",
    "plt.figure(figsize=(10, 6))\n",
    "\n",
    "# First coordinate of x for each residue class, skipping the first 20 entries.\n",
    "for history, label in ((x_strategy_history_mod_0, 't mod 3 = 0, Strategy 1'),\n",
    "                       (x_strategy_history_mod_1, 't mod 3 = 1, Strategy 1'),\n",
    "                       (x_strategy_history_mod_2, 't mod 3 = 2, Strategy 1')):\n",
    "    plt.plot([point[0] for point in history[20:]], label=label)\n",
    "\n",
    "plt.legend()\n",
    "plt.title('Extra-MWU, Strategy 1 Evolution Over Time for Different t mod 3 Cases')\n",
    "plt.xlabel('Time Step')\n",
    "plt.ylabel('Strategy 1 Value')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "5499312c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using matplotlib backend: MacOSX\n"
     ]
    }
   ],
   "source": [
    "#OMWU, period-3 game (same matrices as the '#ani,extra-mwu' cell)\n",
    "\n",
    "from scipy.linalg import expm, sinm, cosm\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import math\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "%matplotlib inline\n",
    "%matplotlib auto\n",
    "# 'seaborn-white' was renamed 'seaborn-v0_8-white' in matplotlib 3.6; the old name is gone in 3.8.\n",
    "plt.style.use('seaborn-v0_8-white' if 'seaborn-v0_8-white' in plt.style.available else 'seaborn-white')\n",
    "%config InlineBackend.figure_format = 'svg'\n",
    "\n",
    "\n",
    "#training steps\n",
    "T = 2000\n",
    "\n",
    "#step size\n",
    "s = 0.3\n",
    "\n",
    "\n",
    "# x,y: mixed strategies of the two players, one row per step\n",
    "x = np.zeros((T,3))\n",
    "y = np.zeros((T,3))\n",
    "\n",
    "\n",
    "# k_1, k_2, k_3 store the coordinates of y's strategy at every step (for plotting)\n",
    "k_1 = np.zeros(T)\n",
    "k_2 = np.zeros(T)\n",
    "k_3 = np.zeros(T)\n",
    "\n",
    "\n",
    "# OMWU needs two initial iterates: steps 0 and 1\n",
    "\n",
    "x[0] = [0.7, 0.2, 0.1]\n",
    "y[0] = [0.5, 0.3, 0.2]\n",
    "\n",
    "x[1] = [0.5, 0.3, 0.2]\n",
    "y[1] = [0.7, 0.2, 0.1]\n",
    "\n",
    "\n",
    "def A(t):\n",
    "    \"\"\"Payoff matrix of player x at step t (cycles with period 3).\"\"\"\n",
    "    if t % 3 == 0:\n",
    "        return np.array([[0, 1, -1], [-1, 0, 1], [1, -1, 0]])  # Rock-Paper-Scissors\n",
    "    elif t % 3 == 1:\n",
    "        return np.array([[0, -1, 1], [1, 0, -1], [-1, 1, 0]])  # Variant of RPS\n",
    "    else:\n",
    "        return np.array([[0, 0.25, 0.75], [1.5, 0, 0], [0, 1, 0]])  # Another variant\n",
    "\n",
    "\n",
    "def A_21(t):\n",
    "    \"\"\"Payoff matrix of player y at step t: -A(t)^T, since the game is zero-sum.\"\"\"\n",
    "    return -A(t).T\n",
    "\n",
    "\n",
    "def mwu_step(p, payoff, step):\n",
    "    \"\"\"Multiplicative-weights update: p_i * exp(step * payoff_i), renormalised to sum to 1.\"\"\"\n",
    "    weights = p * np.exp(step * payoff)\n",
    "    return weights / weights.sum()\n",
    "\n",
    "\n",
    "for t in range(1,T-1):\n",
    "    # Optimistic payoffs: 2 * (payoff now) - (payoff at the previous step).\n",
    "    w = 2 * (A(t) @ y[t]) - A(t-1) @ y[t-1]\n",
    "    z = 2 * (A_21(t) @ x[t]) - A_21(t-1) @ x[t-1]\n",
    "    x[t+1] = mwu_step(x[t], w, s)\n",
    "    y[t+1] = mwu_step(y[t], z, s)\n",
    "\n",
    "# Bug fix: copy every row of y into k. Previously k_2[0] / k_3[0] read y[1][1] / y[2][2]\n",
    "# and k_*[1] was never written, so the curves dropped to 0 at step 1.\n",
    "k_1[:], k_2[:], k_3[:] = y.T\n",
    "\n",
    "\n",
    "time_steps = range(T)\n",
    "\n",
    "plt.figure(figsize=(10, 6))\n",
    "plt.plot(time_steps, k_1, label='Strategy 1')\n",
    "plt.plot(time_steps, k_2, label='Strategy 2')\n",
    "plt.plot(time_steps, k_3, label='Strategy 3')\n",
    "\n",
    "plt.legend()\n",
    "plt.title('Strategies Over Time, OMWU')\n",
    "plt.xlabel('Time Steps')\n",
    "plt.ylabel('Strategy Probability')\n",
    "\n",
    "plt.show()\n",
    "\n",
    "def kl(x1, x2, x3, y1, y2, y3, x_star=(1/3, 1/3, 1/3), y_star=(1/3, 1/3, 1/3)):\n",
    "    \"\"\"Return KL(x* || x) + KL(y* || y), by default with x* = y* = uniform.\n",
    "\n",
    "    Uniform is an equilibrium of the two RPS matrices but not of the third one\n",
    "    (t % 3 == 2), so here this measures the distance to uniform, not to a common equilibrium.\n",
    "    Zero entries of x_star / y_star contribute nothing (0 * log 0 = 0).\n",
    "    \"\"\"\n",
    "    return sum(p * math.log(p / q)\n",
    "               for p, q in zip((*x_star, *y_star), (x1, x2, x3, y1, y2, y3))\n",
    "               if p > 0)\n",
    "\n",
    "# Evaluate the divergence at every stored step, including the first and the last one\n",
    "# (previously the last step was skipped and value[0] was overwritten with value[1]).\n",
    "value = np.array([kl(*x[i], *y[i]) for i in range(T)])\n",
    "\n",
    "t = np.arange(0, T)\n",
    "fig = plt.figure()\n",
    "\n",
    "plt.plot(t, value)\n",
    "\n",
    "plt.title('KL-divergence Over Time, OMWU')\n",
    "plt.xlabel('Time Steps')\n",
    "plt.ylabel('KL-divergence')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "fedd02e7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using matplotlib backend: MacOSX\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "%matplotlib auto\n",
    "# 'seaborn-white' was renamed 'seaborn-v0_8-white' in matplotlib 3.6; the old name is gone in 3.8.\n",
    "plt.style.use('seaborn-v0_8-white' if 'seaborn-v0_8-white' in plt.style.available else 'seaborn-white')\n",
    "%config InlineBackend.figure_format = 'svg'\n",
    "\n",
    "\n",
    "\n",
    "def calculate_y(a, eta):\n",
    "    \"\"\"Return a * e^(3 eta) / (a * e^(3 eta) + (1 - a)), elementwise over a.\n",
    "\n",
    "    Equivalently: the multiplicative-weights update of a probability a whose payoff is 3\n",
    "    while the remaining mass (1 - a) has payoff 0, with step size eta.\n",
    "    \"\"\"\n",
    "    return (a * np.exp(3 * eta)) / (a * np.exp(3 * eta) + (1 - a))\n",
    "\n",
    "a_values = np.linspace(0, 1, 500)\n",
    "\n",
    "eta_values = [0.1, 0.3, 0.5]\n",
    "\n",
    "plt.figure(figsize=(10, 6))\n",
    "for eta in eta_values:\n",
    "    y_values = calculate_y(a_values, eta)\n",
    "    plt.plot(a_values, y_values, label=f'eta = {eta}')\n",
    "\n",
    "plt.title(r'$\\frac{ae^{3\\eta}}{ae^{3\\eta} + (1-a)}$ as a function of $a$')\n",
    "plt.xlabel('a')\n",
    "plt.ylabel(r'$\\frac{ae^{3*\\eta}}{ae^{3*\\eta} + (1-a)}$')\n",
    "plt.legend()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e102a963",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tensorflow",
   "language": "python",
   "name": "tensorflow"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
