{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f8c0828e",
   "metadata": {},
   "source": [
    "## This notebook is just a check: I want to understand whether, with `freeze_encoder=False`, the model actually changes its weights, and I also want to look at some predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bd650fb8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package cmudict to /home/XXXXXX/nltk_data...\n",
      "[nltk_data]   Package cmudict is already up-to-date!\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "import pickle\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "from torch.utils.data import DataLoader\n",
    "import torch\n",
    "from dataset import SpeechSentenceDataset, idsToPhonemes, getDatasetLoaders,getDatasetLoaders_V3, PHONE_DEF, PHONE_DEF_SIL\n",
    "import re \n",
    "from g2p_en import G2p\n",
    "import numpy as np\n",
    "from model.ctc_modelling import LightningGRUDecoder, LightningGRUDecoder_MFCC_v3\n",
    "from model.hybrid_modelling import HybridCausalLMOutput, HybridGRUDecoder\n",
    "import time\n",
    "import numpy as np\n",
    "from edit_distance import SequenceMatcher\n",
    "import tqdm\n",
    "import pytorch_lightning as pl\n",
    "import jiwer\n",
    "import nltk\n",
    "from nltk.corpus import cmudict\n",
    "from pytorch_lightning.loggers import WandbLogger\n",
    "import wandb\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping\n",
    "import copy\n",
    "from difflib import get_close_matches\n",
    "from transformers import GPT2LMHeadModel, GPT2Config, GPT2Tokenizer\n",
    "import pandas as pd\n",
    "from torchaudio.models.decoder import ctc_decoder\n",
    "import string\n",
    "from config import DATASET_SM_ROBUST, DATASET_SM_ZSCORE, DATASET_FULL_TRIALS_ZSCORE\n",
    "# from model.ctc_modelling import Light\n",
    "import os\n",
    "# Download CMU Pronouncing Dictionary (First-time use)\n",
    "nltk.download(\"cmudict\")\n",
    "\n",
    "# Load CMUdict\n",
    "cmu_dict = cmudict.dict()\n",
    "\n",
    "from copy import deepcopy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7cd46930",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of trials:  10020\n",
      "Number of days:  24\n",
      "Number of trials after filtering by indices:  8800\n",
      "Number of trials:  880\n",
      "Number of days:  24\n",
      "Number of trials after filtering by indices:  880\n"
     ]
    }
   ],
   "source": [
    "train_loader, test_loader,_, loadedData = getDatasetLoaders_V3(DATASET_FULL_TRIALS_ZSCORE, 64, include_prego=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "698a4ba3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Resetting neural_dim based on channels\n",
      "neural_dim 256 256\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXX/anaconda3/envs/evo/lib/python3.9/site-packages/torch/functional.py:534: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3595.)\n",
      "  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_3443712/1459677493.py:36: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  neural_encoder.load_state_dict(torch.load(neural_encoder_model_weights_path)[\"state_dict\"])\n"
     ]
    }
   ],
   "source": [
    "\n",
    "nInputFeatures = 256 #channels \n",
    "nClasses = 40 \n",
    "dropout = 0.4 \n",
    "hidden_dim = 1024\n",
    "nlayers = 5\n",
    "stride_len = 4\n",
    "kernel_len =32\n",
    "gaussian_smooth_width = 2\n",
    "bidirectional = True\n",
    "\n",
    "white_noise_SD = 0.8\n",
    "constant_offset_SD = 0.2\n",
    "seq_len = 150\n",
    "max_time_series_len = 12000\n",
    "\n",
    "lr_start = 3e-4\n",
    "lr_end = 0.02\n",
    "l2_decay = 1e-5\n",
    "\n",
    "neural_encoder_model_weights_path = \"../.checkpoints/mfcc_sm_gru_ctc_LONGRUN/best_model.ckpt\"\n",
    "neural_encoder = LightningGRUDecoder_MFCC_v3(\n",
    "            neural_dim=nInputFeatures,\n",
    "            n_classes=nClasses,\n",
    "            hidden_dim=hidden_dim,\n",
    "            layer_dim=nlayers,\n",
    "            strideLen=stride_len,\n",
    "            kernelLen=kernel_len,\n",
    "            gaussianSmoothWidth=gaussian_smooth_width,\n",
    "            bidirectional=bidirectional,\n",
    "            dropout=dropout,\n",
    "            white_noise_SD=white_noise_SD,\n",
    "            constant_offset_SD=constant_offset_SD,\n",
    "            weight_decay=l2_decay,\n",
    "            learning_rate=lr_start)\n",
    "\n",
    "# Load on CPU: the encoder has not been moved to the GPU yet, so there is no reason to stage\n",
    "# the checkpoint on its saved device. weights_only=False keeps the current torch.load behaviour\n",
    "# explicitly (and silences the FutureWarning) -- only do this for checkpoints you trust.\n",
    "neural_encoder.load_state_dict(torch.load(neural_encoder_model_weights_path, map_location=\"cpu\", weights_only=False)[\"state_dict\"])\n",
    "\n",
    "neural_encoder_back = copy.deepcopy(neural_encoder)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0149bc53",
   "metadata": {},
   "outputs": [],
   "source": [
    "weights_path = \".checkpoints/gru_ctc_mfcc_bart/best_model_wer-v8.ckpt\"  # relative to the kernel cwd (speech_decoding_BCI/optimization), like the encoder checkpoint path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "bad0f1a5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using LoRA with r=128, alpha=64, dropout=0.2 on target modules ['q_proj', 'k_proj', 'v_proj', 'out_proj'].\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_3443712/977749837.py:13: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  model.load_state_dict(torch.load(weights_path)[\"state_dict\"])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "model = HybridGRUDecoder(\n",
    "    neural_encoder=neural_encoder,\n",
    "    learning_rate=1e-4,\n",
    "    weight_decay=0,\n",
    "    lm_model_dim = 768, \n",
    "    freeze_lm=False,\n",
    "    freeze_encoder=False,\n",
    "    use_lora=True,\n",
    "    lora_r=128,\n",
    "    lora_alpha=64,\n",
    ")\n",
    "\n",
    "# Load on CPU (the model is moved to cuda:0 in a later cell); weights_only=False is made explicit\n",
    "# to keep the current torch.load behaviour -- only do this for checkpoints you trust.\n",
    "model.load_state_dict(torch.load(weights_path, map_location=\"cpu\", weights_only=False)[\"state_dict\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1847f33e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['neural_feats', 'phone_seq', 'neural_time_bins', 'phone_seq_len', 'day', 'sentence', 'audio_file', 'mfcc', 'go_onset', 'speech_label'])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch = next(iter(test_loader))\n",
    "\n",
    "batch.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6d1031a4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "HybridGRUDecoder(\n",
       "  (encoder): LightningGRUDecoder_MFCC_v3(\n",
       "    (inputLayerNonlinearity): Softsign()\n",
       "    (unfolder): Unfold(kernel_size=(32, 1), dilation=1, padding=0, stride=4)\n",
       "    (mfcc_unfolder): Unfold(kernel_size=(4, 1), dilation=1, padding=0, stride=4)\n",
       "    (gaussianSmoother): GaussianSmoothing()\n",
       "    (gru_decoder): GRU(8192, 1024, num_layers=5, batch_first=True, dropout=0.4, bidirectional=True)\n",
       "    (fc_decoder_out): Linear(in_features=2048, out_features=41, bias=True)\n",
       "    (mfcc_decoder): Linear(in_features=2048, out_features=56, bias=True)\n",
       "    (ctc_loss): CTCLoss()\n",
       "    (l1oss): L1Loss()\n",
       "  )\n",
       "  (project): Sequential(\n",
       "    (0): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n",
       "    (1): Linear(in_features=2048, out_features=768, bias=True)\n",
       "    (2): ReLU()\n",
       "    (3): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "  )\n",
       "  (language_model): PeftModel(\n",
       "    (base_model): LoraModel(\n",
       "      (model): BartForConditionalGeneration(\n",
       "        (model): BartModel(\n",
       "          (shared): BartScaledWordEmbedding(50265, 768, padding_idx=1)\n",
       "          (encoder): BartEncoder(\n",
       "            (embed_tokens): BartScaledWordEmbedding(50265, 768, padding_idx=1)\n",
       "            (embed_positions): BartLearnedPositionalEmbedding(1026, 768)\n",
       "            (layers): ModuleList(\n",
       "              (0-5): 6 x BartEncoderLayer(\n",
       "                (self_attn): BartSdpaAttention(\n",
       "                  (k_proj): lora.Linear(\n",
       "                    (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
       "                    (lora_dropout): ModuleDict(\n",
       "                      (default): Dropout(p=0.2, inplace=False)\n",
       "                    )\n",
       "                    (lora_A): ModuleDict(\n",
       "                      (default): Linear(in_features=768, out_features=128, bias=False)\n",
       "                    )\n",
       "                    (lora_B): ModuleDict(\n",
       "                      (default): Linear(in_features=128, out_features=768, bias=False)\n",
       "                    )\n",
       "                    (lora_embedding_A): ParameterDict()\n",
       "                    (lora_embedding_B): ParameterDict()\n",
       "                    (lora_magnitude_vector): ModuleDict()\n",
       "                  )\n",
       "                  (v_proj): lora.Linear(\n",
       "                    (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
       "                    (lora_dropout): ModuleDict(\n",
       "                      (default): Dropout(p=0.2, inplace=False)\n",
       "                    )\n",
       "                    (lora_A): ModuleDict(\n",
       "                      (default): Linear(in_features=768, out_features=128, bias=False)\n",
       "                    )\n",
       "                    (lora_B): ModuleDict(\n",
       "                      (default): Linear(in_features=128, out_features=768, bias=False)\n",
       "                    )\n",
       "                    (lora_embedding_A): ParameterDict()\n",
       "                    (lora_embedding_B): ParameterDict()\n",
       "                    (lora_magnitude_vector): ModuleDict()\n",
       "                  )\n",
       "                  (q_proj): lora.Linear(\n",
       "                    (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
       "                    (lora_dropout): ModuleDict(\n",
       "                      (default): Dropout(p=0.2, inplace=False)\n",
       "                    )\n",
       "                    (lora_A): ModuleDict(\n",
       "                      (default): Linear(in_features=768, out_features=128, bias=False)\n",
       "                    )\n",
       "                    (lora_B): ModuleDict(\n",
       "                      (default): Linear(in_features=128, out_features=768, bias=False)\n",
       "                    )\n",
       "                    (lora_embedding_A): ParameterDict()\n",
       "                    (lora_embedding_B): ParameterDict()\n",
       "                    (lora_magnitude_vector): ModuleDict()\n",
       "                  )\n",
       "                  (out_proj): lora.Linear(\n",
       "                    (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
       "                    (lora_dropout): ModuleDict(\n",
       "                      (default): Dropout(p=0.2, inplace=False)\n",
       "                    )\n",
       "                    (lora_A): ModuleDict(\n",
       "                      (default): Linear(in_features=768, out_features=128, bias=False)\n",
       "                    )\n",
       "                    (lora_B): ModuleDict(\n",
       "                      (default): Linear(in_features=128, out_features=768, bias=False)\n",
       "                    )\n",
       "                    (lora_embedding_A): ParameterDict()\n",
       "                    (lora_embedding_B): ParameterDict()\n",
       "                    (lora_magnitude_vector): ModuleDict()\n",
       "                  )\n",
       "                )\n",
       "                (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "                (activation_fn): GELUActivation()\n",
       "                (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
       "                (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
       "                (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "              )\n",
       "            )\n",
       "            (layernorm_embedding): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          )\n",
       "          (decoder): BartDecoder(\n",
       "            (embed_tokens): BartScaledWordEmbedding(50265, 768, padding_idx=1)\n",
       "            (embed_positions): BartLearnedPositionalEmbedding(1026, 768)\n",
       "            (layers): ModuleList(\n",
       "              (0-5): 6 x BartDecoderLayer(\n",
       "                (self_attn): BartSdpaAttention(\n",
       "                  (k_proj): lora.Linear(\n",
       "                    (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
       "                    (lora_dropout): ModuleDict(\n",
       "                      (default): Dropout(p=0.2, inplace=False)\n",
       "                    )\n",
       "                    (lora_A): ModuleDict(\n",
       "                      (default): Linear(in_features=768, out_features=128, bias=False)\n",
       "                    )\n",
       "                    (lora_B): ModuleDict(\n",
       "                      (default): Linear(in_features=128, out_features=768, bias=False)\n",
       "                    )\n",
       "                    (lora_embedding_A): ParameterDict()\n",
       "                    (lora_embedding_B): ParameterDict()\n",
       "                    (lora_magnitude_vector): ModuleDict()\n",
       "                  )\n",
       "                  (v_proj): lora.Linear(\n",
       "                    (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
       "                    (lora_dropout): ModuleDict(\n",
       "                      (default): Dropout(p=0.2, inplace=False)\n",
       "                    )\n",
       "                    (lora_A): ModuleDict(\n",
       "                      (default): Linear(in_features=768, out_features=128, bias=False)\n",
       "                    )\n",
       "                    (lora_B): ModuleDict(\n",
       "                      (default): Linear(in_features=128, out_features=768, bias=False)\n",
       "                    )\n",
       "                    (lora_embedding_A): ParameterDict()\n",
       "                    (lora_embedding_B): ParameterDict()\n",
       "                    (lora_magnitude_vector): ModuleDict()\n",
       "                  )\n",
       "                  (q_proj): lora.Linear(\n",
       "                    (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
       "                    (lora_dropout): ModuleDict(\n",
       "                      (default): Dropout(p=0.2, inplace=False)\n",
       "                    )\n",
       "                    (lora_A): ModuleDict(\n",
       "                      (default): Linear(in_features=768, out_features=128, bias=False)\n",
       "                    )\n",
       "                    (lora_B): ModuleDict(\n",
       "                      (default): Linear(in_features=128, out_features=768, bias=False)\n",
       "                    )\n",
       "                    (lora_embedding_A): ParameterDict()\n",
       "                    (lora_embedding_B): ParameterDict()\n",
       "                    (lora_magnitude_vector): ModuleDict()\n",
       "                  )\n",
       "                  (out_proj): lora.Linear(\n",
       "                    (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
       "                    (lora_dropout): ModuleDict(\n",
       "                      (default): Dropout(p=0.2, inplace=False)\n",
       "                    )\n",
       "                    (lora_A): ModuleDict(\n",
       "                      (default): Linear(in_features=768, out_features=128, bias=False)\n",
       "                    )\n",
       "                    (lora_B): ModuleDict(\n",
       "                      (default): Linear(in_features=128, out_features=768, bias=False)\n",
       "                    )\n",
       "                    (lora_embedding_A): ParameterDict()\n",
       "                    (lora_embedding_B): ParameterDict()\n",
       "                    (lora_magnitude_vector): ModuleDict()\n",
       "                  )\n",
       "                )\n",
       "                (activation_fn): GELUActivation()\n",
       "                (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "                (encoder_attn): BartSdpaAttention(\n",
       "                  (k_proj): lora.Linear(\n",
       "                    (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
       "                    (lora_dropout): ModuleDict(\n",
       "                      (default): Dropout(p=0.2, inplace=False)\n",
       "                    )\n",
       "                    (lora_A): ModuleDict(\n",
       "                      (default): Linear(in_features=768, out_features=128, bias=False)\n",
       "                    )\n",
       "                    (lora_B): ModuleDict(\n",
       "                      (default): Linear(in_features=128, out_features=768, bias=False)\n",
       "                    )\n",
       "                    (lora_embedding_A): ParameterDict()\n",
       "                    (lora_embedding_B): ParameterDict()\n",
       "                    (lora_magnitude_vector): ModuleDict()\n",
       "                  )\n",
       "                  (v_proj): lora.Linear(\n",
       "                    (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
       "                    (lora_dropout): ModuleDict(\n",
       "                      (default): Dropout(p=0.2, inplace=False)\n",
       "                    )\n",
       "                    (lora_A): ModuleDict(\n",
       "                      (default): Linear(in_features=768, out_features=128, bias=False)\n",
       "                    )\n",
       "                    (lora_B): ModuleDict(\n",
       "                      (default): Linear(in_features=128, out_features=768, bias=False)\n",
       "                    )\n",
       "                    (lora_embedding_A): ParameterDict()\n",
       "                    (lora_embedding_B): ParameterDict()\n",
       "                    (lora_magnitude_vector): ModuleDict()\n",
       "                  )\n",
       "                  (q_proj): lora.Linear(\n",
       "                    (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
       "                    (lora_dropout): ModuleDict(\n",
       "                      (default): Dropout(p=0.2, inplace=False)\n",
       "                    )\n",
       "                    (lora_A): ModuleDict(\n",
       "                      (default): Linear(in_features=768, out_features=128, bias=False)\n",
       "                    )\n",
       "                    (lora_B): ModuleDict(\n",
       "                      (default): Linear(in_features=128, out_features=768, bias=False)\n",
       "                    )\n",
       "                    (lora_embedding_A): ParameterDict()\n",
       "                    (lora_embedding_B): ParameterDict()\n",
       "                    (lora_magnitude_vector): ModuleDict()\n",
       "                  )\n",
       "                  (out_proj): lora.Linear(\n",
       "                    (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
       "                    (lora_dropout): ModuleDict(\n",
       "                      (default): Dropout(p=0.2, inplace=False)\n",
       "                    )\n",
       "                    (lora_A): ModuleDict(\n",
       "                      (default): Linear(in_features=768, out_features=128, bias=False)\n",
       "                    )\n",
       "                    (lora_B): ModuleDict(\n",
       "                      (default): Linear(in_features=128, out_features=768, bias=False)\n",
       "                    )\n",
       "                    (lora_embedding_A): ParameterDict()\n",
       "                    (lora_embedding_B): ParameterDict()\n",
       "                    (lora_magnitude_vector): ModuleDict()\n",
       "                  )\n",
       "                )\n",
       "                (encoder_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "                (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
       "                (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
       "                (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "              )\n",
       "            )\n",
       "            (layernorm_embedding): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          )\n",
       "        )\n",
       "        (lm_head): Linear(in_features=768, out_features=50265, bias=False)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.to(\"cuda:0\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "23585fa8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data/XXXXXX/speech_decoding_BCI/optimization/../augmentations.py:91: UserWarning: Using padding='same' with even kernel lengths and odd dilation may require a zero-padded copy of the input be created (Triggered internally at ../aten/src/ATen/native/Convolution.cpp:1036.)\n",
      "  return self.conv(input, weight=self.weight, groups=self.groups, padding=\"same\")\n",
      "/home/XXXXXX/anaconda3/envs/evo/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:677: UserWarning: `num_beams` is set to 1. However, `early_stopping` is set to `True` -- this flag is only used in beam-based generation modes. You should set `num_beams>1` or unset `early_stopping`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "['The analysis.',\n",
       " \"Tes's projects---ism.\",\n",
       " 'The use in be in effect.',\n",
       " \"Count's open to be glelele-.\",\n",
       " 'The band, the brain,, the come out.',\n",
       " 'That can six the small soft was g--age.',\n",
       " 'Markage are children every every every.',\n",
       " 'Yyyay of trouble of bill.',\n",
       " 'The run made it won, only rle.',\n",
       " 'Would a boy back in a man in face.',\n",
       " 'He pick an exception on a oil.',\n",
       " 'Pail boy only only every morning.',\n",
       " 'With this blind,,,ary.',\n",
       " 'Pppp have and up and and.',\n",
       " 'And was the scrape but a a paint of the problem.',\n",
       " 'We are come the be the first the own of the own.',\n",
       " 'Momabababumabum-',\n",
       " 'The issues will be be be the supply of this face.',\n",
       " 'They eyes is put pleasure and last pleasure.',\n",
       " 'Tabababodababulous.',\n",
       " 'To twenty by the build to build intelligence.',\n",
       " 'An entire,,, are not not the setting.',\n",
       " 'In the fact it would to would to be the vision of the people.',\n",
       " \"You're bink mink mail you.\",\n",
       " 'Stage for the sl and the brain on the defense.',\n",
       " 'Ticalical is tend to measure to measure.',\n",
       " 'Gggyy was ground was ground for light.',\n",
       " 'Feed,,, and buy.',\n",
       " 'Careunts is identical.',\n",
       " 'Please not this clothes for sight.',\n",
       " 'Who also the original example.',\n",
       " 'The family that be be be early.',\n",
       " 'The measurements every times every.',\n",
       " 'Will will short before before.',\n",
       " 'They found is was I been.',\n",
       " 'The unexpected,, a have a college as a college.',\n",
       " 'Eightery, this have little little and be the sp.',\n",
       " 'The saw the slam, the bottle.',\n",
       " 'The chart is reduced.',\n",
       " 'They came petanned of.',\n",
       " 'A crash by a chip by a hood.',\n",
       " 'Boy you eat but you eat.',\n",
       " 'Tts are each intelligence.',\n",
       " 'The man-iled.',\n",
       " \"This cat's cated.\",\n",
       " 'She dressed and good.',\n",
       " 'Do you know when it have gone.',\n",
       " 'In the paper for the come for come.',\n",
       " 'Join is on on pain.',\n",
       " 'He said,, be be be the work.',\n",
       " 'ShSying is a large on a general.',\n",
       " 'Then seems longer.',\n",
       " 'A third the third.',\n",
       " 'The community from the little sl--',\n",
       " 'Like at off of bill.',\n",
       " 'Then he then then.',\n",
       " 'I am an hour in my own vision.',\n",
       " 'Both is one by one by after.',\n",
       " 'Samamamivelyively.',\n",
       " 'It all look and look and good.',\n",
       " 'That was no lillillilly.',\n",
       " 'In sey, were were child.',\n",
       " 'It can impress for her out for herself.',\n",
       " 'Ented taken him at him.']"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# A freshly constructed module is in train mode, so GRU dropout (p=0.4) and LoRA dropout (p=0.2)\n",
    "# would be active while decoding unless generate() switches modes itself (not visible here).\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    predictions = model.generate(batch[\"neural_feats\"].to(\"cuda:0\"), batch[\"day\"].to(\"cuda:0\"))\n",
    "predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "37c2deb8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[Parameter containing:\n",
       "  tensor([[-0.0044, -0.0023, -0.0050,  ..., -0.0084, -0.0042, -0.0066],\n",
       "          [ 0.0101,  0.0038, -0.0006,  ..., -0.0018, -0.0027, -0.0078],\n",
       "          [ 0.0111,  0.0116,  0.0105,  ..., -0.0073, -0.0061, -0.0049],\n",
       "          ...,\n",
       "          [ 0.0072,  0.0114,  0.0157,  ..., -0.0236, -0.0218, -0.0234],\n",
       "          [-0.0046, -0.0045,  0.0043,  ...,  0.0120,  0.0027,  0.0229],\n",
       "          [-0.0033, -0.0001, -0.0002,  ...,  0.0126, -0.0165,  0.0004]],\n",
       "         device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([[-0.0069, -0.0019,  0.0003,  ...,  0.0064,  0.0019, -0.0229],\n",
       "          [ 0.0092, -0.0150, -0.0269,  ..., -0.0239,  0.0310,  0.0040],\n",
       "          [-0.0087, -0.0008, -0.0054,  ...,  0.0051, -0.0127, -0.0092],\n",
       "          ...,\n",
       "          [-0.0272, -0.0045, -0.0157,  ...,  0.0356, -0.0365,  0.0227],\n",
       "          [ 0.0149, -0.0014,  0.0172,  ...,  0.0074,  0.0323,  0.0208],\n",
       "          [ 0.0112, -0.0041,  0.0214,  ...,  0.0269, -0.0131,  0.0435]],\n",
       "         device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([-0.0052, -0.0234, -0.0020,  ...,  0.0575, -0.0477,  0.0309],\n",
       "         device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0299, -0.0052,  0.0032,  ...,  0.0379, -0.0159,  0.0403],\n",
       "         device='cuda:0', requires_grad=True)],\n",
       " [Parameter containing:\n",
       "  tensor([[-0.0005, -0.0023,  0.0022,  ..., -0.0050, -0.0038, -0.0056],\n",
       "          [ 0.0091,  0.0080,  0.0159,  ..., -0.0024, -0.0086, -0.0126],\n",
       "          [ 0.0012,  0.0012, -0.0019,  ..., -0.0110, -0.0123, -0.0090],\n",
       "          ...,\n",
       "          [-0.0206, -0.0032,  0.0038,  ..., -0.0175, -0.0179, -0.0047],\n",
       "          [ 0.0203, -0.0100,  0.0124,  ..., -0.0119, -0.0073, -0.0160],\n",
       "          [ 0.0109, -0.0038, -0.0009,  ..., -0.0090, -0.0033, -0.0004]],\n",
       "         device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([[-0.0136, -0.0187,  0.0089,  ..., -0.0111, -0.0227,  0.0045],\n",
       "          [ 0.0422,  0.0030,  0.0142,  ...,  0.0191,  0.0029,  0.0100],\n",
       "          [-0.0095, -0.0145,  0.0007,  ...,  0.0061, -0.0066, -0.0080],\n",
       "          ...,\n",
       "          [-0.0062, -0.0218, -0.0305,  ...,  0.0170,  0.0138, -0.0019],\n",
       "          [-0.0035,  0.0660,  0.0103,  ..., -0.0112,  0.0117, -0.0182],\n",
       "          [-0.0397, -0.0228,  0.0266,  ..., -0.0092,  0.0298,  0.0399]],\n",
       "         device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0353,  0.0130,  0.0167,  ...,  0.0262,  0.0340, -0.0141],\n",
       "         device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0400,  0.0125,  0.0373,  ..., -0.0120,  0.0004, -0.0115],\n",
       "         device='cuda:0', requires_grad=True)],\n",
       " [Parameter containing:\n",
       "  tensor([[-0.0093,  0.0184, -0.0084,  ...,  0.0098,  0.0014,  0.0040],\n",
       "          [ 0.0090, -0.0033,  0.0056,  ..., -0.0003, -0.0006,  0.0097],\n",
       "          [-0.0202, -0.0052, -0.0124,  ...,  0.0074, -0.0070,  0.0070],\n",
       "          ...,\n",
       "          [ 0.0160,  0.0217, -0.0135,  ..., -0.0283, -0.0237,  0.0189],\n",
       "          [ 0.0229,  0.0070,  0.0077,  ...,  0.0115,  0.0144, -0.0190],\n",
       "          [-0.0043,  0.0003, -0.0403,  ..., -0.0142, -0.0311, -0.0337]],\n",
       "         device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([[ 1.5918e-02, -1.2569e-03, -2.2349e-03,  ..., -1.3961e-02,\n",
       "           -7.8066e-03, -1.7154e-03],\n",
       "          [ 1.4132e-02, -1.6089e-02,  6.3565e-03,  ...,  7.2755e-03,\n",
       "           -1.4228e-02, -4.0263e-03],\n",
       "          [ 9.9748e-03,  5.9188e-03,  2.1691e-02,  ..., -1.2677e-02,\n",
       "           -1.8236e-02, -8.5487e-03],\n",
       "          ...,\n",
       "          [ 2.2825e-02,  3.4633e-02, -2.5968e-02,  ...,  2.4206e-02,\n",
       "            1.8346e-03,  1.6353e-02],\n",
       "          [ 7.1963e-04, -1.7597e-05, -2.8866e-02,  ...,  3.4136e-03,\n",
       "            5.5424e-02,  4.6944e-03],\n",
       "          [ 1.2541e-02, -1.3969e-02,  6.0558e-03,  ..., -2.1161e-02,\n",
       "           -8.6614e-03,  3.4314e-02]], device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 1.8552e-02,  2.2236e-02,  2.3295e-02,  ..., -2.1889e-02,\n",
       "          -4.1010e-02, -7.9831e-05], device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 1.3801e-02,  2.7421e-02,  7.2231e-05,  ..., -1.5536e-02,\n",
       "          -3.5528e-03, -2.4865e-02], device='cuda:0', requires_grad=True)],\n",
       " [Parameter containing:\n",
       "  tensor([[ 1.4312e-02, -1.6631e-02, -1.1382e-02,  ...,  2.7317e-03,\n",
       "            5.3338e-03,  1.0776e-02],\n",
       "          [ 1.4663e-02,  1.8477e-02, -7.3896e-03,  ...,  2.4296e-02,\n",
       "           -7.1121e-03,  2.4054e-03],\n",
       "          [-5.5759e-03,  1.5628e-02,  1.7560e-02,  ..., -2.4209e-03,\n",
       "           -6.4703e-03, -2.5245e-03],\n",
       "          ...,\n",
       "          [ 1.8696e-02,  2.3336e-02, -3.7312e-02,  ...,  6.6283e-05,\n",
       "            7.5660e-03, -2.4696e-02],\n",
       "          [-4.5078e-02,  3.6337e-02,  3.3525e-02,  ...,  1.0326e-02,\n",
       "            1.5592e-02,  1.4567e-02],\n",
       "          [-2.8089e-02, -3.3518e-02,  3.9523e-02,  ..., -1.0550e-03,\n",
       "            3.6600e-03, -1.7445e-02]], device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([[-0.0091,  0.0031,  0.0098,  ..., -0.0108, -0.0161, -0.0157],\n",
       "          [-0.0119,  0.0224,  0.0002,  ..., -0.0010, -0.0076, -0.0072],\n",
       "          [-0.0124,  0.0007,  0.0006,  ..., -0.0037,  0.0124, -0.0165],\n",
       "          ...,\n",
       "          [-0.0045,  0.0068, -0.0064,  ...,  0.0089, -0.0099,  0.0229],\n",
       "          [-0.0032,  0.0058,  0.0209,  ...,  0.0095,  0.0185, -0.0025],\n",
       "          [ 0.0300, -0.0253, -0.0053,  ..., -0.0118,  0.0221,  0.0012]],\n",
       "         device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0059,  0.0263,  0.0271,  ..., -0.0459,  0.0241,  0.0012],\n",
       "         device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0087,  0.0241,  0.0376,  ..., -0.0316,  0.0039,  0.0254],\n",
       "         device='cuda:0', requires_grad=True)],\n",
       " [Parameter containing:\n",
       "  tensor([[-2.8173e-03,  2.2498e-03, -8.9168e-03,  ..., -1.5682e-02,\n",
       "           -5.4059e-03,  3.4606e-04],\n",
       "          [ 1.8380e-02, -7.8838e-03,  1.0785e-02,  ..., -7.5265e-03,\n",
       "            7.0531e-03,  2.5981e-03],\n",
       "          [-1.7087e-02,  1.7269e-03,  9.7937e-03,  ..., -1.6696e-02,\n",
       "            3.5423e-03,  1.4356e-02],\n",
       "          ...,\n",
       "          [ 1.3268e-02,  3.6282e-03,  8.7936e-03,  ...,  3.8010e-03,\n",
       "            1.6477e-02, -1.7975e-03],\n",
       "          [ 2.6765e-02,  1.9178e-02,  1.1439e-02,  ...,  1.1463e-02,\n",
       "            2.9836e-02, -7.1729e-03],\n",
       "          [ 9.6478e-05, -1.4573e-02, -5.6656e-03,  ...,  2.3145e-02,\n",
       "            2.4088e-02,  1.4242e-02]], device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([[ 1.6051e-03, -4.2559e-03,  2.8567e-03,  ..., -1.1574e-02,\n",
       "            6.4465e-03, -2.6761e-03],\n",
       "          [ 8.8060e-03, -5.3465e-03, -6.4087e-03,  ..., -8.2543e-03,\n",
       "            1.0604e-04, -9.7659e-03],\n",
       "          [ 6.2947e-03, -1.6884e-03,  1.9673e-02,  ..., -1.5686e-03,\n",
       "            9.0580e-03,  4.2378e-03],\n",
       "          ...,\n",
       "          [-1.4276e-02, -5.7515e-05, -4.0357e-03,  ...,  3.7697e-02,\n",
       "            1.7242e-03,  1.1795e-02],\n",
       "          [ 4.3363e-02, -7.2026e-03,  4.3395e-03,  ..., -7.1553e-03,\n",
       "            5.1368e-02,  1.3605e-02],\n",
       "          [-6.7334e-03,  5.2974e-03, -4.2865e-03,  ...,  8.8822e-04,\n",
       "            5.6222e-03,  5.6220e-02]], device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0022,  0.0177,  0.0168,  ..., -0.0095, -0.0028,  0.0351],\n",
       "         device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0119,  0.0103,  0.0266,  ..., -0.0067, -0.0128,  0.0172],\n",
       "         device='cuda:0', requires_grad=True)],\n",
       " [Parameter containing:\n",
       "  tensor([[ 5.8445e-03,  1.4390e-02,  5.4644e-03,  ..., -1.6870e-03,\n",
       "            2.8118e-03, -5.3684e-03],\n",
       "          [ 3.6881e-04, -1.7293e-03,  1.0902e-03,  ..., -4.8677e-03,\n",
       "            1.0086e-02, -1.4298e-02],\n",
       "          [ 2.7962e-03, -7.9294e-03, -4.1887e-03,  ..., -2.4055e-04,\n",
       "           -8.5968e-05,  9.1243e-05],\n",
       "          ...,\n",
       "          [-1.1541e-02, -2.2115e-02, -4.4295e-02,  ..., -1.3712e-02,\n",
       "            1.4445e-02, -8.9336e-03],\n",
       "          [ 2.2928e-02,  4.9056e-03,  8.9440e-03,  ...,  3.0587e-02,\n",
       "            1.1667e-03, -3.2378e-02],\n",
       "          [ 7.8100e-03,  5.9985e-03, -9.2184e-03,  ..., -2.6244e-02,\n",
       "           -1.7644e-02,  6.8278e-03]], device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([[-5.5111e-05, -5.6352e-03,  8.9084e-03,  ...,  3.3379e-03,\n",
       "           -3.3208e-03, -2.4577e-02],\n",
       "          [ 3.3861e-03,  7.1946e-03, -5.5061e-03,  ..., -1.7104e-02,\n",
       "           -7.4123e-03, -1.8756e-03],\n",
       "          [-7.1169e-04, -1.0005e-02,  1.0312e-03,  ..., -6.7325e-04,\n",
       "            1.2453e-02, -3.8248e-03],\n",
       "          ...,\n",
       "          [ 3.4905e-03,  1.3379e-02,  4.7813e-02,  ...,  5.3354e-02,\n",
       "            1.5951e-02, -1.5819e-02],\n",
       "          [ 2.2557e-02, -7.1769e-03,  7.2348e-03,  ...,  1.6899e-02,\n",
       "            4.2868e-02,  2.1981e-02],\n",
       "          [ 2.7390e-02, -2.0409e-03, -2.5539e-02,  ..., -6.5740e-03,\n",
       "           -3.4945e-03,  6.5052e-02]], device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0348,  0.0280,  0.0289,  ...,  0.0016, -0.0228, -0.0277],\n",
       "         device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0250,  0.0260,  0.0282,  ..., -0.0081, -0.0077,  0.0162],\n",
       "         device='cuda:0', requires_grad=True)],\n",
       " [Parameter containing:\n",
       "  tensor([[ 0.0260, -0.0071,  0.0121,  ..., -0.0060, -0.0027, -0.0033],\n",
       "          [ 0.0069, -0.0137,  0.0126,  ..., -0.0057, -0.0127,  0.0025],\n",
       "          [ 0.0125,  0.0031,  0.0023,  ...,  0.0048,  0.0053, -0.0165],\n",
       "          ...,\n",
       "          [ 0.0237, -0.0231,  0.0073,  ..., -0.0243, -0.0238, -0.0140],\n",
       "          [-0.0177, -0.0179, -0.0273,  ..., -0.0220, -0.0054, -0.0315],\n",
       "          [ 0.0248, -0.0073, -0.0117,  ..., -0.0098,  0.0303,  0.0245]],\n",
       "         device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([[-0.0132, -0.0091,  0.0156,  ...,  0.0179,  0.0096, -0.0048],\n",
       "          [ 0.0153, -0.0065,  0.0029,  ...,  0.0091,  0.0257, -0.0207],\n",
       "          [ 0.0097, -0.0194,  0.0006,  ..., -0.0048,  0.0111, -0.0048],\n",
       "          ...,\n",
       "          [ 0.0001, -0.0020,  0.0077,  ...,  0.0332, -0.0258, -0.0045],\n",
       "          [ 0.0325,  0.0212,  0.0124,  ..., -0.0118,  0.0802, -0.0125],\n",
       "          [-0.0294,  0.0080, -0.0331,  ...,  0.0278, -0.0282,  0.0537]],\n",
       "         device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0217,  0.0109,  0.0248,  ..., -0.0216,  0.0199,  0.0154],\n",
       "         device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0212,  0.0190,  0.0111,  ..., -0.0072, -0.0082, -0.0213],\n",
       "         device='cuda:0', requires_grad=True)],\n",
       " [Parameter containing:\n",
       "  tensor([[-0.0073, -0.0197, -0.0188,  ..., -0.0117,  0.0064,  0.0241],\n",
       "          [ 0.0154, -0.0055, -0.0019,  ..., -0.0002, -0.0262,  0.0072],\n",
       "          [-0.0003, -0.0077,  0.0222,  ..., -0.0025, -0.0127,  0.0053],\n",
       "          ...,\n",
       "          [ 0.0098,  0.0023,  0.0069,  ..., -0.0025, -0.0172, -0.0278],\n",
       "          [-0.0058,  0.0208, -0.0038,  ...,  0.0178,  0.0162, -0.0298],\n",
       "          [ 0.0161,  0.0459,  0.0207,  ...,  0.0283,  0.0101,  0.0107]],\n",
       "         device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([[-0.0257, -0.0168, -0.0173,  ...,  0.0090, -0.0041, -0.0044],\n",
       "          [-0.0065, -0.0010, -0.0102,  ..., -0.0077,  0.0016,  0.0142],\n",
       "          [-0.0007, -0.0061, -0.0069,  ..., -0.0182, -0.0023,  0.0217],\n",
       "          ...,\n",
       "          [-0.0401, -0.0158, -0.0030,  ...,  0.0344, -0.0156, -0.0169],\n",
       "          [-0.0072, -0.0194, -0.0098,  ...,  0.0140,  0.0071, -0.0329],\n",
       "          [-0.0105, -0.0005,  0.0008,  ..., -0.0279,  0.0153,  0.0277]],\n",
       "         device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0201,  0.0081, -0.0048,  ..., -0.0151,  0.0070, -0.0134],\n",
       "         device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0194, -0.0004, -0.0004,  ..., -0.0047,  0.0130, -0.0278],\n",
       "         device='cuda:0', requires_grad=True)],\n",
       " [Parameter containing:\n",
       "  tensor([[ 0.0346,  0.0258,  0.0067,  ..., -0.0170, -0.0034,  0.0245],\n",
       "          [ 0.0071,  0.0034,  0.0073,  ..., -0.0093, -0.0051, -0.0076],\n",
       "          [-0.0135, -0.0066, -0.0145,  ..., -0.0107, -0.0045,  0.0123],\n",
       "          ...,\n",
       "          [ 0.0149, -0.0242, -0.0001,  ...,  0.0107, -0.0052, -0.0044],\n",
       "          [-0.0141,  0.0042,  0.0192,  ...,  0.0216,  0.0159,  0.0207],\n",
       "          [ 0.0220, -0.0309, -0.0327,  ..., -0.0229, -0.0086, -0.0260]],\n",
       "         device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([[ 0.0457, -0.0139,  0.0335,  ...,  0.0017, -0.0106, -0.0039],\n",
       "          [-0.0031,  0.0125, -0.0018,  ...,  0.0025,  0.0038,  0.0224],\n",
       "          [ 0.0110,  0.0035, -0.0280,  ..., -0.0090,  0.0060,  0.0072],\n",
       "          ...,\n",
       "          [-0.0084, -0.0084,  0.0008,  ...,  0.0255, -0.0056, -0.0171],\n",
       "          [ 0.0193,  0.0027,  0.0279,  ..., -0.0364,  0.0357,  0.0204],\n",
       "          [-0.0359, -0.0182,  0.0149,  ..., -0.0010,  0.0107,  0.0240]],\n",
       "         device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0162,  0.0071,  0.0002,  ..., -0.0051,  0.0059,  0.0005],\n",
       "         device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([-0.0028,  0.0206,  0.0056,  ...,  0.0116, -0.0059,  0.0026],\n",
       "         device='cuda:0', requires_grad=True)],\n",
       " [Parameter containing:\n",
       "  tensor([[-7.8657e-05,  9.5959e-03, -4.8672e-03,  ...,  1.1935e-02,\n",
       "            1.3429e-04, -1.3365e-02],\n",
       "          [-5.2223e-03, -2.0591e-02, -7.0762e-03,  ...,  3.6497e-03,\n",
       "           -4.4151e-03, -8.5901e-03],\n",
       "          [-2.3526e-02, -4.4537e-03, -5.5906e-03,  ..., -6.5109e-03,\n",
       "            2.5775e-03,  7.0874e-03],\n",
       "          ...,\n",
       "          [ 7.2634e-03, -1.8029e-03, -6.7822e-04,  ...,  1.3679e-02,\n",
       "            9.7104e-03, -1.7462e-02],\n",
       "          [ 2.4691e-02,  1.0715e-04,  7.5727e-03,  ...,  3.5838e-02,\n",
       "           -7.6700e-03, -4.5338e-03],\n",
       "          [ 1.1234e-02, -7.3439e-03,  4.3485e-04,  ...,  2.6016e-02,\n",
       "            1.6748e-02,  1.1816e-02]], device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([[-0.0110, -0.0065, -0.0003,  ..., -0.0043, -0.0111,  0.0067],\n",
       "          [ 0.0008,  0.0245, -0.0079,  ..., -0.0103,  0.0046,  0.0020],\n",
       "          [-0.0020, -0.0086, -0.0143,  ..., -0.0028, -0.0073, -0.0091],\n",
       "          ...,\n",
       "          [ 0.0052, -0.0137, -0.0035,  ...,  0.0122,  0.0084,  0.0480],\n",
       "          [ 0.0168, -0.0260, -0.0114,  ...,  0.0068,  0.0655, -0.0050],\n",
       "          [ 0.0247,  0.0103, -0.0129,  ...,  0.0045, -0.0113,  0.0364]],\n",
       "         device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0097,  0.0009,  0.0097,  ..., -0.0043,  0.0157, -0.0040],\n",
       "         device='cuda:0', requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0092, -0.0019, -0.0046,  ...,  0.0122, -0.0158,  0.0064],\n",
       "         device='cuda:0', requires_grad=True)]]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# GRU parameters of model.encoder (on cuda:0, per the output below)\n",
    "model.encoder.gru_decoder.all_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a0fca047",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[Parameter containing:\n",
       "  tensor([[-0.0044, -0.0023, -0.0050,  ..., -0.0084, -0.0042, -0.0066],\n",
       "          [ 0.0101,  0.0038, -0.0006,  ..., -0.0018, -0.0027, -0.0078],\n",
       "          [ 0.0111,  0.0116,  0.0105,  ..., -0.0073, -0.0061, -0.0049],\n",
       "          ...,\n",
       "          [ 0.0072,  0.0114,  0.0157,  ..., -0.0236, -0.0218, -0.0234],\n",
       "          [-0.0046, -0.0045,  0.0043,  ...,  0.0120,  0.0027,  0.0229],\n",
       "          [-0.0033, -0.0001, -0.0002,  ...,  0.0126, -0.0165,  0.0004]],\n",
       "         requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([[-0.0069, -0.0019,  0.0003,  ...,  0.0064,  0.0019, -0.0229],\n",
       "          [ 0.0092, -0.0150, -0.0269,  ..., -0.0239,  0.0310,  0.0040],\n",
       "          [-0.0087, -0.0008, -0.0054,  ...,  0.0051, -0.0127, -0.0092],\n",
       "          ...,\n",
       "          [-0.0272, -0.0045, -0.0157,  ...,  0.0356, -0.0365,  0.0227],\n",
       "          [ 0.0149, -0.0014,  0.0172,  ...,  0.0074,  0.0323,  0.0208],\n",
       "          [ 0.0112, -0.0041,  0.0214,  ...,  0.0269, -0.0131,  0.0435]],\n",
       "         requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([-0.0052, -0.0234, -0.0020,  ...,  0.0575, -0.0477,  0.0309],\n",
       "         requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0299, -0.0052,  0.0032,  ...,  0.0379, -0.0159,  0.0403],\n",
       "         requires_grad=True)],\n",
       " [Parameter containing:\n",
       "  tensor([[-0.0005, -0.0023,  0.0022,  ..., -0.0050, -0.0038, -0.0056],\n",
       "          [ 0.0091,  0.0080,  0.0159,  ..., -0.0024, -0.0086, -0.0126],\n",
       "          [ 0.0012,  0.0012, -0.0019,  ..., -0.0110, -0.0123, -0.0090],\n",
       "          ...,\n",
       "          [-0.0206, -0.0032,  0.0038,  ..., -0.0175, -0.0179, -0.0047],\n",
       "          [ 0.0203, -0.0100,  0.0124,  ..., -0.0119, -0.0073, -0.0160],\n",
       "          [ 0.0109, -0.0038, -0.0009,  ..., -0.0090, -0.0033, -0.0004]],\n",
       "         requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([[-0.0136, -0.0187,  0.0089,  ..., -0.0111, -0.0227,  0.0045],\n",
       "          [ 0.0422,  0.0030,  0.0142,  ...,  0.0191,  0.0029,  0.0100],\n",
       "          [-0.0095, -0.0145,  0.0007,  ...,  0.0061, -0.0066, -0.0080],\n",
       "          ...,\n",
       "          [-0.0062, -0.0218, -0.0305,  ...,  0.0170,  0.0138, -0.0019],\n",
       "          [-0.0035,  0.0660,  0.0103,  ..., -0.0112,  0.0117, -0.0182],\n",
       "          [-0.0397, -0.0228,  0.0266,  ..., -0.0092,  0.0298,  0.0399]],\n",
       "         requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0353,  0.0130,  0.0167,  ...,  0.0262,  0.0340, -0.0141],\n",
       "         requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0400,  0.0125,  0.0373,  ..., -0.0120,  0.0004, -0.0115],\n",
       "         requires_grad=True)],\n",
       " [Parameter containing:\n",
       "  tensor([[-0.0093,  0.0184, -0.0084,  ...,  0.0098,  0.0014,  0.0040],\n",
       "          [ 0.0090, -0.0033,  0.0056,  ..., -0.0003, -0.0006,  0.0097],\n",
       "          [-0.0202, -0.0052, -0.0124,  ...,  0.0074, -0.0070,  0.0070],\n",
       "          ...,\n",
       "          [ 0.0160,  0.0217, -0.0135,  ..., -0.0283, -0.0237,  0.0189],\n",
       "          [ 0.0229,  0.0070,  0.0077,  ...,  0.0115,  0.0144, -0.0190],\n",
       "          [-0.0043,  0.0003, -0.0403,  ..., -0.0142, -0.0311, -0.0337]],\n",
       "         requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([[ 1.5918e-02, -1.2569e-03, -2.2349e-03,  ..., -1.3961e-02,\n",
       "           -7.8066e-03, -1.7154e-03],\n",
       "          [ 1.4132e-02, -1.6089e-02,  6.3565e-03,  ...,  7.2755e-03,\n",
       "           -1.4228e-02, -4.0263e-03],\n",
       "          [ 9.9748e-03,  5.9188e-03,  2.1691e-02,  ..., -1.2677e-02,\n",
       "           -1.8236e-02, -8.5487e-03],\n",
       "          ...,\n",
       "          [ 2.2825e-02,  3.4633e-02, -2.5968e-02,  ...,  2.4206e-02,\n",
       "            1.8346e-03,  1.6353e-02],\n",
       "          [ 7.1963e-04, -1.7597e-05, -2.8866e-02,  ...,  3.4136e-03,\n",
       "            5.5424e-02,  4.6944e-03],\n",
       "          [ 1.2541e-02, -1.3969e-02,  6.0558e-03,  ..., -2.1161e-02,\n",
       "           -8.6614e-03,  3.4314e-02]], requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 1.8552e-02,  2.2236e-02,  2.3295e-02,  ..., -2.1889e-02,\n",
       "          -4.1010e-02, -7.9831e-05], requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 1.3801e-02,  2.7421e-02,  7.2231e-05,  ..., -1.5536e-02,\n",
       "          -3.5528e-03, -2.4865e-02], requires_grad=True)],\n",
       " [Parameter containing:\n",
       "  tensor([[ 1.4312e-02, -1.6631e-02, -1.1382e-02,  ...,  2.7317e-03,\n",
       "            5.3338e-03,  1.0776e-02],\n",
       "          [ 1.4663e-02,  1.8477e-02, -7.3896e-03,  ...,  2.4296e-02,\n",
       "           -7.1121e-03,  2.4054e-03],\n",
       "          [-5.5759e-03,  1.5628e-02,  1.7560e-02,  ..., -2.4209e-03,\n",
       "           -6.4703e-03, -2.5245e-03],\n",
       "          ...,\n",
       "          [ 1.8696e-02,  2.3336e-02, -3.7312e-02,  ...,  6.6283e-05,\n",
       "            7.5660e-03, -2.4696e-02],\n",
       "          [-4.5078e-02,  3.6337e-02,  3.3525e-02,  ...,  1.0326e-02,\n",
       "            1.5592e-02,  1.4567e-02],\n",
       "          [-2.8089e-02, -3.3518e-02,  3.9523e-02,  ..., -1.0550e-03,\n",
       "            3.6600e-03, -1.7445e-02]], requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([[-0.0091,  0.0031,  0.0098,  ..., -0.0108, -0.0161, -0.0157],\n",
       "          [-0.0119,  0.0224,  0.0002,  ..., -0.0010, -0.0076, -0.0072],\n",
       "          [-0.0124,  0.0007,  0.0006,  ..., -0.0037,  0.0124, -0.0165],\n",
       "          ...,\n",
       "          [-0.0045,  0.0068, -0.0064,  ...,  0.0089, -0.0099,  0.0229],\n",
       "          [-0.0032,  0.0058,  0.0209,  ...,  0.0095,  0.0185, -0.0025],\n",
       "          [ 0.0300, -0.0253, -0.0053,  ..., -0.0118,  0.0221,  0.0012]],\n",
       "         requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0059,  0.0263,  0.0271,  ..., -0.0459,  0.0241,  0.0012],\n",
       "         requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0087,  0.0241,  0.0376,  ..., -0.0316,  0.0039,  0.0254],\n",
       "         requires_grad=True)],\n",
       " [Parameter containing:\n",
       "  tensor([[-2.8173e-03,  2.2498e-03, -8.9168e-03,  ..., -1.5682e-02,\n",
       "           -5.4059e-03,  3.4606e-04],\n",
       "          [ 1.8380e-02, -7.8838e-03,  1.0785e-02,  ..., -7.5265e-03,\n",
       "            7.0531e-03,  2.5981e-03],\n",
       "          [-1.7087e-02,  1.7269e-03,  9.7937e-03,  ..., -1.6696e-02,\n",
       "            3.5423e-03,  1.4356e-02],\n",
       "          ...,\n",
       "          [ 1.3268e-02,  3.6282e-03,  8.7936e-03,  ...,  3.8010e-03,\n",
       "            1.6477e-02, -1.7975e-03],\n",
       "          [ 2.6765e-02,  1.9178e-02,  1.1439e-02,  ...,  1.1463e-02,\n",
       "            2.9836e-02, -7.1729e-03],\n",
       "          [ 9.6478e-05, -1.4573e-02, -5.6656e-03,  ...,  2.3145e-02,\n",
       "            2.4088e-02,  1.4242e-02]], requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([[ 1.6051e-03, -4.2559e-03,  2.8567e-03,  ..., -1.1574e-02,\n",
       "            6.4465e-03, -2.6761e-03],\n",
       "          [ 8.8060e-03, -5.3465e-03, -6.4087e-03,  ..., -8.2543e-03,\n",
       "            1.0604e-04, -9.7659e-03],\n",
       "          [ 6.2947e-03, -1.6884e-03,  1.9673e-02,  ..., -1.5686e-03,\n",
       "            9.0580e-03,  4.2378e-03],\n",
       "          ...,\n",
       "          [-1.4276e-02, -5.7515e-05, -4.0357e-03,  ...,  3.7697e-02,\n",
       "            1.7242e-03,  1.1795e-02],\n",
       "          [ 4.3363e-02, -7.2026e-03,  4.3395e-03,  ..., -7.1553e-03,\n",
       "            5.1368e-02,  1.3605e-02],\n",
       "          [-6.7334e-03,  5.2974e-03, -4.2865e-03,  ...,  8.8822e-04,\n",
       "            5.6222e-03,  5.6220e-02]], requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0022,  0.0177,  0.0168,  ..., -0.0095, -0.0028,  0.0351],\n",
       "         requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0119,  0.0103,  0.0266,  ..., -0.0067, -0.0128,  0.0172],\n",
       "         requires_grad=True)],\n",
       " [Parameter containing:\n",
       "  tensor([[ 5.8445e-03,  1.4390e-02,  5.4644e-03,  ..., -1.6870e-03,\n",
       "            2.8118e-03, -5.3684e-03],\n",
       "          [ 3.6881e-04, -1.7293e-03,  1.0902e-03,  ..., -4.8677e-03,\n",
       "            1.0086e-02, -1.4298e-02],\n",
       "          [ 2.7962e-03, -7.9294e-03, -4.1887e-03,  ..., -2.4055e-04,\n",
       "           -8.5968e-05,  9.1243e-05],\n",
       "          ...,\n",
       "          [-1.1541e-02, -2.2115e-02, -4.4295e-02,  ..., -1.3712e-02,\n",
       "            1.4445e-02, -8.9336e-03],\n",
       "          [ 2.2928e-02,  4.9056e-03,  8.9440e-03,  ...,  3.0587e-02,\n",
       "            1.1667e-03, -3.2378e-02],\n",
       "          [ 7.8100e-03,  5.9985e-03, -9.2184e-03,  ..., -2.6244e-02,\n",
       "           -1.7644e-02,  6.8278e-03]], requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([[-5.5111e-05, -5.6352e-03,  8.9084e-03,  ...,  3.3379e-03,\n",
       "           -3.3208e-03, -2.4577e-02],\n",
       "          [ 3.3861e-03,  7.1946e-03, -5.5061e-03,  ..., -1.7104e-02,\n",
       "           -7.4123e-03, -1.8756e-03],\n",
       "          [-7.1169e-04, -1.0005e-02,  1.0312e-03,  ..., -6.7325e-04,\n",
       "            1.2453e-02, -3.8248e-03],\n",
       "          ...,\n",
       "          [ 3.4905e-03,  1.3379e-02,  4.7813e-02,  ...,  5.3354e-02,\n",
       "            1.5951e-02, -1.5819e-02],\n",
       "          [ 2.2557e-02, -7.1769e-03,  7.2348e-03,  ...,  1.6899e-02,\n",
       "            4.2868e-02,  2.1981e-02],\n",
       "          [ 2.7390e-02, -2.0409e-03, -2.5539e-02,  ..., -6.5740e-03,\n",
       "           -3.4945e-03,  6.5052e-02]], requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0348,  0.0280,  0.0289,  ...,  0.0016, -0.0228, -0.0277],\n",
       "         requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0250,  0.0260,  0.0282,  ..., -0.0081, -0.0077,  0.0162],\n",
       "         requires_grad=True)],\n",
       " [Parameter containing:\n",
       "  tensor([[ 0.0260, -0.0071,  0.0121,  ..., -0.0060, -0.0027, -0.0033],\n",
       "          [ 0.0069, -0.0137,  0.0126,  ..., -0.0057, -0.0127,  0.0025],\n",
       "          [ 0.0125,  0.0031,  0.0023,  ...,  0.0048,  0.0053, -0.0165],\n",
       "          ...,\n",
       "          [ 0.0237, -0.0231,  0.0073,  ..., -0.0243, -0.0238, -0.0140],\n",
       "          [-0.0177, -0.0179, -0.0273,  ..., -0.0220, -0.0054, -0.0315],\n",
       "          [ 0.0248, -0.0073, -0.0117,  ..., -0.0098,  0.0303,  0.0245]],\n",
       "         requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([[-0.0132, -0.0091,  0.0156,  ...,  0.0179,  0.0096, -0.0048],\n",
       "          [ 0.0153, -0.0065,  0.0029,  ...,  0.0091,  0.0257, -0.0207],\n",
       "          [ 0.0097, -0.0194,  0.0006,  ..., -0.0048,  0.0111, -0.0048],\n",
       "          ...,\n",
       "          [ 0.0001, -0.0020,  0.0077,  ...,  0.0332, -0.0258, -0.0045],\n",
       "          [ 0.0325,  0.0212,  0.0124,  ..., -0.0118,  0.0802, -0.0125],\n",
       "          [-0.0294,  0.0080, -0.0331,  ...,  0.0278, -0.0282,  0.0537]],\n",
       "         requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0217,  0.0109,  0.0248,  ..., -0.0216,  0.0199,  0.0154],\n",
       "         requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0212,  0.0190,  0.0111,  ..., -0.0072, -0.0082, -0.0213],\n",
       "         requires_grad=True)],\n",
       " [Parameter containing:\n",
       "  tensor([[-0.0073, -0.0197, -0.0188,  ..., -0.0117,  0.0064,  0.0241],\n",
       "          [ 0.0154, -0.0055, -0.0019,  ..., -0.0002, -0.0262,  0.0072],\n",
       "          [-0.0003, -0.0077,  0.0222,  ..., -0.0025, -0.0127,  0.0053],\n",
       "          ...,\n",
       "          [ 0.0098,  0.0023,  0.0069,  ..., -0.0025, -0.0172, -0.0278],\n",
       "          [-0.0058,  0.0208, -0.0038,  ...,  0.0178,  0.0162, -0.0298],\n",
       "          [ 0.0161,  0.0459,  0.0207,  ...,  0.0283,  0.0101,  0.0107]],\n",
       "         requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([[-0.0257, -0.0168, -0.0173,  ...,  0.0090, -0.0041, -0.0044],\n",
       "          [-0.0065, -0.0010, -0.0102,  ..., -0.0077,  0.0016,  0.0142],\n",
       "          [-0.0007, -0.0061, -0.0069,  ..., -0.0182, -0.0023,  0.0217],\n",
       "          ...,\n",
       "          [-0.0401, -0.0158, -0.0030,  ...,  0.0344, -0.0156, -0.0169],\n",
       "          [-0.0072, -0.0194, -0.0098,  ...,  0.0140,  0.0071, -0.0329],\n",
       "          [-0.0105, -0.0005,  0.0008,  ..., -0.0279,  0.0153,  0.0277]],\n",
       "         requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0201,  0.0081, -0.0048,  ..., -0.0151,  0.0070, -0.0134],\n",
       "         requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0194, -0.0004, -0.0004,  ..., -0.0047,  0.0130, -0.0278],\n",
       "         requires_grad=True)],\n",
       " [Parameter containing:\n",
       "  tensor([[ 0.0346,  0.0258,  0.0067,  ..., -0.0170, -0.0034,  0.0245],\n",
       "          [ 0.0071,  0.0034,  0.0073,  ..., -0.0093, -0.0051, -0.0076],\n",
       "          [-0.0135, -0.0066, -0.0145,  ..., -0.0107, -0.0045,  0.0123],\n",
       "          ...,\n",
       "          [ 0.0149, -0.0242, -0.0001,  ...,  0.0107, -0.0052, -0.0044],\n",
       "          [-0.0141,  0.0042,  0.0192,  ...,  0.0216,  0.0159,  0.0207],\n",
       "          [ 0.0220, -0.0309, -0.0327,  ..., -0.0229, -0.0086, -0.0260]],\n",
       "         requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([[ 0.0457, -0.0139,  0.0335,  ...,  0.0017, -0.0106, -0.0039],\n",
       "          [-0.0031,  0.0125, -0.0018,  ...,  0.0025,  0.0038,  0.0224],\n",
       "          [ 0.0110,  0.0035, -0.0280,  ..., -0.0090,  0.0060,  0.0072],\n",
       "          ...,\n",
       "          [-0.0084, -0.0084,  0.0008,  ...,  0.0255, -0.0056, -0.0171],\n",
       "          [ 0.0193,  0.0027,  0.0279,  ..., -0.0364,  0.0357,  0.0204],\n",
       "          [-0.0359, -0.0182,  0.0149,  ..., -0.0010,  0.0107,  0.0240]],\n",
       "         requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0162,  0.0071,  0.0002,  ..., -0.0051,  0.0059,  0.0005],\n",
       "         requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([-0.0028,  0.0206,  0.0056,  ...,  0.0116, -0.0059,  0.0026],\n",
       "         requires_grad=True)],\n",
       " [Parameter containing:\n",
       "  tensor([[-7.8657e-05,  9.5959e-03, -4.8672e-03,  ...,  1.1935e-02,\n",
       "            1.3429e-04, -1.3365e-02],\n",
       "          [-5.2223e-03, -2.0591e-02, -7.0762e-03,  ...,  3.6497e-03,\n",
       "           -4.4151e-03, -8.5901e-03],\n",
       "          [-2.3526e-02, -4.4537e-03, -5.5906e-03,  ..., -6.5109e-03,\n",
       "            2.5775e-03,  7.0874e-03],\n",
       "          ...,\n",
       "          [ 7.2634e-03, -1.8029e-03, -6.7822e-04,  ...,  1.3679e-02,\n",
       "            9.7104e-03, -1.7462e-02],\n",
       "          [ 2.4691e-02,  1.0715e-04,  7.5727e-03,  ...,  3.5838e-02,\n",
       "           -7.6700e-03, -4.5338e-03],\n",
       "          [ 1.1234e-02, -7.3439e-03,  4.3485e-04,  ...,  2.6016e-02,\n",
       "            1.6748e-02,  1.1816e-02]], requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([[-0.0110, -0.0065, -0.0003,  ..., -0.0043, -0.0111,  0.0067],\n",
       "          [ 0.0008,  0.0245, -0.0079,  ..., -0.0103,  0.0046,  0.0020],\n",
       "          [-0.0020, -0.0086, -0.0143,  ..., -0.0028, -0.0073, -0.0091],\n",
       "          ...,\n",
       "          [ 0.0052, -0.0137, -0.0035,  ...,  0.0122,  0.0084,  0.0480],\n",
       "          [ 0.0168, -0.0260, -0.0114,  ...,  0.0068,  0.0655, -0.0050],\n",
       "          [ 0.0247,  0.0103, -0.0129,  ...,  0.0045, -0.0113,  0.0364]],\n",
       "         requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0097,  0.0009,  0.0097,  ..., -0.0043,  0.0157, -0.0040],\n",
       "         requires_grad=True),\n",
       "  Parameter containing:\n",
       "  tensor([ 0.0092, -0.0019, -0.0046,  ...,  0.0122, -0.0158,  0.0064],\n",
       "         requires_grad=True)]]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# GRU parameters of neural_encoder_back (on CPU, per the output below);\n",
    "# presumably the reference copy to compare against model.encoder — TODO confirm\n",
    "neural_encoder_back.gru_decoder.all_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d2f8e61",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The printed reprs above are truncated (`...`), so matching prints do NOT prove the\n",
    "# weights are equal: values in the elided middle are never shown. Compare every GRU\n",
    "# parameter exactly instead. `model` lives on cuda:0 and `neural_encoder_back` on CPU,\n",
    "# so both sides are detached and moved to CPU before subtracting.\n",
    "params_after = dict(model.encoder.gru_decoder.named_parameters())\n",
    "params_backup = dict(neural_encoder_back.gru_decoder.named_parameters())\n",
    "assert params_after.keys() == params_backup.keys(), 'GRU parameter names differ'\n",
    "max_abs_diff = {\n",
    "    name: (params_after[name].detach().cpu() - params_backup[name].detach().cpu()).abs().max().item()\n",
    "    for name in params_after\n",
    "}\n",
    "print('any GRU weight changed:', any(diff > 0 for diff in max_abs_diff.values()))\n",
    "max_abs_diff"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "evo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
