{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45be1da2-d634-477c-907d-c2ecc83e48a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running with seed 42!\n",
      "Saving results to: /root/capsule/code/files_seed42_NNM4_tau1000p0_hs256_lr0p001_ctx0p01_init10000_ad1000\n",
      "Training on 15 tasks (first 3): ['yang19.dnmc-v0', 'yang19.dlyanti-v0', 'yang19.ctxdlydm2-v0']\n",
      "Adapting to 5 tasks: ['yang19.dms-v0', 'yang19.dm2-v0', 'yang19.ctxdm1-v0', 'yang19.go-v0', 'yang19.anti-v0']\n",
      "Device: cuda\n",
      "\n",
      "\n",
      "--- Training and Adapting Model: nmRNN_Spatial ---\n",
      "Total trainable parameters for nmRNN_Spatial: 359,952\n",
      "\n",
      "--- Phase 1: Initial Training on 15 tasks for nmRNN_Spatial ---\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.8/site-packages/gym/utils/passive_env_checker.py:216: UserWarning: \u001b[33mWARN: Future gym versions will require that `Env.reset` can be passed a `seed` instead of using `Env.seed` for resetting the environment random number generator. \u001b[0m\n",
      "  logger.warn(\n",
      "/opt/conda/lib/python3.8/site-packages/gym/utils/passive_env_checker.py:228: UserWarning: \u001b[33mWARN: Future gym versions will require that `Env.reset` can be passed `return_info` to return information from the environment resetting.\u001b[0m\n",
      "  logger.warn(\n",
      "/opt/conda/lib/python3.8/site-packages/gym/utils/passive_env_checker.py:233: UserWarning: \u001b[33mWARN: Future gym versions will require that `Env.reset` can be passed `options` to allow the environment initialisation to be passed additional information.\u001b[0m\n",
      "  logger.warn(\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[4], line 739\u001b[0m\n\u001b[1;32m    736\u001b[0m optimizer_model\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[1;32m    737\u001b[0m outputs, _ \u001b[38;5;241m=\u001b[39m model(inputs) \n\u001b[0;32m--> 739\u001b[0m outputs_view \u001b[38;5;241m=\u001b[39m \u001b[43moutputs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mview\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mact_size\u001b[49m\u001b[43m)\u001b[49m \n\u001b[1;32m    740\u001b[0m outputs_for_loss \u001b[38;5;241m=\u001b[39m outputs_view[valid_indices]\n\u001b[1;32m    741\u001b[0m labels_filtered \u001b[38;5;241m=\u001b[39m labels_flat[valid_indices]\n",
      "\u001b[0;31mRuntimeError\u001b[0m: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead."
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "from pathlib import Path\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import gym\n",
    "import neurogym as ngym\n",
    "from neurogym.wrappers import ScheduleEnvs\n",
    "# from neurogym.utils.scheduler import RandomSchedule # Not directly used with custom datasets\n",
    "# from models import RNNNet, get_performance # Assuming models.py contains these\n",
    "\n",
    "# import argparse # Removed argparse\n",
    "import numpy as np\n",
    "import random\n",
    "import copy\n",
    "import pandas as pd\n",
    "import math # For nmRNN decay and other calculations\n",
    "import scipy.spatial as ss # For SpatialWeight in nmRNN\n",
    "\n",
    "# --- Notebook-friendly Argument Definition ---\n",
    "class Args:\n",
    "    \"\"\"Notebook replacement for command-line arguments: every run setting is a plain instance attribute.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        settings = {\n",
    "            'seed': 0,\n",
    "            'batch_size': 64,\n",
    "            'seq_len': 100,\n",
    "            'hidden_size': 256,\n",
    "            'lr_model': 1e-3,\n",
    "            'lr_context': 1e-2,\n",
    "            'initial_train_steps': 10000,  # Reduced for quicker runs\n",
    "            'adapt_steps': 1000,           # Reduced\n",
    "            'print_step': 200,\n",
    "            'eval_trials': 50,             # Reduced\n",
    "            'N_NM': 4,                     # Number of neuromodulators for nmRNN\n",
    "            'tau_nmrnn': 1000.0,           # Time constant (ms) for nmRNN decay calculation\n",
    "            'grad_clip_nmrnn': 0.1,        # Gradient clipping value for nmRNN internal hook (0 or None to disable)\n",
    "        }\n",
    "        # Assign in declaration order so vars(self) lists the settings as written above.\n",
    "        for option_name, option_value in settings.items():\n",
    "            setattr(self, option_name, option_value)\n",
    "\n",
    "args = Args() # Instantiate the new args class\n",
    "\n",
    "def get_performance(model, env, num_trial=100, device='cpu',\n",
    "                    context_vector_global=None, raw_obs_size=None, seq_len_eval=100):\n",
    "    \"\"\"\n",
    "    Evaluates model performance on a given environment.\n",
    "\n",
    "    For each trial the env is reset and stepped with randomly sampled actions while the\n",
    "    observations and the `info['gt']` labels are recorded; the model is then run once on the\n",
    "    whole recorded sequence and its argmax predictions are compared with the labels.\n",
    "    Steps whose `info` has no 'gt' entry are recorded as -1 and excluded from the score.\n",
    "\n",
    "    The model is put in eval mode for the evaluation, and the train/eval flag of every\n",
    "    submodule is restored on exit, so calling this from inside a training loop does not\n",
    "    leave the model in eval mode.\n",
    "\n",
    "    Args:\n",
    "        model: The neural network model. Called as `model(inputs)` (or `model(inputs, None)`\n",
    "               for RNNNet / LSTMNet) with a (1, T, features) float tensor; must return a\n",
    "               2-tuple whose first item holds the per-step scores.\n",
    "        env: The Gym environment instance (old-style API: `reset()` returns the observation,\n",
    "             `step()` returns a 4-tuple).\n",
    "        num_trial (int): Number of trials to run for evaluation.\n",
    "        device (str): Device to run evaluation on ('cpu' or 'cuda').\n",
    "        context_vector_global (np.array, optional): A global context vector to be concatenated\n",
    "                                                     to raw observations. Shape (num_total_tasks,).\n",
    "        raw_obs_size (int, optional): The feature size of the raw observation from the env.\n",
    "                                      Required if context_vector_global is used.\n",
    "        seq_len_eval (int): The sequence length for evaluation trials, used when the env exposes\n",
    "                            neither `seq_len` nor `spec.max_episode_steps`.\n",
    "    Returns:\n",
    "        float: Average performance (accuracy) over all labelled steps, or 0.0 if there were none.\n",
    "    \"\"\"\n",
    "    def _raw_observation(env_obs):\n",
    "        # Dict observations keep the array under the 'observation' key.\n",
    "        if isinstance(env_obs, dict) and 'observation' in env_obs:\n",
    "            return env_obs['observation']\n",
    "        return env_obs\n",
    "\n",
    "    # Determine sequence length for evaluation\n",
    "    # Priority: env.seq_len > env.spec.max_episode_steps > seq_len_eval (arg)\n",
    "    if hasattr(env, 'seq_len') and env.seq_len is not None:\n",
    "        current_seq_len = env.seq_len\n",
    "    elif hasattr(env, 'spec') and env.spec is not None and hasattr(env.spec, 'max_episode_steps') and env.spec.max_episode_steps is not None:\n",
    "        current_seq_len = env.spec.max_episode_steps\n",
    "    else:\n",
    "        current_seq_len = seq_len_eval\n",
    "\n",
    "    # Remember every submodule's mode so evaluation does not leak eval mode into training.\n",
    "    previous_modes = [(module, module.training) for module in model.modules()] if hasattr(model, 'modules') else []\n",
    "    model.eval() # Set model to evaluation mode for performance assessment\n",
    "    total_correct = 0\n",
    "    total_steps = 0\n",
    "\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            for _ in range(num_trial):\n",
    "                raw_obs = _raw_observation(env.reset())\n",
    "                hidden_state = None # Every trial starts from a fresh recurrent state\n",
    "\n",
    "                trial_raw_obs_list = []\n",
    "                trial_gt_list = []\n",
    "\n",
    "                for _step in range(current_seq_len): # Iterate for the determined sequence length\n",
    "                    trial_raw_obs_list.append(raw_obs)\n",
    "\n",
    "                    env_obs, _, done, info = env.step(env.action_space.sample())\n",
    "                    raw_obs = _raw_observation(env_obs)\n",
    "                    # -1 marks a step without a ground-truth label; it is filtered out below.\n",
    "                    trial_gt_list.append(info['gt'] if 'gt' in info else -1)\n",
    "\n",
    "                    if done:\n",
    "                        break\n",
    "\n",
    "                raw_obs_sequence_np = np.array(trial_raw_obs_list)\n",
    "                gt_sequence_np = np.array(trial_gt_list)\n",
    "                trial_actual_len = raw_obs_sequence_np.shape[0]\n",
    "\n",
    "                if trial_actual_len == 0:\n",
    "                    continue\n",
    "\n",
    "                model_input_np = raw_obs_sequence_np\n",
    "                if context_vector_global is not None and raw_obs_size is not None:\n",
    "                    # NOTE: raw_obs_size only switches context injection on; the observation\n",
    "                    # width is not validated against it.\n",
    "                    context_expanded = np.tile(context_vector_global, (trial_actual_len, 1))\n",
    "                    model_input_np = np.concatenate((raw_obs_sequence_np.astype(np.float32), context_expanded.astype(np.float32)), axis=-1)\n",
    "\n",
    "                inputs = torch.from_numpy(model_input_np).type(torch.float).unsqueeze(0).to(device)\n",
    "\n",
    "                # RNNNet / LSTMNet take the (initial) hidden state as a second argument;\n",
    "                # every other model is called with the inputs only.\n",
    "                if isinstance(model, (RNNNet, LSTMNet)):\n",
    "                    outputs, hidden_state = model(inputs, hidden_state)\n",
    "                else:\n",
    "                    outputs, _ = model(inputs)\n",
    "\n",
    "                predictions = torch.argmax(outputs.squeeze(0), dim=-1)\n",
    "\n",
    "                valid_gt_indices = gt_sequence_np != -1\n",
    "                if not valid_gt_indices.any():\n",
    "                    continue\n",
    "\n",
    "                correct_in_trial = (predictions.cpu().numpy()[valid_gt_indices] == gt_sequence_np[valid_gt_indices]).sum()\n",
    "                total_correct += correct_in_trial\n",
    "                total_steps += valid_gt_indices.sum() # Count only valid steps\n",
    "    finally:\n",
    "        # Put every submodule back into the mode it had before evaluation.\n",
    "        for module, was_training in previous_modes:\n",
    "            module.training = was_training\n",
    "\n",
    "    return total_correct / total_steps if total_steps > 0 else 0.0\n",
    "\n",
    "\n",
    "# --- Model Definitions (Placeholders for LSTM and Transformer) ---\n",
    "class LSTMNet(nn.Module):\n",
    "    \"\"\"Single-layer, batch-first LSTM followed by a linear readout applied at every timestep.\"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size, output_size, dt):\n",
    "        super().__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)\n",
    "        self.fc = nn.Linear(in_features=hidden_size, out_features=output_size)\n",
    "        self.dt = dt  # stored only; forward() does not read it\n",
    "\n",
    "    def forward(self, x, hidden=None):\n",
    "        \"\"\"Return (per-step readout, LSTM state); `hidden` is the optional initial (h, c) pair.\"\"\"\n",
    "        features, next_hidden = self.lstm(x, hidden)\n",
    "        return self.fc(features), next_hidden\n",
    "\n",
    "class TransformerNet(nn.Module):\n",
    "    \"\"\"Batch-first Transformer encoder with an optional input projection and a linear readout.\"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size, output_size, nhead=4, num_layers=2, dt=None):\n",
    "        super().__init__()\n",
    "        self.input_size = input_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.dt = dt  # stored only; forward() does not read it\n",
    "\n",
    "        # Inputs are projected to the encoder width only when the two sizes differ.\n",
    "        self.input_proj = nn.Identity() if input_size == hidden_size else nn.Linear(input_size, hidden_size)\n",
    "\n",
    "        layer_template = nn.TransformerEncoderLayer(\n",
    "            d_model=hidden_size,\n",
    "            nhead=nhead,\n",
    "            dim_feedforward=hidden_size * 2,\n",
    "            dropout=0.1,\n",
    "            batch_first=True,\n",
    "        )\n",
    "        self.transformer_encoder = nn.TransformerEncoder(layer_template, num_layers=num_layers)\n",
    "        self.fc_out = nn.Linear(hidden_size, output_size)\n",
    "\n",
    "    def forward(self, src, src_mask=None):\n",
    "        \"\"\"Return (per-step readout, None); the None keeps the (output, state) shape of the other models.\"\"\"\n",
    "        encoded = self.transformer_encoder(self.input_proj(src), mask=src_mask)\n",
    "        return self.fc_out(encoded), None\n",
    "\n",
    "\n",
    "# --- nmRNN Model Classes (from user) ---\n",
    "class SpatialWeight(nn.Module):\n",
    "    \"\"\"\n",
    "    Module to compute spatially dependent weights based on neuron positions.\n",
    "    Includes distance decay and inhibitory neuron specification.\n",
    "\n",
    "    Every unit gets a random 2-D position (`np.random.rand`) and, with probability 0.5, an\n",
    "    inhibitory flag (`np.random.choice`); both draws use NumPy's global RNG. `forward(W)` returns\n",
    "\n",
    "        scale * (-1)**inhib * exp(W - distance / ell) * mask\n",
    "\n",
    "    where `mask` zeroes the diagonal and every pair at distance >= 5 * ell. The flag is\n",
    "    indexed by the first axis, so a flagged unit makes its whole first-axis slice negative.\n",
    "\n",
    "    Args:\n",
    "        N_nm (int): Size of the trailing axis of the 3-D buffers; 0 is treated as 1.\n",
    "        observable_size (int): Number of units (size of the first two axes).\n",
    "        ell (float): Length scale of the distance decay.\n",
    "    \"\"\"\n",
    "    def __init__(self, N_nm, observable_size=64, ell=0.1):\n",
    "        super(SpatialWeight, self).__init__()\n",
    "\n",
    "        # Draw order matters for reproducibility under a NumPy seed: positions first, flags second.\n",
    "        self.register_buffer('pos_const', torch.tensor(np.random.rand(observable_size, 2), dtype=torch.float32))\n",
    "\n",
    "        # Distances are computed from the float32-rounded positions stored in the buffer.\n",
    "        pos_np = self.pos_const.cpu().numpy()\n",
    "        delpoints_np = ss.distance.cdist(pos_np, pos_np)\n",
    "        # N_nm == 0 still gets a trailing axis of size 1 so every buffer stays 3-D.\n",
    "        N_nm_effective = 1 if N_nm == 0 else N_nm\n",
    "        delpoints_expanded_np = delpoints_np[:, :, None] * np.ones([observable_size, observable_size, N_nm_effective])\n",
    "\n",
    "        self.ell = ell\n",
    "        pinhib = 0.5  # probability that a unit is flagged inhibitory\n",
    "        self.scale = 1.0\n",
    "\n",
    "        inhib_np = (np.random.choice([0, 1], size=(observable_size, 1, 1), p=[1 - pinhib, pinhib])) * np.ones_like(delpoints_expanded_np)\n",
    "        self.register_buffer('inhib_const', torch.tensor(inhib_np, dtype=torch.float32))\n",
    "\n",
    "        self.register_buffer('Delta_const', torch.tensor(delpoints_expanded_np / self.ell, dtype=torch.float32))\n",
    "\n",
    "        # Keep only off-diagonal pairs that are closer than 5 length scales.\n",
    "        off_diagonal_np = np.eye(observable_size)[:, :, None] * np.ones_like(delpoints_expanded_np) == 0\n",
    "        mask_np = np.logical_and(delpoints_expanded_np < 5 * self.ell, off_diagonal_np)\n",
    "        self.register_buffer('mask_const', torch.tensor(mask_np, dtype=torch.float32))\n",
    "\n",
    "    def forward(self, W):\n",
    "        \"\"\"\n",
    "        W is the base weight tensor before spatial modulation.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If W's trailing dimension differs from that of the buffers. Without this\n",
    "                check a mismatched W either broadcasts into a wrongly shaped result or fails\n",
    "                with an opaque broadcasting error.\n",
    "        \"\"\"\n",
    "        if W.shape[-1] != self.mask_const.shape[-1]:\n",
    "            raise ValueError(\n",
    "                f'SpatialWeight expected W with trailing dimension {self.mask_const.shape[-1]} '\n",
    "                f'(buffers have shape {tuple(self.mask_const.shape)}), got W of shape {tuple(W.shape)}'\n",
    "            )\n",
    "\n",
    "        return self.scale * ((-1.0)**self.inhib_const) * torch.exp(W - self.Delta_const) * self.mask_const\n",
    "\n",
    "\n",
    "class spatial_nmRNNCell_base(nn.Module):\n",
    "    \"\"\"\n",
    "    Parameter holder shared by the neuromodulated RNN cells.\n",
    "\n",
    "    Owns the input weights, the optional NM-gated recurrent tensor (passed through\n",
    "    SpatialWeight when `use_spatial_net` is set), the hidden->NM and NM->NM weights, the\n",
    "    baseline recurrent matrix W0 (trainable only when `keepW0` is set, otherwise a\n",
    "    non-persistent zero buffer) and the optional bias. Subclasses supply `forward`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, N_nm, input_size, hidden_size, nonlinearity, bias, keepW0=False, use_spatial_net=True):\n",
    "        super().__init__()\n",
    "        self.input_size = input_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.nonlinearity = nonlinearity\n",
    "        self.N_nm = N_nm\n",
    "        self.keepW0 = keepW0\n",
    "        self.g = 10.0  # gain used when initialising the NM-gated recurrent tensor\n",
    "        self.use_spatial_net = use_spatial_net\n",
    "\n",
    "        has_nm = N_nm > 0\n",
    "        recurrent_shape = (hidden_size, hidden_size, N_nm)\n",
    "\n",
    "        self.weight_ih = nn.Parameter(torch.empty(hidden_size, input_size))\n",
    "\n",
    "        # NM-gated recurrent tensor: spatially shaped, direct, or absent when there are no NMs.\n",
    "        if self.use_spatial_net and has_nm:\n",
    "            self.spatialNet_instance = SpatialWeight(N_nm=N_nm, observable_size=hidden_size)\n",
    "            self.base_weight_hh_modulated = nn.Parameter(torch.empty(*recurrent_shape))\n",
    "        elif has_nm:\n",
    "            self.spatialNet_instance = None\n",
    "            self.weight_hh_direct_modulated = nn.Parameter(torch.empty(*recurrent_shape))\n",
    "        else:\n",
    "            # No neuromodulators: no modulated recurrent weights exist at all.\n",
    "            self.spatialNet_instance = None\n",
    "\n",
    "        self.weight_h2nm = nn.Parameter(torch.empty(N_nm, hidden_size)) if has_nm else None\n",
    "        self.weight_nm2nm = nn.Parameter(torch.empty(N_nm, N_nm)) if has_nm else None\n",
    "\n",
    "        if keepW0:\n",
    "            self.weight0_hh = nn.Parameter(torch.empty(hidden_size, hidden_size))\n",
    "        else:\n",
    "            # Separate name so the zero buffer never clashes with the trainable W0.\n",
    "            self.register_buffer('weight0_hh_const', torch.zeros(hidden_size, hidden_size), persistent=False)\n",
    "\n",
    "        if bias:\n",
    "            self.bias = nn.Parameter(torch.empty(hidden_size))\n",
    "        else:\n",
    "            self.register_parameter('bias', None)\n",
    "\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def get_weight0_hh(self):\n",
    "        \"\"\"Return W0: the trainable parameter when keepW0 is set, the zero buffer otherwise.\"\"\"\n",
    "        return self.weight0_hh if self.keepW0 else self.weight0_hh_const\n",
    "\n",
    "    def get_modulated_hh_weights(self):\n",
    "        \"\"\"Return the NM-gated recurrent tensor (spatially transformed if enabled), or None without NMs.\"\"\"\n",
    "        if not self.N_nm > 0:\n",
    "            return None\n",
    "        if self.use_spatial_net and getattr(self, 'spatialNet_instance', None) is not None:\n",
    "            return self.spatialNet_instance(self.base_weight_hh_modulated)\n",
    "        if hasattr(self, 'weight_hh_direct_modulated'):\n",
    "            return self.weight_hh_direct_modulated\n",
    "        return None\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        \"\"\"(Re-)initialise every parameter; the order below fixes the RNG draw order.\"\"\"\n",
    "        nn.init.kaiming_uniform_(self.weight_ih, a=math.sqrt(5))\n",
    "\n",
    "        # Pick whichever NM-gated recurrent tensor this instance owns, if any.\n",
    "        modulated = None\n",
    "        if self.N_nm > 0:\n",
    "            if self.use_spatial_net and hasattr(self, 'base_weight_hh_modulated'):\n",
    "                modulated = self.base_weight_hh_modulated\n",
    "            elif hasattr(self, 'weight_hh_direct_modulated'):\n",
    "                modulated = self.weight_hh_direct_modulated\n",
    "        if modulated is not None:\n",
    "            if self.hidden_size > 0:\n",
    "                nn.init.kaiming_uniform_(modulated, a=self.g / math.sqrt(self.hidden_size))\n",
    "            else:\n",
    "                nn.init.zeros_(modulated)\n",
    "\n",
    "        if self.weight_h2nm is not None:\n",
    "            nn.init.sparse_(self.weight_h2nm, sparsity=0.1)\n",
    "        if self.weight_nm2nm is not None:\n",
    "            nn.init.zeros_(self.weight_nm2nm)\n",
    "\n",
    "        # The zero buffer used without keepW0 needs no initialisation.\n",
    "        baseline = self.get_weight0_hh()\n",
    "        if self.keepW0 and isinstance(baseline, nn.Parameter):\n",
    "            nn.init.kaiming_uniform_(baseline, a=math.sqrt(5))\n",
    "\n",
    "        if self.bias is not None:\n",
    "            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_ih)\n",
    "            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0\n",
    "            nn.init.uniform_(self.bias, -bound, bound)\n",
    "\n",
    "\n",
    "class s_nmRNNCell(spatial_nmRNNCell_base):\n",
    "    \"\"\"\n",
    "    One update step of the neuromodulated RNN.\n",
    "\n",
    "    The combined state is [hidden | nm] along the feature axis. Both parts are updated as\n",
    "    decay * old + (1 - decay) * nonlinearity(pre-activation); the NM update is driven by the\n",
    "    hidden state from before this step.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, N_nm, input_size, hidden_size, out_size, nonlinearity=None, decay=0.0, bias=True, keepW0=True, use_spatial_net=True):\n",
    "        super().__init__(N_nm, input_size, hidden_size, nonlinearity, bias, keepW0=keepW0, use_spatial_net=use_spatial_net)\n",
    "        self.decay = decay\n",
    "        self.out_size = out_size\n",
    "\n",
    "    def forward(self, input_val, hiddenCombined):\n",
    "        \"\"\"Advance one step; returns the new combined state with a leading axis of size 1.\"\"\"\n",
    "        # Accept both (batch, features) and (1, batch, features) states.\n",
    "        state = hiddenCombined\n",
    "        if state.dim() == 3 and state.shape[0] == 1:\n",
    "            state = state.squeeze(0)\n",
    "\n",
    "        has_nm = self.N_nm > 0\n",
    "        if has_nm:\n",
    "            hidden = state[:, :self.hidden_size]\n",
    "            nm = state[:, self.hidden_size:]\n",
    "            if nm.shape[1] != self.N_nm:\n",
    "                raise ValueError(f\"NM state slice incorrect in s_nmRNNCell. Expected {self.N_nm} features, got {nm.shape[1]}\")\n",
    "        else:\n",
    "            hidden, nm = state, None\n",
    "\n",
    "        # Pre-activation = input drive + baseline recurrence (+ NM-gated recurrence) (+ bias).\n",
    "        pre_activity = torch.matmul(input_val, self.weight_ih.t())\n",
    "        pre_activity += torch.matmul(hidden, self.get_weight0_hh().t())\n",
    "\n",
    "        gated_weights = self.get_modulated_hh_weights()\n",
    "        if nm is not None and has_nm and gated_weights is not None:\n",
    "            pre_activity += torch.einsum('bj,ijk,bk->bi', hidden, gated_weights, nm)\n",
    "\n",
    "        if self.bias is not None:\n",
    "            pre_activity += self.bias\n",
    "\n",
    "        hidden_new = self.decay * hidden + (1 - self.decay) * self.nonlinearity(pre_activity)\n",
    "\n",
    "        nm_is_updatable = (\n",
    "            nm is not None\n",
    "            and has_nm\n",
    "            and self.weight_h2nm is not None\n",
    "            and self.weight_nm2nm is not None\n",
    "        )\n",
    "        if not nm_is_updatable:\n",
    "            return hidden_new.unsqueeze(0)\n",
    "\n",
    "        pre_activity_nm = torch.matmul(hidden, self.weight_h2nm.t())\n",
    "        pre_activity_nm += torch.matmul(nm, self.weight_nm2nm.t())\n",
    "        nm_new = self.decay * nm + (1 - self.decay) * self.nonlinearity(pre_activity_nm)\n",
    "        return torch.cat([hidden_new, nm_new], dim=1).unsqueeze(0)\n",
    "\n",
    "\n",
    "class s_nmRNNLayer(nn.Module):\n",
    "    \"\"\"\n",
    "    Unrolls an s_nmRNNCell over a time-first input and reads out at every step.\n",
    "\n",
    "    With neuromodulators the readout is the bilinear form einsum('bj,ijk,bk->bi') of the\n",
    "    rates, `weight_readout` and the NM state; without them it is a plain linear layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, N_nm, input_size, hidden_size, out_size, nonlinearity, decay=0.9, bias=False, keepW0=False, use_spatial_net=True):\n",
    "        super().__init__()\n",
    "        self.rnncell = s_nmRNNCell(N_nm, input_size, hidden_size, out_size, nonlinearity=nonlinearity, decay=decay, bias=bias, keepW0=keepW0, use_spatial_net=use_spatial_net)\n",
    "        self.N_nm = N_nm\n",
    "        self.hidden_size = hidden_size\n",
    "        self.out_size = out_size\n",
    "\n",
    "        if self.N_nm <= 0:\n",
    "            self.fc_readout_no_nm = nn.Linear(self.hidden_size, self.out_size)\n",
    "            return\n",
    "\n",
    "        self.weight_readout = nn.Parameter(torch.Tensor(self.out_size, self.hidden_size, self.N_nm))\n",
    "        if self.hidden_size > 0:\n",
    "            nn.init.kaiming_uniform_(self.weight_readout, a=1 / math.sqrt(self.hidden_size))\n",
    "        else:\n",
    "            nn.init.zeros_(self.weight_readout)\n",
    "\n",
    "    def forward(self, input_val, initH):\n",
    "        \"\"\"\n",
    "        Run the cell over `input_val` (unbound along axis 0) starting from `initH`.\n",
    "\n",
    "        Returns (stacked readouts, stacked combined states, final combined state as returned\n",
    "        by the cell). Raises ValueError when a state width does not match hidden_size + N_nm.\n",
    "        \"\"\"\n",
    "        state_width = self.hidden_size + self.N_nm\n",
    "        if initH.shape[-1] != state_width:\n",
    "            raise ValueError(f\"Initial hidden state feature dimension mismatch. Expected {state_width}, got {initH.shape[-1]}\")\n",
    "\n",
    "        state = initH\n",
    "        readouts, state_history = [], []\n",
    "\n",
    "        for step_input in input_val.unbind(0):\n",
    "            state = self.rnncell(step_input, state)\n",
    "            flat_state = state.squeeze(0)\n",
    "\n",
    "            if self.N_nm > 0:\n",
    "                rates = flat_state[:, :self.hidden_size]\n",
    "                nm_state = flat_state[:, self.hidden_size:]\n",
    "                if nm_state.shape[1] != self.N_nm:\n",
    "                    raise ValueError(f\"NM state slice incorrect in s_nmRNNLayer. Expected {self.N_nm} features, got {nm_state.shape[1]}\")\n",
    "                step_readout = torch.einsum('bj,ijk,bk->bi', rates, self.weight_readout, nm_state)\n",
    "            else:\n",
    "                step_readout = self.fc_readout_no_nm(flat_state)\n",
    "\n",
    "            readouts.append(step_readout)\n",
    "            state_history.append(flat_state)\n",
    "\n",
    "        return torch.stack(readouts, dim=0), torch.stack(state_history, dim=0), state\n",
    "\n",
    "\n",
    "class Model_nm(nn.Module):\n",
    "    \"\"\"\n",
    "    Builds a recurrent layer from a hyper-parameter dict and splits its combined state\n",
    "    history into the rate part and the neuromodulator (NM) part.\n",
    "\n",
    "    Args:\n",
    "        hp (dict): Required keys: 'n_input', 'n_rnn', 'n_output', 'decay', 'N_NM'.\n",
    "            Optional keys (default): 'bias' (True), 'keepW0' (True), 'grad_clip' (None),\n",
    "            'use_spatial_net' (True), 'activation' ('relu'; one of 'relu', 'tanh').\n",
    "        RNNLayer_class: Layer class, instantiated as RNNLayer_class(N_NM, n_input, n_rnn,\n",
    "            n_output, nonlinearity, decay, bias=..., keepW0=..., use_spatial_net=...).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If hp['activation'] is neither 'relu' nor 'tanh'.\n",
    "    \"\"\"\n",
    "    def __init__(self, hp, RNNLayer_class):\n",
    "        super().__init__()\n",
    "        n_input = hp['n_input']\n",
    "        n_rnn = hp['n_rnn']\n",
    "        n_output = hp['n_output']\n",
    "        decay = hp['decay']\n",
    "        N_NM = hp['N_NM']\n",
    "        bias = hp.get('bias', True)\n",
    "        keepW0 = hp.get('keepW0', True)\n",
    "        clip_value = hp.get('grad_clip', None)\n",
    "        use_spatial_net = hp.get('use_spatial_net', True)\n",
    "        activation_str = hp.get('activation', 'relu')\n",
    "\n",
    "        if activation_str == 'relu':\n",
    "            nonlinearity = nn.ReLU()\n",
    "        elif activation_str == 'tanh':\n",
    "            nonlinearity = nn.Tanh()\n",
    "        else:\n",
    "            raise ValueError(f\"Unsupported activation function: {activation_str}\")\n",
    "\n",
    "        self.n_rnn = n_rnn\n",
    "        self.N_NM = N_NM\n",
    "\n",
    "        self.rnn = RNNLayer_class(N_NM, n_input, n_rnn, n_output, nonlinearity, decay, bias=bias, keepW0=keepW0, use_spatial_net=use_spatial_net)\n",
    "\n",
    "        # Element-wise gradient clipping: each trainable parameter gets a hook that clamps its\n",
    "        # gradient to [-clip_value, clip_value] during backward. 0 / None disables it.\n",
    "        if clip_value is not None and clip_value > 0:\n",
    "            for p in self.parameters():\n",
    "                if p.requires_grad:\n",
    "                    p.register_hook(lambda grad: torch.clamp(grad, -clip_value, clip_value) if grad is not None else None)\n",
    "\n",
    "\n",
    "    def forward(self, x, device=None): # x: (Time, batch, input_size)\n",
    "        \"\"\"\n",
    "        Run the layer from an all-zero initial state.\n",
    "\n",
    "        Args:\n",
    "            x: Input of shape (Time, batch, input_size).\n",
    "            device: Device for the initial state. Defaults to `x.device`, so the state always\n",
    "                lives where the input does; an explicit value is still honoured.\n",
    "\n",
    "        Returns:\n",
    "            (readout, rate history, NM history). The NM history has a trailing axis of size 0\n",
    "            when N_NM == 0.\n",
    "        \"\"\"\n",
    "        if device is None:\n",
    "            device = x.device\n",
    "\n",
    "        batch_size = x.shape[1]\n",
    "        hidden0_rnn = torch.zeros(1, batch_size, self.n_rnn, device=device)\n",
    "\n",
    "        if self.N_NM > 0:\n",
    "            nm0 = torch.zeros(1, batch_size, self.N_NM, device=device)\n",
    "            hiddenCombined0_for_layer = torch.cat([hidden0_rnn, nm0], dim=2)\n",
    "        else:\n",
    "            hiddenCombined0_for_layer = hidden0_rnn\n",
    "\n",
    "        # The final combined state is returned by the layer but not needed here.\n",
    "        output_readout, hiddenCombined_seq, _ = self.rnn(x, hiddenCombined0_for_layer)\n",
    "\n",
    "        hidden_rnn_seq = hiddenCombined_seq[:, :, :self.n_rnn]\n",
    "        if self.N_NM > 0:\n",
    "            nm_seq = hiddenCombined_seq[:, :, self.n_rnn : self.n_rnn + self.N_NM]\n",
    "        else:\n",
    "            nm_seq = torch.empty(hiddenCombined_seq.shape[0], hiddenCombined_seq.shape[1], 0, device=device)\n",
    "\n",
    "        return output_readout, hidden_rnn_seq, nm_seq\n",
    "\n",
    "\n",
    "class nmRNN_Adapter(nn.Module):\n",
    "    \"\"\"\n",
    "    Batch-first wrapper around Model_nm (built on s_nmRNNLayer), so the nmRNN can be called\n",
    "    like the other models: `model(x)` with x of shape (batch, seq, features) returns\n",
    "    (per-step readout, per-step rate history), both batch-first.\n",
    "\n",
    "    The per-step decay handed to the cell is exp(-dt / tau_nmrnn).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If dt or tau_nmrnn is not strictly positive.\n",
    "    \"\"\"\n",
    "    def __init__(self, input_size, hidden_size, output_size, dt, N_NM, tau_nmrnn, use_spatial_net, activation='relu', bias=True, keepW0=True, grad_clip_nmrnn=None):\n",
    "        super().__init__()\n",
    "        self.dt = float(dt)\n",
    "        self.tau_nmrnn = float(tau_nmrnn)\n",
    "        if self.dt <= 0: raise ValueError(\"dt must be positive.\")\n",
    "        if self.tau_nmrnn <= 0: raise ValueError(\"tau_nmrnn must be positive for decay calculation.\")\n",
    "\n",
    "        decay_val = math.exp(-self.dt / self.tau_nmrnn)\n",
    "\n",
    "        hp = {\n",
    "            'n_input': input_size,\n",
    "            'n_rnn': hidden_size,\n",
    "            'n_output': output_size,\n",
    "            'decay': decay_val,\n",
    "            'N_NM': N_NM,\n",
    "            'activation': activation,\n",
    "            'bias': bias,\n",
    "            'keepW0': keepW0,\n",
    "            'grad_clip': grad_clip_nmrnn,\n",
    "            'use_spatial_net': use_spatial_net\n",
    "        }\n",
    "        self.model_nm = Model_nm(hp, s_nmRNNLayer)\n",
    "\n",
    "    def forward(self, x, hidden_state_ignored=None): # x: (batch, seq, features)\n",
    "        \"\"\"\n",
    "        Run the nmRNN on a batch-first input; `hidden_state_ignored` is accepted for call\n",
    "        compatibility and never used (the wrapped model always starts from a zero state).\n",
    "\n",
    "        Both returned tensors are contiguous. A bare transpose would hand back strided views,\n",
    "        on which `outputs.view(-1, n)` raises \"view size is not compatible with input tensor's\n",
    "        size and stride\".\n",
    "        \"\"\"\n",
    "        current_device = x.device\n",
    "        x_transposed = x.transpose(0, 1) # (seq, batch, features)\n",
    "\n",
    "        output_readout_seq_timefirst, hidden_rnn_seq_timefirst, _ = self.model_nm(x_transposed, device=current_device)\n",
    "\n",
    "        # Back to batch-first, materialised so that callers may use .view() on the result.\n",
    "        output_final_batchfirst = output_readout_seq_timefirst.transpose(0, 1).contiguous()\n",
    "        hidden_to_return_batchfirst = hidden_rnn_seq_timefirst.transpose(0, 1).contiguous()\n",
    "\n",
    "        return output_final_batchfirst, hidden_to_return_batchfirst\n",
    "\n",
    "\n",
    "# --- Main script execution starts here ---\n",
    "def set_seed(seed):\n",
    "    \"\"\"Seed Python's `random`, NumPy and PyTorch (plus every CUDA device when available), then print the seed.\"\"\"\n",
    "    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
    "        seed_fn(seed)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed_all(seed)\n",
    "    print(f\"Running with seed {seed}!\")\n",
    "\n",
    "set_seed(args.seed)\n",
    "\n",
    "# --- Create Save Directory ---\n",
    "# The run tag encodes the main hyper-parameters; dots become 'p' so the folder name has no '.'.\n",
    "path_str = '_'.join([\n",
    "    f'files_seed{args.seed}',\n",
    "    f'NNM{args.N_NM}',\n",
    "    f'tau{args.tau_nmrnn}',\n",
    "    f'hs{args.hidden_size}',\n",
    "    f'lr{args.lr_model}',\n",
    "    f'ctx{args.lr_context}',\n",
    "    f'init{args.initial_train_steps}',\n",
    "    f'ad{args.adapt_steps}',\n",
    "])\n",
    "path = Path('.') / path_str.replace('.', 'p')\n",
    "os.makedirs(path, exist_ok=True)\n",
    "print(f\"Saving results to: {path.resolve()}\")\n",
    "\n",
    "# --- Environment Setup ---\n",
    "# Builds the task list, splits it into 15 training / 5 adaptation tasks, and derives the\n",
    "# model input/output sizes and the torch device used by the rest of the notebook.\n",
    "# 'dt' is forwarded to gym.make and is also the dt handed to the model constructors.\n",
    "kwargs_env = {'dt': 100.0} \n",
    "all_tasks_collection_names = ngym.get_collection('yang19') \n",
    "\n",
    "# Fallback: used when the collection lookup yields nothing or fewer than 20 task names.\n",
    "if not all_tasks_collection_names or len(all_tasks_collection_names) < 20:\n",
    "    print(\"Default Yang19 task collection issue. Using a manual list.\")\n",
    "    all_tasks_collection_names = [ \n",
    "        'ContextDecisionMaking-v0', 'DecisionMaking-v0', 'DelayComparison-v0',\n",
    "        'DelayMatchCategory-v0', 'DelayMatchSample-v0', 'Detection-v0',\n",
    "        'DualDecisionMaking-v0', 'GoNogo-v0', 'IntervalDiscrimination-v0',\n",
    "        'MatchCategory-v0', 'MatchSample-v0', 'MultisensoryIntegration-v0',\n",
    "        'PerceptualDecisionMaking-v0', 'ReachingDelayResponse-v0', 'SpatialSuppressMotion-v0',\n",
    "        'AntiReach-v0', 'Countermanding-v0', 'EconomicDecisionMaking-v0', \n",
    "        'MotorTiming-v0', 'ProbabilisticReasoning-v0'\n",
    "    ]\n",
    "    # The manual list above has exactly 20 entries, so this guard only fires if it is shortened.\n",
    "    if len(all_tasks_collection_names) < 20:\n",
    "         raise ValueError(\"Fallback Yang19 task list also has less than 20 tasks.\")\n",
    "    # Smoke test: make and immediately close the first fallback task; re-raise on failure.\n",
    "    try:\n",
    "        gym.make(all_tasks_collection_names[0], **kwargs_env).close()\n",
    "    except Exception as e:\n",
    "        print(f\"Error making fallback task: {e}. Ensure neurogym tasks are registered or list is correct.\")\n",
    "        raise\n",
    "\n",
    "# Shuffle with Python's `random` (seeded by set_seed above), then split: the first 15\n",
    "# shuffled indices are training tasks, the next 5 are adaptation tasks.\n",
    "task_indices = list(range(len(all_tasks_collection_names)))\n",
    "random.shuffle(task_indices) \n",
    "\n",
    "train_task_indices = task_indices[:15]\n",
    "adapt_task_indices = task_indices[15:20]\n",
    "\n",
    "train_tasks_names = [all_tasks_collection_names[i] for i in train_task_indices]\n",
    "adapt_tasks_names = [all_tasks_collection_names[i] for i in adapt_task_indices]\n",
    "\n",
    "print(f\"Training on {len(train_tasks_names)} tasks (first 3): {train_tasks_names[:3]}\")\n",
    "print(f\"Adapting to {len(adapt_tasks_names)} tasks: {adapt_tasks_names}\")\n",
    "\n",
    "# Observation and action sizes are read from the first task only.\n",
    "# NOTE(review): this assumes every task in the list shares these sizes - confirm.\n",
    "_temp_env = gym.make(all_tasks_collection_names[0], **kwargs_env)\n",
    "if isinstance(_temp_env.observation_space, gym.spaces.Dict):\n",
    "    raw_ob_size = _temp_env.observation_space['observation'].shape[0]\n",
    "else:\n",
    "    raw_ob_size = _temp_env.observation_space.shape[0]\n",
    "act_size = _temp_env.action_space.n\n",
    "_temp_env.close()\n",
    "\n",
    "# Model input = raw observation plus one extra feature per task in the list.\n",
    "num_total_tasks = len(all_tasks_collection_names) \n",
    "model_input_size = raw_ob_size + num_total_tasks\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "print(f\"Device: {device}\")\n",
    "\n",
    "# --- Model Dictionary ---\n",
    "models_to_test = {\n",
    "    \"nmRNN_Spatial\": lambda: nmRNN_Adapter(input_size=model_input_size, hidden_size=args.hidden_size, output_size=act_size, dt=kwargs_env['dt'],\n",
    "                                           N_NM=args.N_NM, tau_nmrnn=args.tau_nmrnn, use_spatial_net=True, grad_clip_nmrnn=args.grad_clip_nmrnn).to(device),\n",
    "    \"nmRNN_NonSpatial\": lambda: nmRNN_Adapter(input_size=model_input_size, hidden_size=args.hidden_size, output_size=act_size, dt=kwargs_env['dt'],\n",
    "                                              N_NM=args.N_NM, tau_nmrnn=args.tau_nmrnn, use_spatial_net=False, grad_clip_nmrnn=args.grad_clip_nmrnn).to(device),\n",
    "    \"VanillaRNN\": lambda: RNNNet(input_size=model_input_size, hidden_size=args.hidden_size, output_size=act_size, dt=kwargs_env['dt']).to(device),\n",
    "    \"LSTM\": lambda: LSTMNet(input_size=model_input_size, hidden_size=args.hidden_size, output_size=act_size, dt=kwargs_env['dt']).to(device),\n",
    "    \"Transformer\": lambda: TransformerNet(input_size=model_input_size, hidden_size=args.hidden_size, output_size=act_size, dt=kwargs_env['dt'], nhead=4, num_layers=2).to(device),\n",
    "}\n",
    "if args.N_NM == 0: \n",
    "    print(\"N_NM is 0, running only nmRNN_NonSpatial (effectively a standard RNN with nmRNN structure but no NMs).\")\n",
    "    models_to_test.pop(\"nmRNN_Spatial\", None) \n",
    "    if \"nmRNN_NonSpatial\" in models_to_test: # Rename for clarity\n",
    "        models_to_test[\"nmRNN_N0\"] = models_to_test.pop(\"nmRNN_NonSpatial\")\n",
    "\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# --- Results Storage ---\n",
    "initial_training_performance_curves = {model_name: [] for model_name in models_to_test}\n",
    "adaptation_performance_curves = {model_name: {task_name: [] for task_name in adapt_tasks_names} for model_name in models_to_test}\n",
    "summary_table_data = []\n",
    "\n",
    "\n",
    "# --- Custom Dataset for Training (Phase 1) ---\n",
    "class Yang19TrainDataset:\n",
    "    def __init__(self, task_names_to_train, all_task_names_ordered_list, batch_size, seq_len, env_kwargs_dict):\n",
    "        self.batch_size = batch_size\n",
    "        self.seq_len = seq_len\n",
    "        self.task_names_to_train = task_names_to_train\n",
    "        self.all_task_names_ordered_list = all_task_names_ordered_list\n",
    "        self.num_total_tasks = len(self.all_task_names_ordered_list)\n",
    "        self.env_kwargs_dict = env_kwargs_dict\n",
    "        \n",
    "        self.envs_dict = {}\n",
    "        valid_task_names_for_training = []\n",
    "        for name in self.task_names_to_train:\n",
    "            try:\n",
    "                self.envs_dict[name] = gym.make(name, **self.env_kwargs_dict)\n",
    "                valid_task_names_for_training.append(name)\n",
    "            except Exception as e:\n",
    "                print(f\"Could not make environment {name} for training dataset: {e}. It will be excluded.\")\n",
    "        \n",
    "        self.task_names_to_train = valid_task_names_for_training # Update to only include valid tasks\n",
    "        self.task_global_indices = {name: self.all_task_names_ordered_list.index(name) for name in self.task_names_to_train}\n",
    "\n",
    "        if not self.task_names_to_train or not self.envs_dict:\n",
    "            raise ValueError(\"No valid environments could be created for Yang19TrainDataset.\")\n",
    "        self.current_task_name = random.choice(list(self.envs_dict.keys()))\n",
    "\n",
    "\n",
    "    def __call__(self):\n",
    "        inputs_batch = []\n",
    "        labels_batch = []\n",
    "        for _ in range(self.batch_size):\n",
    "            chosen_task_name = random.choice(list(self.envs_dict.keys())) \n",
    "            current_env = self.envs_dict[chosen_task_name]\n",
    "            task_global_idx = self.task_global_indices[chosen_task_name]\n",
    "\n",
    "            raw_obs_list_item = []\n",
    "            gt_actions_list_item = []\n",
    "            \n",
    "            env_obs = current_env.reset()\n",
    "            current_raw_obs = env_obs['observation'] if isinstance(env_obs, dict) and 'observation' in env_obs else env_obs\n",
    "\n",
    "            for step_num in range(self.seq_len):\n",
    "                raw_obs_list_item.append(current_raw_obs)\n",
    "                action = current_env.action_space.sample() \n",
    "                env_obs, _, done, info = current_env.step(action)\n",
    "                current_raw_obs = env_obs['observation'] if isinstance(env_obs, dict) and 'observation' in env_obs else env_obs\n",
    "                \n",
    "                if 'gt' not in info: \n",
    "                    gt_actions_list_item.append(-1) \n",
    "                else:\n",
    "                    gt_actions_list_item.append(info['gt'])\n",
    "                \n",
    "                if done: \n",
    "                    padding_needed = self.seq_len - (step_num + 1)\n",
    "                    if padding_needed > 0:\n",
    "                        raw_obs_list_item.extend([current_raw_obs] * padding_needed)\n",
    "                        last_gt = gt_actions_list_item[-1] if gt_actions_list_item else -1\n",
    "                        gt_actions_list_item.extend([last_gt] * padding_needed)\n",
    "                    break\n",
    "            \n",
    "            raw_obs_seq_np = np.array(raw_obs_list_item) \n",
    "            context_vec = np.zeros((self.seq_len, self.num_total_tasks), dtype=np.float32) \n",
    "            context_vec[:, task_global_idx] = 1.0\n",
    "            model_input_seq = np.concatenate((raw_obs_seq_np.astype(np.float32), context_vec), axis=-1)\n",
    "            \n",
    "            inputs_batch.append(model_input_seq)\n",
    "            labels_batch.append(np.array(gt_actions_list_item))\n",
    "\n",
    "        return np.array(inputs_batch), np.array(labels_batch)\n",
    "\n",
    "    def close_envs(self):\n",
    "        for env in self.envs_dict.values():\n",
    "            env.close()\n",
    "\n",
    "# --- Custom Dataset for Adaptation (Phase 2) ---\n",
    "class SingleTaskDataset:\n",
    "    def __init__(self, env_instance, batch_size, seq_len): \n",
    "        self.env = env_instance\n",
    "        self.batch_size = batch_size\n",
    "        self.seq_len = seq_len\n",
    "\n",
    "    def __call__(self):\n",
    "        obs_batch = []\n",
    "        labels_batch = []\n",
    "        for _ in range(self.batch_size):\n",
    "            obs_list_item = []\n",
    "            gt_list_item = []\n",
    "            env_obs = self.env.reset()\n",
    "            current_raw_obs = env_obs['observation'] if isinstance(env_obs, dict) and 'observation' in env_obs else env_obs\n",
    "\n",
    "            for step_num in range(self.seq_len):\n",
    "                obs_list_item.append(current_raw_obs)\n",
    "                action = self.env.action_space.sample() \n",
    "                env_obs, _, done, info = self.env.step(action)\n",
    "                current_raw_obs = env_obs['observation'] if isinstance(env_obs, dict) and 'observation' in env_obs else env_obs\n",
    "                \n",
    "                if 'gt' not in info:\n",
    "                    gt_list_item.append(-1)\n",
    "                else:\n",
    "                    gt_list_item.append(info['gt'])\n",
    "\n",
    "                if done:\n",
    "                    padding_needed = self.seq_len - (step_num + 1)\n",
    "                    if padding_needed > 0:\n",
    "                        obs_list_item.extend([current_raw_obs] * padding_needed)\n",
    "                        last_gt = gt_list_item[-1] if gt_list_item else -1\n",
    "                        gt_list_item.extend([last_gt] * padding_needed)\n",
    "                    break\n",
    "            obs_batch.append(np.array(obs_list_item))\n",
    "            labels_batch.append(np.array(gt_list_item))\n",
    "        return np.array(obs_batch, dtype=np.float32), np.array(labels_batch) \n",
    "\n",
    "\n",
    "# --- Main Loop for Each Model Type ---\n",
    "for model_name, model_constructor in models_to_test.items():\n",
    "    print(f\"\\n\\n--- Training and Adapting Model: {model_name} ---\")\n",
    "    model = model_constructor() \n",
    "    # print(model) # Optional: print model structure\n",
    "    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "    print(f\"Total trainable parameters for {model_name}: {total_params:,}\")\n",
    "\n",
    "    # === Phase 1: Initial Training ===\n",
    "    print(f\"\\n--- Phase 1: Initial Training on {len(train_tasks_names)} tasks for {model_name} ---\")\n",
    "    try:\n",
    "        dataset_phase1 = Yang19TrainDataset(train_tasks_names, all_tasks_collection_names, args.batch_size, args.seq_len, env_kwargs_dict=kwargs_env)\n",
    "    except ValueError as e:\n",
    "        print(f\"Error initializing dataset for Phase 1 for model {model_name}: {e}. Skipping this model.\")\n",
    "        continue # Skip to the next model if dataset fails\n",
    "\n",
    "    optimizer_model = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr_model)\n",
    "    \n",
    "    running_loss = 0.0\n",
    "    current_initial_train_perf_curve = []\n",
    "    time_phase1_start = time.time()\n",
    "\n",
    "    for i in range(args.initial_train_steps):\n",
    "        model.train() # Ensure model is in training mode for this phase\n",
    "        inputs, labels = dataset_phase1() \n",
    "        inputs = torch.from_numpy(inputs).type(torch.float).to(device) \n",
    "        labels_flat = torch.from_numpy(labels.flatten()).type(torch.long).to(device)\n",
    "\n",
    "        valid_indices = labels_flat != -1\n",
    "        if not valid_indices.any(): \n",
    "            continue\n",
    "        \n",
    "        optimizer_model.zero_grad()\n",
    "        outputs, _ = model(inputs) \n",
    "        \n",
    "        outputs_view = outputs.view(-1, act_size) \n",
    "        outputs_for_loss = outputs_view[valid_indices]\n",
    "        labels_filtered = labels_flat[valid_indices]\n",
    "        \n",
    "        if outputs_for_loss.shape[0] == 0: \n",
    "            continue\n",
    "\n",
    "        loss = criterion(outputs_for_loss, labels_filtered)\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(filter(lambda p: p.requires_grad, model.parameters()), 1.0) \n",
    "        optimizer_model.step()\n",
    "        running_loss += loss.item()\n",
    "\n",
    "        if i % args.print_step == (args.print_step - 1) or i == args.initial_train_steps -1 :\n",
    "            steps_in_epoch = args.print_step if i % args.print_step == (args.print_step -1) else (i % args.print_step) + 1\n",
    "            avg_loss = running_loss / steps_in_epoch if steps_in_epoch > 0 else running_loss\n",
    "            print(f'Model: {model_name}, Initial Train Step: {i + 1}/{args.initial_train_steps}, Loss: {avg_loss:.4f}')\n",
    "            if i % args.print_step == (args.print_step -1): running_loss = 0.0\n",
    "\n",
    "            avg_perf_on_train_tasks = 0.0\n",
    "            temp_perfs = []\n",
    "            num_eval_tasks = len(dataset_phase1.task_names_to_train) # Use actual number of valid tasks\n",
    "            trials_per_eval_task = max(1, args.eval_trials // num_eval_tasks if num_eval_tasks > 0 else args.eval_trials)\n",
    "\n",
    "            for train_task_name_eval in dataset_phase1.task_names_to_train: \n",
    "                task_global_idx = all_tasks_collection_names.index(train_task_name_eval)\n",
    "                context_val = np.zeros(num_total_tasks, dtype=np.float32)\n",
    "                context_val[task_global_idx] = 1.0\n",
    "                \n",
    "                try:\n",
    "                    eval_env_train = gym.make(train_task_name_eval, **kwargs_env)\n",
    "                    # get_performance will set model.eval()\n",
    "                    perf = get_performance(model, eval_env_train, num_trial=trials_per_eval_task, \n",
    "                                           device=device, context_vector_global=context_val, \n",
    "                                           raw_obs_size=raw_ob_size, seq_len_eval=args.seq_len)\n",
    "                    temp_perfs.append(perf)\n",
    "                    eval_env_train.close()\n",
    "                except Exception as e:\n",
    "                    print(f\"Error evaluating task {train_task_name_eval} during training: {e}\")\n",
    "            \n",
    "            if temp_perfs: avg_perf_on_train_tasks = np.mean(temp_perfs)\n",
    "            current_initial_train_perf_curve.append(avg_perf_on_train_tasks)\n",
    "            print(f'Model: {model_name}, Initial Train Step: {i + 1}, Avg Perf on Train Set: {avg_perf_on_train_tasks:.3f}')\n",
    "    \n",
    "    time_phase1_end = time.time()\n",
    "    print(f\"Finished Initial Training for {model_name} in {time_phase1_end - time_phase1_start:.2f} seconds.\")\n",
    "    initial_training_performance_curves[model_name] = current_initial_train_perf_curve\n",
    "    trained_model_state_dict = copy.deepcopy(model.state_dict()) \n",
    "    dataset_phase1.close_envs() \n",
    "\n",
    "\n",
    "    # === Phase 2: Adaptation ===\n",
    "    print(f\"\\n--- Phase 2: Adaptation for {model_name} on {len(adapt_tasks_names)} tasks ---\")\n",
    "    time_phase2_start = time.time()\n",
    "    \n",
    "    for adapt_task_name in adapt_tasks_names:\n",
    "        print(f\"\\n-- Adapting {model_name} to task: {adapt_task_name} --\")\n",
    "        model.load_state_dict(trained_model_state_dict) \n",
    "        \n",
    "        for param in model.parameters():\n",
    "            param.requires_grad = False \n",
    "\n",
    "        trainable_context = torch.randn(num_total_tasks, device=device, requires_grad=True) \n",
    "        optimizer_context = torch.optim.Adam([trainable_context], lr=args.lr_context)\n",
    "\n",
    "        try:\n",
    "            adapt_env_instance = gym.make(adapt_task_name, **kwargs_env)\n",
    "        except Exception as e:\n",
    "            print(f\"Could not make environment {adapt_task_name} for adaptation. Skipping. Error: {e}\")\n",
    "            adaptation_performance_curves[model_name][adapt_task_name] = [] \n",
    "            summary_table_data.append({\n",
    "                \"Model\": model_name, \"Adapted Task\": adapt_task_name,\n",
    "                \"Best Adapted Performance\": 0.0, \"Final Adapted Context Sum\": 0.0\n",
    "            })\n",
    "            continue \n",
    "\n",
    "        adapt_dataset = SingleTaskDataset(adapt_env_instance, args.batch_size, args.seq_len)\n",
    "        \n",
    "        current_adapt_perf_curve = []\n",
    "        best_adapted_perf = 0.0\n",
    "        running_adapt_loss = 0.0\n",
    "\n",
    "        for adapt_step in range(args.adapt_steps):\n",
    "            model.train() \n",
    "            \n",
    "            raw_observations, labels = adapt_dataset() \n",
    "            raw_observations = torch.from_numpy(raw_observations).type(torch.float).to(device) \n",
    "            labels_flat = torch.from_numpy(labels.flatten()).type(torch.long).to(device)\n",
    "\n",
    "            valid_indices_adapt = labels_flat != -1\n",
    "            if not valid_indices_adapt.any(): continue\n",
    "\n",
    "            context_expanded = trainable_context.unsqueeze(0).unsqueeze(0).repeat(raw_observations.shape[0], raw_observations.shape[1], 1)\n",
    "            model_inputs = torch.cat((raw_observations, context_expanded), dim=-1)\n",
    "\n",
    "            optimizer_context.zero_grad()\n",
    "            outputs, _ = model(model_inputs) \n",
    "            \n",
    "            outputs_view_adapt = outputs.view(-1, act_size)\n",
    "            outputs_for_loss_adapt = outputs_view_adapt[valid_indices_adapt]\n",
    "            labels_filtered_adapt = labels_flat[valid_indices_adapt]\n",
    "\n",
    "            if outputs_for_loss_adapt.shape[0] == 0: continue\n",
    "\n",
    "            loss = criterion(outputs_for_loss_adapt, labels_filtered_adapt)\n",
    "            loss.backward() # This will now work with CuDNN RNNs\n",
    "            optimizer_context.step()\n",
    "            running_adapt_loss += loss.item()\n",
    "\n",
    "            if adapt_step % (args.print_step // 2) == (args.print_step // 2 - 1) or adapt_step == args.adapt_steps -1:\n",
    "                steps_in_eval_period = (args.print_step // 2) if adapt_step % (args.print_step // 2) == (args.print_step // 2 - 1) else (adapt_step % (args.print_step//2)) +1\n",
    "                avg_adapt_loss = running_adapt_loss / steps_in_eval_period if steps_in_eval_period > 0 else running_adapt_loss\n",
    "                if adapt_step % (args.print_step // 2) == (args.print_step // 2 - 1): running_adapt_loss = 0.0\n",
    "                \n",
    "                current_ctx_np = trainable_context.detach().cpu().numpy()\n",
    "                \n",
    "                eval_adapt_env_temp = gym.make(adapt_task_name, **kwargs_env) \n",
    "                # get_performance will set model.eval() internally\n",
    "                perf = get_performance(model, eval_adapt_env_temp, num_trial=args.eval_trials, device=device,\n",
    "                                       context_vector_global=current_ctx_np, raw_obs_size=raw_ob_size, seq_len_eval=args.seq_len)\n",
    "                eval_adapt_env_temp.close()\n",
    "                \n",
    "                current_adapt_perf_curve.append(perf)\n",
    "                if perf > best_adapted_perf: best_adapted_perf = perf\n",
    "                \n",
    "                print(f'  Adapt Task: {adapt_task_name}, Step: {adapt_step + 1}/{args.adapt_steps}, Loss: {avg_adapt_loss:.4f}, Perf: {perf:.3f}')\n",
    "        \n",
    "        adaptation_performance_curves[model_name][adapt_task_name] = current_adapt_perf_curve\n",
    "        summary_table_data.append({\n",
    "            \"Model\": model_name,\n",
    "            \"Adapted Task\": adapt_task_name,\n",
    "            \"Best Adapted Performance\": best_adapted_perf,\n",
    "            \"Final Adapted Context Sum\": trainable_context.detach().sum().item() \n",
    "        })\n",
    "        adapt_env_instance.close()\n",
    "    \n",
    "    time_phase2_end = time.time()\n",
    "    print(f\"Finished Adaptation for {model_name} in {time_phase2_end - time_phase2_start:.2f} seconds.\")\n",
    "\n",
    "    del model \n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.empty_cache()\n",
    "\n",
    "\n",
    "# --- Reporting Results ---\n",
    "print(\"\\n\\n--- Overall Results ---\")\n",
    "\n",
    "print(\"\\nInitial Training Performance Curves (Avg Perf on 15 Train Tasks vs. Eval Points):\")\n",
    "for model_name, curve in initial_training_performance_curves.items():\n",
    "    print(f\"Model: {model_name}\")\n",
    "    curve_str = \", \".join([f\"{p:.3f}\" for p in curve])\n",
    "    print(f\"  Performance points: [{curve_str}]\")\n",
    "    np.save(path / f'{model_name}_initial_train_curve.npy', np.array(curve))\n",
    "\n",
    "print(\"\\nAdaptation Training Curves (Perf on Held-out Task vs. Eval Points):\")\n",
    "for model_name, task_curves in adaptation_performance_curves.items():\n",
    "    print(f\"Model: {model_name}\")\n",
    "    for task_name, curve in task_curves.items():\n",
    "        print(f\"  Adapted Task: {task_name}\")\n",
    "        if curve: \n",
    "            curve_str = \", \".join([f\"{p:.3f}\" for p in curve])\n",
    "            print(f\"    Performance points: [{curve_str}]\")\n",
    "            np.save(path / f'{model_name}_adapt_curve_{task_name.replace(\" \", \"_\").replace(\"-v0\", \"\")}.npy', np.array(curve))\n",
    "        else:\n",
    "            print(f\"    No performance data recorded (task might have failed to initialize).\")\n",
    "\n",
    "\n",
    "print(\"\\nSummary Table of Best Adapted Performance:\")\n",
    "if summary_table_data:\n",
    "    summary_df = pd.DataFrame(summary_table_data)\n",
    "    print(summary_df.to_string()) \n",
    "    summary_df.to_csv(path / 'adaptation_summary.csv', index=False)\n",
    "else:\n",
    "    print(\"No summary data to report.\")\n",
    "\n",
    "print(f\"\\nFinished Adaptability Analysis. Results saved in {path.resolve()}\")\n",
    "print(\"To visualize curves, load the .npy files or use the printed lists/DataFrame with a plotting library.\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "597e2872-8452-431f-9334-93fed495b31a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running with seed 0!\n",
      "Saving results to: /root/capsule/code/files_seed0_NNM4_tau0p05_hs256_lr0p001_ctx0p01_init10000_ad1000\n",
      "Training on 15 tasks (first 3): ['yang19.multidm-v0', 'yang19.dmc-v0', 'yang19.dms-v0']\n",
      "Adapting to 5 tasks: ['yang19.multidlydm-v0', 'yang19.ctxdm1-v0', 'yang19.rtgo-v0', 'yang19.ctxdlydm1-v0', 'yang19.dlydm2-v0']\n",
      "Device: cuda\n",
      "\n",
      "\n",
      "--- Training and Adapting Model: nmRNN_Spatial ---\n",
      "Total trainable parameters for nmRNN_Spatial: 359,952\n",
      "\n",
      "--- Phase 1: Initial Training on 15 tasks for nmRNN_Spatial ---\n",
      "Model: nmRNN_Spatial, Initial Train Step: 200/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 200, Avg Perf on Train Set: 0.922\n",
      "Model: nmRNN_Spatial, Initial Train Step: 400/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 400, Avg Perf on Train Set: 0.920\n",
      "Model: nmRNN_Spatial, Initial Train Step: 600/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 600, Avg Perf on Train Set: 0.919\n",
      "Model: nmRNN_Spatial, Initial Train Step: 800/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 800, Avg Perf on Train Set: 0.922\n",
      "Model: nmRNN_Spatial, Initial Train Step: 1000/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 1000, Avg Perf on Train Set: 0.924\n",
      "Model: nmRNN_Spatial, Initial Train Step: 1200/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 1200, Avg Perf on Train Set: 0.922\n",
      "Model: nmRNN_Spatial, Initial Train Step: 1400/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 1400, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_Spatial, Initial Train Step: 1600/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 1600, Avg Perf on Train Set: 0.920\n",
      "Model: nmRNN_Spatial, Initial Train Step: 1800/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 1800, Avg Perf on Train Set: 0.922\n",
      "Model: nmRNN_Spatial, Initial Train Step: 2000/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 2000, Avg Perf on Train Set: 0.922\n",
      "Model: nmRNN_Spatial, Initial Train Step: 2200/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 2200, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_Spatial, Initial Train Step: 2400/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 2400, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_Spatial, Initial Train Step: 2600/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 2600, Avg Perf on Train Set: 0.920\n",
      "Model: nmRNN_Spatial, Initial Train Step: 2800/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 2800, Avg Perf on Train Set: 0.922\n",
      "Model: nmRNN_Spatial, Initial Train Step: 3000/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 3000, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_Spatial, Initial Train Step: 3200/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 3200, Avg Perf on Train Set: 0.923\n",
      "Model: nmRNN_Spatial, Initial Train Step: 3400/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 3400, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_Spatial, Initial Train Step: 3600/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 3600, Avg Perf on Train Set: 0.923\n",
      "Model: nmRNN_Spatial, Initial Train Step: 3800/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 3800, Avg Perf on Train Set: 0.920\n",
      "Model: nmRNN_Spatial, Initial Train Step: 4000/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 4000, Avg Perf on Train Set: 0.919\n",
      "Model: nmRNN_Spatial, Initial Train Step: 4200/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 4200, Avg Perf on Train Set: 0.919\n",
      "Model: nmRNN_Spatial, Initial Train Step: 4400/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 4400, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_Spatial, Initial Train Step: 4600/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 4600, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_Spatial, Initial Train Step: 4800/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 4800, Avg Perf on Train Set: 0.923\n",
      "Model: nmRNN_Spatial, Initial Train Step: 5000/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 5000, Avg Perf on Train Set: 0.922\n",
      "Model: nmRNN_Spatial, Initial Train Step: 5200/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 5200, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_Spatial, Initial Train Step: 5400/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 5400, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_Spatial, Initial Train Step: 5600/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 5600, Avg Perf on Train Set: 0.920\n",
      "Model: nmRNN_Spatial, Initial Train Step: 5800/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 5800, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_Spatial, Initial Train Step: 6000/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 6000, Avg Perf on Train Set: 0.920\n",
      "Model: nmRNN_Spatial, Initial Train Step: 6200/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 6200, Avg Perf on Train Set: 0.920\n",
      "Model: nmRNN_Spatial, Initial Train Step: 6400/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 6400, Avg Perf on Train Set: 0.919\n",
      "Model: nmRNN_Spatial, Initial Train Step: 6600/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 6600, Avg Perf on Train Set: 0.923\n",
      "Model: nmRNN_Spatial, Initial Train Step: 6800/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 6800, Avg Perf on Train Set: 0.920\n",
      "Model: nmRNN_Spatial, Initial Train Step: 7000/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 7000, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_Spatial, Initial Train Step: 7200/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 7200, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_Spatial, Initial Train Step: 7400/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 7400, Avg Perf on Train Set: 0.920\n",
      "Model: nmRNN_Spatial, Initial Train Step: 7600/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 7600, Avg Perf on Train Set: 0.922\n",
      "Model: nmRNN_Spatial, Initial Train Step: 7800/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 7800, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_Spatial, Initial Train Step: 8000/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 8000, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_Spatial, Initial Train Step: 8200/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 8200, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_Spatial, Initial Train Step: 8400/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 8400, Avg Perf on Train Set: 0.920\n",
      "Model: nmRNN_Spatial, Initial Train Step: 8600/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 8600, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_Spatial, Initial Train Step: 8800/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 8800, Avg Perf on Train Set: 0.920\n",
      "Model: nmRNN_Spatial, Initial Train Step: 9000/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 9000, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_Spatial, Initial Train Step: 9200/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 9200, Avg Perf on Train Set: 0.919\n",
      "Model: nmRNN_Spatial, Initial Train Step: 9400/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 9400, Avg Perf on Train Set: 0.920\n",
      "Model: nmRNN_Spatial, Initial Train Step: 9600/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 9600, Avg Perf on Train Set: 0.922\n",
      "Model: nmRNN_Spatial, Initial Train Step: 9800/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 9800, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_Spatial, Initial Train Step: 10000/10000, Loss: 2.8332\n",
      "Model: nmRNN_Spatial, Initial Train Step: 10000, Avg Perf on Train Set: 0.922\n",
      "Finished Initial Training for nmRNN_Spatial in 4738.57 seconds.\n",
      "\n",
      "--- Phase 2: Adaptation for nmRNN_Spatial on 5 tasks ---\n",
      "\n",
      "-- Adapting nmRNN_Spatial to task: yang19.multidlydm-v0 --\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 100/1000, Loss: 2.8332, Perf: 0.955\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 200/1000, Loss: 2.8332, Perf: 0.955\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 300/1000, Loss: 2.8332, Perf: 0.955\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 400/1000, Loss: 2.8332, Perf: 0.955\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 500/1000, Loss: 2.8332, Perf: 0.954\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 600/1000, Loss: 2.8332, Perf: 0.954\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 700/1000, Loss: 2.8332, Perf: 0.954\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 800/1000, Loss: 2.8332, Perf: 0.954\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 900/1000, Loss: 2.8332, Perf: 0.955\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 1000/1000, Loss: 2.8332, Perf: 0.955\n",
      "\n",
      "-- Adapting nmRNN_Spatial to task: yang19.ctxdm1-v0 --\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 100/1000, Loss: 2.8332, Perf: 0.881\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 200/1000, Loss: 2.8332, Perf: 0.882\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 300/1000, Loss: 2.8332, Perf: 0.882\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 400/1000, Loss: 2.8332, Perf: 0.881\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 500/1000, Loss: 2.8332, Perf: 0.882\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 600/1000, Loss: 2.8332, Perf: 0.882\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 700/1000, Loss: 2.8332, Perf: 0.882\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 800/1000, Loss: 2.8332, Perf: 0.882\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 900/1000, Loss: 2.8332, Perf: 0.880\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 1000/1000, Loss: 2.8332, Perf: 0.881\n",
      "\n",
      "-- Adapting nmRNN_Spatial to task: yang19.rtgo-v0 --\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 100/1000, Loss: 2.8332, Perf: 0.800\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 200/1000, Loss: 2.8332, Perf: 0.800\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 300/1000, Loss: 2.8332, Perf: 0.800\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 400/1000, Loss: 2.8332, Perf: 0.800\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 500/1000, Loss: 2.8332, Perf: 0.800\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 600/1000, Loss: 2.8332, Perf: 0.800\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 700/1000, Loss: 2.8332, Perf: 0.800\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 800/1000, Loss: 2.8332, Perf: 0.800\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 900/1000, Loss: 2.8332, Perf: 0.800\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 1000/1000, Loss: 2.8332, Perf: 0.800\n",
      "\n",
      "-- Adapting nmRNN_Spatial to task: yang19.ctxdlydm1-v0 --\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 100/1000, Loss: 2.8332, Perf: 0.954\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 200/1000, Loss: 2.8332, Perf: 0.954\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 300/1000, Loss: 2.8332, Perf: 0.955\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 400/1000, Loss: 2.8332, Perf: 0.955\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 500/1000, Loss: 2.8332, Perf: 0.955\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 600/1000, Loss: 2.8332, Perf: 0.954\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 700/1000, Loss: 2.8332, Perf: 0.954\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 800/1000, Loss: 2.8332, Perf: 0.956\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 900/1000, Loss: 2.8332, Perf: 0.956\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 1000/1000, Loss: 2.8332, Perf: 0.955\n",
      "\n",
      "-- Adapting nmRNN_Spatial to task: yang19.dlydm2-v0 --\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 100/1000, Loss: 2.8332, Perf: 0.953\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 200/1000, Loss: 2.8332, Perf: 0.956\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 300/1000, Loss: 2.8332, Perf: 0.956\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 400/1000, Loss: 2.8332, Perf: 0.954\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 500/1000, Loss: 2.8332, Perf: 0.954\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 600/1000, Loss: 2.8332, Perf: 0.955\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 700/1000, Loss: 2.8332, Perf: 0.955\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 800/1000, Loss: 2.8332, Perf: 0.954\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 900/1000, Loss: 2.8332, Perf: 0.954\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 1000/1000, Loss: 2.8332, Perf: 0.956\n",
      "Finished Adaptation for nmRNN_Spatial in 2493.21 seconds.\n",
      "\n",
      "\n",
      "--- Training and Adapting Model: nmRNN_NonSpatial ---\n",
      "Total trainable parameters for nmRNN_NonSpatial: 359,952\n",
      "\n",
      "--- Phase 1: Initial Training on 15 tasks for nmRNN_NonSpatial ---\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 200/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 200, Avg Perf on Train Set: 0.922\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 400/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 400, Avg Perf on Train Set: 0.922\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 600/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 600, Avg Perf on Train Set: 0.919\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 800/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 800, Avg Perf on Train Set: 0.920\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 1000/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 1000, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 1200/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 1200, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 1400/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 1400, Avg Perf on Train Set: 0.920\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 1600/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 1600, Avg Perf on Train Set: 0.920\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 1800/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 1800, Avg Perf on Train Set: 0.922\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 2000/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 2000, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 2200/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 2200, Avg Perf on Train Set: 0.922\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 2400/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 2400, Avg Perf on Train Set: 0.919\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 2600/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 2600, Avg Perf on Train Set: 0.920\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 2800/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 2800, Avg Perf on Train Set: 0.922\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 3000/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 3000, Avg Perf on Train Set: 0.922\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 3200/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 3200, Avg Perf on Train Set: 0.920\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 3400/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 3400, Avg Perf on Train Set: 0.922\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 3600/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 3600, Avg Perf on Train Set: 0.920\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 3800/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 3800, Avg Perf on Train Set: 0.922\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 4000/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 4000, Avg Perf on Train Set: 0.923\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 4200/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 4200, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 4400/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 4400, Avg Perf on Train Set: 0.920\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 4600/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 4600, Avg Perf on Train Set: 0.920\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 4800/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 4800, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 5000/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 5000, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 5200/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 5200, Avg Perf on Train Set: 0.920\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 5400/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 5400, Avg Perf on Train Set: 0.920\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 5600/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 5600, Avg Perf on Train Set: 0.922\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 5800/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 5800, Avg Perf on Train Set: 0.920\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 6000/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 6000, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 6200/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 6200, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 6400/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 6400, Avg Perf on Train Set: 0.922\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 6600/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 6600, Avg Perf on Train Set: 0.922\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 6800/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 6800, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 7000/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 7000, Avg Perf on Train Set: 0.922\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 7200/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 7200, Avg Perf on Train Set: 0.920\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 7400/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 7400, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 7600/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 7600, Avg Perf on Train Set: 0.920\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 7800/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 7800, Avg Perf on Train Set: 0.922\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 8000/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 8000, Avg Perf on Train Set: 0.922\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 8200/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 8200, Avg Perf on Train Set: 0.920\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 8400/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 8400, Avg Perf on Train Set: 0.922\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 8600/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 8600, Avg Perf on Train Set: 0.923\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 8800/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 8800, Avg Perf on Train Set: 0.922\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 9000/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 9000, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 9200/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 9200, Avg Perf on Train Set: 0.919\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 9400/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 9400, Avg Perf on Train Set: 0.922\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 9600/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 9600, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 9800/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 9800, Avg Perf on Train Set: 0.921\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 10000/10000, Loss: nan\n",
      "Model: nmRNN_NonSpatial, Initial Train Step: 10000, Avg Perf on Train Set: 0.919\n",
      "Finished Initial Training for nmRNN_NonSpatial in 4425.05 seconds.\n",
      "\n",
      "--- Phase 2: Adaptation for nmRNN_NonSpatial on 5 tasks ---\n",
      "\n",
      "-- Adapting nmRNN_NonSpatial to task: yang19.multidlydm-v0 --\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 100/1000, Loss: nan, Perf: 0.954\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 200/1000, Loss: nan, Perf: 0.955\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 300/1000, Loss: nan, Perf: 0.956\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 400/1000, Loss: nan, Perf: 0.956\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 500/1000, Loss: nan, Perf: 0.954\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 600/1000, Loss: nan, Perf: 0.955\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 700/1000, Loss: nan, Perf: 0.955\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 800/1000, Loss: nan, Perf: 0.955\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 900/1000, Loss: nan, Perf: 0.954\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 1000/1000, Loss: nan, Perf: 0.955\n",
      "\n",
      "-- Adapting nmRNN_NonSpatial to task: yang19.ctxdm1-v0 --\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 100/1000, Loss: nan, Perf: 0.883\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 200/1000, Loss: nan, Perf: 0.882\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 300/1000, Loss: nan, Perf: 0.881\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 400/1000, Loss: nan, Perf: 0.881\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 500/1000, Loss: nan, Perf: 0.883\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 600/1000, Loss: nan, Perf: 0.881\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 700/1000, Loss: nan, Perf: 0.883\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 800/1000, Loss: nan, Perf: 0.883\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 900/1000, Loss: nan, Perf: 0.881\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 1000/1000, Loss: nan, Perf: 0.882\n",
      "\n",
      "-- Adapting nmRNN_NonSpatial to task: yang19.rtgo-v0 --\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 100/1000, Loss: nan, Perf: 0.800\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 200/1000, Loss: nan, Perf: 0.800\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 300/1000, Loss: nan, Perf: 0.800\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 400/1000, Loss: nan, Perf: 0.800\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 500/1000, Loss: nan, Perf: 0.800\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 600/1000, Loss: nan, Perf: 0.800\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 700/1000, Loss: nan, Perf: 0.800\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 800/1000, Loss: nan, Perf: 0.800\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 900/1000, Loss: nan, Perf: 0.800\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 1000/1000, Loss: nan, Perf: 0.800\n",
      "\n",
      "-- Adapting nmRNN_NonSpatial to task: yang19.ctxdlydm1-v0 --\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 100/1000, Loss: nan, Perf: 0.954\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 200/1000, Loss: nan, Perf: 0.954\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 300/1000, Loss: nan, Perf: 0.955\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 400/1000, Loss: nan, Perf: 0.955\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 500/1000, Loss: nan, Perf: 0.955\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 600/1000, Loss: nan, Perf: 0.954\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 700/1000, Loss: nan, Perf: 0.955\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 800/1000, Loss: nan, Perf: 0.954\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 900/1000, Loss: nan, Perf: 0.954\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 1000/1000, Loss: nan, Perf: 0.955\n",
      "\n",
      "-- Adapting nmRNN_NonSpatial to task: yang19.dlydm2-v0 --\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 100/1000, Loss: nan, Perf: 0.954\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 200/1000, Loss: nan, Perf: 0.955\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 300/1000, Loss: nan, Perf: 0.954\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 400/1000, Loss: nan, Perf: 0.955\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 500/1000, Loss: nan, Perf: 0.953\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 600/1000, Loss: nan, Perf: 0.955\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 700/1000, Loss: nan, Perf: 0.954\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 800/1000, Loss: nan, Perf: 0.955\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 900/1000, Loss: nan, Perf: 0.955\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 1000/1000, Loss: nan, Perf: 0.954\n",
      "Finished Adaptation for nmRNN_NonSpatial in 2293.17 seconds.\n",
      "\n",
      "\n",
      "--- Training and Adapting Model: VanillaRNN ---\n",
      "Total trainable parameters for VanillaRNN: 83,985\n",
      "\n",
      "--- Phase 1: Initial Training on 15 tasks for VanillaRNN ---\n",
      "Model: VanillaRNN, Initial Train Step: 200/10000, Loss: 0.4840\n",
      "Model: VanillaRNN, Initial Train Step: 200, Avg Perf on Train Set: 0.923\n",
      "Model: VanillaRNN, Initial Train Step: 400/10000, Loss: 0.2108\n",
      "Model: VanillaRNN, Initial Train Step: 400, Avg Perf on Train Set: 0.951\n",
      "Model: VanillaRNN, Initial Train Step: 600/10000, Loss: 0.1244\n",
      "Model: VanillaRNN, Initial Train Step: 600, Avg Perf on Train Set: 0.967\n",
      "Model: VanillaRNN, Initial Train Step: 800/10000, Loss: 0.0876\n",
      "Model: VanillaRNN, Initial Train Step: 800, Avg Perf on Train Set: 0.980\n",
      "Model: VanillaRNN, Initial Train Step: 1000/10000, Loss: 0.0630\n",
      "Model: VanillaRNN, Initial Train Step: 1000, Avg Perf on Train Set: 0.985\n",
      "Model: VanillaRNN, Initial Train Step: 1200/10000, Loss: 0.0451\n",
      "Model: VanillaRNN, Initial Train Step: 1200, Avg Perf on Train Set: 0.986\n",
      "Model: VanillaRNN, Initial Train Step: 1400/10000, Loss: 0.0370\n",
      "Model: VanillaRNN, Initial Train Step: 1400, Avg Perf on Train Set: 0.990\n",
      "Model: VanillaRNN, Initial Train Step: 1600/10000, Loss: 0.0303\n",
      "Model: VanillaRNN, Initial Train Step: 1600, Avg Perf on Train Set: 0.992\n",
      "Model: VanillaRNN, Initial Train Step: 1800/10000, Loss: 0.0288\n",
      "Model: VanillaRNN, Initial Train Step: 1800, Avg Perf on Train Set: 0.992\n",
      "Model: VanillaRNN, Initial Train Step: 2000/10000, Loss: 0.0243\n",
      "Model: VanillaRNN, Initial Train Step: 2000, Avg Perf on Train Set: 0.995\n",
      "Model: VanillaRNN, Initial Train Step: 2200/10000, Loss: 0.0234\n",
      "Model: VanillaRNN, Initial Train Step: 2200, Avg Perf on Train Set: 0.993\n",
      "Model: VanillaRNN, Initial Train Step: 2400/10000, Loss: 0.0225\n",
      "Model: VanillaRNN, Initial Train Step: 2400, Avg Perf on Train Set: 0.992\n",
      "Model: VanillaRNN, Initial Train Step: 2600/10000, Loss: 0.0206\n",
      "Model: VanillaRNN, Initial Train Step: 2600, Avg Perf on Train Set: 0.994\n",
      "Model: VanillaRNN, Initial Train Step: 2800/10000, Loss: 0.0201\n",
      "Model: VanillaRNN, Initial Train Step: 2800, Avg Perf on Train Set: 0.993\n",
      "Model: VanillaRNN, Initial Train Step: 3000/10000, Loss: 0.0201\n",
      "Model: VanillaRNN, Initial Train Step: 3000, Avg Perf on Train Set: 0.993\n",
      "Model: VanillaRNN, Initial Train Step: 3200/10000, Loss: 0.0173\n",
      "Model: VanillaRNN, Initial Train Step: 3200, Avg Perf on Train Set: 0.994\n",
      "Model: VanillaRNN, Initial Train Step: 3400/10000, Loss: 0.0162\n",
      "Model: VanillaRNN, Initial Train Step: 3400, Avg Perf on Train Set: 0.994\n",
      "Model: VanillaRNN, Initial Train Step: 3600/10000, Loss: 0.0175\n",
      "Model: VanillaRNN, Initial Train Step: 3600, Avg Perf on Train Set: 0.993\n",
      "Model: VanillaRNN, Initial Train Step: 3800/10000, Loss: 0.0158\n",
      "Model: VanillaRNN, Initial Train Step: 3800, Avg Perf on Train Set: 0.995\n",
      "Model: VanillaRNN, Initial Train Step: 4000/10000, Loss: 0.0144\n",
      "Model: VanillaRNN, Initial Train Step: 4000, Avg Perf on Train Set: 0.994\n",
      "Model: VanillaRNN, Initial Train Step: 4200/10000, Loss: 0.0153\n",
      "Model: VanillaRNN, Initial Train Step: 4200, Avg Perf on Train Set: 0.995\n",
      "Model: VanillaRNN, Initial Train Step: 4400/10000, Loss: 0.0157\n",
      "Model: VanillaRNN, Initial Train Step: 4400, Avg Perf on Train Set: 0.995\n",
      "Model: VanillaRNN, Initial Train Step: 4600/10000, Loss: 0.0140\n",
      "Model: VanillaRNN, Initial Train Step: 4600, Avg Perf on Train Set: 0.995\n",
      "Model: VanillaRNN, Initial Train Step: 4800/10000, Loss: 0.0143\n",
      "Model: VanillaRNN, Initial Train Step: 4800, Avg Perf on Train Set: 0.994\n",
      "Model: VanillaRNN, Initial Train Step: 5000/10000, Loss: 0.0129\n",
      "Model: VanillaRNN, Initial Train Step: 5000, Avg Perf on Train Set: 0.995\n",
      "Model: VanillaRNN, Initial Train Step: 5200/10000, Loss: 0.0145\n",
      "Model: VanillaRNN, Initial Train Step: 5200, Avg Perf on Train Set: 0.992\n",
      "Model: VanillaRNN, Initial Train Step: 5400/10000, Loss: 0.0144\n",
      "Model: VanillaRNN, Initial Train Step: 5400, Avg Perf on Train Set: 0.996\n",
      "Model: VanillaRNN, Initial Train Step: 5600/10000, Loss: 0.0133\n",
      "Model: VanillaRNN, Initial Train Step: 5600, Avg Perf on Train Set: 0.992\n",
      "Model: VanillaRNN, Initial Train Step: 5800/10000, Loss: 0.0171\n",
      "Model: VanillaRNN, Initial Train Step: 5800, Avg Perf on Train Set: 0.995\n",
      "Model: VanillaRNN, Initial Train Step: 6000/10000, Loss: 0.0119\n",
      "Model: VanillaRNN, Initial Train Step: 6000, Avg Perf on Train Set: 0.996\n",
      "Model: VanillaRNN, Initial Train Step: 6200/10000, Loss: 0.0112\n",
      "Model: VanillaRNN, Initial Train Step: 6200, Avg Perf on Train Set: 0.997\n",
      "Model: VanillaRNN, Initial Train Step: 6400/10000, Loss: 0.0133\n",
      "Model: VanillaRNN, Initial Train Step: 6400, Avg Perf on Train Set: 0.995\n",
      "Model: VanillaRNN, Initial Train Step: 6600/10000, Loss: 0.0108\n",
      "Model: VanillaRNN, Initial Train Step: 6600, Avg Perf on Train Set: 0.996\n",
      "Model: VanillaRNN, Initial Train Step: 6800/10000, Loss: 0.0129\n",
      "Model: VanillaRNN, Initial Train Step: 6800, Avg Perf on Train Set: 0.994\n",
      "Model: VanillaRNN, Initial Train Step: 7000/10000, Loss: 0.0108\n",
      "Model: VanillaRNN, Initial Train Step: 7000, Avg Perf on Train Set: 0.996\n",
      "Model: VanillaRNN, Initial Train Step: 7200/10000, Loss: 0.0140\n",
      "Model: VanillaRNN, Initial Train Step: 7200, Avg Perf on Train Set: 0.994\n",
      "Model: VanillaRNN, Initial Train Step: 7400/10000, Loss: 0.0144\n",
      "Model: VanillaRNN, Initial Train Step: 7400, Avg Perf on Train Set: 0.996\n",
      "Model: VanillaRNN, Initial Train Step: 7600/10000, Loss: 0.0121\n",
      "Model: VanillaRNN, Initial Train Step: 7600, Avg Perf on Train Set: 0.997\n",
      "Model: VanillaRNN, Initial Train Step: 7800/10000, Loss: 0.0113\n",
      "Model: VanillaRNN, Initial Train Step: 7800, Avg Perf on Train Set: 0.995\n",
      "Model: VanillaRNN, Initial Train Step: 8000/10000, Loss: 0.0123\n",
      "Model: VanillaRNN, Initial Train Step: 8000, Avg Perf on Train Set: 0.996\n",
      "Model: VanillaRNN, Initial Train Step: 8200/10000, Loss: 0.0102\n",
      "Model: VanillaRNN, Initial Train Step: 8200, Avg Perf on Train Set: 0.996\n",
      "Model: VanillaRNN, Initial Train Step: 8400/10000, Loss: 0.0113\n",
      "Model: VanillaRNN, Initial Train Step: 8400, Avg Perf on Train Set: 0.995\n",
      "Model: VanillaRNN, Initial Train Step: 8600/10000, Loss: 0.0146\n",
      "Model: VanillaRNN, Initial Train Step: 8600, Avg Perf on Train Set: 0.996\n",
      "Model: VanillaRNN, Initial Train Step: 8800/10000, Loss: 0.0107\n",
      "Model: VanillaRNN, Initial Train Step: 8800, Avg Perf on Train Set: 0.996\n",
      "Model: VanillaRNN, Initial Train Step: 9000/10000, Loss: 0.0095\n",
      "Model: VanillaRNN, Initial Train Step: 9000, Avg Perf on Train Set: 0.997\n",
      "Model: VanillaRNN, Initial Train Step: 9200/10000, Loss: 0.0093\n",
      "Model: VanillaRNN, Initial Train Step: 9200, Avg Perf on Train Set: 0.997\n",
      "Model: VanillaRNN, Initial Train Step: 9400/10000, Loss: 0.0098\n",
      "Model: VanillaRNN, Initial Train Step: 9400, Avg Perf on Train Set: 0.997\n",
      "Model: VanillaRNN, Initial Train Step: 9600/10000, Loss: 0.0123\n",
      "Model: VanillaRNN, Initial Train Step: 9600, Avg Perf on Train Set: 0.996\n",
      "Model: VanillaRNN, Initial Train Step: 9800/10000, Loss: 0.0104\n",
      "Model: VanillaRNN, Initial Train Step: 9800, Avg Perf on Train Set: 0.996\n",
      "Model: VanillaRNN, Initial Train Step: 10000/10000, Loss: 0.0122\n",
      "Model: VanillaRNN, Initial Train Step: 10000, Avg Perf on Train Set: 0.996\n",
      "Finished Initial Training for VanillaRNN in 2355.00 seconds.\n",
      "\n",
      "--- Phase 2: Adaptation for VanillaRNN on 5 tasks ---\n",
      "\n",
      "-- Adapting VanillaRNN to task: yang19.multidlydm-v0 --\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 100/1000, Loss: 0.1691, Perf: 0.969\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 200/1000, Loss: 0.1389, Perf: 0.971\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 300/1000, Loss: 0.1297, Perf: 0.973\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 400/1000, Loss: 0.1264, Perf: 0.972\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 500/1000, Loss: 0.1207, Perf: 0.970\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 600/1000, Loss: 0.1196, Perf: 0.969\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 700/1000, Loss: 0.1178, Perf: 0.971\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 800/1000, Loss: 0.1177, Perf: 0.971\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 900/1000, Loss: 0.1158, Perf: 0.971\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 1000/1000, Loss: 0.1152, Perf: 0.973\n",
      "\n",
      "-- Adapting VanillaRNN to task: yang19.ctxdm1-v0 --\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 100/1000, Loss: 0.1055, Perf: 0.964\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 200/1000, Loss: 0.0916, Perf: 0.965\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 300/1000, Loss: 0.0886, Perf: 0.964\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 400/1000, Loss: 0.0893, Perf: 0.967\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 500/1000, Loss: 0.0884, Perf: 0.967\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 600/1000, Loss: 0.0880, Perf: 0.968\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 700/1000, Loss: 0.0878, Perf: 0.969\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 800/1000, Loss: 0.0873, Perf: 0.968\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 900/1000, Loss: 0.0884, Perf: 0.969\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 1000/1000, Loss: 0.0888, Perf: 0.965\n",
      "\n",
      "-- Adapting VanillaRNN to task: yang19.rtgo-v0 --\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 100/1000, Loss: 4.3606, Perf: 0.809\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 200/1000, Loss: 1.7896, Perf: 0.818\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 300/1000, Loss: 1.5023, Perf: 0.819\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 400/1000, Loss: 1.4202, Perf: 0.819\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 500/1000, Loss: 1.3558, Perf: 0.820\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 600/1000, Loss: 1.3130, Perf: 0.820\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 700/1000, Loss: 1.2842, Perf: 0.813\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 800/1000, Loss: 1.2554, Perf: 0.815\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 900/1000, Loss: 1.2171, Perf: 0.818\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 1000/1000, Loss: 1.1593, Perf: 0.817\n",
      "\n",
      "-- Adapting VanillaRNN to task: yang19.ctxdlydm1-v0 --\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 100/1000, Loss: 0.1720, Perf: 0.972\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 200/1000, Loss: 0.1635, Perf: 0.975\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 300/1000, Loss: 0.1592, Perf: 0.972\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 400/1000, Loss: 0.1548, Perf: 0.974\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 500/1000, Loss: 0.1536, Perf: 0.974\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 600/1000, Loss: 0.1548, Perf: 0.974\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 700/1000, Loss: 0.1515, Perf: 0.974\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 800/1000, Loss: 0.1475, Perf: 0.974\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 900/1000, Loss: 0.1468, Perf: 0.974\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 1000/1000, Loss: 0.1463, Perf: 0.976\n",
      "\n",
      "-- Adapting VanillaRNN to task: yang19.dlydm2-v0 --\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 100/1000, Loss: 0.2866, Perf: 0.974\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 200/1000, Loss: 0.1767, Perf: 0.972\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 300/1000, Loss: 0.1553, Perf: 0.975\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 400/1000, Loss: 0.1480, Perf: 0.973\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 500/1000, Loss: 0.1446, Perf: 0.974\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 600/1000, Loss: 0.1422, Perf: 0.973\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 700/1000, Loss: 0.1392, Perf: 0.972\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 800/1000, Loss: 0.1351, Perf: 0.974\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 900/1000, Loss: 0.1342, Perf: 0.974\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 1000/1000, Loss: 0.1299, Perf: 0.975\n",
      "Finished Adaptation for VanillaRNN in 1275.68 seconds.\n",
      "\n",
      "\n",
      "--- Training and Adapting Model: LSTM ---\n",
      "Total trainable parameters for LSTM: 322,833\n",
      "\n",
      "--- Phase 1: Initial Training on 15 tasks for LSTM ---\n",
      "Model: LSTM, Initial Train Step: 200/10000, Loss: 0.5519\n",
      "Model: LSTM, Initial Train Step: 200, Avg Perf on Train Set: 0.922\n",
      "Model: LSTM, Initial Train Step: 400/10000, Loss: 0.2786\n",
      "Model: LSTM, Initial Train Step: 400, Avg Perf on Train Set: 0.930\n",
      "Model: LSTM, Initial Train Step: 600/10000, Loss: 0.1743\n",
      "Model: LSTM, Initial Train Step: 600, Avg Perf on Train Set: 0.954\n",
      "Model: LSTM, Initial Train Step: 800/10000, Loss: 0.1170\n",
      "Model: LSTM, Initial Train Step: 800, Avg Perf on Train Set: 0.967\n",
      "Model: LSTM, Initial Train Step: 1000/10000, Loss: 0.0763\n",
      "Model: LSTM, Initial Train Step: 1000, Avg Perf on Train Set: 0.984\n",
      "Model: LSTM, Initial Train Step: 1200/10000, Loss: 0.0498\n",
      "Model: LSTM, Initial Train Step: 1200, Avg Perf on Train Set: 0.990\n",
      "Model: LSTM, Initial Train Step: 1400/10000, Loss: 0.0369\n",
      "Model: LSTM, Initial Train Step: 1400, Avg Perf on Train Set: 0.992\n",
      "Model: LSTM, Initial Train Step: 1600/10000, Loss: 0.0309\n",
      "Model: LSTM, Initial Train Step: 1600, Avg Perf on Train Set: 0.991\n",
      "Model: LSTM, Initial Train Step: 1800/10000, Loss: 0.0267\n",
      "Model: LSTM, Initial Train Step: 1800, Avg Perf on Train Set: 0.993\n",
      "Model: LSTM, Initial Train Step: 2000/10000, Loss: 0.0242\n",
      "Model: LSTM, Initial Train Step: 2000, Avg Perf on Train Set: 0.993\n",
      "Model: LSTM, Initial Train Step: 2200/10000, Loss: 0.0224\n",
      "Model: LSTM, Initial Train Step: 2200, Avg Perf on Train Set: 0.994\n",
      "Model: LSTM, Initial Train Step: 2400/10000, Loss: 0.0205\n",
      "Model: LSTM, Initial Train Step: 2400, Avg Perf on Train Set: 0.995\n",
      "Model: LSTM, Initial Train Step: 2600/10000, Loss: 0.0192\n",
      "Model: LSTM, Initial Train Step: 2600, Avg Perf on Train Set: 0.995\n",
      "Model: LSTM, Initial Train Step: 2800/10000, Loss: 0.0184\n",
      "Model: LSTM, Initial Train Step: 2800, Avg Perf on Train Set: 0.995\n",
      "Model: LSTM, Initial Train Step: 3000/10000, Loss: 0.0163\n",
      "Model: LSTM, Initial Train Step: 3000, Avg Perf on Train Set: 0.994\n",
      "Model: LSTM, Initial Train Step: 3200/10000, Loss: 0.0159\n",
      "Model: LSTM, Initial Train Step: 3200, Avg Perf on Train Set: 0.995\n",
      "Model: LSTM, Initial Train Step: 3400/10000, Loss: 0.0146\n",
      "Model: LSTM, Initial Train Step: 3400, Avg Perf on Train Set: 0.995\n",
      "Model: LSTM, Initial Train Step: 3600/10000, Loss: 0.0142\n",
      "Model: LSTM, Initial Train Step: 3600, Avg Perf on Train Set: 0.994\n",
      "Model: LSTM, Initial Train Step: 3800/10000, Loss: 0.0138\n",
      "Model: LSTM, Initial Train Step: 3800, Avg Perf on Train Set: 0.996\n",
      "Model: LSTM, Initial Train Step: 4000/10000, Loss: 0.0130\n",
      "Model: LSTM, Initial Train Step: 4000, Avg Perf on Train Set: 0.996\n",
      "Model: LSTM, Initial Train Step: 4200/10000, Loss: 0.0121\n",
      "Model: LSTM, Initial Train Step: 4200, Avg Perf on Train Set: 0.993\n",
      "Model: LSTM, Initial Train Step: 4400/10000, Loss: 0.0121\n",
      "Model: LSTM, Initial Train Step: 4400, Avg Perf on Train Set: 0.995\n",
      "Model: LSTM, Initial Train Step: 4600/10000, Loss: 0.0113\n",
      "Model: LSTM, Initial Train Step: 4600, Avg Perf on Train Set: 0.997\n",
      "Model: LSTM, Initial Train Step: 4800/10000, Loss: 0.0117\n",
      "Model: LSTM, Initial Train Step: 4800, Avg Perf on Train Set: 0.997\n",
      "Model: LSTM, Initial Train Step: 5000/10000, Loss: 0.0111\n",
      "Model: LSTM, Initial Train Step: 5000, Avg Perf on Train Set: 0.996\n",
      "Model: LSTM, Initial Train Step: 5200/10000, Loss: 0.0109\n",
      "Model: LSTM, Initial Train Step: 5200, Avg Perf on Train Set: 0.995\n",
      "Model: LSTM, Initial Train Step: 5400/10000, Loss: 0.0104\n",
      "Model: LSTM, Initial Train Step: 5400, Avg Perf on Train Set: 0.996\n",
      "Model: LSTM, Initial Train Step: 5600/10000, Loss: 0.0099\n",
      "Model: LSTM, Initial Train Step: 5600, Avg Perf on Train Set: 0.996\n",
      "Model: LSTM, Initial Train Step: 5800/10000, Loss: 0.0095\n",
      "Model: LSTM, Initial Train Step: 5800, Avg Perf on Train Set: 0.997\n",
      "Model: LSTM, Initial Train Step: 6000/10000, Loss: 0.0091\n",
      "Model: LSTM, Initial Train Step: 6000, Avg Perf on Train Set: 0.996\n",
      "Model: LSTM, Initial Train Step: 6200/10000, Loss: 0.0093\n",
      "Model: LSTM, Initial Train Step: 6200, Avg Perf on Train Set: 0.996\n",
      "Model: LSTM, Initial Train Step: 6400/10000, Loss: 0.0091\n",
      "Model: LSTM, Initial Train Step: 6400, Avg Perf on Train Set: 0.998\n",
      "Model: LSTM, Initial Train Step: 6600/10000, Loss: 0.0100\n",
      "Model: LSTM, Initial Train Step: 6600, Avg Perf on Train Set: 0.997\n",
      "Model: LSTM, Initial Train Step: 6800/10000, Loss: 0.0091\n",
      "Model: LSTM, Initial Train Step: 6800, Avg Perf on Train Set: 0.997\n",
      "Model: LSTM, Initial Train Step: 7000/10000, Loss: 0.0088\n",
      "Model: LSTM, Initial Train Step: 7000, Avg Perf on Train Set: 0.996\n",
      "Model: LSTM, Initial Train Step: 7200/10000, Loss: 0.0082\n",
      "Model: LSTM, Initial Train Step: 7200, Avg Perf on Train Set: 0.995\n",
      "Model: LSTM, Initial Train Step: 7400/10000, Loss: 0.0082\n",
      "Model: LSTM, Initial Train Step: 7400, Avg Perf on Train Set: 0.997\n",
      "Model: LSTM, Initial Train Step: 7600/10000, Loss: 0.0080\n",
      "Model: LSTM, Initial Train Step: 7600, Avg Perf on Train Set: 0.998\n",
      "Model: LSTM, Initial Train Step: 7800/10000, Loss: 0.0079\n",
      "Model: LSTM, Initial Train Step: 7800, Avg Perf on Train Set: 0.997\n",
      "Model: LSTM, Initial Train Step: 8000/10000, Loss: 0.0078\n",
      "Model: LSTM, Initial Train Step: 8000, Avg Perf on Train Set: 0.997\n",
      "Model: LSTM, Initial Train Step: 8200/10000, Loss: 0.0080\n",
      "Model: LSTM, Initial Train Step: 8200, Avg Perf on Train Set: 0.998\n",
      "Model: LSTM, Initial Train Step: 8400/10000, Loss: 0.0075\n",
      "Model: LSTM, Initial Train Step: 8400, Avg Perf on Train Set: 0.997\n",
      "Model: LSTM, Initial Train Step: 8600/10000, Loss: 0.0075\n",
      "Model: LSTM, Initial Train Step: 8600, Avg Perf on Train Set: 0.997\n",
      "Model: LSTM, Initial Train Step: 8800/10000, Loss: 0.0073\n",
      "Model: LSTM, Initial Train Step: 8800, Avg Perf on Train Set: 0.997\n",
      "Model: LSTM, Initial Train Step: 9000/10000, Loss: 0.0074\n",
      "Model: LSTM, Initial Train Step: 9000, Avg Perf on Train Set: 0.997\n",
      "Model: LSTM, Initial Train Step: 9200/10000, Loss: 0.0071\n",
      "Model: LSTM, Initial Train Step: 9200, Avg Perf on Train Set: 0.997\n",
      "Model: LSTM, Initial Train Step: 9400/10000, Loss: 0.0073\n",
      "Model: LSTM, Initial Train Step: 9400, Avg Perf on Train Set: 0.997\n",
      "Model: LSTM, Initial Train Step: 9600/10000, Loss: 0.0070\n",
      "Model: LSTM, Initial Train Step: 9600, Avg Perf on Train Set: 0.997\n",
      "Model: LSTM, Initial Train Step: 9800/10000, Loss: 0.0069\n",
      "Model: LSTM, Initial Train Step: 9800, Avg Perf on Train Set: 0.997\n",
      "Model: LSTM, Initial Train Step: 10000/10000, Loss: 0.0066\n",
      "Model: LSTM, Initial Train Step: 10000, Avg Perf on Train Set: 0.997\n",
      "Finished Initial Training for LSTM in 2392.82 seconds.\n",
      "\n",
      "--- Phase 2: Adaptation for LSTM on 5 tasks ---\n",
      "\n",
      "-- Adapting LSTM to task: yang19.multidlydm-v0 --\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 100/1000, Loss: 0.4455, Perf: 0.955\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 200/1000, Loss: 0.2326, Perf: 0.969\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 300/1000, Loss: 0.2286, Perf: 0.966\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 400/1000, Loss: 0.2208, Perf: 0.967\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 500/1000, Loss: 0.2186, Perf: 0.968\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 600/1000, Loss: 0.2118, Perf: 0.968\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 700/1000, Loss: 0.2076, Perf: 0.967\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 800/1000, Loss: 0.2059, Perf: 0.968\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 900/1000, Loss: 0.1965, Perf: 0.969\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 1000/1000, Loss: 0.1954, Perf: 0.969\n",
      "\n",
      "-- Adapting LSTM to task: yang19.ctxdm1-v0 --\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 100/1000, Loss: 0.3214, Perf: 0.963\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 200/1000, Loss: 0.1154, Perf: 0.972\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 300/1000, Loss: 0.0897, Perf: 0.970\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 400/1000, Loss: 0.0812, Perf: 0.973\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 500/1000, Loss: 0.0791, Perf: 0.973\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 600/1000, Loss: 0.0783, Perf: 0.974\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 700/1000, Loss: 0.0770, Perf: 0.975\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 800/1000, Loss: 0.0765, Perf: 0.973\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 900/1000, Loss: 0.0765, Perf: 0.970\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 1000/1000, Loss: 0.0748, Perf: 0.974\n",
      "\n",
      "-- Adapting LSTM to task: yang19.rtgo-v0 --\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 100/1000, Loss: 1.0275, Perf: 0.824\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 200/1000, Loss: 0.6439, Perf: 0.835\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 300/1000, Loss: 0.6110, Perf: 0.837\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 400/1000, Loss: 0.5983, Perf: 0.842\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 500/1000, Loss: 0.5868, Perf: 0.842\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 600/1000, Loss: 0.5815, Perf: 0.847\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 700/1000, Loss: 0.5731, Perf: 0.841\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 800/1000, Loss: 0.5655, Perf: 0.848\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 900/1000, Loss: 0.5540, Perf: 0.848\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 1000/1000, Loss: 0.5404, Perf: 0.848\n",
      "\n",
      "-- Adapting LSTM to task: yang19.ctxdlydm1-v0 --\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 100/1000, Loss: 0.3206, Perf: 0.966\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 200/1000, Loss: 0.2224, Perf: 0.973\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 300/1000, Loss: 0.2049, Perf: 0.971\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 400/1000, Loss: 0.1850, Perf: 0.970\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 500/1000, Loss: 0.1525, Perf: 0.975\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 600/1000, Loss: 0.1349, Perf: 0.970\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 700/1000, Loss: 0.1300, Perf: 0.971\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 800/1000, Loss: 0.1257, Perf: 0.972\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 900/1000, Loss: 0.1120, Perf: 0.971\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 1000/1000, Loss: 0.1095, Perf: 0.970\n",
      "\n",
      "-- Adapting LSTM to task: yang19.dlydm2-v0 --\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 100/1000, Loss: 0.1569, Perf: 0.972\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 200/1000, Loss: 0.1221, Perf: 0.972\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 300/1000, Loss: 0.1148, Perf: 0.972\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 400/1000, Loss: 0.1111, Perf: 0.973\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 500/1000, Loss: 0.1084, Perf: 0.972\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 600/1000, Loss: 0.1058, Perf: 0.972\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 700/1000, Loss: 0.1033, Perf: 0.975\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 800/1000, Loss: 0.1032, Perf: 0.974\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 900/1000, Loss: 0.0995, Perf: 0.974\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 1000/1000, Loss: 0.0998, Perf: 0.974\n",
      "Finished Adaptation for LSTM in 1292.27 seconds.\n",
      "\n",
      "\n",
      "--- Training and Adapting Model: Transformer ---\n",
      "Total trainable parameters for Transformer: 1,072,401\n",
      "\n",
      "--- Phase 1: Initial Training on 15 tasks for Transformer ---\n",
      "Model: Transformer, Initial Train Step: 200/10000, Loss: 0.2253\n",
      "Model: Transformer, Initial Train Step: 200, Avg Perf on Train Set: 0.941\n",
      "Model: Transformer, Initial Train Step: 400/10000, Loss: 0.1700\n",
      "Model: Transformer, Initial Train Step: 400, Avg Perf on Train Set: 0.946\n",
      "Model: Transformer, Initial Train Step: 600/10000, Loss: 0.1602\n",
      "Model: Transformer, Initial Train Step: 600, Avg Perf on Train Set: 0.947\n",
      "Model: Transformer, Initial Train Step: 800/10000, Loss: 0.1553\n",
      "Model: Transformer, Initial Train Step: 800, Avg Perf on Train Set: 0.948\n",
      "Model: Transformer, Initial Train Step: 1000/10000, Loss: 0.1526\n",
      "Model: Transformer, Initial Train Step: 1000, Avg Perf on Train Set: 0.948\n",
      "Model: Transformer, Initial Train Step: 1200/10000, Loss: 0.1499\n",
      "Model: Transformer, Initial Train Step: 1200, Avg Perf on Train Set: 0.947\n",
      "Model: Transformer, Initial Train Step: 1400/10000, Loss: 0.1488\n",
      "Model: Transformer, Initial Train Step: 1400, Avg Perf on Train Set: 0.946\n",
      "Model: Transformer, Initial Train Step: 1600/10000, Loss: 0.1493\n",
      "Model: Transformer, Initial Train Step: 1600, Avg Perf on Train Set: 0.947\n",
      "Model: Transformer, Initial Train Step: 1800/10000, Loss: 0.1470\n",
      "Model: Transformer, Initial Train Step: 1800, Avg Perf on Train Set: 0.947\n",
      "Model: Transformer, Initial Train Step: 2000/10000, Loss: 0.1460\n",
      "Model: Transformer, Initial Train Step: 2000, Avg Perf on Train Set: 0.947\n",
      "Model: Transformer, Initial Train Step: 2200/10000, Loss: 0.1481\n",
      "Model: Transformer, Initial Train Step: 2200, Avg Perf on Train Set: 0.945\n",
      "Model: Transformer, Initial Train Step: 2400/10000, Loss: 0.1461\n",
      "Model: Transformer, Initial Train Step: 2400, Avg Perf on Train Set: 0.948\n",
      "Model: Transformer, Initial Train Step: 2600/10000, Loss: 0.1491\n",
      "Model: Transformer, Initial Train Step: 2600, Avg Perf on Train Set: 0.944\n",
      "Model: Transformer, Initial Train Step: 2800/10000, Loss: 0.1442\n",
      "Model: Transformer, Initial Train Step: 2800, Avg Perf on Train Set: 0.947\n",
      "Model: Transformer, Initial Train Step: 3000/10000, Loss: 0.1450\n",
      "Model: Transformer, Initial Train Step: 3000, Avg Perf on Train Set: 0.948\n",
      "Model: Transformer, Initial Train Step: 3200/10000, Loss: 0.1458\n",
      "Model: Transformer, Initial Train Step: 3200, Avg Perf on Train Set: 0.947\n",
      "Model: Transformer, Initial Train Step: 3400/10000, Loss: 0.1437\n",
      "Model: Transformer, Initial Train Step: 3400, Avg Perf on Train Set: 0.948\n",
      "Model: Transformer, Initial Train Step: 3600/10000, Loss: 0.1429\n",
      "Model: Transformer, Initial Train Step: 3600, Avg Perf on Train Set: 0.947\n",
      "Model: Transformer, Initial Train Step: 3800/10000, Loss: 0.1446\n",
      "Model: Transformer, Initial Train Step: 3800, Avg Perf on Train Set: 0.947\n",
      "Model: Transformer, Initial Train Step: 4000/10000, Loss: 0.1443\n",
      "Model: Transformer, Initial Train Step: 4000, Avg Perf on Train Set: 0.945\n",
      "Model: Transformer, Initial Train Step: 4200/10000, Loss: 0.1436\n",
      "Model: Transformer, Initial Train Step: 4200, Avg Perf on Train Set: 0.947\n",
      "Model: Transformer, Initial Train Step: 4400/10000, Loss: 0.1425\n",
      "Model: Transformer, Initial Train Step: 4400, Avg Perf on Train Set: 0.946\n",
      "Model: Transformer, Initial Train Step: 4600/10000, Loss: 0.1430\n",
      "Model: Transformer, Initial Train Step: 4600, Avg Perf on Train Set: 0.947\n",
      "Model: Transformer, Initial Train Step: 4800/10000, Loss: 0.1420\n",
      "Model: Transformer, Initial Train Step: 4800, Avg Perf on Train Set: 0.947\n",
      "Model: Transformer, Initial Train Step: 5000/10000, Loss: 0.1491\n",
      "Model: Transformer, Initial Train Step: 5000, Avg Perf on Train Set: 0.946\n",
      "Model: Transformer, Initial Train Step: 5200/10000, Loss: 0.1441\n",
      "Model: Transformer, Initial Train Step: 5200, Avg Perf on Train Set: 0.945\n",
      "Model: Transformer, Initial Train Step: 5400/10000, Loss: 0.1442\n",
      "Model: Transformer, Initial Train Step: 5400, Avg Perf on Train Set: 0.948\n",
      "Model: Transformer, Initial Train Step: 5600/10000, Loss: 0.1431\n",
      "Model: Transformer, Initial Train Step: 5600, Avg Perf on Train Set: 0.949\n",
      "Model: Transformer, Initial Train Step: 5800/10000, Loss: 0.1419\n",
      "Model: Transformer, Initial Train Step: 5800, Avg Perf on Train Set: 0.948\n",
      "Model: Transformer, Initial Train Step: 6000/10000, Loss: 0.1399\n",
      "Model: Transformer, Initial Train Step: 6000, Avg Perf on Train Set: 0.946\n",
      "Model: Transformer, Initial Train Step: 6200/10000, Loss: 0.1417\n",
      "Model: Transformer, Initial Train Step: 6200, Avg Perf on Train Set: 0.947\n",
      "Model: Transformer, Initial Train Step: 6400/10000, Loss: 0.1419\n",
      "Model: Transformer, Initial Train Step: 6400, Avg Perf on Train Set: 0.949\n",
      "Model: Transformer, Initial Train Step: 6600/10000, Loss: 0.1429\n",
      "Model: Transformer, Initial Train Step: 6600, Avg Perf on Train Set: 0.948\n",
      "Model: Transformer, Initial Train Step: 6800/10000, Loss: 0.1422\n",
      "Model: Transformer, Initial Train Step: 6800, Avg Perf on Train Set: 0.948\n",
      "Model: Transformer, Initial Train Step: 7000/10000, Loss: 0.1418\n",
      "Model: Transformer, Initial Train Step: 7000, Avg Perf on Train Set: 0.948\n",
      "Model: Transformer, Initial Train Step: 7200/10000, Loss: 0.1413\n",
      "Model: Transformer, Initial Train Step: 7200, Avg Perf on Train Set: 0.947\n",
      "Model: Transformer, Initial Train Step: 7400/10000, Loss: 0.1403\n",
      "Model: Transformer, Initial Train Step: 7400, Avg Perf on Train Set: 0.947\n",
      "Model: Transformer, Initial Train Step: 7600/10000, Loss: 0.1415\n",
      "Model: Transformer, Initial Train Step: 7600, Avg Perf on Train Set: 0.949\n",
      "Model: Transformer, Initial Train Step: 7800/10000, Loss: 0.1417\n",
      "Model: Transformer, Initial Train Step: 7800, Avg Perf on Train Set: 0.949\n",
      "Model: Transformer, Initial Train Step: 8000/10000, Loss: 0.1400\n",
      "Model: Transformer, Initial Train Step: 8000, Avg Perf on Train Set: 0.946\n",
      "Model: Transformer, Initial Train Step: 8200/10000, Loss: 0.1425\n",
      "Model: Transformer, Initial Train Step: 8200, Avg Perf on Train Set: 0.948\n",
      "Model: Transformer, Initial Train Step: 8400/10000, Loss: 0.1427\n",
      "Model: Transformer, Initial Train Step: 8400, Avg Perf on Train Set: 0.948\n",
      "Model: Transformer, Initial Train Step: 8600/10000, Loss: 0.1408\n",
      "Model: Transformer, Initial Train Step: 8600, Avg Perf on Train Set: 0.948\n",
      "Model: Transformer, Initial Train Step: 8800/10000, Loss: 0.1399\n",
      "Model: Transformer, Initial Train Step: 8800, Avg Perf on Train Set: 0.950\n",
      "Model: Transformer, Initial Train Step: 9000/10000, Loss: 0.1412\n",
      "Model: Transformer, Initial Train Step: 9000, Avg Perf on Train Set: 0.947\n",
      "Model: Transformer, Initial Train Step: 9200/10000, Loss: 0.1414\n",
      "Model: Transformer, Initial Train Step: 9200, Avg Perf on Train Set: 0.947\n",
      "Model: Transformer, Initial Train Step: 9400/10000, Loss: 0.1402\n",
      "Model: Transformer, Initial Train Step: 9400, Avg Perf on Train Set: 0.948\n",
      "Model: Transformer, Initial Train Step: 9600/10000, Loss: 0.1405\n",
      "Model: Transformer, Initial Train Step: 9600, Avg Perf on Train Set: 0.947\n",
      "Model: Transformer, Initial Train Step: 9800/10000, Loss: 0.1406\n",
      "Model: Transformer, Initial Train Step: 9800, Avg Perf on Train Set: 0.949\n",
      "Model: Transformer, Initial Train Step: 10000/10000, Loss: 0.1396\n",
      "Model: Transformer, Initial Train Step: 10000, Avg Perf on Train Set: 0.948\n",
      "Finished Initial Training for Transformer in 2621.69 seconds.\n",
      "\n",
      "--- Phase 2: Adaptation for Transformer on 5 tasks ---\n",
      "\n",
      "-- Adapting Transformer to task: yang19.multidlydm-v0 --\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 100/1000, Loss: 0.1000, Perf: 0.967\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 200/1000, Loss: 0.0808, Perf: 0.968\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 300/1000, Loss: 0.0795, Perf: 0.968\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 400/1000, Loss: 0.0789, Perf: 0.969\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 500/1000, Loss: 0.0781, Perf: 0.969\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 600/1000, Loss: 0.0777, Perf: 0.969\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 700/1000, Loss: 0.0772, Perf: 0.968\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 800/1000, Loss: 0.0772, Perf: 0.968\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 900/1000, Loss: 0.0770, Perf: 0.969\n",
      "  Adapt Task: yang19.multidlydm-v0, Step: 1000/1000, Loss: 0.0765, Perf: 0.967\n",
      "\n",
      "-- Adapting Transformer to task: yang19.ctxdm1-v0 --\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 100/1000, Loss: 0.5601, Perf: 0.900\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 200/1000, Loss: 0.3189, Perf: 0.902\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 300/1000, Loss: 0.3026, Perf: 0.900\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 400/1000, Loss: 0.2953, Perf: 0.903\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 500/1000, Loss: 0.2888, Perf: 0.902\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 600/1000, Loss: 0.2862, Perf: 0.901\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 700/1000, Loss: 0.2838, Perf: 0.904\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 800/1000, Loss: 0.2823, Perf: 0.902\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 900/1000, Loss: 0.2822, Perf: 0.905\n",
      "  Adapt Task: yang19.ctxdm1-v0, Step: 1000/1000, Loss: 0.2806, Perf: 0.904\n",
      "\n",
      "-- Adapting Transformer to task: yang19.rtgo-v0 --\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 100/1000, Loss: 1.2571, Perf: 0.803\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 200/1000, Loss: 0.8100, Perf: 0.812\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 300/1000, Loss: 0.6322, Perf: 0.810\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 400/1000, Loss: 0.6159, Perf: 0.809\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 500/1000, Loss: 0.6074, Perf: 0.808\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 600/1000, Loss: 0.6015, Perf: 0.808\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 700/1000, Loss: 0.5960, Perf: 0.811\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 800/1000, Loss: 0.5927, Perf: 0.808\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 900/1000, Loss: 0.5899, Perf: 0.811\n",
      "  Adapt Task: yang19.rtgo-v0, Step: 1000/1000, Loss: 0.5865, Perf: 0.811\n",
      "\n",
      "-- Adapting Transformer to task: yang19.ctxdlydm1-v0 --\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 100/1000, Loss: 0.2771, Perf: 0.965\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 200/1000, Loss: 0.0942, Perf: 0.966\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 300/1000, Loss: 0.0831, Perf: 0.967\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 400/1000, Loss: 0.0796, Perf: 0.969\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 500/1000, Loss: 0.0780, Perf: 0.970\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 600/1000, Loss: 0.0778, Perf: 0.969\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 700/1000, Loss: 0.0772, Perf: 0.968\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 800/1000, Loss: 0.0767, Perf: 0.968\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 900/1000, Loss: 0.0768, Perf: 0.967\n",
      "  Adapt Task: yang19.ctxdlydm1-v0, Step: 1000/1000, Loss: 0.0763, Perf: 0.968\n",
      "\n",
      "-- Adapting Transformer to task: yang19.dlydm2-v0 --\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 100/1000, Loss: 0.2283, Perf: 0.968\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 200/1000, Loss: 0.0811, Perf: 0.968\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 300/1000, Loss: 0.0760, Perf: 0.969\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 400/1000, Loss: 0.0741, Perf: 0.967\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 500/1000, Loss: 0.0730, Perf: 0.969\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 600/1000, Loss: 0.0726, Perf: 0.968\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 700/1000, Loss: 0.0721, Perf: 0.970\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 800/1000, Loss: 0.0720, Perf: 0.967\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 900/1000, Loss: 0.0719, Perf: 0.968\n",
      "  Adapt Task: yang19.dlydm2-v0, Step: 1000/1000, Loss: 0.0716, Perf: 0.969\n",
      "Finished Adaptation for Transformer in 1385.59 seconds.\n",
      "\n",
      "\n",
      "--- Overall Results ---\n",
      "\n",
      "Initial Training Performance Curves (Avg Perf on 15 Train Tasks vs. Eval Points):\n",
      "Model: nmRNN_Spatial\n",
      "  Performance points: [0.922, 0.920, 0.919, 0.922, 0.924, 0.922, 0.921, 0.920, 0.922, 0.922, 0.921, 0.921, 0.920, 0.922, 0.921, 0.923, 0.921, 0.923, 0.920, 0.919, 0.919, 0.921, 0.921, 0.923, 0.922, 0.921, 0.921, 0.920, 0.921, 0.920, 0.920, 0.919, 0.923, 0.920, 0.921, 0.921, 0.920, 0.922, 0.921, 0.921, 0.921, 0.920, 0.921, 0.920, 0.921, 0.919, 0.920, 0.922, 0.921, 0.922]\n",
      "Model: nmRNN_NonSpatial\n",
      "  Performance points: [0.922, 0.922, 0.919, 0.920, 0.921, 0.921, 0.920, 0.920, 0.922, 0.921, 0.922, 0.919, 0.920, 0.922, 0.922, 0.920, 0.922, 0.920, 0.922, 0.923, 0.921, 0.920, 0.920, 0.921, 0.921, 0.920, 0.920, 0.922, 0.920, 0.921, 0.921, 0.922, 0.922, 0.921, 0.922, 0.920, 0.921, 0.920, 0.922, 0.922, 0.920, 0.922, 0.923, 0.922, 0.921, 0.919, 0.922, 0.921, 0.921, 0.919]\n",
      "Model: VanillaRNN\n",
      "  Performance points: [0.923, 0.951, 0.967, 0.980, 0.985, 0.986, 0.990, 0.992, 0.992, 0.995, 0.993, 0.992, 0.994, 0.993, 0.993, 0.994, 0.994, 0.993, 0.995, 0.994, 0.995, 0.995, 0.995, 0.994, 0.995, 0.992, 0.996, 0.992, 0.995, 0.996, 0.997, 0.995, 0.996, 0.994, 0.996, 0.994, 0.996, 0.997, 0.995, 0.996, 0.996, 0.995, 0.996, 0.996, 0.997, 0.997, 0.997, 0.996, 0.996, 0.996]\n",
      "Model: LSTM\n",
      "  Performance points: [0.922, 0.930, 0.954, 0.967, 0.984, 0.990, 0.992, 0.991, 0.993, 0.993, 0.994, 0.995, 0.995, 0.995, 0.994, 0.995, 0.995, 0.994, 0.996, 0.996, 0.993, 0.995, 0.997, 0.997, 0.996, 0.995, 0.996, 0.996, 0.997, 0.996, 0.996, 0.998, 0.997, 0.997, 0.996, 0.995, 0.997, 0.998, 0.997, 0.997, 0.998, 0.997, 0.997, 0.997, 0.997, 0.997, 0.997, 0.997, 0.997, 0.997]\n",
      "Model: Transformer\n",
      "  Performance points: [0.941, 0.946, 0.947, 0.948, 0.948, 0.947, 0.946, 0.947, 0.947, 0.947, 0.945, 0.948, 0.944, 0.947, 0.948, 0.947, 0.948, 0.947, 0.947, 0.945, 0.947, 0.946, 0.947, 0.947, 0.946, 0.945, 0.948, 0.949, 0.948, 0.946, 0.947, 0.949, 0.948, 0.948, 0.948, 0.947, 0.947, 0.949, 0.949, 0.946, 0.948, 0.948, 0.948, 0.950, 0.947, 0.947, 0.948, 0.947, 0.949, 0.948]\n",
      "\n",
      "Adaptation Training Curves (Perf on Held-out Task vs. Eval Points):\n",
      "Model: nmRNN_Spatial\n",
      "  Adapted Task: yang19.multidlydm-v0\n",
      "    Performance points: [0.955, 0.955, 0.955, 0.955, 0.954, 0.954, 0.954, 0.954, 0.955, 0.955]\n",
      "  Adapted Task: yang19.ctxdm1-v0\n",
      "    Performance points: [0.881, 0.882, 0.882, 0.881, 0.882, 0.882, 0.882, 0.882, 0.880, 0.881]\n",
      "  Adapted Task: yang19.rtgo-v0\n",
      "    Performance points: [0.800, 0.800, 0.800, 0.800, 0.800, 0.800, 0.800, 0.800, 0.800, 0.800]\n",
      "  Adapted Task: yang19.ctxdlydm1-v0\n",
      "    Performance points: [0.954, 0.954, 0.955, 0.955, 0.955, 0.954, 0.954, 0.956, 0.956, 0.955]\n",
      "  Adapted Task: yang19.dlydm2-v0\n",
      "    Performance points: [0.953, 0.956, 0.956, 0.954, 0.954, 0.955, 0.955, 0.954, 0.954, 0.956]\n",
      "Model: nmRNN_NonSpatial\n",
      "  Adapted Task: yang19.multidlydm-v0\n",
      "    Performance points: [0.954, 0.955, 0.956, 0.956, 0.954, 0.955, 0.955, 0.955, 0.954, 0.955]\n",
      "  Adapted Task: yang19.ctxdm1-v0\n",
      "    Performance points: [0.883, 0.882, 0.881, 0.881, 0.883, 0.881, 0.883, 0.883, 0.881, 0.882]\n",
      "  Adapted Task: yang19.rtgo-v0\n",
      "    Performance points: [0.800, 0.800, 0.800, 0.800, 0.800, 0.800, 0.800, 0.800, 0.800, 0.800]\n",
      "  Adapted Task: yang19.ctxdlydm1-v0\n",
      "    Performance points: [0.954, 0.954, 0.955, 0.955, 0.955, 0.954, 0.955, 0.954, 0.954, 0.955]\n",
      "  Adapted Task: yang19.dlydm2-v0\n",
      "    Performance points: [0.954, 0.955, 0.954, 0.955, 0.953, 0.955, 0.954, 0.955, 0.955, 0.954]\n",
      "Model: VanillaRNN\n",
      "  Adapted Task: yang19.multidlydm-v0\n",
      "    Performance points: [0.969, 0.971, 0.973, 0.972, 0.970, 0.969, 0.971, 0.971, 0.971, 0.973]\n",
      "  Adapted Task: yang19.ctxdm1-v0\n",
      "    Performance points: [0.964, 0.965, 0.964, 0.967, 0.967, 0.968, 0.969, 0.968, 0.969, 0.965]\n",
      "  Adapted Task: yang19.rtgo-v0\n",
      "    Performance points: [0.809, 0.818, 0.819, 0.819, 0.820, 0.820, 0.813, 0.815, 0.818, 0.817]\n",
      "  Adapted Task: yang19.ctxdlydm1-v0\n",
      "    Performance points: [0.972, 0.975, 0.972, 0.974, 0.974, 0.974, 0.974, 0.974, 0.974, 0.976]\n",
      "  Adapted Task: yang19.dlydm2-v0\n",
      "    Performance points: [0.974, 0.972, 0.975, 0.973, 0.974, 0.973, 0.972, 0.974, 0.974, 0.975]\n",
      "Model: LSTM\n",
      "  Adapted Task: yang19.multidlydm-v0\n",
      "    Performance points: [0.955, 0.969, 0.966, 0.967, 0.968, 0.968, 0.967, 0.968, 0.969, 0.969]\n",
      "  Adapted Task: yang19.ctxdm1-v0\n",
      "    Performance points: [0.963, 0.972, 0.970, 0.973, 0.973, 0.974, 0.975, 0.973, 0.970, 0.974]\n",
      "  Adapted Task: yang19.rtgo-v0\n",
      "    Performance points: [0.824, 0.835, 0.837, 0.842, 0.842, 0.847, 0.841, 0.848, 0.848, 0.848]\n",
      "  Adapted Task: yang19.ctxdlydm1-v0\n",
      "    Performance points: [0.966, 0.973, 0.971, 0.970, 0.975, 0.970, 0.971, 0.972, 0.971, 0.970]\n",
      "  Adapted Task: yang19.dlydm2-v0\n",
      "    Performance points: [0.972, 0.972, 0.972, 0.973, 0.972, 0.972, 0.975, 0.974, 0.974, 0.974]\n",
      "Model: Transformer\n",
      "  Adapted Task: yang19.multidlydm-v0\n",
      "    Performance points: [0.967, 0.968, 0.968, 0.969, 0.969, 0.969, 0.968, 0.968, 0.969, 0.967]\n",
      "  Adapted Task: yang19.ctxdm1-v0\n",
      "    Performance points: [0.900, 0.902, 0.900, 0.903, 0.902, 0.901, 0.904, 0.902, 0.905, 0.904]\n",
      "  Adapted Task: yang19.rtgo-v0\n",
      "    Performance points: [0.803, 0.812, 0.810, 0.809, 0.808, 0.808, 0.811, 0.808, 0.811, 0.811]\n",
      "  Adapted Task: yang19.ctxdlydm1-v0\n",
      "    Performance points: [0.965, 0.966, 0.967, 0.969, 0.970, 0.969, 0.968, 0.968, 0.967, 0.968]\n",
      "  Adapted Task: yang19.dlydm2-v0\n",
      "    Performance points: [0.968, 0.968, 0.969, 0.967, 0.969, 0.968, 0.970, 0.967, 0.968, 0.969]\n",
      "\n",
      "Summary Table of Best Adapted Performance:\n",
      "               Model          Adapted Task  Best Adapted Performance  Final Adapted Context Sum\n",
      "0      nmRNN_Spatial  yang19.multidlydm-v0                    0.9551                  -0.162027\n",
      "1      nmRNN_Spatial      yang19.ctxdm1-v0                    0.8825                   3.955762\n",
      "2      nmRNN_Spatial        yang19.rtgo-v0                    0.8000                   2.977698\n",
      "3      nmRNN_Spatial   yang19.ctxdlydm1-v0                    0.9557                  -2.054563\n",
      "4      nmRNN_Spatial      yang19.dlydm2-v0                    0.9560                   5.906074\n",
      "5   nmRNN_NonSpatial  yang19.multidlydm-v0                    0.9558                        NaN\n",
      "6   nmRNN_NonSpatial      yang19.ctxdm1-v0                    0.8833                        NaN\n",
      "7   nmRNN_NonSpatial        yang19.rtgo-v0                    0.8000                        NaN\n",
      "8   nmRNN_NonSpatial   yang19.ctxdlydm1-v0                    0.9552                        NaN\n",
      "9   nmRNN_NonSpatial      yang19.dlydm2-v0                    0.9551                        NaN\n",
      "10        VanillaRNN  yang19.multidlydm-v0                    0.9729                  12.014842\n",
      "11        VanillaRNN      yang19.ctxdm1-v0                    0.9693                   1.323712\n",
      "12        VanillaRNN        yang19.rtgo-v0                    0.8199                  11.709137\n",
      "13        VanillaRNN   yang19.ctxdlydm1-v0                    0.9755                   2.129296\n",
      "14        VanillaRNN      yang19.dlydm2-v0                    0.9753                   9.336136\n",
      "15              LSTM  yang19.multidlydm-v0                    0.9688                  -4.475739\n",
      "16              LSTM      yang19.ctxdm1-v0                    0.9752                   0.722906\n",
      "17              LSTM        yang19.rtgo-v0                    0.8481                  -4.098817\n",
      "18              LSTM   yang19.ctxdlydm1-v0                    0.9745                   3.843504\n",
      "19              LSTM      yang19.dlydm2-v0                    0.9746                  -2.858192\n",
      "20       Transformer  yang19.multidlydm-v0                    0.9694                   1.252292\n",
      "21       Transformer      yang19.ctxdm1-v0                    0.9052                   0.888458\n",
      "22       Transformer        yang19.rtgo-v0                    0.8123                   6.347764\n",
      "23       Transformer   yang19.ctxdlydm1-v0                    0.9698                   0.865036\n",
      "24       Transformer      yang19.dlydm2-v0                    0.9703                   0.683540\n",
      "\n",
      "Finished Adaptability Analysis. Results saved in /root/capsule/code/files_seed0_NNM4_tau0p05_hs256_lr0p001_ctx0p01_init10000_ad1000\n",
      "To visualize curves, load the .npy files or use the printed lists/DataFrame with a plotting library.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "from pathlib import Path\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import gym\n",
    "import neurogym as ngym\n",
    "from neurogym.wrappers import ScheduleEnvs\n",
    "# from neurogym.utils.scheduler import RandomSchedule # Not directly used with custom datasets\n",
    "# from models import RNNNet, get_performance # Assuming models.py contains these\n",
    "\n",
    "# import argparse # Removed argparse\n",
    "import numpy as np\n",
    "import random\n",
    "import copy\n",
    "import pandas as pd\n",
    "import math # For nmRNN decay and other calculations\n",
    "import scipy.spatial as ss # For SpatialWeight in nmRNN\n",
    "\n",
    "# --- Notebook-friendly Argument Definition ---\n",
    "class Args:\n",
    "    \"\"\"Notebook stand-in for an argparse namespace: every setting is a plain instance attribute.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        # (name, value) pairs, assigned in this order.\n",
    "        settings = (\n",
    "            ('seed', 0),\n",
    "            ('batch_size', 64),\n",
    "            ('seq_len', 100),\n",
    "            ('hidden_size', 256),\n",
    "            ('lr_model', 1e-3),\n",
    "            ('lr_context', 1e-2),\n",
    "            ('initial_train_steps', 10000),\n",
    "            ('adapt_steps', 1000),\n",
    "            ('print_step', 200),\n",
    "            ('eval_trials', 100),\n",
    "            # Number of neuromodulators for nmRNN.\n",
    "            ('N_NM', 4),\n",
    "            # Handed to nmRNN_Adapter as tau_nmrnn; the adapter uses the value directly as the cell's decay factor.\n",
    "            ('tau_nmrnn', 0.05),\n",
    "            # Clamp bound for the nmRNN gradient hooks (0 or None disables them).\n",
    "            ('grad_clip_nmrnn', 0.1),\n",
    "        )\n",
    "        for name, value in settings:\n",
    "            setattr(self, name, value)\n",
    "\n",
    "\n",
    "# Single shared settings object read by the rest of the notebook.\n",
    "args = Args()\n",
    "\n",
    "# --- Placeholder for models.py content if not in separate file ---\n",
    "# If RNNNet and get_performance are not in a separate models.py, define them here.\n",
    "# For this example, I'll include simplified versions or assume they exist.\n",
    "\n",
    "class RNNNet(nn.Module):\n",
    "    \"\"\"Vanilla recurrent baseline (after the NeuroGym examples): nn.RNN plus a per-step linear readout.\"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size, output_size, dt):\n",
    "        super().__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)\n",
    "        self.fc = nn.Linear(hidden_size, output_size)\n",
    "        # Stored only; forward() never reads it.\n",
    "        self.dt = dt\n",
    "\n",
    "    def forward(self, x, hidden=None):\n",
    "        \"\"\"\n",
    "        Run the RNN over a batch-first sequence and read out every step.\n",
    "\n",
    "        Args:\n",
    "            x: Input of shape (batch_size, seq_len, input_size).\n",
    "            hidden: Optional initial state of shape (1, batch_size, hidden_size).\n",
    "        Returns:\n",
    "            (output, hidden): output has shape (batch_size, seq_len, output_size);\n",
    "            hidden is the final RNN state.\n",
    "        \"\"\"\n",
    "        step_features, final_state = self.rnn(x, hidden)\n",
    "        return self.fc(step_features), final_state\n",
    "\n",
    "def get_performance(model, env, num_trial=100, device='cpu',\n",
    "                    context_vector_global=None, raw_obs_size=None, seq_len_eval=100):\n",
    "    \"\"\"\n",
    "    Estimate the per-step accuracy of ``model`` on ``env``.\n",
    "\n",
    "    Each trial is rolled out with randomly sampled actions (the model never drives\n",
    "    the environment). The recorded observation sequence is then fed to the model in\n",
    "    one forward pass and its argmax predictions are compared with the labels the\n",
    "    environment reports in ``info['gt']``.\n",
    "\n",
    "    The model is put in eval mode for the measurement and afterwards restored to the\n",
    "    mode it was in before the call, so evaluating in the middle of training does not\n",
    "    leave train-time behaviour such as dropout switched off.\n",
    "\n",
    "    Args:\n",
    "        model: Network called as ``model(inputs)`` returning a tuple whose first item\n",
    "            is the logits of shape (1, steps, n_actions).\n",
    "        env: Gym-style environment (``reset()``, ``step(action)`` returning\n",
    "            ``(obs, reward, done, info)``, and ``action_space``).\n",
    "        num_trial (int): Number of trials to roll out.\n",
    "        device (str): Device the input tensor is moved to.\n",
    "        context_vector_global (np.ndarray, optional): Vector appended to every\n",
    "            observation of a trial.\n",
    "        raw_obs_size (int, optional): Must be given together with\n",
    "            ``context_vector_global`` for the context to be appended.\n",
    "        seq_len_eval (int): Fallback number of steps per trial, used only when neither\n",
    "            ``env.seq_len`` nor ``env.spec.max_episode_steps`` is set.\n",
    "    Returns:\n",
    "        float: Correct steps divided by labelled steps, or 0.0 if no step was labelled.\n",
    "    \"\"\"\n",
    "    def _raw(obs):\n",
    "        # Dict observations carry the raw features under 'observation'.\n",
    "        return obs['observation'] if isinstance(obs, dict) and 'observation' in obs else obs\n",
    "\n",
    "    # Steps per trial. Priority: env.seq_len > env.spec.max_episode_steps > seq_len_eval.\n",
    "    if getattr(env, 'seq_len', None) is not None:\n",
    "        max_steps = env.seq_len\n",
    "    elif getattr(getattr(env, 'spec', None), 'max_episode_steps', None) is not None:\n",
    "        max_steps = env.spec.max_episode_steps\n",
    "    else:\n",
    "        max_steps = seq_len_eval\n",
    "\n",
    "    append_context = context_vector_global is not None and raw_obs_size is not None\n",
    "\n",
    "    total_correct = 0\n",
    "    total_steps = 0\n",
    "    was_training = model.training  # remembered so the caller's mode can be restored\n",
    "    model.eval()\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            for _ in range(num_trial):\n",
    "                raw_obs = _raw(env.reset())\n",
    "                obs_steps, gt_steps = [], []\n",
    "\n",
    "                for _t in range(max_steps):\n",
    "                    obs_steps.append(raw_obs)\n",
    "                    next_obs, _, done, info = env.step(env.action_space.sample())\n",
    "                    raw_obs = _raw(next_obs)\n",
    "                    # -1 marks a step without a label; such steps are filtered out below.\n",
    "                    gt_steps.append(info['gt'] if 'gt' in info else -1)\n",
    "                    if done:\n",
    "                        break\n",
    "\n",
    "                obs_seq = np.array(obs_steps)\n",
    "                gt_seq = np.array(gt_steps)\n",
    "                n_steps = obs_seq.shape[0]\n",
    "                if n_steps == 0:\n",
    "                    continue\n",
    "\n",
    "                model_input = obs_seq\n",
    "                if append_context:\n",
    "                    # The same context row is repeated for every step of the trial.\n",
    "                    context = np.tile(context_vector_global, (n_steps, 1))\n",
    "                    model_input = np.concatenate(\n",
    "                        (obs_seq.astype(np.float32), context.astype(np.float32)), axis=-1)\n",
    "\n",
    "                inputs = torch.from_numpy(model_input).type(torch.float).unsqueeze(0).to(device)\n",
    "\n",
    "                # The whole trial is evaluated in one call from the default initial state;\n",
    "                # only the first element of the returned tuple (the logits) is needed.\n",
    "                outputs = model(inputs)[0]\n",
    "                predictions = torch.argmax(outputs.squeeze(0), dim=-1).cpu().numpy()\n",
    "\n",
    "                labelled = gt_seq != -1\n",
    "                if not labelled.any():\n",
    "                    continue\n",
    "\n",
    "                total_correct += (predictions[labelled] == gt_seq[labelled]).sum()\n",
    "                total_steps += labelled.sum()  # count only labelled steps\n",
    "    finally:\n",
    "        model.train(was_training)\n",
    "\n",
    "    return total_correct / total_steps if total_steps > 0 else 0.0\n",
    "\n",
    "\n",
    "# --- Model Definitions (Placeholders for LSTM and Transformer) ---\n",
    "class LSTMNet(nn.Module):\n",
    "    \"\"\"LSTM baseline: batch-first nn.LSTM followed by a linear readout applied at every time step.\"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size, output_size, dt):\n",
    "        super().__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)\n",
    "        self.fc = nn.Linear(hidden_size, output_size)\n",
    "        # Stored only; forward() never reads it.\n",
    "        self.dt = dt\n",
    "\n",
    "    def forward(self, x, hidden=None):\n",
    "        \"\"\"Return (per-step readout, LSTM state) for a batch-first input ``x``.\"\"\"\n",
    "        step_features, state = self.lstm(x, hidden)\n",
    "        return self.fc(step_features), state\n",
    "\n",
    "class TransformerNet(nn.Module):\n",
    "    \"\"\"Transformer-encoder baseline: optional input projection, encoder stack, per-step linear readout.\"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size, output_size, nhead=4, num_layers=2, dt=None):\n",
    "        super().__init__()\n",
    "        self.input_size = input_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.dt = dt\n",
    "\n",
    "        # Inputs are projected only when their width differs from the encoder width.\n",
    "        needs_projection = input_size != hidden_size\n",
    "        self.input_proj = nn.Linear(input_size, hidden_size) if needs_projection else nn.Identity()\n",
    "\n",
    "        layer_template = nn.TransformerEncoderLayer(\n",
    "            d_model=hidden_size,\n",
    "            nhead=nhead,\n",
    "            dim_feedforward=hidden_size * 2,\n",
    "            batch_first=True,\n",
    "            dropout=0.1,\n",
    "        )\n",
    "        self.transformer_encoder = nn.TransformerEncoder(layer_template, num_layers=num_layers)\n",
    "        self.fc_out = nn.Linear(hidden_size, output_size)\n",
    "\n",
    "    def forward(self, src, src_mask=None):\n",
    "        \"\"\"Return (per-step readout, None); the second item is always None.\"\"\"\n",
    "        encoded = self.transformer_encoder(self.input_proj(src), mask=src_mask)\n",
    "        return self.fc_out(encoded), None\n",
    "\n",
    "\n",
    "# --- nmRNN Model Classes (from user) ---\n",
    "class SpatialWeight(nn.Module):\n",
    "    \"\"\"\n",
    "    Turns a base weight tensor into distance-dependent, signed weights.\n",
    "\n",
    "    Units get random positions in the unit square. The returned weight is\n",
    "    ``scale * sign * exp(W - distance / ell) * mask``, where ``mask`` removes the\n",
    "    diagonal and every pair at distance 5 * ell or more, and ``sign`` is -1 for\n",
    "    first-axis indices drawn as inhibitory (probability 0.5) and +1 otherwise.\n",
    "    All spatial constants are buffers, not trainable parameters.\n",
    "    \"\"\"\n",
    "    def __init__(self, N_nm, observable_size=64, ell=0.1): # N_nm is the first positional argument\n",
    "        super().__init__()\n",
    "\n",
    "        positions = torch.tensor(np.random.rand(observable_size, 2), dtype=torch.float32)\n",
    "        self.register_buffer('pos_const', positions)\n",
    "\n",
    "        # Distances are computed from the float32-rounded positions stored in the buffer.\n",
    "        coords = self.pos_const.cpu().numpy()\n",
    "        pairwise = ss.distance.cdist(coords, coords)\n",
    "\n",
    "        # Keep a trailing axis of at least length 1 so broadcasting works when N_nm is 0.\n",
    "        n_channels = N_nm if N_nm > 0 else 1\n",
    "        dist3 = pairwise[:, :, None] * np.ones([observable_size, observable_size, n_channels])\n",
    "\n",
    "        self.ell = ell\n",
    "        self.scale = 1.0\n",
    "        p_inhib = 0.5\n",
    "\n",
    "        # One 0/1 flag per first-axis index, broadcast over the other two axes.\n",
    "        inhib_flags = np.random.choice([0, 1], size=(observable_size, 1, 1), p=[1 - p_inhib, p_inhib])\n",
    "        self.register_buffer('inhib_const', torch.tensor(inhib_flags * np.ones_like(dist3), dtype=torch.float32))\n",
    "\n",
    "        self.register_buffer('Delta_const', torch.tensor(dist3 / self.ell, dtype=torch.float32))\n",
    "\n",
    "        within_range = dist3 < 5 * self.ell\n",
    "        off_diagonal = np.eye(observable_size)[:, :, None] * np.ones_like(dist3) == 0\n",
    "        self.register_buffer('mask_const', torch.tensor(np.logical_and(within_range, off_diagonal), dtype=torch.float32))\n",
    "\n",
    "    def forward(self, W):\n",
    "        \"\"\"\n",
    "        Apply the spatial modulation to the base weight tensor ``W``.\n",
    "\n",
    "        ``W`` must have the same trailing size as the spatial buffers. The one\n",
    "        exception is a 2-D ``W`` when the buffers have a single channel, which is\n",
    "        expanded to (H, H, 1). Any other mismatch raises ValueError.\n",
    "        \"\"\"\n",
    "        n_channels = self.mask_const.shape[-1]\n",
    "        if W.shape[-1] != n_channels:\n",
    "            if not (W.dim() == 2 and n_channels == 1):\n",
    "                raise ValueError(f\"Shape mismatch in SpatialWeight: W shape {W.shape}, mask_const shape {self.mask_const.shape}\")\n",
    "            W = W.unsqueeze(-1)  # (H, H) -> (H, H, 1)\n",
    "\n",
    "        sign = (-1.0) ** self.inhib_const\n",
    "        return self.scale * sign * torch.exp(W - self.Delta_const) * self.mask_const\n",
    "\n",
    "\n",
    "class spatial_nmRNNCell_base(nn.Module):\n",
    "    \"\"\"\n",
    "    Parameter container for the neuromodulated RNN cell.\n",
    "\n",
    "    Holds the input weights, an optional plain recurrent matrix (``keepW0``), the\n",
    "    neuromodulator-gated recurrent tensor of shape (hidden, hidden, N_nm) and the\n",
    "    weights driving the neuromodulator state. With ``use_spatial_net`` the gated\n",
    "    tensor is passed through a SpatialWeight module before use; otherwise it is\n",
    "    used directly. When ``N_nm`` is 0 no neuromodulator parameters are created.\n",
    "    \"\"\"\n",
    "    def __init__(self, N_nm, input_size, hidden_size, nonlinearity, bias, keepW0=False, use_spatial_net=True):\n",
    "        super().__init__()\n",
    "        self.input_size = input_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.nonlinearity = nonlinearity\n",
    "        self.N_nm = N_nm\n",
    "        self.keepW0 = keepW0\n",
    "        self.g = 10.0\n",
    "        self.use_spatial_net = use_spatial_net\n",
    "\n",
    "        self.weight_ih = nn.Parameter(torch.Tensor(hidden_size, input_size))\n",
    "\n",
    "        has_nm = self.N_nm > 0\n",
    "        if self.use_spatial_net and has_nm:\n",
    "            self.spatialNet_instance = SpatialWeight(N_nm=N_nm, observable_size=hidden_size)\n",
    "            self.base_weight_hh_modulated = nn.Parameter(torch.Tensor(hidden_size, hidden_size, N_nm))\n",
    "        else:\n",
    "            self.spatialNet_instance = None\n",
    "            if has_nm:\n",
    "                self.weight_hh_direct_modulated = nn.Parameter(torch.Tensor(hidden_size, hidden_size, N_nm))\n",
    "            # With N_nm == 0 no modulated recurrent weights exist at all.\n",
    "\n",
    "        if N_nm > 0:\n",
    "            self.weight_h2nm = nn.Parameter(torch.Tensor(N_nm, hidden_size))\n",
    "            self.weight_nm2nm = nn.Parameter(torch.Tensor(N_nm, N_nm))\n",
    "        else:\n",
    "            self.weight_h2nm = None\n",
    "            self.weight_nm2nm = None\n",
    "\n",
    "        if keepW0:\n",
    "            self.weight0_hh = nn.Parameter(torch.Tensor(hidden_size, hidden_size))\n",
    "        else:\n",
    "            # Fixed all-zero stand-in, excluded from the state_dict.\n",
    "            self.register_buffer('weight0_hh_const', torch.zeros(hidden_size, hidden_size), persistent=False)\n",
    "\n",
    "        if bias:\n",
    "            self.bias = nn.Parameter(torch.Tensor(hidden_size))\n",
    "        else:\n",
    "            self.register_parameter('bias', None)\n",
    "\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def get_weight0_hh(self):\n",
    "        \"\"\"Return the plain recurrent matrix: the trainable one if ``keepW0``, else the zero buffer.\"\"\"\n",
    "        return self.weight0_hh if self.keepW0 else self.weight0_hh_const\n",
    "\n",
    "    def get_modulated_hh_weights(self):\n",
    "        \"\"\"Return the (hidden, hidden, N_nm) gated recurrent weights, or None when there are none.\"\"\"\n",
    "        if self.N_nm > 0:\n",
    "            spatial = getattr(self, 'spatialNet_instance', None)\n",
    "            if self.use_spatial_net and spatial is not None:\n",
    "                return spatial(self.base_weight_hh_modulated)\n",
    "            if hasattr(self, 'weight_hh_direct_modulated'):\n",
    "                return self.weight_hh_direct_modulated\n",
    "        return None\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        \"\"\"(Re)initialise every parameter; the order of the random draws is fixed.\"\"\"\n",
    "        nn.init.kaiming_uniform_(self.weight_ih, a=math.sqrt(5))\n",
    "\n",
    "        # Pick whichever gated recurrent tensor this configuration created, if any.\n",
    "        gated = None\n",
    "        if self.N_nm > 0:\n",
    "            if self.use_spatial_net and hasattr(self, 'base_weight_hh_modulated'):\n",
    "                gated = self.base_weight_hh_modulated\n",
    "            elif hasattr(self, 'weight_hh_direct_modulated'):\n",
    "                gated = self.weight_hh_direct_modulated\n",
    "        if gated is not None:\n",
    "            if self.hidden_size > 0:\n",
    "                nn.init.kaiming_uniform_(gated, a=self.g / math.sqrt(self.hidden_size))\n",
    "            else:\n",
    "                nn.init.zeros_(gated)\n",
    "\n",
    "        if self.weight_h2nm is not None:\n",
    "            nn.init.sparse_(self.weight_h2nm, sparsity=0.1)\n",
    "        if self.weight_nm2nm is not None:\n",
    "            nn.init.zeros_(self.weight_nm2nm)\n",
    "\n",
    "        if self.keepW0:\n",
    "            plain_hh = self.get_weight0_hh()\n",
    "            if isinstance(plain_hh, nn.Parameter):\n",
    "                nn.init.kaiming_uniform_(plain_hh, a=math.sqrt(5))\n",
    "\n",
    "        if self.bias is not None:\n",
    "            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_ih)\n",
    "            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0\n",
    "            nn.init.uniform_(self.bias, -bound, bound)\n",
    "\n",
    "\n",
    "\n",
    "class s_nmRNNCell(spatial_nmRNNCell_base):\n",
    "    \"\"\"One time step of the neuromodulated RNN: leaky update of the hidden rates and of the neuromodulator state.\"\"\"\n",
    "\n",
    "    def __init__(self, N_nm, input_size, hidden_size, out_size, nonlinearity=None, decay=0.0, bias=True, keepW0=True, use_spatial_net=True):\n",
    "        super().__init__(N_nm, input_size, hidden_size, nonlinearity, bias, keepW0=keepW0, use_spatial_net=use_spatial_net)\n",
    "        self.decay = decay\n",
    "        self.out_size = out_size\n",
    "\n",
    "    def forward(self, input_val, hiddenCombined):\n",
    "        \"\"\"\n",
    "        Advance the combined state by one step.\n",
    "\n",
    "        Args:\n",
    "            input_val: Input for this step, (batch, input_size).\n",
    "            hiddenCombined: Previous state, (batch, hidden_size + N_nm) or the same with\n",
    "                a leading axis of length 1. The first hidden_size features are the\n",
    "                hidden rates, the rest the neuromodulator state.\n",
    "        Returns:\n",
    "            New combined state of shape (1, batch, hidden_size + N_nm).\n",
    "        Raises:\n",
    "            ValueError: if the neuromodulator slice does not have N_nm features.\n",
    "        \"\"\"\n",
    "        if hiddenCombined.dim() == 3 and hiddenCombined.shape[0] == 1:\n",
    "            hiddenCombined = hiddenCombined.squeeze(0)\n",
    "\n",
    "        if self.N_nm > 0:\n",
    "            hidden = hiddenCombined[:, :self.hidden_size]\n",
    "            nm = hiddenCombined[:, self.hidden_size:]\n",
    "            if nm.shape[1] != self.N_nm:\n",
    "                raise ValueError(f\"NM state slice incorrect in s_nmRNNCell. Expected {self.N_nm} features, got {nm.shape[1]}. HiddenCombined shape: {hiddenCombined.shape}\")\n",
    "        else:\n",
    "            hidden, nm = hiddenCombined, None\n",
    "        nm_active = nm is not None and self.N_nm > 0\n",
    "\n",
    "        # Input drive plus the plain (unmodulated) recurrent drive.\n",
    "        drive = torch.matmul(input_val, self.weight_ih.t())\n",
    "        drive = drive + torch.matmul(hidden, self.get_weight0_hh().t())\n",
    "\n",
    "        # Recurrent drive gated by the neuromodulator state: sum_jk h_j * W_ijk * nm_k.\n",
    "        gated_hh = self.get_modulated_hh_weights()\n",
    "        if nm_active and gated_hh is not None:\n",
    "            drive = drive + torch.einsum('bj,ijk,bk->bi', hidden, gated_hh, nm)\n",
    "\n",
    "        if self.bias is not None:\n",
    "            drive = drive + self.bias\n",
    "\n",
    "        keep = self.decay\n",
    "        hidden_new = keep * hidden + (1 - keep) * self.nonlinearity(drive)\n",
    "\n",
    "        if nm_active and self.weight_h2nm is not None and self.weight_nm2nm is not None:\n",
    "            # The neuromodulator update reads the previous hidden rates, not hidden_new.\n",
    "            nm_drive = torch.matmul(hidden, self.weight_h2nm.t())\n",
    "            nm_drive = nm_drive + torch.matmul(nm, self.weight_nm2nm.t())\n",
    "            nm_new = keep * nm + (1 - keep) * self.nonlinearity(nm_drive)\n",
    "            return torch.cat([hidden_new, nm_new], dim=1).unsqueeze(0)\n",
    "\n",
    "        return hidden_new.unsqueeze(0)\n",
    "\n",
    "\n",
    "class s_nmRNNLayer(nn.Module):\n",
    "    \"\"\"\n",
    "    Unrolls an s_nmRNNCell over a time-major sequence and reads out an output at every step.\n",
    "\n",
    "    With N_nm > 0 the readout is gated by the neuromodulator state through a\n",
    "    (out_size, hidden_size, N_nm) tensor; with N_nm == 0 it is a plain nn.Linear.\n",
    "    \"\"\"\n",
    "    def __init__(self, N_nm, input_size, hidden_size, out_size, nonlinearity, decay=0.9, bias=False, keepW0=False, use_spatial_net=True):\n",
    "        super().__init__()\n",
    "        self.rnncell = s_nmRNNCell(\n",
    "            N_nm, input_size, hidden_size, out_size,\n",
    "            nonlinearity=nonlinearity, decay=decay, bias=bias,\n",
    "            keepW0=keepW0, use_spatial_net=use_spatial_net,\n",
    "        )\n",
    "        self.N_nm = N_nm\n",
    "        self.hidden_size = hidden_size\n",
    "        self.out_size = out_size\n",
    "\n",
    "        if self.N_nm > 0:\n",
    "            self.weight_readout = nn.Parameter(torch.Tensor(self.out_size, self.hidden_size, self.N_nm))\n",
    "            if self.hidden_size > 0:\n",
    "                nn.init.kaiming_uniform_(self.weight_readout, a=1 / math.sqrt(self.hidden_size))\n",
    "            else:\n",
    "                nn.init.zeros_(self.weight_readout)\n",
    "        else:\n",
    "            self.fc_readout_no_nm = nn.Linear(self.hidden_size, self.out_size)\n",
    "\n",
    "    def forward(self, input_val, initH):\n",
    "        \"\"\"\n",
    "        Run the cell over every step of ``input_val`` (time on axis 0).\n",
    "\n",
    "        Args:\n",
    "            input_val: Sequence of shape (time, batch, input_size).\n",
    "            initH: Initial combined state whose last axis has hidden_size + N_nm features.\n",
    "        Returns:\n",
    "            (outputs, history, last_state): per-step readouts stacked on axis 0, per-step\n",
    "            combined states stacked on axis 0, and the last state as returned by the cell.\n",
    "        Raises:\n",
    "            ValueError: if ``initH`` or a produced state has the wrong feature width.\n",
    "        \"\"\"\n",
    "        state_width = self.hidden_size + self.N_nm\n",
    "        if initH.shape[-1] != state_width:\n",
    "            raise ValueError(f\"Initial hidden state feature dimension mismatch. Expected {state_width}, got {initH.shape[-1]}. initH shape: {initH.shape}\")\n",
    "\n",
    "        state = initH\n",
    "        readouts = []\n",
    "        states = []\n",
    "        for step_input in input_val.unbind(0):\n",
    "            state = self.rnncell(step_input, state)\n",
    "            flat_state = state.squeeze(0)\n",
    "\n",
    "            if self.N_nm > 0:\n",
    "                rates = flat_state[:, :self.hidden_size]\n",
    "                nm_state = flat_state[:, self.hidden_size:]\n",
    "                if nm_state.shape[1] != self.N_nm:\n",
    "                    raise ValueError(f\"NM state slice incorrect in s_nmRNNLayer. Expected {self.N_nm} features, got {nm_state.shape[1]}\")\n",
    "                readout = torch.einsum('bj,ijk,bk->bi', rates, self.weight_readout, nm_state)\n",
    "            else:\n",
    "                readout = self.fc_readout_no_nm(flat_state)\n",
    "\n",
    "            readouts.append(readout)\n",
    "            states.append(flat_state)\n",
    "\n",
    "        return torch.stack(readouts, dim=0), torch.stack(states, dim=0), state\n",
    "\n",
    "\n",
    "class Model_nm(nn.Module):\n",
    "    \"\"\"\n",
    "    Builds a recurrent layer from a hyperparameter dict and runs it from an all-zero state.\n",
    "\n",
    "    Keys read from ``hp``: 'n_input', 'n_rnn', 'n_output', 'decay', 'N_NM' (required) and\n",
    "    'bias', 'keepW0', 'grad_clip', 'use_spatial_net', 'activation' (optional).\n",
    "    ``RNNLayer_class`` is instantiated as\n",
    "    ``RNNLayer_class(N_NM, n_input, n_rnn, n_output, nonlinearity, decay, bias=..., keepW0=..., use_spatial_net=...)``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if hp['activation'] is neither 'relu' nor 'tanh'.\n",
    "    \"\"\"\n",
    "    def __init__(self, hp, RNNLayer_class):\n",
    "        super().__init__()\n",
    "        n_input = hp['n_input']\n",
    "        n_rnn = hp['n_rnn']\n",
    "        n_output = hp['n_output']\n",
    "        decay = hp['decay']\n",
    "        N_NM = hp['N_NM']\n",
    "        bias = hp.get('bias', True)\n",
    "        keepW0 = hp.get('keepW0', True)\n",
    "        clip_value = hp.get('grad_clip', None)\n",
    "        use_spatial_net = hp.get('use_spatial_net', True)\n",
    "        activation_str = hp.get('activation', 'relu')\n",
    "\n",
    "        if activation_str == 'relu':\n",
    "            nonlinearity = nn.ReLU()\n",
    "        elif activation_str == 'tanh':\n",
    "            nonlinearity = nn.Tanh()\n",
    "        else:\n",
    "            raise ValueError(f\"Unsupported activation function: {activation_str}\")\n",
    "\n",
    "        self.n_rnn = n_rnn\n",
    "        self.N_NM = N_NM\n",
    "\n",
    "        self.rnn = RNNLayer_class(N_NM, n_input, n_rnn, n_output, nonlinearity, decay, bias=bias, keepW0=keepW0, use_spatial_net=use_spatial_net)\n",
    "\n",
    "        # Optional element-wise gradient clamp, applied through a hook on every trainable parameter.\n",
    "        if clip_value is not None and clip_value > 0:\n",
    "            def _clamp_grad(grad):\n",
    "                return torch.clamp(grad, -clip_value, clip_value) if grad is not None else None\n",
    "\n",
    "            for p in self.parameters():\n",
    "                if p.requires_grad:\n",
    "                    p.register_hook(_clamp_grad)\n",
    "\n",
    "    def forward(self, x, device=None): # x: (Time, batch, input_size)\n",
    "        \"\"\"\n",
    "        Run the layer over ``x`` starting from zero hidden and neuromodulator states.\n",
    "\n",
    "        Args:\n",
    "            x: Time-major input of shape (time, batch, n_input).\n",
    "            device: Device for the initial state. Defaults to ``x.device`` so the state\n",
    "                always lives where the input does; an explicit value is still honoured.\n",
    "        Returns:\n",
    "            (output, hidden_seq, nm_seq): the layer readout, the hidden part of the state\n",
    "            history (time, batch, n_rnn), and the neuromodulator part (time, batch, N_NM),\n",
    "            which has a zero-length last axis when N_NM == 0.\n",
    "        \"\"\"\n",
    "        if device is None:\n",
    "            device = x.device\n",
    "\n",
    "        batch_size = x.shape[1]\n",
    "        hidden0_rnn = torch.zeros(1, batch_size, self.n_rnn, device=device)\n",
    "\n",
    "        if self.N_NM > 0:\n",
    "            nm0 = torch.zeros(1, batch_size, self.N_NM, device=device)\n",
    "            hiddenCombined0_for_layer = torch.cat([hidden0_rnn, nm0], dim=2)\n",
    "        else:\n",
    "            hiddenCombined0_for_layer = hidden0_rnn\n",
    "\n",
    "        output_readout, hiddenCombined_seq, _final_state = self.rnn(x, hiddenCombined0_for_layer)\n",
    "\n",
    "        hidden_rnn_seq = hiddenCombined_seq[:, :, :self.n_rnn]\n",
    "        if self.N_NM > 0:\n",
    "            nm_seq = hiddenCombined_seq[:, :, self.n_rnn : self.n_rnn + self.N_NM]\n",
    "        else:\n",
    "            nm_seq = torch.empty(hiddenCombined_seq.shape[0], hiddenCombined_seq.shape[1], 0, device=device)\n",
    "\n",
    "        return output_readout, hidden_rnn_seq, nm_seq\n",
    "\n",
    "\n",
    "\n",
    "class nmRNN_Adapter(nn.Module):\n",
    "    \"\"\"\n",
    "    Batch-first wrapper that gives Model_nm the same call shape as the other models.\n",
    "\n",
    "    Accepts (batch, seq, features), runs the time-major Model_nm internally, and returns\n",
    "    ``(output, hidden_history)``, both batch-first.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``dt`` or ``tau_nmrnn`` is not positive.\n",
    "    \"\"\"\n",
    "    def __init__(self, input_size, hidden_size, output_size, dt, N_NM, tau_nmrnn, use_spatial_net, activation='relu', bias=True, keepW0=True, grad_clip_nmrnn=None):\n",
    "        super().__init__()\n",
    "        self.dt = float(dt)\n",
    "        self.tau_nmrnn = float(tau_nmrnn)\n",
    "        if self.dt <= 0: raise ValueError(\"dt must be positive.\")\n",
    "        if self.tau_nmrnn <= 0: raise ValueError(\"tau_nmrnn must be positive for decay calculation.\")\n",
    "\n",
    "        # NOTE(review): tau_nmrnn is passed through unchanged as the cell's decay factor and dt\n",
    "        # is not used in the calculation. Confirm this is intended rather than a decay derived\n",
    "        # from dt and tau; values >= 1 would make the cell's (1 - decay) weight non-positive.\n",
    "        decay_val = self.tau_nmrnn\n",
    "\n",
    "        hp = {\n",
    "            'n_input': input_size,\n",
    "            'n_rnn': hidden_size,\n",
    "            'n_output': output_size,\n",
    "            'decay': decay_val,\n",
    "            'N_NM': N_NM,\n",
    "            'activation': activation,\n",
    "            'bias': bias,\n",
    "            'keepW0': keepW0,\n",
    "            'grad_clip': grad_clip_nmrnn,\n",
    "            'use_spatial_net': use_spatial_net\n",
    "        }\n",
    "        self.model_nm = Model_nm(hp, s_nmRNNLayer)\n",
    "\n",
    "    def forward(self, x, hidden_state_ignored=None): # x: (batch, seq, features)\n",
    "        \"\"\"\n",
    "        Run the wrapped model on a batch-first input.\n",
    "\n",
    "        Args:\n",
    "            x: Input of shape (batch, seq, features).\n",
    "            hidden_state_ignored: Accepted for signature parity with the other models; never read.\n",
    "        Returns:\n",
    "            (output, hidden_history): both batch-first and contiguous, so callers may\n",
    "            use ``.view(...)`` on them.\n",
    "        \"\"\"\n",
    "        time_major_input = x.transpose(0, 1) # (seq, batch, features)\n",
    "\n",
    "        output_timefirst, hidden_timefirst, _ = self.model_nm(time_major_input, device=x.device)\n",
    "\n",
    "        # transpose() only swaps strides; .contiguous() materialises the batch-first layout so\n",
    "        # a later .view() does not fail with \"view size is not compatible with input tensor's\n",
    "        # size and stride\".\n",
    "        output_batchfirst = output_timefirst.transpose(0, 1).contiguous()\n",
    "        hidden_batchfirst = hidden_timefirst.transpose(0, 1).contiguous()\n",
    "\n",
    "        return output_batchfirst, hidden_batchfirst\n",
    "\n",
    "\n",
    "# --- Main script execution starts here ---\n",
    "def set_seed(seed):\n",
    "    \"\"\"Seed Python's ``random``, NumPy and PyTorch (plus every CUDA device when available), then report the seed.\"\"\"\n",
    "    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
    "        seed_fn(seed)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed_all(seed)\n",
    "    print(f\"Running with seed {seed}!\")\n",
    "\n",
    "# Seed every RNG before anything random happens; the task shuffle below draws from `random`.\n",
    "set_seed(args.seed)\n",
    "\n",
    "# --- Create Save Directory ---\n",
    "# The run directory name encodes the main hyperparameters. Every '.' in the name\n",
    "# (from the float values) is replaced by 'p'. The directory is created relative to\n",
    "# the current working directory.\n",
    "path_str = f'files_seed{args.seed}_NNM{args.N_NM}_tau{args.tau_nmrnn}_hs{args.hidden_size}_lr{args.lr_model}_ctx{args.lr_context}_init{args.initial_train_steps}_ad{args.adapt_steps}'\n",
    "path = Path('.') / path_str.replace('.', 'p') \n",
    "os.makedirs(path, exist_ok=True)\n",
    "print(f\"Saving results to: {path.resolve()}\")\n",
    "\n",
    "# --- Environment Setup ---\n",
    "# The same kwargs are passed to every gym.make call in this section.\n",
    "kwargs_env = {'dt': 100.0} \n",
    "all_tasks_collection_names = ngym.get_collection('yang19') \n",
    "\n",
    "# At least 20 task names are needed: 15 for training and 5 for adaptation (see the split below).\n",
    "# If the collection is missing or too short, fall back to a hand-written list of 20 task ids.\n",
    "if not all_tasks_collection_names or len(all_tasks_collection_names) < 20:\n",
    "    print(\"Default Yang19 task collection issue. Using a manual list.\")\n",
    "    all_tasks_collection_names = [ \n",
    "        'ContextDecisionMaking-v0', 'DecisionMaking-v0', 'DelayComparison-v0',\n",
    "        'DelayMatchCategory-v0', 'DelayMatchSample-v0', 'Detection-v0',\n",
    "        'DualDecisionMaking-v0', 'GoNogo-v0', 'IntervalDiscrimination-v0',\n",
    "        'MatchCategory-v0', 'MatchSample-v0', 'MultisensoryIntegration-v0',\n",
    "        'PerceptualDecisionMaking-v0', 'ReachingDelayResponse-v0', 'SpatialSuppressMotion-v0',\n",
    "        'AntiReach-v0', 'Countermanding-v0', 'EconomicDecisionMaking-v0', \n",
    "        'MotorTiming-v0', 'ProbabilisticReasoning-v0'\n",
    "    ]\n",
    "    # Guard against the literal list above being edited down below 20 entries.\n",
    "    if len(all_tasks_collection_names) < 20:\n",
    "         raise ValueError(\"Fallback Yang19 task list also has less than 20 tasks.\")\n",
    "    # Smoke test: only the first fallback task is instantiated; the error is re-raised on failure.\n",
    "    try:\n",
    "        gym.make(all_tasks_collection_names[0], **kwargs_env).close()\n",
    "    except Exception as e:\n",
    "        print(f\"Error making fallback task: {e}. Ensure neurogym tasks are registered or list is correct.\")\n",
    "        raise\n",
    "\n",
    "# Random split of the tasks: the first 15 shuffled indices are used for training and the\n",
    "# next 5 for adaptation. Any tasks beyond the first 20 shuffled indices are left unused.\n",
    "task_indices = list(range(len(all_tasks_collection_names)))\n",
    "random.shuffle(task_indices) \n",
    "\n",
    "train_task_indices = task_indices[:15]\n",
    "adapt_task_indices = task_indices[15:20]\n",
    "\n",
    "train_tasks_names = [all_tasks_collection_names[i] for i in train_task_indices]\n",
    "adapt_tasks_names = [all_tasks_collection_names[i] for i in adapt_task_indices]\n",
    "\n",
    "print(f\"Training on {len(train_tasks_names)} tasks (first 3): {train_tasks_names[:3]}\")\n",
    "print(f\"Adapting to {len(adapt_tasks_names)} tasks: {adapt_tasks_names}\")\n",
    "\n",
    "# Observation and action sizes are probed from the first task in the collection only.\n",
    "# NOTE(review): assumes every task shares these sizes - confirm.\n",
    "_temp_env = gym.make(all_tasks_collection_names[0], **kwargs_env)\n",
    "if isinstance(_temp_env.observation_space, gym.spaces.Dict):\n",
    "    raw_ob_size = _temp_env.observation_space['observation'].shape[0]\n",
    "else:\n",
    "    raw_ob_size = _temp_env.observation_space.shape[0]\n",
    "act_size = _temp_env.action_space.n\n",
    "_temp_env.close()\n",
    "\n",
    "# Model input width = raw observation features plus one slot per task in the collection.\n",
    "num_total_tasks = len(all_tasks_collection_names) \n",
    "model_input_size = raw_ob_size + num_total_tasks\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "print(f\"Device: {device}\")\n",
    "\n",
    "# --- Model Dictionary ---\n",
    "models_to_test = {\n",
    "    \"nmRNN_Spatial\": lambda: nmRNN_Adapter(input_size=model_input_size, hidden_size=args.hidden_size, output_size=act_size, dt=kwargs_env['dt'],\n",
    "                                           N_NM=args.N_NM, tau_nmrnn=args.tau_nmrnn, use_spatial_net=True, grad_clip_nmrnn=args.grad_clip_nmrnn).to(device),\n",
    "    \"nmRNN_NonSpatial\": lambda: nmRNN_Adapter(input_size=model_input_size, hidden_size=args.hidden_size, output_size=act_size, dt=kwargs_env['dt'],\n",
    "                                              N_NM=args.N_NM, tau_nmrnn=args.tau_nmrnn, use_spatial_net=False, grad_clip_nmrnn=args.grad_clip_nmrnn).to(device),\n",
    "    \"VanillaRNN\": lambda: RNNNet(input_size=model_input_size, hidden_size=args.hidden_size, output_size=act_size, dt=kwargs_env['dt']).to(device),\n",
    "    \"LSTM\": lambda: LSTMNet(input_size=model_input_size, hidden_size=args.hidden_size, output_size=act_size, dt=kwargs_env['dt']).to(device),\n",
    "    \"Transformer\": lambda: TransformerNet(input_size=model_input_size, hidden_size=args.hidden_size, output_size=act_size, dt=kwargs_env['dt'], nhead=4, num_layers=2).to(device),\n",
    "}\n",
    "if args.N_NM == 0: \n",
    "    print(\"N_NM is 0, running only nmRNN_NonSpatial (effectively a standard RNN with nmRNN structure but no NMs).\")\n",
    "    models_to_test.pop(\"nmRNN_Spatial\", None) \n",
    "    if \"nmRNN_NonSpatial\" in models_to_test: # Rename for clarity\n",
    "        models_to_test[\"nmRNN_N0\"] = models_to_test.pop(\"nmRNN_NonSpatial\")\n",
    "\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# --- Results Storage ---\n",
    "initial_training_performance_curves = {model_name: [] for model_name in models_to_test}\n",
    "adaptation_performance_curves = {model_name: {task_name: [] for task_name in adapt_tasks_names} for model_name in models_to_test}\n",
    "summary_table_data = []\n",
    "\n",
    "\n",
    "# --- Custom Dataset for Training (Phase 1) ---\n",
    "class Yang19TrainDataset:\n",
    "    def __init__(self, task_names_to_train, all_task_names_ordered_list, batch_size, seq_len, env_kwargs_dict):\n",
    "        self.batch_size = batch_size\n",
    "        self.seq_len = seq_len\n",
    "        self.task_names_to_train = task_names_to_train\n",
    "        self.all_task_names_ordered_list = all_task_names_ordered_list\n",
    "        self.num_total_tasks = len(self.all_task_names_ordered_list)\n",
    "        self.env_kwargs_dict = env_kwargs_dict\n",
    "        \n",
    "        self.envs_dict = {}\n",
    "        valid_task_names_for_training = []\n",
    "        for name in self.task_names_to_train:\n",
    "            try:\n",
    "                self.envs_dict[name] = gym.make(name, **self.env_kwargs_dict)\n",
    "                valid_task_names_for_training.append(name)\n",
    "            except Exception as e:\n",
    "                print(f\"Could not make environment {name} for training dataset: {e}. It will be excluded.\")\n",
    "        \n",
    "        self.task_names_to_train = valid_task_names_for_training # Update to only include valid tasks\n",
    "        self.task_global_indices = {name: self.all_task_names_ordered_list.index(name) for name in self.task_names_to_train}\n",
    "\n",
    "        if not self.task_names_to_train or not self.envs_dict:\n",
    "            raise ValueError(\"No valid environments could be created for Yang19TrainDataset.\")\n",
    "        self.current_task_name = random.choice(list(self.envs_dict.keys()))\n",
    "\n",
    "\n",
    "    def __call__(self):\n",
    "        inputs_batch = []\n",
    "        labels_batch = []\n",
    "        for _ in range(self.batch_size):\n",
    "            chosen_task_name = random.choice(list(self.envs_dict.keys())) \n",
    "            current_env = self.envs_dict[chosen_task_name]\n",
    "            task_global_idx = self.task_global_indices[chosen_task_name]\n",
    "\n",
    "            raw_obs_list_item = []\n",
    "            gt_actions_list_item = []\n",
    "            \n",
    "            env_obs = current_env.reset()\n",
    "            current_raw_obs = env_obs['observation'] if isinstance(env_obs, dict) and 'observation' in env_obs else env_obs\n",
    "\n",
    "            for step_num in range(self.seq_len):\n",
    "                raw_obs_list_item.append(current_raw_obs)\n",
    "                action = current_env.action_space.sample() \n",
    "                env_obs, _, done, info = current_env.step(action)\n",
    "                current_raw_obs = env_obs['observation'] if isinstance(env_obs, dict) and 'observation' in env_obs else env_obs\n",
    "                \n",
    "                if 'gt' not in info: \n",
    "                    gt_actions_list_item.append(-1) \n",
    "                else:\n",
    "                    gt_actions_list_item.append(info['gt'])\n",
    "                \n",
    "                if done: \n",
    "                    padding_needed = self.seq_len - (step_num + 1)\n",
    "                    if padding_needed > 0:\n",
    "                        raw_obs_list_item.extend([current_raw_obs] * padding_needed)\n",
    "                        last_gt = gt_actions_list_item[-1] if gt_actions_list_item else -1\n",
    "                        gt_actions_list_item.extend([last_gt] * padding_needed)\n",
    "                    break\n",
    "            \n",
    "            raw_obs_seq_np = np.array(raw_obs_list_item) \n",
    "            context_vec = np.zeros((self.seq_len, self.num_total_tasks), dtype=np.float32) \n",
    "            context_vec[:, task_global_idx] = 1.0\n",
    "            model_input_seq = np.concatenate((raw_obs_seq_np.astype(np.float32), context_vec), axis=-1)\n",
    "            \n",
    "            inputs_batch.append(model_input_seq)\n",
    "            labels_batch.append(np.array(gt_actions_list_item))\n",
    "\n",
    "        return np.array(inputs_batch), np.array(labels_batch)\n",
    "\n",
    "    def close_envs(self):\n",
    "        for env in self.envs_dict.values():\n",
    "            env.close()\n",
    "\n",
    "# --- Custom Dataset for Adaptation (Phase 2) ---\n",
    "class SingleTaskDataset:\n",
    "    def __init__(self, env_instance, batch_size, seq_len): \n",
    "        self.env = env_instance\n",
    "        self.batch_size = batch_size\n",
    "        self.seq_len = seq_len\n",
    "\n",
    "    def __call__(self):\n",
    "        obs_batch = []\n",
    "        labels_batch = []\n",
    "        for _ in range(self.batch_size):\n",
    "            obs_list_item = []\n",
    "            gt_list_item = []\n",
    "            env_obs = self.env.reset()\n",
    "            current_raw_obs = env_obs['observation'] if isinstance(env_obs, dict) and 'observation' in env_obs else env_obs\n",
    "\n",
    "            for step_num in range(self.seq_len):\n",
    "                obs_list_item.append(current_raw_obs)\n",
    "                action = self.env.action_space.sample() \n",
    "                env_obs, _, done, info = self.env.step(action)\n",
    "                current_raw_obs = env_obs['observation'] if isinstance(env_obs, dict) and 'observation' in env_obs else env_obs\n",
    "                \n",
    "                if 'gt' not in info:\n",
    "                    gt_list_item.append(-1)\n",
    "                else:\n",
    "                    gt_list_item.append(info['gt'])\n",
    "\n",
    "                if done:\n",
    "                    padding_needed = self.seq_len - (step_num + 1)\n",
    "                    if padding_needed > 0:\n",
    "                        obs_list_item.extend([current_raw_obs] * padding_needed)\n",
    "                        last_gt = gt_list_item[-1] if gt_list_item else -1\n",
    "                        gt_list_item.extend([last_gt] * padding_needed)\n",
    "                    break\n",
    "            obs_batch.append(np.array(obs_list_item))\n",
    "            labels_batch.append(np.array(gt_list_item))\n",
    "        return np.array(obs_batch, dtype=np.float32), np.array(labels_batch) \n",
    "\n",
    "\n",
    "# --- Main Loop for Each Model Type ---\n",
    "for model_name, model_constructor in models_to_test.items():\n",
    "    print(f\"\\n\\n--- Training and Adapting Model: {model_name} ---\")\n",
    "    model = model_constructor() \n",
    "    # print(model) # Optional: print model structure\n",
    "    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "    print(f\"Total trainable parameters for {model_name}: {total_params:,}\")\n",
    "\n",
    "    # === Phase 1: Initial Training ===\n",
    "    print(f\"\\n--- Phase 1: Initial Training on {len(train_tasks_names)} tasks for {model_name} ---\")\n",
    "    try:\n",
    "        dataset_phase1 = Yang19TrainDataset(train_tasks_names, all_tasks_collection_names, args.batch_size, args.seq_len, env_kwargs_dict=kwargs_env)\n",
    "    except ValueError as e:\n",
    "        print(f\"Error initializing dataset for Phase 1 for model {model_name}: {e}. Skipping this model.\")\n",
    "        continue # Skip to the next model if dataset fails\n",
    "\n",
    "    optimizer_model = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr_model)\n",
    "    \n",
    "    running_loss = 0.0\n",
    "    current_initial_train_perf_curve = []\n",
    "    time_phase1_start = time.time()\n",
    "\n",
    "    for i in range(args.initial_train_steps):\n",
    "        model.train() # Ensure model is in training mode for this phase\n",
    "        inputs, labels = dataset_phase1() \n",
    "        inputs = torch.from_numpy(inputs).type(torch.float).to(device) \n",
    "        labels_flat = torch.from_numpy(labels.flatten()).type(torch.long).to(device)\n",
    "\n",
    "        valid_indices = labels_flat != -1\n",
    "        if not valid_indices.any(): \n",
    "            continue\n",
    "        \n",
    "        optimizer_model.zero_grad()\n",
    "        outputs, _ = model(inputs) \n",
    "        \n",
    "        outputs_reshaped = outputs.reshape(-1, act_size) \n",
    "        outputs_for_loss = outputs_reshaped[valid_indices]\n",
    "        labels_filtered = labels_flat[valid_indices]\n",
    "        \n",
    "        if outputs_for_loss.shape[0] == 0: \n",
    "            continue\n",
    "\n",
    "        loss = criterion(outputs_for_loss, labels_filtered)\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(filter(lambda p: p.requires_grad, model.parameters()), 1.0) \n",
    "        optimizer_model.step()\n",
    "        running_loss += loss.item()\n",
    "\n",
    "        if i % args.print_step == (args.print_step - 1) or i == args.initial_train_steps -1 :\n",
    "            steps_in_epoch = args.print_step if i % args.print_step == (args.print_step -1) else (i % args.print_step) + 1\n",
    "            avg_loss = running_loss / steps_in_epoch if steps_in_epoch > 0 else running_loss\n",
    "            print(f'Model: {model_name}, Initial Train Step: {i + 1}/{args.initial_train_steps}, Loss: {avg_loss:.4f}')\n",
    "            if i % args.print_step == (args.print_step -1): running_loss = 0.0\n",
    "\n",
    "            avg_perf_on_train_tasks = 0.0\n",
    "            temp_perfs = []\n",
    "            num_eval_tasks = len(dataset_phase1.task_names_to_train) # Use actual number of valid tasks\n",
    "            trials_per_eval_task = max(1, args.eval_trials // num_eval_tasks if num_eval_tasks > 0 else args.eval_trials)\n",
    "\n",
    "            for train_task_name_eval in dataset_phase1.task_names_to_train: \n",
    "                task_global_idx = all_tasks_collection_names.index(train_task_name_eval)\n",
    "                context_val = np.zeros(num_total_tasks, dtype=np.float32)\n",
    "                context_val[task_global_idx] = 1.0\n",
    "                \n",
    "                try:\n",
    "                    eval_env_train = gym.make(train_task_name_eval, **kwargs_env)\n",
    "                    perf = get_performance(model, eval_env_train, num_trial=trials_per_eval_task, \n",
    "                                           device=device, context_vector_global=context_val, \n",
    "                                           raw_obs_size=raw_ob_size, seq_len_eval=args.seq_len)\n",
    "                    temp_perfs.append(perf)\n",
    "                    eval_env_train.close()\n",
    "                except Exception as e:\n",
    "                    print(f\"Error evaluating task {train_task_name_eval} during training: {e}\")\n",
    "            \n",
    "            if temp_perfs: avg_perf_on_train_tasks = np.mean(temp_perfs)\n",
    "            current_initial_train_perf_curve.append(avg_perf_on_train_tasks)\n",
    "            print(f'Model: {model_name}, Initial Train Step: {i + 1}, Avg Perf on Train Set: {avg_perf_on_train_tasks:.3f}')\n",
    "    \n",
    "    time_phase1_end = time.time()\n",
    "    print(f\"Finished Initial Training for {model_name} in {time_phase1_end - time_phase1_start:.2f} seconds.\")\n",
    "    initial_training_performance_curves[model_name] = current_initial_train_perf_curve\n",
    "    trained_model_state_dict = copy.deepcopy(model.state_dict()) \n",
    "    dataset_phase1.close_envs() \n",
    "\n",
    "\n",
    "    # === Phase 2: Adaptation ===\n",
    "    print(f\"\\n--- Phase 2: Adaptation for {model_name} on {len(adapt_tasks_names)} tasks ---\")\n",
    "    time_phase2_start = time.time()\n",
    "    \n",
    "    for adapt_task_name in adapt_tasks_names:\n",
    "        print(f\"\\n-- Adapting {model_name} to task: {adapt_task_name} --\")\n",
    "        model.load_state_dict(trained_model_state_dict) \n",
    "        \n",
    "        for param in model.parameters():\n",
    "            param.requires_grad = False \n",
    "\n",
    "        trainable_context = torch.randn(num_total_tasks, device=device, requires_grad=True) \n",
    "        optimizer_context = torch.optim.Adam([trainable_context], lr=args.lr_context)\n",
    "\n",
    "        try:\n",
    "            adapt_env_instance = gym.make(adapt_task_name, **kwargs_env)\n",
    "        except Exception as e:\n",
    "            print(f\"Could not make environment {adapt_task_name} for adaptation. Skipping. Error: {e}\")\n",
    "            adaptation_performance_curves[model_name][adapt_task_name] = [] \n",
    "            summary_table_data.append({\n",
    "                \"Model\": model_name, \"Adapted Task\": adapt_task_name,\n",
    "                \"Best Adapted Performance\": 0.0, \"Final Adapted Context Sum\": 0.0\n",
    "            })\n",
    "            continue \n",
    "\n",
    "        adapt_dataset = SingleTaskDataset(adapt_env_instance, args.batch_size, args.seq_len)\n",
    "        \n",
    "        current_adapt_perf_curve = []\n",
    "        best_adapted_perf = 0.0\n",
    "        running_adapt_loss = 0.0\n",
    "\n",
    "        for adapt_step in range(args.adapt_steps):\n",
    "            model.train() \n",
    "            \n",
    "            raw_observations, labels = adapt_dataset() \n",
    "            raw_observations = torch.from_numpy(raw_observations).type(torch.float).to(device) \n",
    "            labels_flat = torch.from_numpy(labels.flatten()).type(torch.long).to(device)\n",
    "\n",
    "            valid_indices_adapt = labels_flat != -1\n",
    "            if not valid_indices_adapt.any(): continue\n",
    "\n",
    "            context_expanded = trainable_context.unsqueeze(0).unsqueeze(0).repeat(raw_observations.shape[0], raw_observations.shape[1], 1)\n",
    "            model_inputs = torch.cat((raw_observations, context_expanded), dim=-1)\n",
    "\n",
    "            optimizer_context.zero_grad()\n",
    "            outputs, _ = model(model_inputs) \n",
    "            \n",
    "            outputs_reshaped_adapt = outputs.reshape(-1, act_size)\n",
    "            outputs_for_loss_adapt = outputs_reshaped_adapt[valid_indices_adapt]\n",
    "            labels_filtered_adapt = labels_flat[valid_indices_adapt]\n",
    "\n",
    "            if outputs_for_loss_adapt.shape[0] == 0: continue\n",
    "\n",
    "            loss = criterion(outputs_for_loss_adapt, labels_filtered_adapt)\n",
    "            loss.backward() \n",
    "            optimizer_context.step()\n",
    "            running_adapt_loss += loss.item()\n",
    "\n",
    "            if adapt_step % (args.print_step // 2) == (args.print_step // 2 - 1) or adapt_step == args.adapt_steps -1:\n",
    "                steps_in_eval_period = (args.print_step // 2) if adapt_step % (args.print_step // 2) == (args.print_step // 2 - 1) else (adapt_step % (args.print_step//2)) +1\n",
    "                avg_adapt_loss = running_adapt_loss / steps_in_eval_period if steps_in_eval_period > 0 else running_adapt_loss\n",
    "                if adapt_step % (args.print_step // 2) == (args.print_step // 2 - 1): running_adapt_loss = 0.0\n",
    "                \n",
    "                current_ctx_np = trainable_context.detach().cpu().numpy()\n",
    "                \n",
    "                eval_adapt_env_temp = gym.make(adapt_task_name, **kwargs_env) \n",
    "                perf = get_performance(model, eval_adapt_env_temp, num_trial=args.eval_trials, device=device,\n",
    "                                       context_vector_global=current_ctx_np, raw_obs_size=raw_ob_size, seq_len_eval=args.seq_len)\n",
    "                eval_adapt_env_temp.close()\n",
    "                \n",
    "                current_adapt_perf_curve.append(perf)\n",
    "                if perf > best_adapted_perf: best_adapted_perf = perf\n",
    "                \n",
    "                print(f'  Adapt Task: {adapt_task_name}, Step: {adapt_step + 1}/{args.adapt_steps}, Loss: {avg_adapt_loss:.4f}, Perf: {perf:.3f}')\n",
    "        \n",
    "        adaptation_performance_curves[model_name][adapt_task_name] = current_adapt_perf_curve\n",
    "        summary_table_data.append({\n",
    "            \"Model\": model_name,\n",
    "            \"Adapted Task\": adapt_task_name,\n",
    "            \"Best Adapted Performance\": best_adapted_perf,\n",
    "            \"Final Adapted Context Sum\": trainable_context.detach().sum().item() \n",
    "        })\n",
    "        adapt_env_instance.close()\n",
    "    \n",
    "    time_phase2_end = time.time()\n",
    "    print(f\"Finished Adaptation for {model_name} in {time_phase2_end - time_phase2_start:.2f} seconds.\")\n",
    "\n",
    "    del model \n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.empty_cache()\n",
    "\n",
    "\n",
    "# --- Reporting Results ---\n",
    "print(\"\\n\\n--- Overall Results ---\")\n",
    "\n",
    "print(\"\\nInitial Training Performance Curves (Avg Perf on 15 Train Tasks vs. Eval Points):\")\n",
    "for model_name, curve in initial_training_performance_curves.items():\n",
    "    print(f\"Model: {model_name}\")\n",
    "    curve_str = \", \".join([f\"{p:.3f}\" for p in curve])\n",
    "    print(f\"  Performance points: [{curve_str}]\")\n",
    "    np.save(path / f'{model_name}_initial_train_curve.npy', np.array(curve))\n",
    "\n",
    "print(\"\\nAdaptation Training Curves (Perf on Held-out Task vs. Eval Points):\")\n",
    "for model_name, task_curves in adaptation_performance_curves.items():\n",
    "    print(f\"Model: {model_name}\")\n",
    "    for task_name, curve in task_curves.items():\n",
    "        print(f\"  Adapted Task: {task_name}\")\n",
    "        if curve: \n",
    "            curve_str = \", \".join([f\"{p:.3f}\" for p in curve])\n",
    "            print(f\"    Performance points: [{curve_str}]\")\n",
    "            np.save(path / f'{model_name}_adapt_curve_{task_name.replace(\" \", \"_\").replace(\"-v0\", \"\")}.npy', np.array(curve))\n",
    "        else:\n",
    "            print(f\"    No performance data recorded (task might have failed to initialize).\")\n",
    "\n",
    "\n",
    "print(\"\\nSummary Table of Best Adapted Performance:\")\n",
    "if summary_table_data:\n",
    "    summary_df = pd.DataFrame(summary_table_data)\n",
    "    print(summary_df.to_string()) \n",
    "    summary_df.to_csv(path / 'adaptation_summary.csv', index=False)\n",
    "else:\n",
    "    print(\"No summary data to report.\")\n",
    "\n",
    "print(f\"\\nFinished Adaptability Analysis. Results saved in {path.resolve()}\")\n",
    "print(\"To visualize curves, load the .npy files or use the printed lists/DataFrame with a plotting library.\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5970e92b-27b5-469e-bb02-f7bc0a09df9e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
