{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 148,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "device = 'cuda:1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 149,
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import copy\n",
    "import os\n",
    "import sys\n",
    "import numpy as np\n",
    "import gymnasium as gym\n",
    "import gym\n",
    "import d4rl\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.distributions import Normal, Categorical\n",
    "from torch.nn import functional as F\n",
    "from tqdm import tqdm\n",
    "from stable_baselines3 import PPO, SAC\n",
    "import torch\n",
    "\n",
    "sys.path.append('/home/hossein/Off-Policy-Evaluation-Lab')\n",
    "\n",
    "#from opelab.core.baselines.diffuser import Diffuser\n",
    "from opelab.core.policy import MixturePolicy, PPOPolicy, SACPolicy, SacPolicyNetwork, PolicyNetwork, D4RLPolicy, D4RLSACPolicy, TD3Policy\n",
    "from opelab.core.baselines.diffusion.temporal import TemporalUnet\n",
    "from opelab.core.baselines.diffusion.diffusion import GaussianDiffusion\n",
    "from opelab.examples.helpers import TanhBijector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 150,
   "metadata": {},
   "outputs": [],
   "source": [
    "policy = TD3Policy('policy/pendulum/Pi_3.pkl', std=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 164,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(3,)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[-2.1147122]], dtype=float32)"
      ]
     },
     "execution_count": 164,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "state = env.reset()\n",
    "print(state.shape)\n",
    "policy.sample(state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 168,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 2/100 [00:00<00:07, 13.34it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:07<00:00, 14.24it/s]\n"
     ]
    }
   ],
   "source": [
    "# Rollout a trajectory\n",
    "returns = []\n",
    "for i in tqdm(range(100)):\n",
    "    state = env.reset()\n",
    "    done = False\n",
    "    t = 0\n",
    "    trajectory = []\n",
    "    total_reward = 0\n",
    "\n",
    "    while not done and t < max_T:\n",
    "        action = policy.sample(state.reshape(-1), deterministic=True)\n",
    "        next_state, reward, done, _ = env.step(action)\n",
    "        trajectory.append((state, action, reward, next_state, done))\n",
    "        state = next_state\n",
    "        total_reward += reward\n",
    "        t += 1\n",
    "    \n",
    "    returns.append(total_reward)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 145,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|          | 2/200 [00:00<00:14, 13.93it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 200/200 [00:13<00:00, 14.68it/s]\n",
      "100%|██████████| 200/200 [00:13<00:00, 14.61it/s]\n",
      "100%|██████████| 200/200 [00:13<00:00, 14.80it/s]\n",
      "100%|██████████| 200/200 [00:13<00:00, 14.71it/s]\n",
      "100%|██████████| 200/200 [00:13<00:00, 14.55it/s]\n",
      "100%|██████████| 200/200 [00:13<00:00, 14.50it/s]\n",
      "100%|██████████| 200/200 [00:13<00:00, 14.52it/s]\n",
      "100%|██████████| 200/200 [00:13<00:00, 14.82it/s]\n",
      "100%|██████████| 200/200 [00:13<00:00, 14.64it/s]\n",
      "100%|██████████| 200/200 [00:13<00:00, 14.35it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Policy: policy/pendulum/TD3_Pendulum-v1_0_t10k.pkl, Average Return: -906.8829956054688\n",
      "Policy: policy/pendulum/TD3_Pendulum-v1_0_t20k.pkl, Average Return: -888.802734375\n",
      "Policy: policy/pendulum/TD3_Pendulum-v1_0_t30k.pkl, Average Return: -220.9932403564453\n",
      "Policy: policy/pendulum/TD3_Pendulum-v1_0_t40k.pkl, Average Return: -141.5447998046875\n",
      "Policy: policy/pendulum/TD3_Pendulum-v1_0_t50k.pkl, Average Return: -144.28160095214844\n",
      "Policy: policy/pendulum/TD3_Pendulum-v1_0_t60k.pkl, Average Return: -152.19606018066406\n",
      "Policy: policy/pendulum/TD3_Pendulum-v1_0_t70k.pkl, Average Return: -143.69801330566406\n",
      "Policy: policy/pendulum/TD3_Pendulum-v1_0_t80k.pkl, Average Return: -139.94456481933594\n",
      "Policy: policy/pendulum/TD3_Pendulum-v1_0_t90k.pkl, Average Return: -139.42214965820312\n",
      "Policy: policy/pendulum/TD3_Pendulum-v1_0_t100k.pkl, Average Return: -139.98504638671875\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Define the range of policies to load\n",
    "policy_paths = [f'policy/pendulum/TD3_Pendulum-v1_0_t{t}k.pkl' for t in range(10, 101, 10)]\n",
    "\n",
    "# Initialize a dictionary to store the returns for each policy\n",
    "policy_returns = {}\n",
    "\n",
    "# Evaluate each policy\n",
    "for path in policy_paths:\n",
    "    policy = TD3Policy(path, std=0.1)\n",
    "    returns = []\n",
    "    for i in tqdm(range(200)):\n",
    "        state = env.reset()\n",
    "        done = False\n",
    "        t = 0\n",
    "        total_reward = 0\n",
    "\n",
    "        while not done and t < max_T:\n",
    "            action = policy.sample(state.reshape(1, -1), deterministic=True)\n",
    "            next_state, reward, done, _ = env.step(action)\n",
    "            state = next_state\n",
    "            total_reward += reward\n",
    "            t += 1\n",
    "\n",
    "        returns.append(total_reward)\n",
    "    \n",
    "    policy_returns[path] = returns\n",
    "\n",
    "# Print the returns for each policy\n",
    "for path, returns in policy_returns.items():\n",
    "    print(f'Policy: {path}, Average Return: {np.mean(returns)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 147,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|          | 2/200 [00:00<00:13, 14.71it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 200/200 [00:13<00:00, 14.59it/s]\n",
      "100%|██████████| 200/200 [00:13<00:00, 14.66it/s]\n",
      "100%|██████████| 200/200 [00:13<00:00, 14.73it/s]\n",
      "100%|██████████| 200/200 [00:13<00:00, 14.58it/s]\n",
      "100%|██████████| 200/200 [00:13<00:00, 14.58it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Policy: policy/pendulum/Pi_1.pkl, Average Return: -895.6682739257812\n",
      "Policy: policy/pendulum/Pi_2.pkl, Average Return: -910.1066284179688\n",
      "Policy: policy/pendulum/Pi_3.pkl, Average Return: -212.7943115234375\n",
      "Policy: policy/pendulum/Pi_4.pkl, Average Return: -146.25396728515625\n",
      "Policy: policy/pendulum/Pi_5.pkl, Average Return: -142.9557342529297\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Define the range of policies to load\n",
    "policy_paths = [f'policy/pendulum/Pi_{t}.pkl' for t in range(1, 6, 1)]\n",
    "\n",
    "# Initialize a dictionary to store the returns for each policy\n",
    "policy_returns = {}\n",
    "\n",
    "# Evaluate each policy\n",
    "for path in policy_paths:\n",
    "    policy = TD3Policy(path, std=0.1)\n",
    "    returns = []\n",
    "    for i in tqdm(range(200)):\n",
    "        state = env.reset()\n",
    "        done = False\n",
    "        t = 0\n",
    "        total_reward = 0\n",
    "\n",
    "        while not done and t < max_T:\n",
    "            action = policy.sample(state.reshape(1, -1), deterministic=True)\n",
    "            next_state, reward, done, _ = env.step(action)\n",
    "            state = next_state\n",
    "            total_reward += reward\n",
    "            t += 1\n",
    "\n",
    "        returns.append(total_reward)\n",
    "    \n",
    "    policy_returns[path] = returns\n",
    "\n",
    "# Print the returns for each policy\n",
    "for path, returns in policy_returns.items():\n",
    "    print(f'Policy: {path}, Average Return: {np.mean(returns)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import copy\n",
    "import os\n",
    "import sys\n",
    "import numpy as np\n",
    "import gymnasium as gym\n",
    "import gym\n",
    "import d4rl\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.distributions import Normal, Categorical\n",
    "from torch.nn import functional as F\n",
    "from tqdm import tqdm\n",
    "from stable_baselines3 import PPO, SAC\n",
    "import torch\n",
    "\n",
    "sys.path.append('/home/hossein/Off-Policy-Evaluation-Lab')\n",
    "\n",
    "#from opelab.core.baselines.diffuser import Diffuser\n",
    "from opelab.core.policy import MixturePolicy, PPOPolicy, SACPolicy, SacPolicyNetwork, PolicyNetwork, D4RLPolicy, D4RLSACPolicy\n",
    "from opelab.core.baselines.diffusion.temporal import TemporalUnet\n",
    "from opelab.core.baselines.diffusion.diffusion import GaussianDiffusion\n",
    "from opelab.examples.helpers import TanhBijector\n",
    "\n",
    "transform = TanhBijector() if atanh else None\n",
    "\n",
    "temporal_model = TemporalUnet(\n",
    "    horizon=T,\n",
    "    transition_dim=state_dim + action_dim,\n",
    ").to(device)\n",
    "\n",
    "diffusion_model = GaussianDiffusion(\n",
    "    model=temporal_model,\n",
    "    horizon=T,\n",
    "    observation_dim=state_dim,\n",
    "    action_dim=action_dim,\n",
    "    n_timesteps=D,\n",
    "    normalizer=normalize_fn,\n",
    "    unnormalizer=unnormalize_fn,\n",
    "    transform=None,\n",
    "    gmode=atanh\n",
    ").to(device)\n",
    "\n",
    "#model_path = '/home/hossein/Off-Policy-Evaluation-Lab/opelab/examples/hopper_diffusion/diffusion/T16D128/hopper_medium.pth'\n",
    "#model_path = '/home/hossein/Off-Policy-Evaluation-Lab/opelab/examples/hopper_diffusion/diffusion/T16D128/hopper_medium_action.pth'\n",
    "model_path = '/home/hossein/Off-Policy-Evaluation-Lab/opelab/examples/hopper_diffusion/diffusion/T8D256/hopper_action.pth'\n",
    "\n",
    "diffusion_model.load_state_dict(torch.load(model_path))\n",
    "\n",
    "diffusion_model.policy = policy\n",
    "diffusion_model.behavior_policy = behavior"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "myenv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
