{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "050f4326",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b2b8e22b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "63e06995",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e969a255",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Warning: Flow failed to import. Set the environment variable D4RL_SUPPRESS_IMPORT_ERROR=1 to suppress this message.\n",
      "No module named 'flow'\n",
      "Warning: CARLA failed to import. Set the environment variable D4RL_SUPPRESS_IMPORT_ERROR=1 to suppress this message.\n",
      "No module named 'carla'\n",
      "pybullet build time: May 20 2022 19:44:17\n"
     ]
    }
   ],
   "source": [
    "import d4rl\n",
    "import gym\n",
    "import numpy as np\n",
    "import pyrallis\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import wandb\n",
    "from torch.distributions import Normal\n",
    "from tqdm import trange\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d314bccb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import os\n",
    "import random\n",
    "import uuid\n",
    "from copy import deepcopy\n",
    "from dataclasses import asdict, dataclass\n",
    "from typing import Any, Dict, List, Optional, Tuple, Union\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "177dc4dc",
   "metadata": {},
   "source": [
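    "Alternative dataset (disabled): to analyse the medium-quality Hopper data instead of the expert data loaded in the next cell, run:\n",
    "\n",
    "```python\n",
    "with open('hopper-medium-v2.pkl', 'rb') as f:\n",
    "    trajectories = pickle.load(f)\n",
    "```"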
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ca486da3",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('hopper-expert-v2.pkl', 'rb') as f:\n",
    "    trajectories = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "eb9dfb5d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "Starting new experiment: hopper medium\n",
      "1027 trajectories, 999494 timesteps found\n",
      "Average return: 3511.36, std: 328.59\n",
      "Max return: 3759.08, min: 1645.28\n",
      "==================================================\n"
     ]
    }
   ],
   "source": [
    "# Per-trajectory statistics of the dataset held in `trajectories`.\n",
    "states, traj_lens, returns = [], [], []\n",
    "for path in trajectories:\n",
    "    states.append(path['observations'])\n",
    "    traj_lens.append(len(path['observations']))\n",
    "    returns.append(path['rewards'].sum())\n",
    "traj_lens, returns = np.array(traj_lens), np.array(returns)\n",
    "states = np.concatenate(states, axis=0)\n",
    "# the 1e-6 offset keeps the std strictly positive for constant state dimensions\n",
    "state_mean, state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6\n",
    "\n",
    "\n",
    "num_timesteps = sum(traj_lens)\n",
    "env_name = \"hopper\"\n",
    "# The pickle loaded before this cell is hopper-expert-v2.pkl, so the summary is\n",
    "# labelled \"expert\" (the label used to read \"medium\", a leftover from the\n",
    "# hopper-medium-v2.pkl variant).\n",
    "dataset = \"expert\"\n",
    "print('=' * 50)\n",
    "print(f'Starting new experiment: {env_name} {dataset}')\n",
    "print(f'{len(traj_lens)} trajectories, {num_timesteps} timesteps found')\n",
    "print(f'Average return: {np.mean(returns):.2f}, std: {np.std(returns):.2f}')\n",
    "print(f'Max return: {np.max(returns):.2f}, min: {np.min(returns):.2f}')\n",
    "print('=' * 50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "379e93d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "class VectorizedLinear(nn.Module):\n",
    "    def __init__(self, in_features: int, out_features: int, ensemble_size: int):\n",
    "        super().__init__()\n",
    "        self.in_features = in_features\n",
    "        self.out_features = out_features\n",
    "        self.ensemble_size = ensemble_size\n",
    "\n",
    "        self.weight = nn.Parameter(torch.empty(ensemble_size, in_features, out_features))\n",
    "        self.bias = nn.Parameter(torch.empty(ensemble_size, 1, out_features))\n",
    "\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        # default pytorch init for nn.Linear module\n",
    "        for layer in range(self.ensemble_size):\n",
    "            nn.init.kaiming_uniform_(self.weight[layer], a=math.sqrt(5))\n",
    "\n",
    "        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])\n",
    "        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0\n",
    "        nn.init.uniform_(self.bias, -bound, bound)\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        # input: [ensemble_size, batch_size, input_size]\n",
    "        # weight: [ensemble_size, input_size, out_size]\n",
    "        # out: [ensemble_size, batch_size, out_size]\n",
    "        return x @ self.weight + self.bias\n",
    "\n",
    "\n",
    "class Actor(nn.Module):\n",
    "    def __init__(\n",
    "        self, state_dim: int, action_dim: int, hidden_dim: int, max_action: float = 1.0\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.trunk = nn.Sequential(\n",
    "            nn.Linear(state_dim, hidden_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hidden_dim, hidden_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hidden_dim, hidden_dim),\n",
    "            nn.ReLU(),\n",
    "        )\n",
    "        # with separate layers works better than with Linear(hidden_dim, 2 * action_dim)\n",
    "        self.mu = nn.Linear(hidden_dim, action_dim)\n",
    "        self.log_sigma = nn.Linear(hidden_dim, action_dim)\n",
    "\n",
    "        # init as in the EDAC paper\n",
    "        for layer in self.trunk[::2]:\n",
    "            torch.nn.init.constant_(layer.bias, 0.1)\n",
    "\n",
    "        torch.nn.init.uniform_(self.mu.weight, -1e-3, 1e-3)\n",
    "        torch.nn.init.uniform_(self.mu.bias, -1e-3, 1e-3)\n",
    "        torch.nn.init.uniform_(self.log_sigma.weight, -1e-3, 1e-3)\n",
    "        torch.nn.init.uniform_(self.log_sigma.bias, -1e-3, 1e-3)\n",
    "\n",
    "        self.action_dim = action_dim\n",
    "        self.max_action = max_action\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        state: torch.Tensor,\n",
    "        deterministic: bool = False,\n",
    "        need_log_prob: bool = False,\n",
    "    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n",
    "        hidden = self.trunk(state)\n",
    "        mu, log_sigma = self.mu(hidden), self.log_sigma(hidden)\n",
    "\n",
    "        # clipping params from EDAC paper, not as in SAC paper (-20, 2)\n",
    "        log_sigma = torch.clip(log_sigma, -5, 2)\n",
    "        policy_dist = Normal(mu, torch.exp(log_sigma))\n",
    "\n",
    "        if deterministic:\n",
    "            action = mu\n",
    "        else:\n",
    "            action = policy_dist.rsample()\n",
    "\n",
    "        tanh_action, log_prob = torch.tanh(action), None\n",
    "        if need_log_prob:\n",
    "            # change of variables formula (SAC paper, appendix C, eq 21)\n",
    "            log_prob = policy_dist.log_prob(action).sum(axis=-1)\n",
    "            log_prob = log_prob - torch.log(1 - tanh_action.pow(2) + 1e-6).sum(axis=-1)\n",
    "\n",
    "        return tanh_action * self.max_action, log_prob\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def act(self, state: np.ndarray, device: str) -> np.ndarray:\n",
    "        deterministic = not self.training\n",
    "        state = torch.tensor(state, device=device, dtype=torch.float32)\n",
    "        action = self(state, deterministic=deterministic)[0].cpu().numpy()\n",
    "        return action\n",
    "\n",
    "\n",
    "class VectorizedCritic(nn.Module):\n",
    "    \"\"\"Ensemble of `num_critics` Q-networks evaluated together via VectorizedLinear.\"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self, state_dim: int, action_dim: int, hidden_dim: int, num_critics: int\n",
    "    ):\n",
    "        super().__init__()\n",
    "        hidden_shapes = (\n",
    "            (state_dim + action_dim, hidden_dim),\n",
    "            (hidden_dim, hidden_dim),\n",
    "            (hidden_dim, hidden_dim),\n",
    "        )\n",
    "        modules = []\n",
    "        for fan_in, fan_out in hidden_shapes:\n",
    "            modules.append(VectorizedLinear(fan_in, fan_out, num_critics))\n",
    "            modules.append(nn.ReLU())\n",
    "        # scalar Q-value head\n",
    "        modules.append(VectorizedLinear(hidden_dim, 1, num_critics))\n",
    "        self.critic = nn.Sequential(*modules)\n",
    "\n",
    "        # init as in the EDAC paper: constant bias on every linear layer\n",
    "        # (even positions of the Sequential), then a small uniform output head\n",
    "        for linear in self.critic[::2]:\n",
    "            torch.nn.init.constant_(linear.bias, 0.1)\n",
    "\n",
    "        q_head = self.critic[-1]\n",
    "        for param in (q_head.weight, q_head.bias):\n",
    "            torch.nn.init.uniform_(param, -3e-3, 3e-3)\n",
    "\n",
    "        self.num_critics = num_critics\n",
    "\n",
    "    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"Return the Q-values of every ensemble member, shape [num_critics, batch_size].\"\"\"\n",
    "        # [batch_size, state_dim + action_dim]\n",
    "        joint = torch.cat([state, action], dim=-1)\n",
    "        # every critic receives the same input:\n",
    "        # [num_critics, batch_size, state_dim + action_dim]\n",
    "        per_critic_input = joint.unsqueeze(0).repeat_interleave(self.num_critics, dim=0)\n",
    "        # drop the trailing singleton output dimension -> [num_critics, batch_size]\n",
    "        return self.critic(per_critic_input).squeeze(-1)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ea118e88",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A training batch as a list of tensors; SACN.update unpacks it as\n",
    "# [state, action, reward, next_state, done], the order ReplayBuffer.sample returns.\n",
    "TensorBatch = List[torch.Tensor]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "5e0e841d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SACN:\n",
    "    \"\"\"SAC-N trainer: soft actor-critic with an ensemble of critics.\n",
    "\n",
    "    Holds the actor, the critic ensemble, a deep-copied target critic and a\n",
    "    learnable entropy temperature (stored as `log_alpha`), and performs one\n",
    "    gradient step of each in `update`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        actor: Actor,\n",
    "        actor_optimizer: torch.optim.Optimizer,\n",
    "        critic: VectorizedCritic,\n",
    "        critic_optimizer: torch.optim.Optimizer,\n",
    "        gamma: float = 0.99,\n",
    "        tau: float = 0.005,\n",
    "        alpha_learning_rate: float = 1e-4,\n",
    "        device: str = \"cpu\",\n",
    "    ):\n",
    "        \"\"\"Store the networks/optimizers and set up the target critic and alpha.\n",
    "\n",
    "        Args:\n",
    "            actor: policy network.\n",
    "            actor_optimizer: optimizer stepping the actor parameters.\n",
    "            critic: critic ensemble.\n",
    "            critic_optimizer: optimizer stepping the critic parameters.\n",
    "            gamma: discount factor used in the Bellman target.\n",
    "            tau: interpolation factor of the target-critic soft update.\n",
    "            alpha_learning_rate: Adam learning rate for `log_alpha`.\n",
    "            device: device for `log_alpha` and for the batches in `update`.\n",
    "        \"\"\"\n",
    "        self.device = device\n",
    "\n",
    "        self.actor = actor\n",
    "        self.critic = critic\n",
    "        with torch.no_grad():\n",
    "            self.target_critic = deepcopy(self.critic)\n",
    "\n",
    "        self.actor_optimizer = actor_optimizer\n",
    "        self.critic_optimizer = critic_optimizer\n",
    "\n",
    "        self.tau = tau\n",
    "        self.gamma = gamma\n",
    "\n",
    "        # adaptive alpha setup\n",
    "        self.target_entropy = -float(self.actor.action_dim)\n",
    "        self.log_alpha = torch.tensor(\n",
    "            [0.0], dtype=torch.float32, device=self.device, requires_grad=True\n",
    "        )\n",
    "        self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_learning_rate)\n",
    "        self.alpha = self.log_alpha.exp().detach()\n",
    "\n",
    "    @staticmethod\n",
    "    def _soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:\n",
    "        \"\"\"Move `target` parameters towards `source` in place:\n",
    "        target <- (1 - tau) * target + tau * source.\n",
    "\n",
    "        Kept inside the class because no module-level `soft_update` exists in\n",
    "        this notebook; calling one from `update` raised NameError.\n",
    "        \"\"\"\n",
    "        for target_param, source_param in zip(target.parameters(), source.parameters()):\n",
    "            target_param.data.copy_(\n",
    "                (1 - tau) * target_param.data + tau * source_param.data\n",
    "            )\n",
    "\n",
    "    def _alpha_loss(self, state: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"Temperature loss; the actor is evaluated without gradients.\"\"\"\n",
    "        with torch.no_grad():\n",
    "            action, action_log_prob = self.actor(state, need_log_prob=True)\n",
    "\n",
    "        loss = (-self.log_alpha * (action_log_prob + self.target_entropy)).mean()\n",
    "\n",
    "        return loss\n",
    "\n",
    "    def _actor_loss(self, state: torch.Tensor) -> Tuple[torch.Tensor, float, float]:\n",
    "        \"\"\"Actor loss against the minimum over the critic ensemble.\n",
    "\n",
    "        Returns:\n",
    "            `(loss, batch_entropy, q_value_std)`; the last two are Python\n",
    "            floats used only for logging.\n",
    "        \"\"\"\n",
    "        action, action_log_prob = self.actor(state, need_log_prob=True)\n",
    "        q_value_dist = self.critic(state, action)\n",
    "        assert q_value_dist.shape[0] == self.critic.num_critics\n",
    "        q_value_min = q_value_dist.min(0).values\n",
    "        # needed for logging\n",
    "        q_value_std = q_value_dist.std(0).mean().item()\n",
    "        batch_entropy = -action_log_prob.mean().item()\n",
    "\n",
    "        assert action_log_prob.shape == q_value_min.shape\n",
    "        loss = (self.alpha * action_log_prob - q_value_min).mean()\n",
    "\n",
    "        return loss, batch_entropy, q_value_std\n",
    "\n",
    "    def _critic_loss(\n",
    "        self,\n",
    "        state: torch.Tensor,\n",
    "        action: torch.Tensor,\n",
    "        reward: torch.Tensor,\n",
    "        next_state: torch.Tensor,\n",
    "        done: torch.Tensor,\n",
    "    ) -> torch.Tensor:\n",
    "        \"\"\"Squared Bellman error, averaged over the batch and summed over critics.\n",
    "\n",
    "        The target uses the minimum target-critic value for an action sampled\n",
    "        from the current actor, minus `alpha * log_prob`.\n",
    "        \"\"\"\n",
    "        with torch.no_grad():\n",
    "            next_action, next_action_log_prob = self.actor(\n",
    "                next_state, need_log_prob=True\n",
    "            )\n",
    "            q_next = self.target_critic(next_state, next_action).min(0).values\n",
    "            q_next = q_next - self.alpha * next_action_log_prob\n",
    "\n",
    "            assert q_next.unsqueeze(-1).shape == done.shape == reward.shape\n",
    "            q_target = reward + self.gamma * (1 - done) * q_next.unsqueeze(-1)\n",
    "\n",
    "        q_values = self.critic(state, action)\n",
    "        # [ensemble_size, batch_size] - [1, batch_size]\n",
    "        loss = ((q_values - q_target.view(1, -1)) ** 2).mean(dim=1).sum(dim=0)\n",
    "\n",
    "        return loss\n",
    "\n",
    "    def update(self, batch: TensorBatch) -> Dict[str, float]:\n",
    "        \"\"\"Run one alpha, actor and critic step, then soft-update the target critic.\n",
    "\n",
    "        Args:\n",
    "            batch: `[state, action, reward, next_state, done]` tensors.\n",
    "\n",
    "        Returns:\n",
    "            Dictionary of scalar losses and diagnostics for logging.\n",
    "        \"\"\"\n",
    "        state, action, reward, next_state, done = [arr.to(self.device) for arr in batch]\n",
    "        # Usually updates are done in the following order: critic -> actor -> alpha\n",
    "        # But we found that EDAC paper uses reverse (which gives better results)\n",
    "\n",
    "        # Alpha update\n",
    "        alpha_loss = self._alpha_loss(state)\n",
    "        self.alpha_optimizer.zero_grad()\n",
    "        alpha_loss.backward()\n",
    "        self.alpha_optimizer.step()\n",
    "\n",
    "        self.alpha = self.log_alpha.exp().detach()\n",
    "\n",
    "        # Actor update\n",
    "        actor_loss, actor_batch_entropy, q_policy_std = self._actor_loss(state)\n",
    "        self.actor_optimizer.zero_grad()\n",
    "        actor_loss.backward()\n",
    "        self.actor_optimizer.step()\n",
    "\n",
    "        # Critic update\n",
    "        critic_loss = self._critic_loss(state, action, reward, next_state, done)\n",
    "        self.critic_optimizer.zero_grad()\n",
    "        critic_loss.backward()\n",
    "        self.critic_optimizer.step()\n",
    "\n",
    "        #  Target networks soft update\n",
    "        with torch.no_grad():\n",
    "            self._soft_update(self.target_critic, self.critic, tau=self.tau)\n",
    "            # for logging, Q-ensemble std estimate with the random actions:\n",
    "            # a ~ U[-max_action, max_action]\n",
    "            max_action = self.actor.max_action\n",
    "            random_actions = -max_action + 2 * max_action * torch.rand_like(action)\n",
    "\n",
    "            q_random_std = self.critic(state, random_actions).std(0).mean().item()\n",
    "\n",
    "        update_info = {\n",
    "            \"alpha_loss\": alpha_loss.item(),\n",
    "            \"critic_loss\": critic_loss.item(),\n",
    "            \"actor_loss\": actor_loss.item(),\n",
    "            \"batch_entropy\": actor_batch_entropy,\n",
    "            \"alpha\": self.alpha.item(),\n",
    "            \"q_policy_std\": q_policy_std,\n",
    "            \"q_random_std\": q_random_std,\n",
    "        }\n",
    "        return update_info\n",
    "\n",
    "    def state_dict(self) -> Dict[str, Any]:\n",
    "        \"\"\"Collect network, optimizer and temperature state for checkpointing.\n",
    "\n",
    "        `log_alpha` is stored as a plain Python float.\n",
    "        \"\"\"\n",
    "        state = {\n",
    "            \"actor\": self.actor.state_dict(),\n",
    "            \"critic\": self.critic.state_dict(),\n",
    "            \"target_critic\": self.target_critic.state_dict(),\n",
    "            \"log_alpha\": self.log_alpha.item(),\n",
    "            \"actor_optim\": self.actor_optimizer.state_dict(),\n",
    "            \"critic_optim\": self.critic_optimizer.state_dict(),\n",
    "            \"alpha_optim\": self.alpha_optimizer.state_dict(),\n",
    "        }\n",
    "        return state\n",
    "\n",
    "    def load_state_dict(self, state_dict: Dict[str, Any]):\n",
    "        \"\"\"Restore everything written by `state_dict` and refresh `alpha`.\"\"\"\n",
    "        self.actor.load_state_dict(state_dict[\"actor\"])\n",
    "        self.critic.load_state_dict(state_dict[\"critic\"])\n",
    "        self.target_critic.load_state_dict(state_dict[\"target_critic\"])\n",
    "        self.actor_optimizer.load_state_dict(state_dict[\"actor_optim\"])\n",
    "        self.critic_optimizer.load_state_dict(state_dict[\"critic_optim\"])\n",
    "        self.alpha_optimizer.load_state_dict(state_dict[\"alpha_optim\"])\n",
    "        # write in place so the alpha optimizer keeps pointing at the same tensor\n",
    "        self.log_alpha.data[0] = state_dict[\"log_alpha\"]\n",
    "        self.alpha = self.log_alpha.exp().detach()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "cf475958",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration used to rebuild the networks for the Hopper checkpoint.\n",
    "project: str = \"CORL\"\n",
    "group: str = \"SAC-N\"\n",
    "name: str = \"SAC-N\"\n",
    "# model params\n",
    "hidden_dim: int = 256\n",
    "num_critics: int = 200\n",
    "    \n",
    "# NOTE: num_critics has to be changed per environment:\n",
    "# 10 for halfcheetah, 200 for hopper\n",
    "# (presumably it must match the ensemble size of the checkpoint that is loaded)\n",
    "gamma: float = 0.99\n",
    "tau: float = 5e-3\n",
    "actor_learning_rate: float = 3e-4\n",
    "critic_learning_rate: float = 3e-4\n",
    "alpha_learning_rate: float = 3e-4\n",
    "max_action: float = 1.0\n",
    "# training params\n",
    "buffer_size: int = 2_000_000\n",
    "# NOTE(review): the checkpoint loaded further down lives in a\n",
    "# SAC-N-hopper-medium-expert-v2-* directory - confirm that pairing it with the\n",
    "# expert environment/dataset is intended\n",
    "env_name: str = \"hopper-expert-v2\"\n",
    "batch_size: int = 256\n",
    "num_epochs: int = 3000\n",
    "num_updates_on_epoch: int = 1000\n",
    "normalize_reward: bool = False\n",
    "# evaluation params\n",
    "eval_episodes: int = 10\n",
    "eval_every: int = 20\n",
    "# general params\n",
    "checkpoints_path: Optional[str] = \"./checkpoints\"\n",
    "deterministic_torch: bool = False\n",
    "train_seed: int = 10\n",
    "eval_seed: int = 42\n",
    "log_every: int = 100\n",
    "device: str = \"cuda\"\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "1db4573d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def wrap_env(\n",
    "    env: gym.Env,\n",
    "    state_mean: Union[np.ndarray, float] = 0.0,\n",
    "    state_std: Union[np.ndarray, float] = 1.0,\n",
    "    reward_scale: float = 1.0,\n",
    ") -> gym.Env:\n",
    "    \"\"\"Wrap `env` so observations are standardised and rewards optionally scaled.\n",
    "\n",
    "    Args:\n",
    "        env: environment to wrap.\n",
    "        state_mean: value subtracted from every observation.\n",
    "        state_std: value every centred observation is divided by.\n",
    "        reward_scale: reward multiplier; the reward wrapper is only added\n",
    "            when this differs from 1.0.\n",
    "\n",
    "    Returns:\n",
    "        The wrapped environment.\n",
    "    \"\"\"\n",
    "    wrapped = gym.wrappers.TransformObservation(\n",
    "        env, lambda state: (state - state_mean) / state_std\n",
    "    )\n",
    "    if reward_scale == 1.0:\n",
    "        # unit scale: no reward wrapper needed\n",
    "        return wrapped\n",
    "    return gym.wrappers.TransformReward(wrapped, lambda reward: reward_scale * reward)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "15106903",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ReplayBuffer:\n",
    "    \"\"\"Fixed-capacity transition storage kept as preallocated float32 tensors.\"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        state_dim: int,\n",
    "        action_dim: int,\n",
    "        buffer_size: int,\n",
    "        device: str = \"cpu\",\n",
    "    ):\n",
    "        self._buffer_size = buffer_size\n",
    "        self._pointer = 0\n",
    "        self._size = 0\n",
    "\n",
    "        def allocate(width: int) -> torch.Tensor:\n",
    "            # one zero-filled row per buffer slot\n",
    "            return torch.zeros((buffer_size, width), dtype=torch.float32, device=device)\n",
    "\n",
    "        self._states = allocate(state_dim)\n",
    "        self._actions = allocate(action_dim)\n",
    "        self._rewards = allocate(1)\n",
    "        self._next_states = allocate(state_dim)\n",
    "        self._dones = allocate(1)\n",
    "        self._device = device\n",
    "\n",
    "    def _to_tensor(self, data: np.ndarray) -> torch.Tensor:\n",
    "        \"\"\"Copy `data` into a float32 tensor on the buffer's device.\"\"\"\n",
    "        return torch.tensor(data, dtype=torch.float32, device=self._device)\n",
    "\n",
    "    # Loads data in d4rl format, i.e. from Dict[str, np.array].\n",
    "    def load_d4rl_dataset(self, data: Dict[str, np.ndarray]):\n",
    "        \"\"\"Fill the (empty) buffer from a d4rl-style dictionary of arrays.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if the buffer already holds data or is too small.\n",
    "        \"\"\"\n",
    "        if self._size != 0:\n",
    "            raise ValueError(\"Trying to load data into non-empty replay buffer\")\n",
    "        n_transitions = data[\"observations\"].shape[0]\n",
    "        if n_transitions > self._buffer_size:\n",
    "            raise ValueError(\n",
    "                \"Replay buffer is smaller than the dataset you are trying to load!\"\n",
    "            )\n",
    "        # (destination storage, source key, whether a trailing axis must be added)\n",
    "        layout = (\n",
    "            (self._states, \"observations\", False),\n",
    "            (self._actions, \"actions\", False),\n",
    "            (self._rewards, \"rewards\", True),\n",
    "            (self._next_states, \"next_observations\", False),\n",
    "            (self._dones, \"terminals\", True),\n",
    "        )\n",
    "        for storage, key, as_column in layout:\n",
    "            values = data[key]\n",
    "            if as_column:\n",
    "                values = values[..., None]\n",
    "            storage[:n_transitions] = self._to_tensor(values)\n",
    "        self._size += n_transitions\n",
    "        self._pointer = min(self._size, n_transitions)\n",
    "\n",
    "        print(f\"Dataset size: {n_transitions}\")\n",
    "\n",
    "    def sample(self, batch_size: int) -> TensorBatch:\n",
    "        \"\"\"Draw `batch_size` transitions uniformly, with replacement.\n",
    "\n",
    "        Returns:\n",
    "            `[states, actions, rewards, next_states, dones]`.\n",
    "        \"\"\"\n",
    "        upper = min(self._size, self._pointer)\n",
    "        indices = np.random.randint(0, upper, size=batch_size)\n",
    "        storages = (\n",
    "            self._states,\n",
    "            self._actions,\n",
    "            self._rewards,\n",
    "            self._next_states,\n",
    "            self._dones,\n",
    "        )\n",
    "        return [storage[indices] for storage in storages]\n",
    "\n",
    "    def add_transition(self):\n",
    "        # Use this method to add new data into the replay buffer during fine-tuning.\n",
    "        raise NotImplementedError\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "ec106f3c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/anderson/anaconda3/envs/dtf-gym/lib/python3.8/site-packages/gym/logger.py:30: UserWarning: \u001b[33mWARN: Box bound precision lowered by casting to float32\u001b[0m\n",
      "  warnings.warn(colorize('%s: %s'%('WARN', msg % args), 'yellow'))\n",
      "load datafile: 100%|████████████████████████████| 21/21 [00:01<00:00, 20.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset size: 999061\n"
     ]
    }
   ],
   "source": [
    "# data, evaluation, env setup\n",
    "eval_env = wrap_env(gym.make(env_name))\n",
    "state_dim = eval_env.observation_space.shape[0]\n",
    "action_dim = eval_env.action_space.shape[0]\n",
    "\n",
    "d4rl_dataset = d4rl.qlearning_dataset(eval_env)\n",
    "\n",
    "if normalize_reward:\n",
    "    modify_reward(d4rl_dataset, env_name)\n",
    "\n",
    "buffer = ReplayBuffer(\n",
    "    state_dim=state_dim,\n",
    "    action_dim=action_dim,\n",
    "    buffer_size=buffer_size,\n",
    "    device=device,\n",
    ")\n",
    "buffer.load_d4rl_dataset(d4rl_dataset)\n",
    "\n",
    "# Actor & Critic setup\n",
    "actor = Actor(state_dim, action_dim, hidden_dim, max_action)\n",
    "actor.to(device)\n",
    "actor_optimizer = torch.optim.Adam(actor.parameters(), lr=actor_learning_rate)\n",
    "critic = VectorizedCritic(\n",
    "    state_dim, action_dim, hidden_dim, num_critics\n",
    ")\n",
    "critic.to(device)\n",
    "critic_optimizer = torch.optim.Adam(\n",
    "    critic.parameters(), lr=critic_learning_rate\n",
    ")\n",
    "\n",
    "trainer = SACN(\n",
    "    actor=actor,\n",
    "    actor_optimizer=actor_optimizer,\n",
    "    critic=critic,\n",
    "    critic_optimizer=critic_optimizer,\n",
    "    gamma=gamma,\n",
    "    tau=tau,\n",
    "    alpha_learning_rate=alpha_learning_rate,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b4ed1aa2",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "744ef3d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = torch.load(\"./SAC-N-hopper-medium-expert-v2-ac25f151/2900.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "b56fe941",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.load_state_dict(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "8b92b529",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "28946f7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('hopper-expert-v2.pkl', 'rb') as f:\n",
    "    trajectories = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2200e89",
   "metadata": {},
   "outputs": [],
   "source": [
    "type(trajectories[0]['observations'][0][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "410faa3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep the first action dimension over a fixed grid and record the critic\n",
    "# ensemble's mean Q-value at every timestep of every trajectory.\n",
    "data = []\n",
    "\n",
    "for trajectory in trajectories:\n",
    "    data_dict = {'observations': [], 'actions': [], 'q_values': []}\n",
    "\n",
    "    # float32 copies on the configured device; torch.tensor always copies, so\n",
    "    # the in-place edit below can never write through to `trajectories`\n",
    "    state = torch.tensor(trajectory['observations'], dtype=torch.float32, device=device)\n",
    "    base_action = torch.tensor(trajectory['actions'], dtype=torch.float32, device=device)\n",
    "\n",
    "    for first_action in torch.arange(-0.90, 0.901, 0.2):\n",
    "        # keep the recorded action except for its first component\n",
    "        action = base_action.clone()\n",
    "        action[:, 0] = first_action.item()\n",
    "        with torch.no_grad():\n",
    "            q_values = trainer.critic(state, action)\n",
    "        data_dict['observations'].append(state.cpu().numpy())\n",
    "        data_dict['actions'].append(action.cpu().numpy())\n",
    "        # average over the ensemble dimension -> one value per timestep\n",
    "        data_dict['q_values'].append(q_values.mean(dim=0).cpu().numpy())\n",
    "    data.append(data_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "93f692d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"q_values_hopper.pkl\", 'wb') as file:\n",
    "    pickle.dump(data, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e28fd0e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "c3b0d8a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration used to rebuild the networks for the HalfCheetah checkpoint.\n",
    "project: str = \"CORL\"\n",
    "group: str = \"SAC-N\"\n",
    "name: str = \"SAC-N\"\n",
    "# model params\n",
    "hidden_dim: int = 256\n",
    "num_critics: int = 10\n",
    "    \n",
    "# num_critics: 10 for halfcheetah, 200 for hopper\n",
    "# (presumably it must match the ensemble size of the checkpoint that is loaded)\n",
    "gamma: float = 0.99\n",
    "tau: float = 5e-3\n",
    "actor_learning_rate: float = 3e-4\n",
    "critic_learning_rate: float = 3e-4\n",
    "alpha_learning_rate: float = 3e-4\n",
    "max_action: float = 1.0\n",
    "# training params\n",
    "buffer_size: int = 2_000_000\n",
    "# NOTE(review): the checkpoint loaded further down lives in a\n",
    "# SAC-N-halfcheetah-medium-expert-v2-* directory - confirm that pairing it with\n",
    "# the expert environment/dataset is intended\n",
    "env_name: str = \"halfcheetah-expert-v2\"\n",
    "batch_size: int = 256\n",
    "num_epochs: int = 3000\n",
    "num_updates_on_epoch: int = 1000\n",
    "normalize_reward: bool = False\n",
    "# evaluation params\n",
    "eval_episodes: int = 10\n",
    "eval_every: int = 20\n",
    "# general params\n",
    "checkpoints_path: Optional[str] = \"./checkpoints\"\n",
    "deterministic_torch: bool = False\n",
    "train_seed: int = 10\n",
    "eval_seed: int = 42\n",
    "log_every: int = 100\n",
    "device: str = \"cuda\"\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "49c9749b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "load datafile: 100%|████████████████████████████| 21/21 [00:01<00:00, 15.20it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset size: 999000\n"
     ]
    }
   ],
   "source": [
    "# data, evaluation, env setup\n",
    "eval_env = wrap_env(gym.make(env_name))\n",
    "state_dim = eval_env.observation_space.shape[0]\n",
    "action_dim = eval_env.action_space.shape[0]\n",
    "\n",
    "d4rl_dataset = d4rl.qlearning_dataset(eval_env)\n",
    "\n",
    "if normalize_reward:\n",
    "    modify_reward(d4rl_dataset, env_name)\n",
    "\n",
    "buffer = ReplayBuffer(\n",
    "    state_dim=state_dim,\n",
    "    action_dim=action_dim,\n",
    "    buffer_size=buffer_size,\n",
    "    device=device,\n",
    ")\n",
    "buffer.load_d4rl_dataset(d4rl_dataset)\n",
    "\n",
    "# Actor & Critic setup\n",
    "actor = Actor(state_dim, action_dim, hidden_dim, max_action)\n",
    "actor.to(device)\n",
    "actor_optimizer = torch.optim.Adam(actor.parameters(), lr=actor_learning_rate)\n",
    "critic = VectorizedCritic(\n",
    "    state_dim, action_dim, hidden_dim, num_critics\n",
    ")\n",
    "critic.to(device)\n",
    "critic_optimizer = torch.optim.Adam(\n",
    "    critic.parameters(), lr=critic_learning_rate\n",
    ")\n",
    "\n",
    "trainer = SACN(\n",
    "    actor=actor,\n",
    "    actor_optimizer=actor_optimizer,\n",
    "    critic=critic,\n",
    "    critic_optimizer=critic_optimizer,\n",
    "    gamma=gamma,\n",
    "    tau=tau,\n",
    "    alpha_learning_rate=alpha_learning_rate,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "ac218102",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = torch.load(\"./SAC-N-halfcheetah-medium-expert-v2-6ede747b/2900.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "f1891d4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.load_state_dict(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "f34a89e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('halfcheetah-expert-v2.pkl', 'rb') as f:\n",
    "    trajectories = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "4a374ea0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep the first action dimension over a fixed grid and record the critic\n",
    "# ensemble's mean Q-value at every timestep of every trajectory.\n",
    "data = []\n",
    "\n",
    "for trajectory in trajectories:\n",
    "    data_dict = {'observations': [], 'actions': [], 'q_values': []}\n",
    "\n",
    "    # float32 copies on the configured device; torch.tensor always copies, so\n",
    "    # the in-place edit below can never write through to `trajectories`\n",
    "    state = torch.tensor(trajectory['observations'], dtype=torch.float32, device=device)\n",
    "    base_action = torch.tensor(trajectory['actions'], dtype=torch.float32, device=device)\n",
    "\n",
    "    for first_action in torch.arange(-0.90, 0.901, 0.2):\n",
    "        # keep the recorded action except for its first component\n",
    "        action = base_action.clone()\n",
    "        action[:, 0] = first_action.item()\n",
    "        with torch.no_grad():\n",
    "            q_values = trainer.critic(state, action)\n",
    "        data_dict['observations'].append(state.cpu().numpy())\n",
    "        data_dict['actions'].append(action.cpu().numpy())\n",
    "        # average over the ensemble dimension -> one value per timestep\n",
    "        data_dict['q_values'].append(q_values.mean(dim=0).cpu().numpy())\n",
    "    data.append(data_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "f526f08d",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"q_values_halfcheetah.pkl\", 'wb') as file:\n",
    "    pickle.dump(data, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9b87c1e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
