{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.  0.  0.  0.  0. 10. 10. 10. 10. 10. 10. 10. 10. 20. 20. 20. 20. 20.\n",
      " 20. 20.]\n",
      "[ 19.13277907  23.91597384  29.8949673   37.36870912  46.7108864\n",
      "  58.388608    60.48576     63.1072      66.384       70.48\n",
      "  75.6         82.          90.         100.         100.\n",
      " 100.         100.         100.         100.         100.        ]\n",
      "[ 4.54403503  5.68004379  7.10005473  8.87506842 11.09383552 11.4922944\n",
      " 11.990368   12.61296    13.3912     14.364      15.58       17.1\n",
      " 19.         19.         19.         19.         19.         19.\n",
      " 19.         20.        ]\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from crsfd_sac import potential\n",
    "\n",
    "# Demonstration file: 'arr_2' is used as the reward table, 'arr_0' as the observations.\n",
    "data = np.load('demo_v/3/deterministic_sparse_large_standerd_length20.npz')\n",
    "\n",
    "# Rewards arranged as 50 rows of 20 steps each.\n",
    "r = data['arr_2'].reshape(50, 20)\n",
    "print(r[0])\n",
    "\n",
    "# Reward shaping: gamma_1 discounts the value targets, gamma_2 is used in the shaped reward.\n",
    "gamma_1 = 0.8\n",
    "gamma_2 = 0.99\n",
    "length = r.shape[1]\n",
    "\n",
    "# Backward recursion over time. The final step is bootstrapped as r / (1 - gamma_1);\n",
    "# every earlier step uses v[t] = r[t] + gamma_1 * v[t + 1].\n",
    "v_target = np.zeros(r.shape)\n",
    "last = length - 1\n",
    "v_target[:, last] = r[:, last] / (1 - gamma_1)\n",
    "for t in range(last - 1, -1, -1):\n",
    "    v_target[:, t] = r[:, t] + gamma_1 * v_target[:, t + 1]\n",
    "print(v_target[0])\n",
    "\n",
    "# Shaped reward for every step but the last: r[t] + gamma_2 * v[t + 1] - v[t].\n",
    "# The last step has no successor and is set to the constant 20.\n",
    "next_v = v_target[:, 1:]\n",
    "curr_v = v_target[:, :-1]\n",
    "reward_shaping = np.zeros(r.shape)\n",
    "reward_shaping[:, :-1] = r[:, :-1] + gamma_2 * next_v - curr_v\n",
    "reward_shaping[:, -1] = 20\n",
    "print(reward_shaping[0])\n",
    "\n",
    "# Regression data: the last 7 observation features as inputs and the flattened\n",
    "# value targets (one row per step) as labels.\n",
    "obs = data['arr_0'][..., -7:]\n",
    "regression = v_target.reshape(-1, 1).copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(5728.6748, device='cuda:0', grad_fn=<MseLossBackward0>)\n",
      "target: tensor([75.6000], device='cuda:0')\n",
      "output: tensor([-0.0410], device='cuda:0', grad_fn=<SelectBackward0>)\n",
      "tensor(764.1410, device='cuda:0', grad_fn=<MseLossBackward0>)\n",
      "target: tensor([40.3886], device='cuda:0')\n",
      "output: tensor([72.1827], device='cuda:0', grad_fn=<SelectBackward0>)\n",
      "tensor(327.7452, device='cuda:0', grad_fn=<MseLossBackward0>)\n",
      "target: tensor([56.3840], device='cuda:0')\n",
      "output: tensor([70.2141], device='cuda:0', grad_fn=<SelectBackward0>)\n",
      "tensor(94.7011, device='cuda:0', grad_fn=<MseLossBackward0>)\n",
      "target: tensor([53.1072], device='cuda:0')\n",
      "output: tensor([58.6092], device='cuda:0', grad_fn=<SelectBackward0>)\n",
      "tensor(151.8082, device='cuda:0', grad_fn=<MseLossBackward0>)\n",
      "target: tensor([100.], device='cuda:0')\n",
      "output: tensor([91.4885], device='cuda:0', grad_fn=<SelectBackward0>)\n",
      "tensor(82.3696, device='cuda:0', grad_fn=<MseLossBackward0>)\n",
      "target: tensor([100.], device='cuda:0')\n",
      "output: tensor([88.6203], device='cuda:0', grad_fn=<SelectBackward0>)\n",
      "tensor(75.9342, device='cuda:0', grad_fn=<MseLossBackward0>)\n",
      "target: tensor([56.3840], device='cuda:0')\n",
      "output: tensor([60.2661], device='cuda:0', grad_fn=<SelectBackward0>)\n",
      "tensor(60.4683, device='cuda:0', grad_fn=<MseLossBackward0>)\n",
      "target: tensor([100.], device='cuda:0')\n",
      "output: tensor([88.0188], device='cuda:0', grad_fn=<SelectBackward0>)\n",
      "tensor(86.2764, device='cuda:0', grad_fn=<MseLossBackward0>)\n",
      "target: tensor([40.3886], device='cuda:0')\n",
      "output: tensor([41.6115], device='cuda:0', grad_fn=<SelectBackward0>)\n",
      "tensor(39.5515, device='cuda:0', grad_fn=<MseLossBackward0>)\n",
      "target: tensor([100.], device='cuda:0')\n",
      "output: tensor([97.3318], device='cuda:0', grad_fn=<SelectBackward0>)\n",
      "tensor(49.3605, device='cuda:0', grad_fn=<MseLossBackward0>)\n",
      "target: tensor([100.], device='cuda:0')\n",
      "output: tensor([101.4423], device='cuda:0', grad_fn=<SelectBackward0>)\n",
      "tensor(68.3630, device='cuda:0', grad_fn=<MseLossBackward0>)\n",
      "target: tensor([100.], device='cuda:0')\n",
      "output: tensor([101.7296], device='cuda:0', grad_fn=<SelectBackward0>)\n",
      "tensor(50.8231, device='cuda:0', grad_fn=<MseLossBackward0>)\n",
      "target: tensor([29.0360], device='cuda:0')\n",
      "output: tensor([32.6379], device='cuda:0', grad_fn=<SelectBackward0>)\n",
      "tensor(65.8257, device='cuda:0', grad_fn=<MseLossBackward0>)\n",
      "target: tensor([100.], device='cuda:0')\n",
      "output: tensor([103.4223], device='cuda:0', grad_fn=<SelectBackward0>)\n",
      "tensor(51.0106, device='cuda:0', grad_fn=<MseLossBackward0>)\n",
      "target: tensor([82.], device='cuda:0')\n",
      "output: tensor([83.3337], device='cuda:0', grad_fn=<SelectBackward0>)\n",
      "tensor(77.8912, device='cuda:0', grad_fn=<MseLossBackward0>)\n",
      "target: tensor([100.], device='cuda:0')\n",
      "output: tensor([89.0525], device='cuda:0', grad_fn=<SelectBackward0>)\n"
     ]
    }
   ],
   "source": [
    "# Fall back to the CPU when CUDA is unavailable so the cell still runs.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "value = potential(obs_dim=7).to(device)\n",
    "# NOTE(review): weight_decay=0.99 is an unusually strong L2 penalty for Adam\n",
    "# (typical values are <= 1e-2) -- confirm it was not meant as a beta/decay-rate setting.\n",
    "value_optimizer = torch.optim.Adam(value.parameters(), lr=1e-3, weight_decay=0.99)\n",
    "\n",
    "# Each iteration fits one random minibatch of 32 (observation, value target) pairs.\n",
    "for epoch in range(8000):\n",
    "    idx = np.random.randint(0, obs.shape[0], size=32)\n",
    "    obs_batch = torch.from_numpy(obs[idx]).to(device).float()\n",
    "    regression_batch = torch.from_numpy(regression[idx]).to(device).float()\n",
    "    outputs = value(obs_batch)\n",
    "    # F.mse_loss expects (input, target): prediction first, regression target second.\n",
    "    loss = F.mse_loss(outputs, regression_batch)\n",
    "    value_optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    value_optimizer.step()\n",
    "    # Report the minibatch loss and one sample prediction every 500 iterations.\n",
    "    if epoch % 500 == 0:\n",
    "        print(loss)\n",
    "        print('target:', regression_batch[0])\n",
    "        print('output:', outputs[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 22.421227]\n",
      " [ 23.48027 ]\n",
      " [ 28.046183]\n",
      " [ 36.05583 ]\n",
      " [ 46.304394]\n",
      " [ 57.436375]\n",
      " [ 68.66611 ]\n",
      " [ 78.03945 ]\n",
      " [ 80.7573  ]\n",
      " [ 81.75727 ]\n",
      " [ 83.82342 ]\n",
      " [ 87.2133  ]\n",
      " [ 89.78035 ]\n",
      " [ 92.42827 ]\n",
      " [ 95.17423 ]\n",
      " [ 96.52243 ]\n",
      " [ 98.12437 ]\n",
      " [ 99.775734]\n",
      " [101.48047 ]\n",
      " [102.79901 ]]\n",
      "[ 0.82424164  4.28544998  7.64908791  9.78552246 10.55761719 10.54307175\n",
      "  8.59294891  1.91027832  0.18239594  1.2279129   2.5177536   1.66924286\n",
      "  1.72364044  1.79421997  0.38297272  0.62069702  0.65361023  0.68993378\n",
      "  0.29055023  0.        ]\n"
     ]
    }
   ],
   "source": [
    "# Inspect the learned potential on a single trajectory (row `index` of the reshaped observations).\n",
    "state = obs.reshape(50, -1, 7)\n",
    "index = 10\n",
    "reward_shaping2 = np.zeros(r.shape)\n",
    "\n",
    "with torch.no_grad():\n",
    "    episode_states = torch.from_numpy(state[index]).to(device).float()\n",
    "    value_nn = value(episode_states).detach().cpu().numpy()\n",
    "print(value_nn)\n",
    "\n",
    "# Potential difference per step: gamma_2 * V(s[t + 1]) - V(s[t]); the last entry stays 0.\n",
    "next_values = value_nn[1:, 0]\n",
    "current_values = value_nn[:-1, 0]\n",
    "reward_shaping2[index, :-1] = gamma_2 * next_values - current_values\n",
    "print(reward_shaping2[index])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make sure the target directory exists first; torch.save does not create it.\n",
    "os.makedirs('potential_weight', exist_ok=True)\n",
    "torch.save(value.state_dict(), 'potential_weight/value.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.01082538 -0.02062765 -0.3117076 ] 0.0\n",
      "[ 0.00783324 -0.01779822 -0.30569392] 0.0\n",
      "[ 0.00077174 -0.01188893 -0.2866196 ] 0.0\n",
      "[-0.00619212 -0.00628947 -0.25527546] 0.0\n",
      "[-0.01044918 -0.00230936 -0.21734808] 0.0\n",
      "[-0.01118617 -0.00063541 -0.17785837] 10.0\n",
      "[-0.00976246 -0.00047971 -0.14085245] 10.0\n",
      "[-9.7044781e-03 -7.9958168e-05 -1.2696342e-01] 10.0\n",
      "[-9.0520550e-03  5.1729829e-05 -1.2694176e-01] 10.0\n",
      "[-8.1958408e-03 -6.4928994e-05 -1.2698922e-01] 10.0\n",
      "[-7.2191511e-03 -3.4795401e-05 -1.2450873e-01] 10.0\n",
      "[-0.00658089 -0.00042943 -0.11935372] 10.0\n",
      "[-0.00585905 -0.00070997 -0.11357114] 10.0\n",
      "[-0.00517583 -0.00094859 -0.10819958] 20.0\n",
      "[-0.00438961 -0.00131511 -0.10331113] 20.0\n",
      "[-0.00362989 -0.00148181 -0.09865505] 20.0\n",
      "[-0.00292709 -0.0016322  -0.09549285] 20.0\n",
      "[-0.00213311 -0.00178226 -0.09283621] 20.0\n",
      "[-0.00154233 -0.00184098 -0.08977883] 20.0\n",
      "[-0.00110409 -0.00181354 -0.08545475] 20.0\n"
     ]
    }
   ],
   "source": [
    "# Show the first three observation features of the first 20 rows next to the\n",
    "# rewards in row 0 of the reward table.\n",
    "for step in range(20):\n",
    "    features = obs[step][0:3]\n",
    "    reward = r[0][step]\n",
    "    print(features, reward)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "0f479afe68b2a67f6be6581d12f00554171ca63ad3f980713ca439838399a991"
  },
  "kernelspec": {
   "display_name": "Python 3.7.12 ('mjrl-env')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
