{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98843c3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from stable_baselines3 import PPO\n",
    "from coin_toss import Coin_toss\n",
    "from transform import TS_VS\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b5e1fd7",
   "metadata": {},
   "source": [
    "# Get initial trajectory and learn transformation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "676f9673",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Roll out one trajectory of the ergodic=False game, always playing the\n",
    "# constant action np.array([1]), and record the observation returned by\n",
    "# every step.\n",
    "num_steps = 100\n",
    "game = Coin_toss(ergodic=False)\n",
    "wealths = np.zeros(num_steps + 1)\n",
    "# Hard-coded starting value; presumably the game's initial wealth -- TODO confirm.\n",
    "wealths[0] = 100\n",
    "for idx in range(num_steps):\n",
    "    obs, _, _, _ = game.step(np.array([1]))\n",
    "    # Extract the scalar explicitly: implicitly converting a size-1 array to a\n",
    "    # scalar on item assignment is deprecated since NumPy 1.25.\n",
    "    wealths[idx + 1] = np.asarray(obs).item()\n",
    "\n",
    "# Build the transformation from the recorded trajectory.\n",
    "transformation = TS_VS(wealths)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cee49c23",
   "metadata": {},
   "source": [
    "# Train 3 models: with logarithmic transformation, without transformation, and with learned transformation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0efe7ac5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# stable-baselines3 documents total_timesteps as an int, so use an integer\n",
    "# literal instead of the float 1e7.\n",
    "num_time_steps = 10_000_000\n",
    "\n",
    "# Agent trained on the ergodic=True game without a custom transformation.\n",
    "model_logarithm = PPO('MlpPolicy', Coin_toss(ergodic=True)).learn(\n",
    "    total_timesteps=num_time_steps, progress_bar=True\n",
    ")\n",
    "\n",
    "# Agent trained on the ergodic=False game.\n",
    "model_standard = PPO('MlpPolicy', Coin_toss(ergodic=False)).learn(\n",
    "    total_timesteps=num_time_steps, progress_bar=True\n",
    ")\n",
    "\n",
    "# Agent trained on the ergodic=True game using the transformation built from\n",
    "# the initial trajectory.\n",
    "model_transform = PPO('MlpPolicy', Coin_toss(ergodic=True, trans=transformation)).learn(\n",
    "    total_timesteps=num_time_steps, progress_bar=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44bc88fc",
   "metadata": {},
   "source": [
    "# Play 1000 games with trained agents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98ef7f99",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_exp = 1000\n",
    "ep_len = 1000\n",
    "\n",
    "# Every agent is evaluated on its own instance of the ergodic=False game.\n",
    "game_logarithm = Coin_toss(ergodic=False)\n",
    "game_standard = Coin_toss(ergodic=False)\n",
    "game_transform = Coin_toss(ergodic=False)\n",
    "\n",
    "# One row per episode: column 0 is the starting value, column t + 1 holds the\n",
    "# game's cum_reward after step t.\n",
    "traj_logarithm = np.zeros((num_exp, ep_len + 1))\n",
    "traj_standard = np.zeros((num_exp, ep_len + 1))\n",
    "traj_transform = np.zeros((num_exp, ep_len + 1))\n",
    "\n",
    "# (model, game, trajectory buffer, observation preprocessing), always visited\n",
    "# in the order logarithm -> standard -> transform.\n",
    "# Only the logarithm agent is fed np.log of the observation; the other two\n",
    "# receive the observation unchanged.\n",
    "agents = [\n",
    "    (model_logarithm, game_logarithm, traj_logarithm, np.log),\n",
    "    (model_standard, game_standard, traj_standard, None),\n",
    "    (model_transform, game_transform, traj_transform, None),\n",
    "]\n",
    "\n",
    "\n",
    "def select_actions(observations):\n",
    "    \"\"\"Query each policy once, in agent order, and return the chosen actions.\"\"\"\n",
    "    chosen = []\n",
    "    for (model, _game, _traj, prep), obs in zip(agents, observations):\n",
    "        action, _state = model.predict(obs if prep is None else prep(obs))\n",
    "        chosen.append(action)\n",
    "    return chosen\n",
    "\n",
    "\n",
    "for episode in range(num_exp):\n",
    "    observations = []\n",
    "    for _model, game, _traj, _prep in agents:\n",
    "        observations.append(game.reset())\n",
    "    for _model, _game, traj, _prep in agents:\n",
    "        traj[episode, 0] = 100\n",
    "    actions = select_actions(observations)\n",
    "    for step in range(ep_len):\n",
    "        # Advance all three games first, then ask all three policies for their\n",
    "        # next action, then record the running totals.\n",
    "        observations = []\n",
    "        for (_model, game, _traj, _prep), action in zip(agents, actions):\n",
    "            obs, _reward, _, _ = game.step(action)\n",
    "            observations.append(obs)\n",
    "        actions = select_actions(observations)\n",
    "        for _model, game, traj, _prep in agents:\n",
    "            traj[episode, step + 1] = game.cum_reward.item()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb7acb04",
   "metadata": {},
   "source": [
    "# Plot first 10 trajectories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5993ea0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One panel per agent, stacked vertically; each panel draws the first 10\n",
    "# evaluation episodes on a logarithmic y-axis.\n",
    "panels = [\n",
    "    ('logarithmic transform', traj_logarithm),\n",
    "    ('standard learning', traj_standard),\n",
    "    ('learned transform', traj_transform),\n",
    "]\n",
    "fig, axes = plt.subplots(len(panels), 1)\n",
    "for ax, (title, traj) in zip(axes, panels):\n",
    "    ax.plot(traj[0:10, :].T)\n",
    "    ax.set_title(title)\n",
    "    ax.set_yscale('log')\n",
    "    ax.set_ylabel('cum_reward')\n",
    "axes[-1].set_xlabel('step')\n",
    "# Without a layout pass the titles of the lower panels collide with the tick\n",
    "# labels of the panel above them.\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e908fea7",
   "metadata": {},
   "source": [
    "# Compute statistics: mean and median of the final value of each trajectory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a06221d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean of the last column of each trajectory array, printed in the order\n",
    "# logarithm, standard, transform.\n",
    "for final_values in (traj_logarithm[:, -1], traj_standard[:, -1], traj_transform[:, -1]):\n",
    "    print(np.mean(final_values))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c291811c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Median of the last column of each trajectory array, printed in the order\n",
    "# logarithm, standard, transform.\n",
    "for final_values in (traj_logarithm[:, -1], traj_standard[:, -1], traj_transform[:, -1]):\n",
    "    print(np.median(final_values))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
