{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f6be2e27-68a1-4e7e-84f2-cf201ff97608",
   "metadata": {},
   "source": [
    "# Interpretability Study"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5583dceb-8d43-45f8-b8c1-51f6ce8be949",
   "metadata": {},
   "source": [
    "Credits\n",
    "- [PyTorch](https://github.com/pytorch/pytorch)\n",
    "- [pytorch-gradcam](https://github.com/vickyliin/gradcam_plus_plus-pytorch)\n",
    "- [Stable Baselines3](https://github.com/DLR-RM/stable-baselines3)\n",
    "- [RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "558e63b4-6464-4558-ad1d-3ef845b5d0a6",
   "metadata": {},
   "source": [
    "### Import utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0bd86cc-8b33-4e85-8c1d-4de260860932",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scripts.interpretability_utils import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ed5aec0-c849-4703-b916-29ac1a41b705",
   "metadata": {},
   "source": [
    "### Create and wrap the game env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a78cff0-226b-40af-a540-c92f1988f901",
   "metadata": {},
   "outputs": [],
   "source": [
    "game = 'PongNoFrameskip-v4'\n",
    "n_envs = 1\n",
    "seed = 0\n",
    "vec_env = make_atari_env(game, n_envs=n_envs, seed=seed)\n",
    "vec_env = VecFrameStack(vec_env, n_stack=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3205d13d-7cf7-466a-a0c3-68c92db37310",
   "metadata": {},
   "outputs": [],
   "source": [
    "wrap_with_vectranspose = is_image_space(vec_env.observation_space) and not is_image_space_channels_first(vec_env.observation_space)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54d89e81-13e0-4d29-acd9-70aa4116af21",
   "metadata": {},
   "outputs": [],
   "source": [
    "# wrap the vec_env with VecTransposeImage if wrap_with_vectranspose is True\n",
    "if wrap_with_vectranspose:\n",
    "    vec_env = VecTransposeImage(vec_env)\n",
    "    print(\"VecTransposeImage is applied.\")\n",
    "else:\n",
    "    warnings.warn(\"VecTransposeImage is not applied.\")"
   ]
  },
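  {
   "cell_type": "markdown",
   "id": "sketch-transpose-md",
   "metadata": {},
   "source": [
    "For orientation, `VecTransposeImage` moves the channel axis of image observations from last to first, which is the layout PyTorch convolutions expect. The toy array below is a stand-in (the `toy_` names and the 84x84x4 shape are assumptions, not taken from this study) and only illustrates the axis reordering on a single unbatched observation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-transpose-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# toy channels-last observation: 84x84 pixels, 4 stacked frames (assumed shape)\n",
    "toy_hwc = np.zeros((84, 84, 4), dtype=np.uint8)\n",
    "# reorder axes (H, W, C) -> (C, H, W), i.e. channels-first\n",
    "toy_chw = np.transpose(toy_hwc, (2, 0, 1))\n",
    "print(toy_hwc.shape, '->', toy_chw.shape)"
   ]
  },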
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae685149-6683-46c7-acb5-4555f9270f4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# get action meanings\n",
    "action_meanings = deepcopy(vec_env.unwrapped.envs[0].unwrapped.get_action_meanings())\n",
    "print(f\"action meanings: {action_meanings}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "215e7cf8-e2c1-4994-87c4-8a35122193dd",
   "metadata": {},
   "source": [
    "### Re-generate 10 random observations (optional)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de6c9d39-1022-49f7-8aab-abbba48a6b15",
   "metadata": {},
   "source": [
    "You can generate random observations by running the cell below recursively."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d522b328-6cb9-4f2f-ae8e-519a05367912",
   "metadata": {},
   "outputs": [],
   "source": [
    "# my_obs_list = get_random_obs(1)\n",
    "\n",
    "# for idx, obs in enumerate(my_obs_list):\n",
    "#     frames = obs.squeeze()\n",
    "#     fig = plt.figure(idx, figsize=(10,4))\n",
    "#     for pos, frame in enumerate(frames):\n",
    "#         ax = fig.add_subplot(1, 4, pos+1)\n",
    "#         plt.imshow(frame, cmap='gray', vmin=0, vmax=255)\n",
    "#         plt.axis('off')\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8b15fe1-2227-4f9d-b8bb-18f2fb34311a",
   "metadata": {},
   "source": [
    "You can save the observations as a `.npy` file. \n",
    "\n",
    "**Note: this may overwrite the existing observations. Do use a different filename!**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa7753ac-ac60-4b00-abea-8dad35286b53",
   "metadata": {},
   "outputs": [],
   "source": [
    "# with open('gradcam/pong_new_1.npy', 'wb') as f:\n",
    "#     np.save(f, my_obs_list[0])"
   ]
  },
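  {
   "cell_type": "markdown",
   "id": "sketch-safe-save-md",
   "metadata": {},
   "source": [
    "One way to guard against the overwrite mentioned above is to open the file in exclusive-creation mode (`'xb'`), which raises `FileExistsError` instead of replacing an existing file. The sketch below writes a toy array to a temporary directory; `save_obs_no_overwrite` and the `toy_` names are hypothetical and not part of this study's utilities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-safe-save-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tempfile\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "def save_obs_no_overwrite(path, obs):\n",
    "    # hypothetical helper: 'x' mode fails with FileExistsError if path already exists\n",
    "    with open(path, 'xb') as f:\n",
    "        np.save(f, obs)\n",
    "\n",
    "# toy data written to a throwaway directory, not to gradcam/\n",
    "toy_dir = tempfile.mkdtemp()\n",
    "toy_path = os.path.join(toy_dir, 'pong_new_1.npy')\n",
    "toy_obs = np.zeros((1, 4, 84, 84), dtype=np.uint8)\n",
    "\n",
    "save_obs_no_overwrite(toy_path, toy_obs)  # first save creates the file\n",
    "try:\n",
    "    save_obs_no_overwrite(toy_path, toy_obs)  # second save hits the existing file\n",
    "    toy_second_save = 'overwrote'\n",
    "except FileExistsError:\n",
    "    toy_second_save = 'refused'\n",
    "print(toy_second_save)"
   ]
  },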
  {
   "cell_type": "markdown",
   "id": "fcd2e6f4-6d19-4953-b8e3-6f43ff69bb6e",
   "metadata": {},
   "source": [
    "### Load all saved observation files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86c2e62b-b3a6-45ed-98a9-e40e97c899d7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for idx in range(10):\n",
    "    with open(f'gradcam/pong_{idx+1}.npy', 'rb') as f:\n",
    "        obs = np.load(f)\n",
    "        fig = plt.figure(idx, figsize=(10,4))\n",
    "        for pos, frame in enumerate(obs.squeeze()):\n",
    "            ax = fig.add_subplot(1, 4, pos+1)\n",
    "            plt.imshow(frame, cmap='gray', vmin=0, vmax=255)\n",
    "            plt.axis('off')\n",
    "        # plt.savefig(f\"gradcam/pong_{idx+1}.pdf\", dpi=300)\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59462fd6-e3ce-4da3-b88c-cea83757e339",
   "metadata": {},
   "source": [
    "### Define parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b095978-c23f-498a-8e23-5d04e8dedeef",
   "metadata": {},
   "outputs": [],
   "source": [
    "DEVICE = \"cpu\"\n",
    "CHECKPOINT_EXT = \"model_checkpoint_3000000_steps\"\n",
    "GAME_VERSION = \"NoFrameskip-v4\"\n",
    "game_no_version = game.replace(GAME_VERSION, '')\n",
    "sat_to_seed_mapping = {\n",
    "    'CWCA': 42,\n",
    "    'NA': 42,\n",
    "    'SWA': 42,\n",
    "    'CWRA': 1234,\n",
    "    'CWRCA': 0,\n",
    "}\n",
    "sat_list = ['CWCA', 'NA', 'SWA', 'CWRA', 'CWRCA'] # put winner agent at the leftmost"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1e47da4-6ef0-49fd-8936-e95ac0ef74ff",
   "metadata": {},
   "source": [
    "### Generate saliency maps (target layer c1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1db1b86d-0d03-4098-885f-0c927a004807",
   "metadata": {},
   "source": [
    "target_layer = policy.features_extractor.c1\n",
    "\n",
    "deterministic = True"
   ]
  },
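  {
   "cell_type": "markdown",
   "id": "sketch-gradcam-md",
   "metadata": {},
   "source": [
    "As a rough guide to what the saliency maps represent: Grad-CAM weights each channel of the target layer's activations by the spatial average of the gradient of the selected action's score with respect to that channel, sums the weighted channels, and keeps the positive part. The cell below is a minimal NumPy sketch of that combination step on random toy arrays. It is not the `GradCAM` class imported from `scripts.interpretability_utils`, whose internals are not shown in this notebook; the function name, the `toy_` variables, and the 32-channel 20x20 shape are assumptions made for illustration, and upsampling to the 84x84 input is left out."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-gradcam-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def grad_cam_from_arrays(activations, gradients):\n",
    "    # hypothetical helper: both inputs are shaped (C, H, W) for one observation\n",
    "    weights = gradients.mean(axis=(1, 2))             # one weight per channel (spatial average of gradients)\n",
    "    cam = np.tensordot(weights, activations, axes=1)  # weighted sum over channels -> (H, W)\n",
    "    cam = np.maximum(cam, 0.0)                        # ReLU keeps only positive evidence\n",
    "    span = cam.max() - cam.min()\n",
    "    # min-max normalise to [0, 1]; return zeros if the map is constant\n",
    "    return (cam - cam.min()) / span if span > 0 else np.zeros_like(cam)\n",
    "\n",
    "# toy stand-ins for the activations and gradients of a target layer (assumed shape)\n",
    "toy_rng = np.random.default_rng(0)\n",
    "toy_activations = toy_rng.normal(size=(32, 20, 20))\n",
    "toy_gradients = toy_rng.normal(size=(32, 20, 20))\n",
    "\n",
    "toy_cam = grad_cam_from_arrays(toy_activations, toy_gradients)\n",
    "print(toy_cam.shape, float(toy_cam.min()), float(toy_cam.max()))"
   ]
  },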
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33b6a9bd-9241-4659-a3a6-203a4b7ce833",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "heatmaps_all = [] # gradcam\n",
    "actions_all = [] # actions selected by the agent (depends on deterministic parameter)\n",
    "values_all = [] # value_net output\n",
    "logits_all = [] # action_net outputs\n",
    "scores_all = [] # logits selected based on actions\n",
    "gradients_all = [] # gradients of the target layer\n",
    "activations_all = [] # feed-forward output of the target layer\n",
    "attended_feature_maps_all = [] # self_attn_layer outputs\n",
    "\n",
    "h = 84\n",
    "w = 84\n",
    "deterministic = True\n",
    "\n",
    "# each row holds all heatmaps for an obs \n",
    "# each column is an sat model\n",
    "for obs_idx in range(10):\n",
    "    # load observation .npy file\n",
    "    with open(f'gradcam/pong_{obs_idx+1}.npy', 'rb') as f:\n",
    "        obs = np.load(f)\n",
    "    # convert observation to torch tensor\n",
    "    obs_th = th.FloatTensor(obs)\n",
    "    heatmaps_per_obs = []\n",
    "    actions_per_obs = []\n",
    "    values_per_obs = []\n",
    "    logits_per_obs = []\n",
    "    scores_per_obs = []\n",
    "    gradients_per_obs = []\n",
    "    activations_per_obs = []\n",
    "    attended_feature_maps_per_obs = []\n",
    "    for sat in sat_list:\n",
    "        # instantiate the PPO model\n",
    "        policy_kwargs = eval(f\"dict(features_extractor_class=SelfAttnCNNPPO, features_extractor_kwargs=dict(self_attn='{sat}'), net_arch=[])\")\n",
    "        model = PPO(policy='CnnPolicy', env=vec_env, seed=seed, verbose=True, device=DEVICE, policy_kwargs=policy_kwargs) # seed is the env's seed set at the beginning\n",
    "        # update the model by loading the selected checkpoint zip file\n",
    "        # get seed used in the model checkpoint file\n",
    "        seed_checkpoint = sat_to_seed_mapping.get(sat)\n",
    "        model_updated = model.load(f\"gradcam/{game_no_version}_{sat}_{seed_checkpoint}_{CHECKPOINT_EXT}\", device=DEVICE, print_system_info=False)\n",
    "        # get the policy and the target layer\n",
    "        policy = model_updated.policy\n",
    "        target_layer = policy.features_extractor.c1\n",
    "        # instantiate a GradCAM object\n",
    "        gradcam = GradCAM(policy, target_layer, sat=sat)\n",
    "        # call the forward() in GradCAM\n",
    "        saliency_map, action, value, logit, score, gradients, activations, attended_feature_maps = gradcam(obs_th, deterministic=deterministic)\n",
    "        # convert saliency map to heatmap\n",
    "        heatmap = convert_to_heatmap(saliency_map)\n",
    "        heatmaps_per_obs.extend([heatmap])\n",
    "        actions_per_obs.extend([action])\n",
    "        values_per_obs.extend([value])\n",
    "        logits_per_obs.extend([logit])\n",
    "        scores_per_obs.extend([score])\n",
    "        gradients_per_obs.extend([gradients])\n",
    "        activations_per_obs.extend([activations])\n",
    "        # preprocess attended_feature_maps before converting to headmap style\n",
    "        if attended_feature_maps is not None:\n",
    "            attended_feature_maps_sum = attended_feature_maps.sum(1, keepdim=True) # shape=(1, 1, 20, 20)\n",
    "            attended_feature_maps_upsample = F.interpolate(attended_feature_maps_sum, size=(h,w), mode='bilinear', align_corners=False) # upsample to input size (1, 1, 84, 84)\n",
    "            attended_feature_maps_upsample_min, attended_feature_maps_upsample_max = attended_feature_maps_upsample.min(), attended_feature_maps_upsample.max() # get min and max\n",
    "            attended_feature_maps_norm = (attended_feature_maps_upsample - attended_feature_maps_upsample_min).div(attended_feature_maps_upsample_max - attended_feature_maps_upsample_min).data\n",
    "            # convert to heatmap style (color)\n",
    "            attended_feature_maps_color = convert_to_heatmap(attended_feature_maps_norm)\n",
    "            attended_feature_maps_per_obs.extend([attended_feature_maps_color])\n",
    "    heatmaps_all.extend(heatmaps_per_obs)\n",
    "    actions_all.extend(actions_per_obs)\n",
    "    values_all.extend(values_per_obs)\n",
    "    logits_all.extend(logits_per_obs)\n",
    "    scores_all.extend(scores_per_obs)\n",
    "    gradients_all.extend(gradients_per_obs)\n",
    "    activations_all.extend(activations_per_obs)\n",
    "    attended_feature_maps_all.extend(attended_feature_maps_per_obs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91d38c27-1c7e-4016-a147-bd8b4e11bb52",
   "metadata": {},
   "source": [
    "#### Heatmaps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c819383-d16f-452d-9751-af79e4c0831d",
   "metadata": {},
   "outputs": [],
   "source": [
    "grid_image_heatmap = make_grid(heatmaps_all, nrow=5)\n",
    "grid_image_heatmap_PIL = transforms.ToPILImage()(grid_image_heatmap)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62a5a719-f389-4a10-9e58-fd2985ef4f3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "grid_image_heatmap_PIL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13688f07-52de-4504-a78a-7178b6047678",
   "metadata": {},
   "outputs": [],
   "source": [
    "grid_image_heatmap_PIL.save(\"gradcam/heatmap_c1_deterministic_true.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6dd8924f-73a7-4a5e-b0a0-ebb03b64c016",
   "metadata": {},
   "source": [
    "#### Attended feature maps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76d46689-69ed-41c9-8d45-dca620c8380b",
   "metadata": {},
   "outputs": [],
   "source": [
    "grid_image_attended_feature_map = make_grid(attended_feature_maps_all, nrow=4)\n",
    "grid_image_attended_feature_map_PIL = transforms.ToPILImage()(grid_image_attended_feature_map)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79b58c49-13da-448a-b1ea-399fac05a469",
   "metadata": {},
   "outputs": [],
   "source": [
    "grid_image_attended_feature_map_PIL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37dc642f-e248-4d43-be8e-ca4ff29f0b2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "grid_image_attended_feature_map_PIL.save(\"gradcam/attended_feature_map_c1_deterministic_true.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c419f329-33bc-4ee2-ac72-647dc3eb97f5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "actions_all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5aa44cfa-aa6f-4f02-bff6-474fbe0ef37a",
   "metadata": {},
   "outputs": [],
   "source": [
    "actions_CWCA = actions_all[0::5]\n",
    "actions_CWCA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79fc9a5a-bac8-46f0-a49a-8eacf2e6a196",
   "metadata": {},
   "outputs": [],
   "source": [
    "actions_CWRCA = actions_all[4::5]\n",
    "actions_CWRCA"
   ]
  },
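  {
   "cell_type": "markdown",
   "id": "sketch-per-model-md",
   "metadata": {},
   "source": [
    "The stride-5 slices above rely on results being appended in `sat_list` order for every observation, so offset 0 is `CWCA` and offset 4 is `CWRCA`. The cell below is a minimal sketch of the same lookup keyed by model name, assuming that ordering; `per_model` and the `toy_` lists are hypothetical and not part of `scripts.interpretability_utils`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-per-model-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "toy_sat_list = ['CWCA', 'NA', 'SWA', 'CWRA', 'CWRCA']\n",
    "# toy flat results in observation-major, model-minor order (3 observations)\n",
    "toy_flat = [f'obs{o}_{s}' for o in range(3) for s in toy_sat_list]\n",
    "\n",
    "def per_model(flat, names):\n",
    "    # hypothetical helper: split a flat list into one list per model name\n",
    "    stride = len(names)\n",
    "    return {name: flat[i::stride] for i, name in enumerate(names)}\n",
    "\n",
    "toy_by_model = per_model(toy_flat, toy_sat_list)\n",
    "print(toy_by_model['CWRCA'])"
   ]
  },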
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b79e6ef7-cfb9-4faa-9fef-74bff41cd2f7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "values_all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69b02826-de30-49c9-8d5b-c5870f48ee03",
   "metadata": {},
   "outputs": [],
   "source": [
    "values_CWCA = values_all[0::5]\n",
    "values_CWCA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffa9a252-243e-4271-bc2e-fe91e8da4237",
   "metadata": {},
   "outputs": [],
   "source": [
    "values_CWRCA = values_all[4::5]\n",
    "values_CWRCA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef51629d-4586-403a-b190-853b3deb0021",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "logits_all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8e7dea0-9dd7-4287-8787-4e2f49f906fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "logits_CWCA = logits_all[0::5]\n",
    "logits_CWCA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2489e852-5121-47b6-b898-eff475ed5859",
   "metadata": {},
   "outputs": [],
   "source": [
    "logits_CWRCA = logits_all[4::5]\n",
    "logits_CWRCA"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f238c746-f978-477b-a45e-505150d77c5b",
   "metadata": {},
   "source": [
    "### Generate saliency maps (target layer c2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97ec54d5-1666-4f97-adde-208c9d4298de",
   "metadata": {},
   "source": [
    "target_layer = policy.features_extractor.c2\n",
    "\n",
    "deterministic = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca36a2dd-466d-43b1-ad5e-073e008eb3da",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "heatmaps_all = [] # gradcam\n",
    "actions_all = [] # actions selected by the agent (depends on deterministic parameter)\n",
    "values_all = [] # value_net output\n",
    "logits_all = [] # action_net outputs\n",
    "scores_all = [] # logits selected based on actions\n",
    "gradients_all = [] # gradients of the target layer\n",
    "activations_all = [] # feed-forward output of the target layer\n",
    "attended_feature_maps_all = [] # self_attn_layer outputs\n",
    "\n",
    "h = 84\n",
    "w = 84\n",
    "deterministic = True\n",
    "\n",
    "# each row holds all heatmaps for an obs \n",
    "# each column is an sat model\n",
    "for obs_idx in range(10):\n",
    "    # load observation .npy file\n",
    "    with open(f'gradcam/pong_{obs_idx+1}.npy', 'rb') as f:\n",
    "        obs = np.load(f)\n",
    "    # convert observation to torch tensor\n",
    "    obs_th = th.FloatTensor(obs)\n",
    "    heatmaps_per_obs = []\n",
    "    actions_per_obs = []\n",
    "    values_per_obs = []\n",
    "    logits_per_obs = []\n",
    "    scores_per_obs = []\n",
    "    gradients_per_obs = []\n",
    "    activations_per_obs = []\n",
    "    attended_feature_maps_per_obs = []\n",
    "    for sat in sat_list:\n",
    "        # instantiate the PPO model\n",
    "        policy_kwargs = eval(f\"dict(features_extractor_class=SelfAttnCNNPPO, features_extractor_kwargs=dict(self_attn='{sat}'), net_arch=[])\")\n",
    "        model = PPO(policy='CnnPolicy', env=vec_env, seed=seed, verbose=True, device=DEVICE, policy_kwargs=policy_kwargs) # seed is the env's seed set at the beginning\n",
    "        # update the model by loading the selected checkpoint zip file\n",
    "        # get seed used in the model checkpoint file\n",
    "        seed_checkpoint = sat_to_seed_mapping.get(sat)\n",
    "        model_updated = model.load(f\"gradcam/{game_no_version}_{sat}_{seed_checkpoint}_{CHECKPOINT_EXT}\", device=DEVICE, print_system_info=False)\n",
    "        # get the policy and the target layer\n",
    "        policy = model_updated.policy\n",
    "        target_layer = policy.features_extractor.c2\n",
    "        # instantiate a GradCAM object\n",
    "        gradcam = GradCAM(policy, target_layer, sat=sat)\n",
    "        # call the forward() in GradCAM\n",
    "        saliency_map, action, value, logit, score, gradients, activations, attended_feature_maps = gradcam(obs_th, deterministic=deterministic)\n",
    "        # convert saliency map to heatmap\n",
    "        heatmap = convert_to_heatmap(saliency_map)\n",
    "        heatmaps_per_obs.extend([heatmap])\n",
    "        actions_per_obs.extend([action])\n",
    "        values_per_obs.extend([value])\n",
    "        logits_per_obs.extend([logit])\n",
    "        scores_per_obs.extend([score])\n",
    "        gradients_per_obs.extend([gradients])\n",
    "        activations_per_obs.extend([activations])\n",
    "        # preprocess attended_feature_maps before converting to headmap style\n",
    "        if attended_feature_maps is not None:\n",
    "            attended_feature_maps_sum = attended_feature_maps.sum(1, keepdim=True) # shape=(1, 1, 20, 20)\n",
    "            attended_feature_maps_upsample = F.interpolate(attended_feature_maps_sum, size=(h,w), mode='bilinear', align_corners=False) # upsample to input size (1, 1, 84, 84)\n",
    "            attended_feature_maps_upsample_min, attended_feature_maps_upsample_max = attended_feature_maps_upsample.min(), attended_feature_maps_upsample.max() # get min and max\n",
    "            attended_feature_maps_norm = (attended_feature_maps_upsample - attended_feature_maps_upsample_min).div(attended_feature_maps_upsample_max - attended_feature_maps_upsample_min).data\n",
    "            # convert to heatmap style (color)\n",
    "            attended_feature_maps_color = convert_to_heatmap(attended_feature_maps_norm)\n",
    "            attended_feature_maps_per_obs.extend([attended_feature_maps_color])\n",
    "    heatmaps_all.extend(heatmaps_per_obs)\n",
    "    actions_all.extend(actions_per_obs)\n",
    "    values_all.extend(values_per_obs)\n",
    "    logits_all.extend(logits_per_obs)\n",
    "    scores_all.extend(scores_per_obs)\n",
    "    gradients_all.extend(gradients_per_obs)\n",
    "    activations_all.extend(activations_per_obs)\n",
    "    attended_feature_maps_all.extend(attended_feature_maps_per_obs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d168347b-7693-454f-9d2c-8b66e316e1f6",
   "metadata": {},
   "source": [
    "#### Heatmaps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a03a126-705f-4252-86a6-1e67d2f00da2",
   "metadata": {},
   "outputs": [],
   "source": [
    "grid_image_heatmap = make_grid(heatmaps_all, nrow=5)\n",
    "grid_image_heatmap_PIL = transforms.ToPILImage()(grid_image_heatmap)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eaa0fdb5-460f-4a85-abbc-0537aa82bb46",
   "metadata": {},
   "outputs": [],
   "source": [
    "grid_image_heatmap_PIL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3ee3e24-8cf6-4c1e-a142-50199997e77f",
   "metadata": {},
   "outputs": [],
   "source": [
    "grid_image_heatmap_PIL.save(\"gradcam/heatmap_c2_deterministic_true.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03178477-74ac-40fd-9dd1-8da1e9642068",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "actions_all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "757c4801-f613-4bab-be75-1e37146cf569",
   "metadata": {},
   "outputs": [],
   "source": [
    "actions_CWCA = actions_all[0::5]\n",
    "actions_CWCA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cb89523-fd16-435b-a88b-220962f38043",
   "metadata": {},
   "outputs": [],
   "source": [
    "actions_CWRCA = actions_all[4::5]\n",
    "actions_CWRCA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a32e1225-924e-4f26-8d62-c3480e3d5d62",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "values_all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "385b681d-ddd6-4828-8c02-ed9cf504282d",
   "metadata": {},
   "outputs": [],
   "source": [
    "values_CWCA = values_all[0::5]\n",
    "values_CWCA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bce8bb26-4deb-427a-be58-85c41a483ab8",
   "metadata": {},
   "outputs": [],
   "source": [
    "values_CWRCA = values_all[4::5]\n",
    "values_CWRCA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f097858d-8cef-4f32-9bf1-d85600a70748",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "logits_all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18523798-e582-4cd0-8770-4f3f65843768",
   "metadata": {},
   "outputs": [],
   "source": [
    "logits_CWCA = logits_all[0::5]\n",
    "logits_CWCA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66bf3fa7-f96e-4990-8560-d38c3b9d4acb",
   "metadata": {},
   "outputs": [],
   "source": [
    "logits_CWRCA = logits_all[4::5]\n",
    "logits_CWRCA"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
