{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Installing the required libraries\n",
    "```bash\n",
    "conda create -n eppo python=3.12 -y\n",
    "conda activate eppo\n",
    "\n",
    "pip install gymnasium torch numpy scipy seaborn tqdm imageio \n",
    "pip install gymnasium['mujoco']\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import gymnasium as gym\n",
    "import itertools\n",
    "import time\n",
    "import copy\n",
    "from gymnasium.wrappers import RescaleAction\n",
    "from typing import Optional\n",
    "from torch.distributions import Normal\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# common parameters\n",
    "max_steps = 500_000\n",
    "learning_rate = 3e-4\n",
    "learn_frequency = 2047  # it is 2048 but we are using 0 based index\n",
    "batch_size = 256\n",
    "max_iter = 10\n",
    "act_actor = \"relu\"\n",
    "act_critic = \"relu\"\n",
    "depth_actor = 3\n",
    "depth_critic = 3\n",
    "width_actor = 256\n",
    "width_critic = 256\n",
    "gamma = 0.99\n",
    "no_norm_actor = False\n",
    "no_norm_critic = False\n",
    "eval_frequency = 20_000\n",
    "eval_episodes = 10\n",
    "buffer_size = 2048\n",
    "\n",
    "# ppo related\n",
    "clip_param = 0.2\n",
    "gae_lambda = 0.95\n",
    "max_grad_norm = 0.5\n",
    "\n",
    "# eppo related\n",
    "regularization_coeff = 0.01\n",
    "radius = 0.01\n",
    "seed = 1\n",
    "exploration_types = [\"mean\", \"cor\", \"ind\"]\n",
    "exploration_type = exploration_types[0]\n",
    "\n",
    "# experiment related\n",
    "environments = [\"ant\", \"halfcheetah\"]\n",
    "strategies = {\n",
    "    \"ant\": [\"back_one\", \"front_one\", \"back_two\", \"front_two\", \"parallel\", \"cross\"],\n",
    "    \"halfcheetah\": [\"back_one\", \"front_one\", \"cross_v1\", \"cross_v2\"],\n",
    "}\n",
    "environment = environments[1]\n",
    "strategy = strategies[environment][0]\n",
    "exp_name = f\"{environment}_{strategy}\"\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
   ]
  },
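  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick worked check of the update cadence implied by these values: the agent collects `buffer_size = 2048` transitions per rollout, triggers a learn call every 2048th environment step (`learn_frequency = 2047` because steps are 0-indexed), and then runs `max_iter = 10` passes over the buffer in minibatches of `batch_size = 256`, i.e. $2048 / 256 \\times 10 = 80$ gradient steps for the actor and for the critic between rollouts."
   ]
  },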
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ParalyzeActionWrapper\n",
    "class ParalyzeActionWrapper(gym.ActionWrapper):\n",
    "    def __init__(self, env, joint_idxs, paralyzed_ratio=0.0):\n",
    "        super().__init__(env)\n",
    "        self.paralyzed_ratio = paralyzed_ratio\n",
    "        self.joint_idxs = joint_idxs\n",
    "        self.coeff = np.ones(self.action_space.shape)\n",
    "        self.coeff[joint_idxs] = paralyzed_ratio\n",
    "\n",
    "    def step(self, action):\n",
    "        action = action * self.coeff\n",
    "        return self.env.step(action)\n",
    "\n",
    "class SinglePrecision(gym.ObservationWrapper):\n",
    "\n",
    "    def __init__(self, env):\n",
    "        super().__init__(env)\n",
    "\n",
    "        if isinstance(self.observation_space, gym.spaces.Box):\n",
    "            obs_space = self.observation_space\n",
    "            self.observation_space = gym.spaces.Box(obs_space.low, obs_space.high, obs_space.shape)\n",
    "        elif isinstance(self.observation_space, gym.spaces.Dict):\n",
    "            obs_spaces = copy.copy(self.observation_space.spaces)\n",
    "            for k, v in obs_spaces.items():\n",
    "                obs_spaces[k] = gym.spaces.Box(v.low, v.high, v.shape)\n",
    "            self.observation_space = gym.spaces.Dict(obs_spaces)\n",
    "        else:\n",
    "            raise NotImplementedError\n",
    "\n",
    "    def observation(self, observation: np.ndarray) -> np.ndarray:\n",
    "        if isinstance(observation, np.ndarray):\n",
    "            return observation.astype(np.float32)\n",
    "        elif isinstance(observation, dict):\n",
    "            observation = copy.copy(observation)\n",
    "            for k, v in observation.items():\n",
    "                observation[k] = v.astype(np.float32)\n",
    "            return observation\n",
    "\n",
    "\n",
    "def get_idx_to_paralyze(exp_name):\n",
    "    if \"ant\" in exp_name:\n",
    "        if \"front_one\" in exp_name:  # front left joints\n",
    "            return [2, 3]\n",
    "        elif \"front_two\" in exp_name:  # front left and right joints\n",
    "            return [2, 3, 4, 5]\n",
    "        elif \"back_one\" in exp_name:  # back left joints\n",
    "            return [6, 7]\n",
    "        elif \"back_two\" in exp_name:  # back left and right joints\n",
    "            return [0, 1, 6, 7]\n",
    "        elif \"parallel\" in exp_name:  # left front and back joints\n",
    "            return [2, 3, 6, 7]\n",
    "        elif \"cross\" in exp_name:  # left front and right back joints\n",
    "            return [2, 3, 0, 1]\n",
    "    elif \"halfcheetah\" in exp_name:\n",
    "        if \"front_one\" in exp_name:\n",
    "            return [5]\n",
    "        elif \"back_one\" in exp_name:\n",
    "            return [2]\n",
    "        elif \"cross_v1\" in exp_name:\n",
    "            return [2, 4]\n",
    "        elif \"cross_v2\" in exp_name:\n",
    "            return [1, 5]\n",
    "    else:\n",
    "        raise ValueError(f\"Unknown experiment: {exp_name}\")\n",
    "\n",
    "\n",
    "def make_env(\n",
    "    exp_name: str,\n",
    "    seed: int,\n",
    "    idxs: Optional[list] = None,\n",
    "    paralyzed_ratio: Optional[float] = 0.0,\n",
    ") -> gym.Env:\n",
    "\n",
    "    if \"ant\" in exp_name:\n",
    "        env_name = \"Ant-v5\"\n",
    "    elif \"halfcheetah\" in exp_name:\n",
    "        env_name = \"HalfCheetah-v5\"\n",
    "    else:\n",
    "        raise NotImplementedError(f\"Environment {exp_name} not implemented\")\n",
    "\n",
    "    env = gym.make(env_name)\n",
    "    env = RescaleAction(env, -1.0, 1.0)\n",
    "    env = SinglePrecision(env)\n",
    "    env = ParalyzeActionWrapper(env, idxs, paralyzed_ratio)\n",
    "\n",
    "    env.reset(seed=seed)\n",
    "    env.action_space.seed(seed)\n",
    "    env.observation_space.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    return env"
   ]
  },
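  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal, NumPy-only sketch of what `ParalyzeActionWrapper` does to an action (the indices follow the `halfcheetah` / `back_one` entry of `get_idx_to_paralyze`; no MuJoCo needed):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# HalfCheetah-v5 has a 6-dimensional action space; \"back_one\" paralyzes joint 2\n",
    "demo_coeff = np.ones(6)\n",
    "demo_coeff[[2]] = 0.0  # paralyzed_ratio = 0.0 -> the joint receives zero torque\n",
    "demo_action = np.full(6, 0.5)\n",
    "print(demo_action * demo_coeff)  # index 2 is zeroed, all other joints untouched"
   ]
  },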
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Experimenter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def totorch(x, dtype=torch.float32, device=\"cuda\"):\n",
    "    return torch.as_tensor(x, dtype=dtype, device=device)\n",
    "\n",
    "\n",
    "def tonumpy(x):\n",
    "    return x.data.cpu().numpy()\n",
    "\n",
    "\n",
    "class ExperienceMemoryTorch:\n",
    "    \"\"\"Fixed-size buffer to store experience tuples.\"\"\"\n",
    "\n",
    "    field_names = [\"state\", \"action\", \"reward\", \"next_state\", \"terminated\", \"step\"]\n",
    "\n",
    "    def __init__(self, device, buffer_size, dims):\n",
    "        self.device = device\n",
    "        self.buffer_size = buffer_size\n",
    "        self.dims = dims\n",
    "        self.reset()\n",
    "\n",
    "    def reset(self, buffer_size=None):\n",
    "        if buffer_size is not None:\n",
    "            self.buffer_size = buffer_size\n",
    "        self.data_size = 0\n",
    "        self.pointer = 0\n",
    "        self.memory = {\n",
    "            field: torch.empty(self.dims[field], device=self.device)\n",
    "            for field in self.field_names\n",
    "        }\n",
    "\n",
    "    def add(self, state, action, reward, next_state, terminated, step):\n",
    "        for field, value in zip(\n",
    "            self.field_names, [state, action, reward, next_state, terminated, step]\n",
    "        ):\n",
    "            self.memory[field][self.pointer] = value\n",
    "        self.pointer = (self.pointer + 1) % self.buffer_size\n",
    "        self.data_size = min(self.data_size + 1, self.buffer_size)\n",
    "\n",
    "    def sample_by_index(self, index):\n",
    "        return tuple(self.memory[field][index] for field in self.field_names)\n",
    "\n",
    "    def sample_all(self):\n",
    "        return self.sample_by_index(range(self.data_size))\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.data_size\n",
    "\n",
    "    @property\n",
    "    def size(self):\n",
    "        return self.data_size\n",
    "\n",
    "\n",
    "class ParalysisExpriment(object):\n",
    "    def __init__(\n",
    "        self,\n",
    "        agent,\n",
    "        exp_name,\n",
    "        seed,\n",
    "        max_steps,\n",
    "        eval_frequency,\n",
    "        device,\n",
    "        learn_frequency,\n",
    "        eval_episodes,\n",
    "        max_iter,\n",
    "        gamma,\n",
    "    ):\n",
    "        self.exp_name = exp_name\n",
    "        self.seed = seed\n",
    "        self.max_steps = max_steps\n",
    "        self.eval_frequency = eval_frequency\n",
    "        self.device = device\n",
    "        self.learn_frequency = learn_frequency\n",
    "        self.eval_episodes = eval_episodes\n",
    "        self.max_iter = max_iter\n",
    "        self.gamma = gamma\n",
    "\n",
    "        self.idx_to_paralyze = get_idx_to_paralyze(self.exp_name)\n",
    "        self.prepare_tasks()\n",
    "\n",
    "        self.agent = agent\n",
    "        \n",
    "        self.AULCS = []\n",
    "        self.FINAL_RETURNS = []\n",
    "\n",
    "    def prepare_tasks(self):\n",
    "        task_order = [1.0, 0.75, 0.5, 0.25, 0.0, 0.25, 0.5, 0.75, 1.0]\n",
    "        self.n_tasks = len(task_order)\n",
    "\n",
    "        self.tasks = {\n",
    "            task_id: {\"task\": coeff, \"idxs\": self.idx_to_paralyze}\n",
    "            for task_id, coeff in enumerate(task_order)\n",
    "        }\n",
    "        task_names = [task_info[\"task\"] for task_info in self.tasks.values()]\n",
    "        print(f\"Tasks: {task_names}\")\n",
    "\n",
    "    def set_task(self, task_id, task_info):\n",
    "        task = task_info[\"task\"]\n",
    "\n",
    "        self.env = make_env(\n",
    "            exp_name=self.exp_name,\n",
    "            seed=self.seed,\n",
    "            idxs=task_info[\"idxs\"],\n",
    "            paralyzed_ratio=task,\n",
    "        )\n",
    "\n",
    "        self.eval_env = make_env(\n",
    "            exp_name=self.exp_name,\n",
    "            seed=self.seed + 100,\n",
    "            idxs=task_info[\"idxs\"],\n",
    "            paralyzed_ratio=task,\n",
    "        )\n",
    "\n",
    "        return task\n",
    "\n",
    "    def train(self):\n",
    "        time_start = time.time()\n",
    "\n",
    "        information_dict = {\n",
    "            \"episode_rewards\": torch.zeros(self.max_steps * (self.n_tasks + 1)),\n",
    "            \"episode_steps\": torch.zeros(self.max_steps * (self.n_tasks + 1)),\n",
    "            \"step_rewards\": np.empty((2 * self.max_steps * self.n_tasks), dtype=object),\n",
    "        }\n",
    "\n",
    "        r_cum = np.zeros(1)\n",
    "        episode = 0\n",
    "        e_step = 0\n",
    "        for task_id, task_info in self.tasks.items():\n",
    "            # task starts\n",
    "            task = self.set_task(task_id, task_info)\n",
    "            print(f\"Starting to task {task_id}: {task}\")\n",
    "\n",
    "            r_cum = np.zeros(1)\n",
    "            s, _ = self.env.reset()\n",
    "            s = totorch(s, device=self.device)\n",
    "            for step in tqdm(\n",
    "                range(task_id * self.max_steps, (task_id + 1) * self.max_steps),\n",
    "                leave=True,\n",
    "                disable=True,\n",
    "            ):\n",
    "                e_step += 1\n",
    "\n",
    "                if step % self.eval_frequency == 0:\n",
    "                    self.eval(step)\n",
    "\n",
    "                a = self.agent.select_action(s).clip(-1.0, 1.0)\n",
    "\n",
    "                sp, r, done, truncated, info = self.env.step(tonumpy(a))\n",
    "                sp = totorch(sp, device=self.device)\n",
    "\n",
    "                self.agent.store_transition(s, a, r, sp, done, truncated, step + 1)\n",
    "\n",
    "                information_dict[\"step_rewards\"][step] = (\n",
    "                    episode,\n",
    "                    step,\n",
    "                    r,\n",
    "                )\n",
    "\n",
    "                s = sp  # Update state\n",
    "                r_cum += r  # Update cumulative reward\n",
    "\n",
    "                if (step % self.learn_frequency) == 0:\n",
    "                    # print(\"Learning at step: \", step)\n",
    "                    self.agent.learn(max_iter=self.max_iter)\n",
    "\n",
    "                if done or truncated:\n",
    "\n",
    "                    information_dict[\"episode_rewards\"][episode] = r_cum.item()\n",
    "                    information_dict[\"episode_steps\"][episode] = step\n",
    "                    if episode % 10 == 0:\n",
    "                        print(\n",
    "                            f\"Episode: {episode + 1:4d}\\tN-steps: {step:7d}\\tReward: {r_cum.item():10.3f}\"\n",
    "                        )\n",
    "                    s, _ = self.env.reset()\n",
    "                    s = totorch(s, device=self.device)\n",
    "                    r_cum = np.zeros(1)\n",
    "                    episode += 1\n",
    "                    e_step = 0\n",
    "\n",
    "            # task finishes\n",
    "            self.eval(step, final=True)\n",
    "            self.agent.end_task()\n",
    "\n",
    "        time_end = time.time()\n",
    "        print(f\"Training time: {time_end - time_start:.2f} seconds\")\n",
    "        \n",
    "        print(f\"AULC: {np.mean(self.AULCS)}\")\n",
    "        print(f\"Final Return: {np.mean(self.FINAL_RETURNS)}\")\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def eval(self, n_step, final=False):\n",
    "        self.agent.eval()\n",
    "        results = torch.zeros(self.eval_episodes)\n",
    "        collect_infos = {}\n",
    "        for episode in range(self.eval_episodes):\n",
    "            collect_infos[episode] = []\n",
    "            s, info = self.eval_env.reset()\n",
    "            s = totorch(s, device=self.device)\n",
    "            step = 0\n",
    "            a = self.agent.select_action(s, is_training=False)\n",
    "            done = False\n",
    "\n",
    "            while not done:\n",
    "                a = self.agent.select_action(s, is_training=False)\n",
    "\n",
    "                sp, r, term, trunc, info = self.eval_env.step(tonumpy(a))\n",
    "                collect_infos[episode].append(info)\n",
    "\n",
    "                done = term or trunc\n",
    "                s = totorch(sp, device=self.device)\n",
    "                results[episode] += r\n",
    "                step += 1\n",
    "\n",
    "        print(f\"EVALUATION\\tN-steps: {n_step:7d}\\tMean_Reward: {results.mean():10.3f}\")\n",
    "        self.AULCS.append(results.mean())\n",
    "        if final:\n",
    "            self.FINAL_RETURNS.append(results.mean())\n",
    "\n",
    "        self.agent.train()"
   ]
  },
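  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Two summary metrics are tracked: `AULCS` (a proxy for the area under the learning curve) collects the mean evaluation return at every checkpoint, while `FINAL_RETURNS` keeps only the end-of-task evaluations. With `max_steps = 500_000` and `eval_frequency = 20_000`, each task contributes 25 periodic checkpoints plus one final evaluation, and the printed AULC / Final Return are plain means over those lists."
   ]
  },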
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Networks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_activation(act):\n",
    "    if act == \"relu\":\n",
    "        return nn.ReLU\n",
    "    elif act == \"tanh\":\n",
    "        return nn.Tanh\n",
    "    else:\n",
    "        raise NotImplementedError(f\"{act} is not implemented\")\n",
    "\n",
    "\n",
    "def create_net(d_in, d_out, depth, width, act=\"crelu\", has_norm=True, n_elements=1):\n",
    "    assert depth > 0, \"Need at least one layer\"\n",
    "\n",
    "    act = get_activation(act)\n",
    "\n",
    "    if depth == 1:\n",
    "        arch = nn.Linear(d_in, d_out)\n",
    "    elif depth == 2:\n",
    "        arch = nn.Sequential(\n",
    "            nn.Linear(d_in, width),\n",
    "            (\n",
    "                nn.LayerNorm(width, elementwise_affine=False)\n",
    "                if has_norm\n",
    "                else nn.Identity()\n",
    "            ),\n",
    "            act(),\n",
    "            nn.Linear(width, d_out),\n",
    "        )\n",
    "    else:\n",
    "        in_layer = nn.Linear(d_in, width)\n",
    "        if n_elements > 1:\n",
    "            out_layer = nn.Linear(width, d_out, n_elements)\n",
    "        else:\n",
    "            out_layer = nn.Linear(width, d_out)\n",
    "\n",
    "        # This can probably be done in a more readable way, but it's fast and works...\n",
    "        hidden = list(\n",
    "            itertools.chain.from_iterable(\n",
    "                [\n",
    "                    [\n",
    "                        (\n",
    "                            nn.LayerNorm(width, elementwise_affine=False)\n",
    "                            if has_norm\n",
    "                            else nn.Identity()\n",
    "                        ),\n",
    "                        act(),\n",
    "                        nn.Linear(width, width),\n",
    "                    ]\n",
    "                    for _ in range(depth - 1)\n",
    "                ]\n",
    "            )\n",
    "        )[:-1]\n",
    "        arch = nn.Sequential(in_layer, *hidden, out_layer)\n",
    "\n",
    "    return arch\n",
    "\n",
    "\n",
    "class GaussianHead(nn.Module):\n",
    "    def __init__(self, n):\n",
    "        super().__init__()\n",
    "        self._n = n\n",
    "\n",
    "    def forward(self, x, is_training=True, return_dist=False):\n",
    "        mean = x[..., : self._n]\n",
    "        logstd = x[..., self._n :].clamp(-10.0, -2.0)\n",
    "        std = logstd.exp()\n",
    "        dist = Normal(mean, std, validate_args=False)\n",
    "        if is_training:\n",
    "            y = dist.rsample()\n",
    "            y_logprob = dist.log_prob(y).sum(dim=-1, keepdim=True)\n",
    "        else:\n",
    "            y = dist.mode\n",
    "            y_logprob = None\n",
    "        if return_dist:\n",
    "            return y, y_logprob, dist\n",
    "        return y, mean\n",
    "\n",
    "\n",
    "class ActorNetProbabilistic(nn.Module):\n",
    "    def __init__(self, dim_obs, dim_act, depth=3, width=256, act=\"relu\", has_norm=True):\n",
    "        super().__init__()\n",
    "        self.dim_act = dim_act\n",
    "        self.arch = create_net(dim_obs, 2 * dim_act, depth, width, act, has_norm)\n",
    "\n",
    "        self.head = GaussianHead(self.dim_act)\n",
    "\n",
    "    def forward(self, x, is_training=True, return_dist=False):\n",
    "        f = self.arch(x)\n",
    "        return self.head(f, is_training, return_dist=return_dist)\n",
    "\n",
    "\n",
    "class EvidentialCriticNet(nn.Module):\n",
    "    def __init__(self, dim_obs, depth=3, width=256, act=\"relu\", has_norm=False):\n",
    "        super().__init__()\n",
    "\n",
    "        self.arch = create_net(dim_obs, 4, depth, width, act=act, has_norm=has_norm)\n",
    "\n",
    "    @staticmethod\n",
    "    def evidence(x):\n",
    "        return torch.exp(x)\n",
    "\n",
    "    def forward(self, x):\n",
    "        output = self.arch(x)\n",
    "        gamma, logv, log_alpha, log_beta = output.chunk(4, dim=-1)\n",
    "        v = self.evidence(logv)\n",
    "        alpha = self.evidence(log_alpha) + 1  # to ensure that alpha > 1\n",
    "        beta = self.evidence(log_beta)\n",
    "        return gamma, v, alpha, beta"
   ]
  },
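  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small sanity check of the evidential head's range constraints (a sketch; the observation size 17 is an assumption matching HalfCheetah-v5):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_net = EvidentialCriticNet(dim_obs=17)\n",
    "_g, _v, _a, _b = _net(torch.randn(4, 17))\n",
    "# exp() keeps v and beta positive, and the +1 shift keeps alpha > 1\n",
    "assert (_v > 0).all() and (_a > 1).all() and (_b > 0).all()\n",
    "print(_g.shape, _v.shape, _a.shape, _b.shape)  # four (4, 1) tensors"
   ]
  },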
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# EPPO"
   ]
  },
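  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The critic below is *evidential*: instead of a point estimate, it outputs the parameters $(\\gamma, \\nu, \\alpha, \\beta)$ of a Normal-Inverse-Gamma (NIG) distribution over the value. Matching `get_mean_and_variance_of_y`, the value estimate and its predictive variance are\n",
    "\n",
    "$$\\mathbb{E}[y] = \\gamma, \\qquad \\operatorname{Var}[y] = \\frac{\\beta}{\\alpha - 1}\\left(1 + \\frac{1}{\\nu}\\right),$$\n",
    "\n",
    "and `EvidentialCritic.loss` minimizes the NIG negative log-likelihood (up to an additive constant),\n",
    "\n",
    "$$\\mathcal{L} = -\\tfrac{1}{2}\\log\\nu - \\alpha\\log\\Omega + \\left(\\alpha + \\tfrac{1}{2}\\right)\\log\\!\\left(\\nu (y - \\gamma)^2 + \\Omega\\right) + \\log\\Gamma(\\alpha) - \\log\\Gamma\\!\\left(\\alpha + \\tfrac{1}{2}\\right), \\quad \\Omega = 2\\beta(1 + \\nu),$$\n",
    "\n",
    "plus a prior-matching regularizer weighted by `regularization_coeff`."
   ]
  },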
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EPPOActor(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        arch,\n",
    "        n_state,\n",
    "        n_action,\n",
    "        clip_param,\n",
    "        max_grad_norm,\n",
    "        learning_rate,\n",
    "        depth_actor,\n",
    "        width_actor,\n",
    "        act_actor,\n",
    "        no_norm_actor,\n",
    "        device,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.n_state = n_state\n",
    "        self.n_action = n_action\n",
    "        self.arch = arch\n",
    "        self.clip = clip_param\n",
    "        self.max_grad_norm = max_grad_norm\n",
    "        self.learning_rate = learning_rate\n",
    "        self.depth_actor = depth_actor\n",
    "        self.width_actor = width_actor\n",
    "        self.act_actor = act_actor\n",
    "        self.no_norm_actor = no_norm_actor\n",
    "        self.device = device\n",
    "\n",
    "        self.initialize()\n",
    "\n",
    "    def initialize(self):\n",
    "        self.model = self.arch(\n",
    "            self.n_state,\n",
    "            self.n_action,\n",
    "            depth=self.depth_actor,\n",
    "            width=self.width_actor,\n",
    "            act=self.act_actor,\n",
    "            has_norm=not self.no_norm_actor,\n",
    "        ).to(self.device)\n",
    "        self.optim = torch.optim.Adam(self.model.parameters(), self.learning_rate)\n",
    "\n",
    "    def evaluate(self, s):\n",
    "        _, _, dist = self.model(s, return_dist=True)\n",
    "        # return dist here\n",
    "        return dist\n",
    "\n",
    "    def loss(self, s, a, old_probs, adv):\n",
    "        dist = self.evaluate(s)\n",
    "        new_probs = dist.log_prob(a)\n",
    "        foo = new_probs.sum(1, keepdim=True) - old_probs.sum(1, keepdim=True)\n",
    "        prob_ratio = torch.exp(foo.clamp(-20, 1))\n",
    "\n",
    "        weighted_probs = adv * prob_ratio\n",
    "        weighted_clipped_probs = (\n",
    "            torch.clamp(prob_ratio, 1 - self.clip, 1 + self.clip) * adv\n",
    "        )\n",
    "\n",
    "        actor_loss = -torch.min(weighted_probs, weighted_clipped_probs).mean()\n",
    "        self.clip_fraction = (\n",
    "            (torch.abs((prob_ratio - 1)) > self.clip).to(torch.float).mean()\n",
    "        )\n",
    "\n",
    "        return actor_loss\n",
    "\n",
    "    def update(self, s, a, old_probs, adv):\n",
    "        self.optim.zero_grad()\n",
    "        loss = self.loss(s, a, old_probs, adv)\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)\n",
    "        self.optim.step()\n",
    "\n",
    "    def act(self, s, is_training=True):\n",
    "        a, e = self.model(s, is_training=is_training)\n",
    "        return a, e\n",
    "\n",
    "\n",
    "class EvidentialCritic(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        arch,\n",
    "        n_state,\n",
    "        max_grad_norm,\n",
    "        depth_critic,\n",
    "        width_critic,\n",
    "        act_critic,\n",
    "        no_norm_critic,\n",
    "        device,\n",
    "        learning_rate,\n",
    "        regularization_coeff,\n",
    "    ):\n",
    "        super().__init__()\n",
    "\n",
    "        self.arch = arch\n",
    "        self.max_grad_norm = max_grad_norm\n",
    "        self.depth_critic = depth_critic\n",
    "        self.width_critic = width_critic\n",
    "        self.act_critic = act_critic\n",
    "        self.no_norm_critic = no_norm_critic\n",
    "        self.n_state = n_state\n",
    "        self.learning_rate = learning_rate\n",
    "        self.iter = 0\n",
    "        self.device = device\n",
    "        self.initialize()\n",
    "\n",
    "        self.regularization_coeff = regularization_coeff\n",
    "\n",
    "    def initialize(self):\n",
    "        self.model = self.arch(\n",
    "            self.n_state,\n",
    "            depth=self.depth_critic,\n",
    "            width=self.width_critic,\n",
    "            act=self.act_critic,\n",
    "            has_norm=not self.no_norm_critic,\n",
    "        ).to(self.device)\n",
    "        self.loss = torch.nn.MSELoss()\n",
    "        self.optim = torch.optim.Adam(self.model.parameters(), self.learning_rate)\n",
    "\n",
    "        self.prior_gamma = torch.distributions.Normal(\n",
    "            torch.tensor(0.0).to(self.device), torch.tensor(100.0).to(self.device)\n",
    "        )\n",
    "        self.prior_v = torch.distributions.Gamma(\n",
    "            torch.tensor(5.0).to(self.device), torch.tensor(1.0).to(self.device)\n",
    "        )\n",
    "        self.prior_alpha = torch.distributions.TransformedDistribution(\n",
    "            torch.distributions.Gamma(\n",
    "                torch.tensor(5.0).to(self.device), torch.tensor(1.0).to(self.device)\n",
    "            ),\n",
    "            [\n",
    "                torch.distributions.transforms.AffineTransform(\n",
    "                    loc=1.0, scale=1.0, cache_size=1\n",
    "                )\n",
    "            ],\n",
    "        )\n",
    "        self.prior_beta = torch.distributions.Gamma(\n",
    "            torch.tensor(5.0).to(self.device), torch.tensor(1.0).to(self.device)\n",
    "        )\n",
    "\n",
    "    def get_prior(self, x):\n",
    "        gamma, v, alpha, beta = self.model(x)\n",
    "        return gamma, v, alpha, beta\n",
    "\n",
    "    def loss(self, state, target):\n",
    "        gamma, v, alpha, beta = self.get_prior(state)\n",
    "        twoBlambda = 2 * beta * (1 + v)\n",
    "        loss = (\n",
    "            -0.5 * torch.log(v)\n",
    "            - alpha * torch.log(twoBlambda)\n",
    "            + (alpha + 0.5) * torch.log(v * (target - gamma) ** 2 + twoBlambda)\n",
    "            + torch.lgamma(alpha)\n",
    "            - torch.lgamma(alpha + 0.5)\n",
    "        )\n",
    "\n",
    "        regularization = (\n",
    "            self.prior_gamma.log_prob(gamma).mean()\n",
    "            + self.prior_v.log_prob(v).mean()\n",
    "            + self.prior_alpha.log_prob(alpha).mean()\n",
    "            + self.prior_beta.log_prob(beta).mean()\n",
    "        )\n",
    "        loss -= regularization * self.regularization_coeff\n",
    "\n",
    "        return loss.mean()\n",
    "\n",
    "    def update(self, state, target):  # y denotes bellman target\n",
    "        self.optim.zero_grad()\n",
    "        loss = self.loss(state, target)\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)\n",
    "        self.optim.step()\n",
    "\n",
    "    def forward(self, x, target=False):\n",
    "        gamma, v, alpha, beta = self.get_prior(x)\n",
    "        if target:\n",
    "            return gamma\n",
    "\n",
    "        return gamma\n",
    "\n",
    "    def get_mean_and_variance_of_y(self, x):\n",
    "        gamma, v, alpha, beta = self.get_prior(x)\n",
    "        mean = gamma\n",
    "        variance = (beta / (alpha - 1)) * (1 + 1.0 / v)\n",
    "        return mean, variance\n",
    "\n",
    "\n",
    "class EvidentialProximalPolicyOptimization(nn.Module):\n",
    "    _agent_name = \"EPPO\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        env,\n",
    "        actor_nn,\n",
    "        critic_nn,\n",
    "        device,\n",
    "        gamma,\n",
    "        buffer_size,\n",
    "        clip_param,\n",
    "        gae_lambda,\n",
    "        max_grad_norm,\n",
    "        learning_rate,\n",
    "        depth_actor,\n",
    "        width_actor,\n",
    "        act_actor,\n",
    "        no_norm_actor,\n",
    "        depth_critic,\n",
    "        width_critic,\n",
    "        act_critic,\n",
    "        no_norm_critic,\n",
    "        regularization_coeff,\n",
    "        radius,\n",
    "        exploration_type,\n",
    "        batch_size,\n",
    "    ):\n",
    "        super().__init__()\n",
    "\n",
    "        self.device = device\n",
    "        self.env = env\n",
    "        self.eps = 1e-6  # small value to avoid division by zero\n",
    "        self.clip_param = clip_param\n",
    "        self.gae_lambda = gae_lambda\n",
    "        self.max_grad_norm = max_grad_norm\n",
    "        self.learning_rate = learning_rate\n",
    "        self.depth_actor = depth_actor\n",
    "        self.width_actor = width_actor\n",
    "        self.act_actor = act_actor\n",
    "        self.no_norm_actor = no_norm_actor\n",
    "        self.depth_critic = depth_critic\n",
    "        self.width_critic = width_critic\n",
    "        self.act_critic = act_critic\n",
    "        self.no_norm_critic = no_norm_critic\n",
    "        self.regularization_coeff = regularization_coeff\n",
    "        self.radius = radius\n",
    "        self.exploration_type = (exploration_type,)\n",
    "        self.batch_size = batch_size\n",
    "\n",
    "        self.dim_obs, self.dim_act = (\n",
    "            self.env.observation_space.shape,\n",
    "            self.env.action_space.shape,\n",
    "        )\n",
    "        self.dim_obs_flat, self.dim_act_flat = np.prod(self.dim_obs), np.prod(\n",
    "            self.dim_act\n",
    "        )\n",
    "        self._u_min = totorch(self.env.action_space.low, device=self.device)\n",
    "        self._u_max = totorch(self.env.action_space.high, device=self.device)\n",
    "        self._x_min = totorch(self.env.observation_space.low, device=self.device)\n",
    "        self._x_max = totorch(self.env.observation_space.high, device=self.device)\n",
    "\n",
    "        self._gamma = gamma\n",
    "        self.buffer_size = buffer_size\n",
    "\n",
    "        dims = {\n",
    "            \"state\": (self.buffer_size, self.dim_obs_flat),\n",
    "            \"action\": (self.buffer_size, self.dim_act_flat),\n",
    "            \"next_state\": (self.buffer_size, self.dim_obs_flat),\n",
    "            \"reward\": (self.buffer_size),\n",
    "            \"terminated\": (self.buffer_size),\n",
    "            \"step\": (self.buffer_size),\n",
    "        }\n",
    "\n",
    "        self.experience_memory = ExperienceMemoryTorch(\n",
    "            self.device, self.buffer_size, dims\n",
    "        )\n",
    "\n",
    "        self.actor = EPPOActor(\n",
    "            actor_nn,\n",
    "            self.dim_obs_flat,\n",
    "            self.dim_act_flat,\n",
    "            self.clip_param,\n",
    "            self.max_grad_norm,\n",
    "            self.learning_rate,\n",
    "            self.depth_actor,\n",
    "            self.width_actor,\n",
    "            self.act_actor,\n",
    "            self.no_norm_actor,\n",
    "            self.device,\n",
    "        )\n",
    "\n",
    "        self.critic = EvidentialCritic(\n",
    "            critic_nn,\n",
    "            self.dim_obs_flat,\n",
    "            self.max_grad_norm,\n",
    "            self.depth_critic,\n",
    "            self.width_critic,\n",
    "            self.act_critic,\n",
    "            self.no_norm_critic,\n",
    "            self.device,\n",
    "            self.learning_rate,\n",
    "            self.regularization_coeff,\n",
    "        )\n",
    "\n",
    "        self._variance_coeff = (1.0 - self.gae_lambda) / (1.0 + self.gae_lambda)\n",
    "        self._next_variance_coeff = ((1.0 - self.gae_lambda) / (self.gae_lambda)) ** 2\n",
    "        self._accumulation_coeff = (self._gamma * self.gae_lambda) ** 2\n",
    "\n",
    "        if exploration_type == \"mean\":\n",
    "            self.calculate_advantages = self.calculate_advantages_mean\n",
    "        elif exploration_type == \"cor\":\n",
    "            self.calculate_advantages = self.calculate_advantages_cor\n",
    "        elif exploration_type == \"ind\":\n",
    "            self.calculate_advantages = self.calculate_advantages_ind\n",
    "        else:\n",
    "            raise ValueError(f\"Unknown exploration type: {exploration_type}\")\n",
    "\n",
    "    def end_task(self):\n",
    "        self.experience_memory.reset()\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def calculate_advantages_mean(self, states, next_states, rewards, dones):\n",
    "        # EPPO_mean\n",
    "        values = self.critic(states)\n",
    "        next_values = self.critic(next_states)\n",
    "        deltas = rewards + self._gamma * next_values * (1 - dones) - values\n",
    "        advantages = torch.zeros_like(rewards)\n",
    "        advantage = 0\n",
    "        for i in reversed(range(len(deltas))):\n",
    "            advantage = (\n",
    "                self._gamma * self.gae_lambda * advantage * (1 - dones[i]) + deltas[i]\n",
    "            )\n",
    "            advantages[i] = advantage\n",
    "\n",
    "        returns = advantages + values\n",
    "        advantages = (advantages - advantages.mean()) / (advantages.std() + self.eps)\n",
    "        return advantages, returns\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def calculate_advantages_cor(self, states, next_states, rewards, dones):\n",
    "        # EPPO_cor\n",
    "        mean_values, variance_values = self.critic.get_mean_and_variance_of_y(states)\n",
    "        mean_next_values, variance_next_values = self.critic.get_mean_and_variance_of_y(\n",
    "            next_states\n",
    "        )\n",
    "        mean_deltas = (\n",
    "            rewards + self._gamma * mean_next_values * (1 - dones) - mean_values\n",
    "        )\n",
    "\n",
    "        mean_advantages = torch.zeros_like(rewards)\n",
    "        variance_advantages = torch.zeros_like(rewards)\n",
    "        mean_advantage = 0\n",
    "        variance_accumulated = 0\n",
    "        for i in reversed(range(len(mean_deltas))):\n",
    "            mean_advantage = (\n",
    "                self._gamma * self.gae_lambda * mean_advantage * (1 - dones[i])\n",
    "                + mean_deltas[i]\n",
    "            )\n",
    "            mean_advantages[i] = mean_advantage\n",
    "            variance_accumulated = self._accumulation_coeff * (\n",
    "                variance_accumulated * (1 - dones[i]) + variance_next_values[i]\n",
    "            )\n",
    "            variance_advantages[i] = (\n",
    "                variance_values[i] + self._next_variance_coeff * variance_accumulated\n",
    "            )\n",
    "\n",
    "        std_advantages = torch.sqrt(variance_advantages)\n",
    "        advantages = mean_advantages + self.radius * std_advantages\n",
    "        returns = mean_advantages + mean_values\n",
    "        advantages = (advantages - advantages.mean()) / (advantages.std() + self.eps)\n",
    "        return advantages, returns\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def calculate_advantages_ind(self, states, next_states, rewards, dones):\n",
    "        # EPPO_ind\n",
    "        mean_values, variance_values = self.critic.get_mean_and_variance_of_y(states)\n",
    "        mean_next_values, variance_next_values = self.critic.get_mean_and_variance_of_y(\n",
    "            next_states\n",
    "        )\n",
    "        mean_deltas = (\n",
    "            rewards + self._gamma * mean_next_values * (1 - dones) - mean_values\n",
    "        )\n",
    "\n",
    "        mean_advantages = torch.zeros_like(rewards)\n",
    "        variance_advantages = torch.zeros_like(rewards)\n",
    "        mean_advantage = 0\n",
    "        variance_accumulated = 0\n",
    "        for i in reversed(range(len(mean_deltas))):\n",
    "            mean_advantage = (\n",
    "                self._gamma * self.gae_lambda * mean_advantage * (1 - dones[i])\n",
    "                + mean_deltas[i]\n",
    "            )\n",
    "            mean_advantages[i] = mean_advantage\n",
    "            variance_accumulated = self._accumulation_coeff * (\n",
    "                variance_accumulated * (1 - dones[i]) + variance_next_values[i]\n",
    "            )\n",
    "            variance_advantages[i] = (\n",
    "                self._variance_coeff * variance_values[i]\n",
    "                + self._next_variance_coeff * variance_accumulated\n",
    "            )\n",
    "\n",
    "        std_advantages = torch.sqrt(variance_advantages)\n",
    "        advantages = mean_advantages + self.radius * std_advantages\n",
    "        returns = mean_advantages + mean_values\n",
    "        advantages = (advantages - advantages.mean()) / (advantages.std() + self.eps)\n",
    "        return advantages, returns\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def calculate_old_probs(self, states, actions):\n",
    "        dist = self.actor.evaluate(states)\n",
    "        return dist.log_prob(actions)\n",
    "\n",
    "    def learn(self, max_iter=1):\n",
    "        if self.batch_size > len(self.experience_memory):\n",
    "            return None\n",
    "\n",
    "        states, actions, rewards, next_states, terminateds, _ = (\n",
    "            self.experience_memory.sample_all()\n",
    "        )\n",
    "        rewards = rewards.reshape(-1, 1)\n",
    "        terminateds = terminateds.reshape(-1, 1)\n",
    "        advantages, returns = self.calculate_advantages(\n",
    "            states, next_states, rewards, terminateds\n",
    "        )\n",
    "        old_probs = self.calculate_old_probs(states, actions)\n",
    "\n",
    "        for ii in range(max_iter):\n",
    "            # shuffle data\n",
    "            indices = torch.randperm(len(states))\n",
    "            for i in range(0, len(states), self.batch_size):\n",
    "                batch_indices = indices[i : i + self.batch_size]\n",
    "                batch_states = states[batch_indices]\n",
    "                batch_actions = actions[batch_indices]\n",
    "                batch_old_probs = old_probs[batch_indices]\n",
    "                batch_advantages = advantages[batch_indices]\n",
    "                batch_returns = returns[batch_indices]\n",
    "\n",
    "                self.actor.update(\n",
    "                    batch_states, batch_actions, batch_old_probs, batch_advantages\n",
    "                )\n",
    "                self.critic.update(batch_states, batch_returns)\n",
    "\n",
    "        # clear memory after learning due to on-policy\n",
    "        self.experience_memory.reset()\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def select_action(self, s, is_training=True):\n",
    "        a, _ = self.actor.act(s, is_training=is_training)\n",
    "        return a.clamp(self._u_min + self.eps, self._u_max - self.eps)\n",
    "\n",
    "    def Q_value(self, s, a):\n",
    "        return self.critic(s)\n",
    "\n",
    "    def store_transition(self, s, a, r, sp, terminated, truncated, step):\n",
    "        self.experience_memory.add(s, a, r, sp, terminated or truncated, step)"
   ]
  },
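  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, all three `calculate_advantages_*` variants share the standard GAE recursion\n",
    "\n",
    "$$\\delta_t = r_t + \\gamma V(s_{t+1})(1 - d_t) - V(s_t), \\qquad \\hat{A}_t = \\delta_t + \\gamma \\lambda (1 - d_t)\\, \\hat{A}_{t+1}.$$\n",
    "\n",
    "The `mean` variant uses $\\hat{A}_t$ directly; `cor` and `ind` add an optimistic bonus $\\tilde{A}_t = \\hat{A}_t + \\kappa\\, \\sigma_t$, where $\\kappa$ is the `radius` hyperparameter and $\\sigma_t^2$ accumulates the critic's predictive variance along the trajectory (the two differ only in how the current state's variance is weighted). Advantages are standardized before the PPO update."
   ]
  },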
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experimenting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# env\n",
    "eval_env = make_env(\n",
    "    exp_name=exp_name,\n",
    "    seed=seed + 100,\n",
    "    idxs=get_idx_to_paralyze(exp_name),\n",
    "    paralyzed_ratio=0.0,\n",
    ")\n",
    "\n",
    "# EPPO\n",
    "agent = EvidentialProximalPolicyOptimization(\n",
    "    env=eval_env,\n",
    "    actor_nn=ActorNetProbabilistic,\n",
    "    critic_nn=EvidentialCriticNet,\n",
    "    device=device,\n",
    "    gamma=gamma,\n",
    "    buffer_size=buffer_size,\n",
    "    clip_param=clip_param,\n",
    "    gae_lambda=gae_lambda,\n",
    "    max_grad_norm=max_grad_norm,\n",
    "    learning_rate=learning_rate,\n",
    "    depth_actor=depth_actor,\n",
    "    width_actor=width_actor,\n",
    "    act_actor=act_actor,\n",
    "    no_norm_actor=no_norm_actor,\n",
    "    depth_critic=depth_critic,\n",
    "    width_critic=width_critic,\n",
    "    act_critic=act_critic,\n",
    "    no_norm_critic=no_norm_critic,\n",
    "    regularization_coeff=regularization_coeff,\n",
    "    radius=radius,\n",
    "    exploration_type=exploration_type,\n",
    "    batch_size=batch_size,\n",
    ")\n",
    "\n",
    "# experimenter\n",
    "experimenter = ParalysisExpriment(\n",
    "    agent=agent,\n",
    "    exp_name=exp_name,\n",
    "    seed=seed,\n",
    "    max_steps=max_steps,\n",
    "    eval_frequency=eval_frequency,\n",
    "    device=device,\n",
    "    learn_frequency=learn_frequency,\n",
    "    eval_episodes=eval_episodes,\n",
    "    max_iter=max_iter,\n",
    "    gamma=gamma,\n",
    ")\n",
    "\n",
    "experimenter.train()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "eppo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
