{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "./anaconda3/envs/openflamingo/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "./anaconda3/envs/openflamingo/lib/python3.9/site-packages/torch/serialization.py:799: UserWarning: 'torch.load' received a zip file that looks like a TorchScript archive dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to silence this warning)\n",
      "  warnings.warn(\"'torch.load' received a zip file that looks like a TorchScript archive\"\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
      "./anaconda3/envs/openflamingo/lib/python3.9/site-packages/torch/_utils.py:776: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
      "  return self.fget.__get__(instance, owner)()\n",
      "./.cache/huggingface/modules/transformers_modules/attention.py:289: UserWarning: Using `attn_impl: torch`. If your model does not use `alibi` or `prefix_lm` we recommend using `attn_impl: flash` otherwise we recommend using `attn_impl: triton`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "You are using config.init_device='cpu', but you can also use config.init_device=\"meta\" with Composer + FSDP for fast initialization.\n",
      "Flamingo model initialized with 1046992944 trainable parameters\n"
     ]
    }
   ],
   "source": [
    "from open_flamingo import create_model_and_transforms\n",
    "\n",
    "# Checkpoint locations; the leading \"./\" makes them relative to the\n",
    "# directory the kernel was started in.\n",
    "CLIP_WEIGHTS_PATH = \"./.cache/clip/ViT-L-14.pt\"\n",
    "# The language model and its tokenizer are read from the same directory,\n",
    "# so the path is spelled once and reused for both arguments.\n",
    "MPT_CHECKPOINT_DIR = \"./.cache/huggingface/hub/mpt-1b-redpajama-200b/\"\n",
    "\n",
    "# cache_dir is intentionally not passed (it was commented out in the\n",
    "# original call, with a note that it defaults to ~/.cache).\n",
    "flamingo_kwargs = {\n",
    "    \"clip_vision_encoder_path\": \"ViT-L-14\",\n",
    "    \"clip_vision_encoder_pretrained\": CLIP_WEIGHTS_PATH,\n",
    "    \"lang_encoder_path\": MPT_CHECKPOINT_DIR,\n",
    "    \"tokenizer_path\": MPT_CHECKPOINT_DIR,\n",
    "    \"cross_attn_every_n_layers\": 1,\n",
    "}\n",
    "\n",
    "# Build the model together with its image preprocessor and tokenizer.\n",
    "model, image_processor, tokenizer = create_model_and_transforms(**flamingo_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "_IncompatibleKeys(missing_keys=['vision_encoder.class_embedding', 'vision_encoder.positional_embedding', 'vision_encoder.proj', 'vision_encoder.conv1.weight', 'vision_encoder.ln_pre.weight', 'vision_encoder.ln_pre.bias', 'vision_encoder.transformer.resblocks.0.ln_1.weight', 'vision_encoder.transformer.resblocks.0.ln_1.bias', 'vision_encoder.transformer.resblocks.0.attn.in_proj_weight', 'vision_encoder.transformer.resblocks.0.attn.in_proj_bias', 'vision_encoder.transformer.resblocks.0.attn.out_proj.weight', 'vision_encoder.transformer.resblocks.0.attn.out_proj.bias', 'vision_encoder.transformer.resblocks.0.ln_2.weight', 'vision_encoder.transformer.resblocks.0.ln_2.bias', 'vision_encoder.transformer.resblocks.0.mlp.c_fc.weight', 'vision_encoder.transformer.resblocks.0.mlp.c_fc.bias', 'vision_encoder.transformer.resblocks.0.mlp.c_proj.weight', 'vision_encoder.transformer.resblocks.0.mlp.c_proj.bias', 'vision_encoder.transformer.resblocks.1.ln_1.weight', 'vision_encoder.transformer.resblocks.1.ln_1.bias', 'vision_encoder.transformer.resblocks.1.attn.in_proj_weight', 'vision_encoder.transformer.resblocks.1.attn.in_proj_bias', 'vision_encoder.transformer.resblocks.1.attn.out_proj.weight', 'vision_encoder.transformer.resblocks.1.attn.out_proj.bias', 'vision_encoder.transformer.resblocks.1.ln_2.weight', 'vision_encoder.transformer.resblocks.1.ln_2.bias', 'vision_encoder.transformer.resblocks.1.mlp.c_fc.weight', 'vision_encoder.transformer.resblocks.1.mlp.c_fc.bias', 'vision_encoder.transformer.resblocks.1.mlp.c_proj.weight', 'vision_encoder.transformer.resblocks.1.mlp.c_proj.bias', 'vision_encoder.transformer.resblocks.2.ln_1.weight', 'vision_encoder.transformer.resblocks.2.ln_1.bias', 'vision_encoder.transformer.resblocks.2.attn.in_proj_weight', 'vision_encoder.transformer.resblocks.2.attn.in_proj_bias', 'vision_encoder.transformer.resblocks.2.attn.out_proj.weight', 'vision_encoder.transformer.resblocks.2.attn.out_proj.bias', 
'vision_encoder.transformer.resblocks.2.ln_2.weight', 'vision_encoder.transformer.resblocks.2.ln_2.bias', 'vision_encoder.transformer.resblocks.2.mlp.c_fc.weight', 'vision_encoder.transformer.resblocks.2.mlp.c_fc.bias', 'vision_encoder.transformer.resblocks.2.mlp.c_proj.weight', 'vision_encoder.transformer.resblocks.2.mlp.c_proj.bias', 'vision_encoder.transformer.resblocks.3.ln_1.weight', 'vision_encoder.transformer.resblocks.3.ln_1.bias', 'vision_encoder.transformer.resblocks.3.attn.in_proj_weight', 'vision_encoder.transformer.resblocks.3.attn.in_proj_bias', 'vision_encoder.transformer.resblocks.3.attn.out_proj.weight', 'vision_encoder.transformer.resblocks.3.attn.out_proj.bias', 'vision_encoder.transformer.resblocks.3.ln_2.weight', 'vision_encoder.transformer.resblocks.3.ln_2.bias', 'vision_encoder.transformer.resblocks.3.mlp.c_fc.weight', 'vision_encoder.transformer.resblocks.3.mlp.c_fc.bias', 'vision_encoder.transformer.resblocks.3.mlp.c_proj.weight', 'vision_encoder.transformer.resblocks.3.mlp.c_proj.bias', 'vision_encoder.transformer.resblocks.4.ln_1.weight', 'vision_encoder.transformer.resblocks.4.ln_1.bias', 'vision_encoder.transformer.resblocks.4.attn.in_proj_weight', 'vision_encoder.transformer.resblocks.4.attn.in_proj_bias', 'vision_encoder.transformer.resblocks.4.attn.out_proj.weight', 'vision_encoder.transformer.resblocks.4.attn.out_proj.bias', 'vision_encoder.transformer.resblocks.4.ln_2.weight', 'vision_encoder.transformer.resblocks.4.ln_2.bias', 'vision_encoder.transformer.resblocks.4.mlp.c_fc.weight', 'vision_encoder.transformer.resblocks.4.mlp.c_fc.bias', 'vision_encoder.transformer.resblocks.4.mlp.c_proj.weight', 'vision_encoder.transformer.resblocks.4.mlp.c_proj.bias', 'vision_encoder.transformer.resblocks.5.ln_1.weight', 'vision_encoder.transformer.resblocks.5.ln_1.bias', 'vision_encoder.transformer.resblocks.5.attn.in_proj_weight', 'vision_encoder.transformer.resblocks.5.attn.in_proj_bias', 
'vision_encoder.transformer.resblocks.5.attn.out_proj.weight', 'vision_encoder.transformer.resblocks.5.attn.out_proj.bias', 'vision_encoder.transformer.resblocks.5.ln_2.weight', 'vision_encoder.transformer.resblocks.5.ln_2.bias', 'vision_encoder.transformer.resblocks.5.mlp.c_fc.weight', 'vision_encoder.transformer.resblocks.5.mlp.c_fc.bias', 'vision_encoder.transformer.resblocks.5.mlp.c_proj.weight', 'vision_encoder.transformer.resblocks.5.mlp.c_proj.bias', 'vision_encoder.transformer.resblocks.6.ln_1.weight', 'vision_encoder.transformer.resblocks.6.ln_1.bias', 'vision_encoder.transformer.resblocks.6.attn.in_proj_weight', 'vision_encoder.transformer.resblocks.6.attn.in_proj_bias', 'vision_encoder.transformer.resblocks.6.attn.out_proj.weight', 'vision_encoder.transformer.resblocks.6.attn.out_proj.bias', 'vision_encoder.transformer.resblocks.6.ln_2.weight', 'vision_encoder.transformer.resblocks.6.ln_2.bias', 'vision_encoder.transformer.resblocks.6.mlp.c_fc.weight', 'vision_encoder.transformer.resblocks.6.mlp.c_fc.bias', 'vision_encoder.transformer.resblocks.6.mlp.c_proj.weight', 'vision_encoder.transformer.resblocks.6.mlp.c_proj.bias', 'vision_encoder.transformer.resblocks.7.ln_1.weight', 'vision_encoder.transformer.resblocks.7.ln_1.bias', 'vision_encoder.transformer.resblocks.7.attn.in_proj_weight', 'vision_encoder.transformer.resblocks.7.attn.in_proj_bias', 'vision_encoder.transformer.resblocks.7.attn.out_proj.weight', 'vision_encoder.transformer.resblocks.7.attn.out_proj.bias', 'vision_encoder.transformer.resblocks.7.ln_2.weight', 'vision_encoder.transformer.resblocks.7.ln_2.bias', 'vision_encoder.transformer.resblocks.7.mlp.c_fc.weight', 'vision_encoder.transformer.resblocks.7.mlp.c_fc.bias', 'vision_encoder.transformer.resblocks.7.mlp.c_proj.weight', 'vision_encoder.transformer.resblocks.7.mlp.c_proj.bias', 'vision_encoder.transformer.resblocks.8.ln_1.weight', 'vision_encoder.transformer.resblocks.8.ln_1.bias', 
'vision_encoder.transformer.resblocks.8.attn.in_proj_weight', 'vision_encoder.transformer.resblocks.8.attn.in_proj_bias', 'vision_encoder.transformer.resblocks.8.attn.out_proj.weight', 'vision_encoder.transformer.resblocks.8.attn.out_proj.bias', 'vision_encoder.transformer.resblocks.8.ln_2.weight', 'vision_encoder.transformer.resblocks.8.ln_2.bias', 'vision_encoder.transformer.resblocks.8.mlp.c_fc.weight', 'vision_encoder.transformer.resblocks.8.mlp.c_fc.bias', 'vision_encoder.transformer.resblocks.8.mlp.c_proj.weight', 'vision_encoder.transformer.resblocks.8.mlp.c_proj.bias', 'vision_encoder.transformer.resblocks.9.ln_1.weight', 'vision_encoder.transformer.resblocks.9.ln_1.bias', 'vision_encoder.transformer.resblocks.9.attn.in_proj_weight', 'vision_encoder.transformer.resblocks.9.attn.in_proj_bias', 'vision_encoder.transformer.resblocks.9.attn.out_proj.weight', 'vision_encoder.transformer.resblocks.9.attn.out_proj.bias', 'vision_encoder.transformer.resblocks.9.ln_2.weight', 'vision_encoder.transformer.resblocks.9.ln_2.bias', 'vision_encoder.transformer.resblocks.9.mlp.c_fc.weight', 'vision_encoder.transformer.resblocks.9.mlp.c_fc.bias', 'vision_encoder.transformer.resblocks.9.mlp.c_proj.weight', 'vision_encoder.transformer.resblocks.9.mlp.c_proj.bias', 'vision_encoder.transformer.resblocks.10.ln_1.weight', 'vision_encoder.transformer.resblocks.10.ln_1.bias', 'vision_encoder.transformer.resblocks.10.attn.in_proj_weight', 'vision_encoder.transformer.resblocks.10.attn.in_proj_bias', 'vision_encoder.transformer.resblocks.10.attn.out_proj.weight', 'vision_encoder.transformer.resblocks.10.attn.out_proj.bias', 'vision_encoder.transformer.resblocks.10.ln_2.weight', 'vision_encoder.transformer.resblocks.10.ln_2.bias', 'vision_encoder.transformer.resblocks.10.mlp.c_fc.weight', 'vision_encoder.transformer.resblocks.10.mlp.c_fc.bias', 'vision_encoder.transformer.resblocks.10.mlp.c_proj.weight', 'vision_encoder.transformer.resblocks.10.mlp.c_proj.bias', 
'vision_encoder.transformer.resblocks.11.ln_1.weight', 'vision_encoder.transformer.resblocks.11.ln_1.bias', 'vision_encoder.transformer.resblocks.11.attn.in_proj_weight', 'vision_encoder.transformer.resblocks.11.attn.in_proj_bias', 'vision_encoder.transformer.resblocks.11.attn.out_proj.weight', 'vision_encoder.transformer.resblocks.11.attn.out_proj.bias', 'vision_encoder.transformer.resblocks.11.ln_2.weight', 'vision_encoder.transformer.resblocks.11.ln_2.bias', 'vision_encoder.transformer.resblocks.11.mlp.c_fc.weight', 'vision_encoder.transformer.resblocks.11.mlp.c_fc.bias', 'vision_encoder.transformer.resblocks.11.mlp.c_proj.weight', 'vision_encoder.transformer.resblocks.11.mlp.c_proj.bias', 'vision_encoder.transformer.resblocks.12.ln_1.weight', 'vision_encoder.transformer.resblocks.12.ln_1.bias', 'vision_encoder.transformer.resblocks.12.attn.in_proj_weight', 'vision_encoder.transformer.resblocks.12.attn.in_proj_bias', 'vision_encoder.transformer.resblocks.12.attn.out_proj.weight', 'vision_encoder.transformer.resblocks.12.attn.out_proj.bias', 'vision_encoder.transformer.resblocks.12.ln_2.weight', 'vision_encoder.transformer.resblocks.12.ln_2.bias', 'vision_encoder.transformer.resblocks.12.mlp.c_fc.weight', 'vision_encoder.transformer.resblocks.12.mlp.c_fc.bias', 'vision_encoder.transformer.resblocks.12.mlp.c_proj.weight', 'vision_encoder.transformer.resblocks.12.mlp.c_proj.bias', 'vision_encoder.transformer.resblocks.13.ln_1.weight', 'vision_encoder.transformer.resblocks.13.ln_1.bias', 'vision_encoder.transformer.resblocks.13.attn.in_proj_weight', 'vision_encoder.transformer.resblocks.13.attn.in_proj_bias', 'vision_encoder.transformer.resblocks.13.attn.out_proj.weight', 'vision_encoder.transformer.resblocks.13.attn.out_proj.bias', 'vision_encoder.transformer.resblocks.13.ln_2.weight', 'vision_encoder.transformer.resblocks.13.ln_2.bias', 'vision_encoder.transformer.resblocks.13.mlp.c_fc.weight', 'vision_encoder.transformer.resblocks.13.mlp.c_fc.bias', 
'vision_encoder.transformer.resblocks.13.mlp.c_proj.weight', 'vision_encoder.transformer.resblocks.13.mlp.c_proj.bias', 'vision_encoder.transformer.resblocks.14.ln_1.weight', 'vision_encoder.transformer.resblocks.14.ln_1.bias', 'vision_encoder.transformer.resblocks.14.attn.in_proj_weight', 'vision_encoder.transformer.resblocks.14.attn.in_proj_bias', 'vision_encoder.transformer.resblocks.14.attn.out_proj.weight', 'vision_encoder.transformer.resblocks.14.attn.out_proj.bias', 'vision_encoder.transformer.resblocks.14.ln_2.weight', 'vision_encoder.transformer.resblocks.14.ln_2.bias', 'vision_encoder.transformer.resblocks.14.mlp.c_fc.weight', 'vision_encoder.transformer.resblocks.14.mlp.c_fc.bias', 'vision_encoder.transformer.resblocks.14.mlp.c_proj.weight', 'vision_encoder.transformer.resblocks.14.mlp.c_proj.bias', 'vision_encoder.transformer.resblocks.15.ln_1.weight', 'vision_encoder.transformer.resblocks.15.ln_1.bias', 'vision_encoder.transformer.resblocks.15.attn.in_proj_weight', 'vision_encoder.transformer.resblocks.15.attn.in_proj_bias', 'vision_encoder.transformer.resblocks.15.attn.out_proj.weight', 'vision_encoder.transformer.resblocks.15.attn.out_proj.bias', 'vision_encoder.transformer.resblocks.15.ln_2.weight', 'vision_encoder.transformer.resblocks.15.ln_2.bias', 'vision_encoder.transformer.resblocks.15.mlp.c_fc.weight', 'vision_encoder.transformer.resblocks.15.mlp.c_fc.bias', 'vision_encoder.transformer.resblocks.15.mlp.c_proj.weight', 'vision_encoder.transformer.resblocks.15.mlp.c_proj.bias', 'vision_encoder.transformer.resblocks.16.ln_1.weight', 'vision_encoder.transformer.resblocks.16.ln_1.bias', 'vision_encoder.transformer.resblocks.16.attn.in_proj_weight', 'vision_encoder.transformer.resblocks.16.attn.in_proj_bias', 'vision_encoder.transformer.resblocks.16.attn.out_proj.weight', 'vision_encoder.transformer.resblocks.16.attn.out_proj.bias', 'vision_encoder.transformer.resblocks.16.ln_2.weight', 'vision_encoder.transformer.resblocks.16.ln_2.bias', 
'vision_encoder.transformer.resblocks.16.mlp.c_fc.weight', 'vision_encoder.transformer.resblocks.16.mlp.c_fc.bias', 'vision_encoder.transformer.resblocks.16.mlp.c_proj.weight', 'vision_encoder.transformer.resblocks.16.mlp.c_proj.bias', 'vision_encoder.transformer.resblocks.17.ln_1.weight', 'vision_encoder.transformer.resblocks.17.ln_1.bias', 'vision_encoder.transformer.resblocks.17.attn.in_proj_weight', 'vision_encoder.transformer.resblocks.17.attn.in_proj_bias', 'vision_encoder.transformer.resblocks.17.attn.out_proj.weight', 'vision_encoder.transformer.resblocks.17.attn.out_proj.bias', 'vision_encoder.transformer.resblocks.17.ln_2.weight', 'vision_encoder.transformer.resblocks.17.ln_2.bias', 'vision_encoder.transformer.resblocks.17.mlp.c_fc.weight', 'vision_encoder.transformer.resblocks.17.mlp.c_fc.bias', 'vision_encoder.transformer.resblocks.17.mlp.c_proj.weight', 'vision_encoder.transformer.resblocks.17.mlp.c_proj.bias', 'vision_encoder.transformer.resblocks.18.ln_1.weight', 'vision_encoder.transformer.resblocks.18.ln_1.bias', 'vision_encoder.transformer.resblocks.18.attn.in_proj_weight', 'vision_encoder.transformer.resblocks.18.attn.in_proj_bias', 'vision_encoder.transformer.resblocks.18.attn.out_proj.weight', 'vision_encoder.transformer.resblocks.18.attn.out_proj.bias', 'vision_encoder.transformer.resblocks.18.ln_2.weight', 'vision_encoder.transformer.resblocks.18.ln_2.bias', 'vision_encoder.transformer.resblocks.18.mlp.c_fc.weight', 'vision_encoder.transformer.resblocks.18.mlp.c_fc.bias', 'vision_encoder.transformer.resblocks.18.mlp.c_proj.weight', 'vision_encoder.transformer.resblocks.18.mlp.c_proj.bias', 'vision_encoder.transformer.resblocks.19.ln_1.weight', 'vision_encoder.transformer.resblocks.19.ln_1.bias', 'vision_encoder.transformer.resblocks.19.attn.in_proj_weight', 'vision_encoder.transformer.resblocks.19.attn.in_proj_bias', 'vision_encoder.transformer.resblocks.19.attn.out_proj.weight', 'vision_encoder.transformer.resblocks.19.attn.out_proj.bias', 
'vision_encoder.transformer.resblocks.19.ln_2.weight', 'vision_encoder.transformer.resblocks.19.ln_2.bias', 'vision_encoder.transformer.resblocks.19.mlp.c_fc.weight', 'vision_encoder.transformer.resblocks.19.mlp.c_fc.bias', 'vision_encoder.transformer.resblocks.19.mlp.c_proj.weight', 'vision_encoder.transformer.resblocks.19.mlp.c_proj.bias', 'vision_encoder.transformer.resblocks.20.ln_1.weight', 'vision_encoder.transformer.resblocks.20.ln_1.bias', 'vision_encoder.transformer.resblocks.20.attn.in_proj_weight', 'vision_encoder.transformer.resblocks.20.attn.in_proj_bias', 'vision_encoder.transformer.resblocks.20.attn.out_proj.weight', 'vision_encoder.transformer.resblocks.20.attn.out_proj.bias', 'vision_encoder.transformer.resblocks.20.ln_2.weight', 'vision_encoder.transformer.resblocks.20.ln_2.bias', 'vision_encoder.transformer.resblocks.20.mlp.c_fc.weight', 'vision_encoder.transformer.resblocks.20.mlp.c_fc.bias', 'vision_encoder.transformer.resblocks.20.mlp.c_proj.weight', 'vision_encoder.transformer.resblocks.20.mlp.c_proj.bias', 'vision_encoder.transformer.resblocks.21.ln_1.weight', 'vision_encoder.transformer.resblocks.21.ln_1.bias', 'vision_encoder.transformer.resblocks.21.attn.in_proj_weight', 'vision_encoder.transformer.resblocks.21.attn.in_proj_bias', 'vision_encoder.transformer.resblocks.21.attn.out_proj.weight', 'vision_encoder.transformer.resblocks.21.attn.out_proj.bias', 'vision_encoder.transformer.resblocks.21.ln_2.weight', 'vision_encoder.transformer.resblocks.21.ln_2.bias', 'vision_encoder.transformer.resblocks.21.mlp.c_fc.weight', 'vision_encoder.transformer.resblocks.21.mlp.c_fc.bias', 'vision_encoder.transformer.resblocks.21.mlp.c_proj.weight', 'vision_encoder.transformer.resblocks.21.mlp.c_proj.bias', 'vision_encoder.transformer.resblocks.22.ln_1.weight', 'vision_encoder.transformer.resblocks.22.ln_1.bias', 'vision_encoder.transformer.resblocks.22.attn.in_proj_weight', 'vision_encoder.transformer.resblocks.22.attn.in_proj_bias', 
'vision_encoder.transformer.resblocks.22.attn.out_proj.weight', 'vision_encoder.transformer.resblocks.22.attn.out_proj.bias', 'vision_encoder.transformer.resblocks.22.ln_2.weight', 'vision_encoder.transformer.resblocks.22.ln_2.bias', 'vision_encoder.transformer.resblocks.22.mlp.c_fc.weight', 'vision_encoder.transformer.resblocks.22.mlp.c_fc.bias', 'vision_encoder.transformer.resblocks.22.mlp.c_proj.weight', 'vision_encoder.transformer.resblocks.22.mlp.c_proj.bias', 'vision_encoder.transformer.resblocks.23.ln_1.weight', 'vision_encoder.transformer.resblocks.23.ln_1.bias', 'vision_encoder.transformer.resblocks.23.attn.in_proj_weight', 'vision_encoder.transformer.resblocks.23.attn.in_proj_bias', 'vision_encoder.transformer.resblocks.23.attn.out_proj.weight', 'vision_encoder.transformer.resblocks.23.attn.out_proj.bias', 'vision_encoder.transformer.resblocks.23.ln_2.weight', 'vision_encoder.transformer.resblocks.23.ln_2.bias', 'vision_encoder.transformer.resblocks.23.mlp.c_fc.weight', 'vision_encoder.transformer.resblocks.23.mlp.c_fc.bias', 'vision_encoder.transformer.resblocks.23.mlp.c_proj.weight', 'vision_encoder.transformer.resblocks.23.mlp.c_proj.bias', 'vision_encoder.ln_post.weight', 'vision_encoder.ln_post.bias', 'lang_encoder.transformer.blocks.0.decoder_layer.ln_1.weight', 'lang_encoder.transformer.blocks.0.decoder_layer.attn.Wqkv.weight', 'lang_encoder.transformer.blocks.0.decoder_layer.attn.q_ln.weight', 'lang_encoder.transformer.blocks.0.decoder_layer.attn.k_ln.weight', 'lang_encoder.transformer.blocks.0.decoder_layer.attn.out_proj.weight', 'lang_encoder.transformer.blocks.0.decoder_layer.ln_2.weight', 'lang_encoder.transformer.blocks.0.decoder_layer.mlp.mlp_up.weight', 'lang_encoder.transformer.blocks.0.decoder_layer.mlp.mlp_down.weight', 'lang_encoder.transformer.blocks.1.decoder_layer.ln_1.weight', 'lang_encoder.transformer.blocks.1.decoder_layer.attn.Wqkv.weight', 'lang_encoder.transformer.blocks.1.decoder_layer.attn.q_ln.weight', 
'lang_encoder.transformer.blocks.1.decoder_layer.attn.k_ln.weight', 'lang_encoder.transformer.blocks.1.decoder_layer.attn.out_proj.weight', 'lang_encoder.transformer.blocks.1.decoder_layer.ln_2.weight', 'lang_encoder.transformer.blocks.1.decoder_layer.mlp.mlp_up.weight', 'lang_encoder.transformer.blocks.1.decoder_layer.mlp.mlp_down.weight', 'lang_encoder.transformer.blocks.2.decoder_layer.ln_1.weight', 'lang_encoder.transformer.blocks.2.decoder_layer.attn.Wqkv.weight', 'lang_encoder.transformer.blocks.2.decoder_layer.attn.q_ln.weight', 'lang_encoder.transformer.blocks.2.decoder_layer.attn.k_ln.weight', 'lang_encoder.transformer.blocks.2.decoder_layer.attn.out_proj.weight', 'lang_encoder.transformer.blocks.2.decoder_layer.ln_2.weight', 'lang_encoder.transformer.blocks.2.decoder_layer.mlp.mlp_up.weight', 'lang_encoder.transformer.blocks.2.decoder_layer.mlp.mlp_down.weight', 'lang_encoder.transformer.blocks.3.decoder_layer.ln_1.weight', 'lang_encoder.transformer.blocks.3.decoder_layer.attn.Wqkv.weight', 'lang_encoder.transformer.blocks.3.decoder_layer.attn.q_ln.weight', 'lang_encoder.transformer.blocks.3.decoder_layer.attn.k_ln.weight', 'lang_encoder.transformer.blocks.3.decoder_layer.attn.out_proj.weight', 'lang_encoder.transformer.blocks.3.decoder_layer.ln_2.weight', 'lang_encoder.transformer.blocks.3.decoder_layer.mlp.mlp_up.weight', 'lang_encoder.transformer.blocks.3.decoder_layer.mlp.mlp_down.weight', 'lang_encoder.transformer.blocks.4.decoder_layer.ln_1.weight', 'lang_encoder.transformer.blocks.4.decoder_layer.attn.Wqkv.weight', 'lang_encoder.transformer.blocks.4.decoder_layer.attn.q_ln.weight', 'lang_encoder.transformer.blocks.4.decoder_layer.attn.k_ln.weight', 'lang_encoder.transformer.blocks.4.decoder_layer.attn.out_proj.weight', 'lang_encoder.transformer.blocks.4.decoder_layer.ln_2.weight', 'lang_encoder.transformer.blocks.4.decoder_layer.mlp.mlp_up.weight', 'lang_encoder.transformer.blocks.4.decoder_layer.mlp.mlp_down.weight', 
'lang_encoder.transformer.blocks.5.decoder_layer.ln_1.weight', 'lang_encoder.transformer.blocks.5.decoder_layer.attn.Wqkv.weight', 'lang_encoder.transformer.blocks.5.decoder_layer.attn.q_ln.weight', 'lang_encoder.transformer.blocks.5.decoder_layer.attn.k_ln.weight', 'lang_encoder.transformer.blocks.5.decoder_layer.attn.out_proj.weight', 'lang_encoder.transformer.blocks.5.decoder_layer.ln_2.weight', 'lang_encoder.transformer.blocks.5.decoder_layer.mlp.mlp_up.weight', 'lang_encoder.transformer.blocks.5.decoder_layer.mlp.mlp_down.weight', 'lang_encoder.transformer.blocks.6.decoder_layer.ln_1.weight', 'lang_encoder.transformer.blocks.6.decoder_layer.attn.Wqkv.weight', 'lang_encoder.transformer.blocks.6.decoder_layer.attn.q_ln.weight', 'lang_encoder.transformer.blocks.6.decoder_layer.attn.k_ln.weight', 'lang_encoder.transformer.blocks.6.decoder_layer.attn.out_proj.weight', 'lang_encoder.transformer.blocks.6.decoder_layer.ln_2.weight', 'lang_encoder.transformer.blocks.6.decoder_layer.mlp.mlp_up.weight', 'lang_encoder.transformer.blocks.6.decoder_layer.mlp.mlp_down.weight', 'lang_encoder.transformer.blocks.7.decoder_layer.ln_1.weight', 'lang_encoder.transformer.blocks.7.decoder_layer.attn.Wqkv.weight', 'lang_encoder.transformer.blocks.7.decoder_layer.attn.q_ln.weight', 'lang_encoder.transformer.blocks.7.decoder_layer.attn.k_ln.weight', 'lang_encoder.transformer.blocks.7.decoder_layer.attn.out_proj.weight', 'lang_encoder.transformer.blocks.7.decoder_layer.ln_2.weight', 'lang_encoder.transformer.blocks.7.decoder_layer.mlp.mlp_up.weight', 'lang_encoder.transformer.blocks.7.decoder_layer.mlp.mlp_down.weight', 'lang_encoder.transformer.blocks.8.decoder_layer.ln_1.weight', 'lang_encoder.transformer.blocks.8.decoder_layer.attn.Wqkv.weight', 'lang_encoder.transformer.blocks.8.decoder_layer.attn.q_ln.weight', 'lang_encoder.transformer.blocks.8.decoder_layer.attn.k_ln.weight', 'lang_encoder.transformer.blocks.8.decoder_layer.attn.out_proj.weight', 
'lang_encoder.transformer.blocks.8.decoder_layer.ln_2.weight', 'lang_encoder.transformer.blocks.8.decoder_layer.mlp.mlp_up.weight', 'lang_encoder.transformer.blocks.8.decoder_layer.mlp.mlp_down.weight', 'lang_encoder.transformer.blocks.9.decoder_layer.ln_1.weight', 'lang_encoder.transformer.blocks.9.decoder_layer.attn.Wqkv.weight', 'lang_encoder.transformer.blocks.9.decoder_layer.attn.q_ln.weight', 'lang_encoder.transformer.blocks.9.decoder_layer.attn.k_ln.weight', 'lang_encoder.transformer.blocks.9.decoder_layer.attn.out_proj.weight', 'lang_encoder.transformer.blocks.9.decoder_layer.ln_2.weight', 'lang_encoder.transformer.blocks.9.decoder_layer.mlp.mlp_up.weight', 'lang_encoder.transformer.blocks.9.decoder_layer.mlp.mlp_down.weight', 'lang_encoder.transformer.blocks.10.decoder_layer.ln_1.weight', 'lang_encoder.transformer.blocks.10.decoder_layer.attn.Wqkv.weight', 'lang_encoder.transformer.blocks.10.decoder_layer.attn.q_ln.weight', 'lang_encoder.transformer.blocks.10.decoder_layer.attn.k_ln.weight', 'lang_encoder.transformer.blocks.10.decoder_layer.attn.out_proj.weight', 'lang_encoder.transformer.blocks.10.decoder_layer.ln_2.weight', 'lang_encoder.transformer.blocks.10.decoder_layer.mlp.mlp_up.weight', 'lang_encoder.transformer.blocks.10.decoder_layer.mlp.mlp_down.weight', 'lang_encoder.transformer.blocks.11.decoder_layer.ln_1.weight', 'lang_encoder.transformer.blocks.11.decoder_layer.attn.Wqkv.weight', 'lang_encoder.transformer.blocks.11.decoder_layer.attn.q_ln.weight', 'lang_encoder.transformer.blocks.11.decoder_layer.attn.k_ln.weight', 'lang_encoder.transformer.blocks.11.decoder_layer.attn.out_proj.weight', 'lang_encoder.transformer.blocks.11.decoder_layer.ln_2.weight', 'lang_encoder.transformer.blocks.11.decoder_layer.mlp.mlp_up.weight', 'lang_encoder.transformer.blocks.11.decoder_layer.mlp.mlp_down.weight', 'lang_encoder.transformer.blocks.12.decoder_layer.ln_1.weight', 'lang_encoder.transformer.blocks.12.decoder_layer.attn.Wqkv.weight', 
'lang_encoder.transformer.blocks.12.decoder_layer.attn.q_ln.weight', 'lang_encoder.transformer.blocks.12.decoder_layer.attn.k_ln.weight', 'lang_encoder.transformer.blocks.12.decoder_layer.attn.out_proj.weight', 'lang_encoder.transformer.blocks.12.decoder_layer.ln_2.weight', 'lang_encoder.transformer.blocks.12.decoder_layer.mlp.mlp_up.weight', 'lang_encoder.transformer.blocks.12.decoder_layer.mlp.mlp_down.weight', 'lang_encoder.transformer.blocks.13.decoder_layer.ln_1.weight', 'lang_encoder.transformer.blocks.13.decoder_layer.attn.Wqkv.weight', 'lang_encoder.transformer.blocks.13.decoder_layer.attn.q_ln.weight', 'lang_encoder.transformer.blocks.13.decoder_layer.attn.k_ln.weight', 'lang_encoder.transformer.blocks.13.decoder_layer.attn.out_proj.weight', 'lang_encoder.transformer.blocks.13.decoder_layer.ln_2.weight', 'lang_encoder.transformer.blocks.13.decoder_layer.mlp.mlp_up.weight', 'lang_encoder.transformer.blocks.13.decoder_layer.mlp.mlp_down.weight', 'lang_encoder.transformer.blocks.14.decoder_layer.ln_1.weight', 'lang_encoder.transformer.blocks.14.decoder_layer.attn.Wqkv.weight', 'lang_encoder.transformer.blocks.14.decoder_layer.attn.q_ln.weight', 'lang_encoder.transformer.blocks.14.decoder_layer.attn.k_ln.weight', 'lang_encoder.transformer.blocks.14.decoder_layer.attn.out_proj.weight', 'lang_encoder.transformer.blocks.14.decoder_layer.ln_2.weight', 'lang_encoder.transformer.blocks.14.decoder_layer.mlp.mlp_up.weight', 'lang_encoder.transformer.blocks.14.decoder_layer.mlp.mlp_down.weight', 'lang_encoder.transformer.blocks.15.decoder_layer.ln_1.weight', 'lang_encoder.transformer.blocks.15.decoder_layer.attn.Wqkv.weight', 'lang_encoder.transformer.blocks.15.decoder_layer.attn.q_ln.weight', 'lang_encoder.transformer.blocks.15.decoder_layer.attn.k_ln.weight', 'lang_encoder.transformer.blocks.15.decoder_layer.attn.out_proj.weight', 'lang_encoder.transformer.blocks.15.decoder_layer.ln_2.weight', 'lang_encoder.transformer.blocks.15.decoder_layer.mlp.mlp_up.weight', 
'lang_encoder.transformer.blocks.15.decoder_layer.mlp.mlp_down.weight', 'lang_encoder.transformer.blocks.16.decoder_layer.ln_1.weight', 'lang_encoder.transformer.blocks.16.decoder_layer.attn.Wqkv.weight', 'lang_encoder.transformer.blocks.16.decoder_layer.attn.q_ln.weight', 'lang_encoder.transformer.blocks.16.decoder_layer.attn.k_ln.weight', 'lang_encoder.transformer.blocks.16.decoder_layer.attn.out_proj.weight', 'lang_encoder.transformer.blocks.16.decoder_layer.ln_2.weight', 'lang_encoder.transformer.blocks.16.decoder_layer.mlp.mlp_up.weight', 'lang_encoder.transformer.blocks.16.decoder_layer.mlp.mlp_down.weight', 'lang_encoder.transformer.blocks.17.decoder_layer.ln_1.weight', 'lang_encoder.transformer.blocks.17.decoder_layer.attn.Wqkv.weight', 'lang_encoder.transformer.blocks.17.decoder_layer.attn.q_ln.weight', 'lang_encoder.transformer.blocks.17.decoder_layer.attn.k_ln.weight', 'lang_encoder.transformer.blocks.17.decoder_layer.attn.out_proj.weight', 'lang_encoder.transformer.blocks.17.decoder_layer.ln_2.weight', 'lang_encoder.transformer.blocks.17.decoder_layer.mlp.mlp_up.weight', 'lang_encoder.transformer.blocks.17.decoder_layer.mlp.mlp_down.weight', 'lang_encoder.transformer.blocks.18.decoder_layer.ln_1.weight', 'lang_encoder.transformer.blocks.18.decoder_layer.attn.Wqkv.weight', 'lang_encoder.transformer.blocks.18.decoder_layer.attn.q_ln.weight', 'lang_encoder.transformer.blocks.18.decoder_layer.attn.k_ln.weight', 'lang_encoder.transformer.blocks.18.decoder_layer.attn.out_proj.weight', 'lang_encoder.transformer.blocks.18.decoder_layer.ln_2.weight', 'lang_encoder.transformer.blocks.18.decoder_layer.mlp.mlp_up.weight', 'lang_encoder.transformer.blocks.18.decoder_layer.mlp.mlp_down.weight', 'lang_encoder.transformer.blocks.19.decoder_layer.ln_1.weight', 'lang_encoder.transformer.blocks.19.decoder_layer.attn.Wqkv.weight', 'lang_encoder.transformer.blocks.19.decoder_layer.attn.q_ln.weight', 'lang_encoder.transformer.blocks.19.decoder_layer.attn.k_ln.weight', 
'lang_encoder.transformer.blocks.19.decoder_layer.attn.out_proj.weight', 'lang_encoder.transformer.blocks.19.decoder_layer.ln_2.weight', 'lang_encoder.transformer.blocks.19.decoder_layer.mlp.mlp_up.weight', 'lang_encoder.transformer.blocks.19.decoder_layer.mlp.mlp_down.weight', 'lang_encoder.transformer.blocks.20.decoder_layer.ln_1.weight', 'lang_encoder.transformer.blocks.20.decoder_layer.attn.Wqkv.weight', 'lang_encoder.transformer.blocks.20.decoder_layer.attn.q_ln.weight', 'lang_encoder.transformer.blocks.20.decoder_layer.attn.k_ln.weight', 'lang_encoder.transformer.blocks.20.decoder_layer.attn.out_proj.weight', 'lang_encoder.transformer.blocks.20.decoder_layer.ln_2.weight', 'lang_encoder.transformer.blocks.20.decoder_layer.mlp.mlp_up.weight', 'lang_encoder.transformer.blocks.20.decoder_layer.mlp.mlp_down.weight', 'lang_encoder.transformer.blocks.21.decoder_layer.ln_1.weight', 'lang_encoder.transformer.blocks.21.decoder_layer.attn.Wqkv.weight', 'lang_encoder.transformer.blocks.21.decoder_layer.attn.q_ln.weight', 'lang_encoder.transformer.blocks.21.decoder_layer.attn.k_ln.weight', 'lang_encoder.transformer.blocks.21.decoder_layer.attn.out_proj.weight', 'lang_encoder.transformer.blocks.21.decoder_layer.ln_2.weight', 'lang_encoder.transformer.blocks.21.decoder_layer.mlp.mlp_up.weight', 'lang_encoder.transformer.blocks.21.decoder_layer.mlp.mlp_down.weight', 'lang_encoder.transformer.blocks.22.decoder_layer.ln_1.weight', 'lang_encoder.transformer.blocks.22.decoder_layer.attn.Wqkv.weight', 'lang_encoder.transformer.blocks.22.decoder_layer.attn.q_ln.weight', 'lang_encoder.transformer.blocks.22.decoder_layer.attn.k_ln.weight', 'lang_encoder.transformer.blocks.22.decoder_layer.attn.out_proj.weight', 'lang_encoder.transformer.blocks.22.decoder_layer.ln_2.weight', 'lang_encoder.transformer.blocks.22.decoder_layer.mlp.mlp_up.weight', 'lang_encoder.transformer.blocks.22.decoder_layer.mlp.mlp_down.weight', 'lang_encoder.transformer.blocks.23.decoder_layer.ln_1.weight', 
'lang_encoder.transformer.blocks.23.decoder_layer.attn.Wqkv.weight', 'lang_encoder.transformer.blocks.23.decoder_layer.attn.q_ln.weight', 'lang_encoder.transformer.blocks.23.decoder_layer.attn.k_ln.weight', 'lang_encoder.transformer.blocks.23.decoder_layer.attn.out_proj.weight', 'lang_encoder.transformer.blocks.23.decoder_layer.ln_2.weight', 'lang_encoder.transformer.blocks.23.decoder_layer.mlp.mlp_up.weight', 'lang_encoder.transformer.blocks.23.decoder_layer.mlp.mlp_down.weight', 'lang_encoder.transformer.ln_f.weight', 'lang_encoder.old_decoder_blocks.0.ln_1.weight', 'lang_encoder.old_decoder_blocks.0.attn.Wqkv.weight', 'lang_encoder.old_decoder_blocks.0.attn.q_ln.weight', 'lang_encoder.old_decoder_blocks.0.attn.k_ln.weight', 'lang_encoder.old_decoder_blocks.0.attn.out_proj.weight', 'lang_encoder.old_decoder_blocks.0.ln_2.weight', 'lang_encoder.old_decoder_blocks.0.mlp.mlp_up.weight', 'lang_encoder.old_decoder_blocks.0.mlp.mlp_down.weight', 'lang_encoder.old_decoder_blocks.1.ln_1.weight', 'lang_encoder.old_decoder_blocks.1.attn.Wqkv.weight', 'lang_encoder.old_decoder_blocks.1.attn.q_ln.weight', 'lang_encoder.old_decoder_blocks.1.attn.k_ln.weight', 'lang_encoder.old_decoder_blocks.1.attn.out_proj.weight', 'lang_encoder.old_decoder_blocks.1.ln_2.weight', 'lang_encoder.old_decoder_blocks.1.mlp.mlp_up.weight', 'lang_encoder.old_decoder_blocks.1.mlp.mlp_down.weight', 'lang_encoder.old_decoder_blocks.2.ln_1.weight', 'lang_encoder.old_decoder_blocks.2.attn.Wqkv.weight', 'lang_encoder.old_decoder_blocks.2.attn.q_ln.weight', 'lang_encoder.old_decoder_blocks.2.attn.k_ln.weight', 'lang_encoder.old_decoder_blocks.2.attn.out_proj.weight', 'lang_encoder.old_decoder_blocks.2.ln_2.weight', 'lang_encoder.old_decoder_blocks.2.mlp.mlp_up.weight', 'lang_encoder.old_decoder_blocks.2.mlp.mlp_down.weight', 'lang_encoder.old_decoder_blocks.3.ln_1.weight', 'lang_encoder.old_decoder_blocks.3.attn.Wqkv.weight', 'lang_encoder.old_decoder_blocks.3.attn.q_ln.weight', 
'lang_encoder.old_decoder_blocks.3.attn.k_ln.weight', 'lang_encoder.old_decoder_blocks.3.attn.out_proj.weight', 'lang_encoder.old_decoder_blocks.3.ln_2.weight', 'lang_encoder.old_decoder_blocks.3.mlp.mlp_up.weight', 'lang_encoder.old_decoder_blocks.3.mlp.mlp_down.weight', 'lang_encoder.old_decoder_blocks.4.ln_1.weight', 'lang_encoder.old_decoder_blocks.4.attn.Wqkv.weight', 'lang_encoder.old_decoder_blocks.4.attn.q_ln.weight', 'lang_encoder.old_decoder_blocks.4.attn.k_ln.weight', 'lang_encoder.old_decoder_blocks.4.attn.out_proj.weight', 'lang_encoder.old_decoder_blocks.4.ln_2.weight', 'lang_encoder.old_decoder_blocks.4.mlp.mlp_up.weight', 'lang_encoder.old_decoder_blocks.4.mlp.mlp_down.weight', 'lang_encoder.old_decoder_blocks.5.ln_1.weight', 'lang_encoder.old_decoder_blocks.5.attn.Wqkv.weight', 'lang_encoder.old_decoder_blocks.5.attn.q_ln.weight', 'lang_encoder.old_decoder_blocks.5.attn.k_ln.weight', 'lang_encoder.old_decoder_blocks.5.attn.out_proj.weight', 'lang_encoder.old_decoder_blocks.5.ln_2.weight', 'lang_encoder.old_decoder_blocks.5.mlp.mlp_up.weight', 'lang_encoder.old_decoder_blocks.5.mlp.mlp_down.weight', 'lang_encoder.old_decoder_blocks.6.ln_1.weight', 'lang_encoder.old_decoder_blocks.6.attn.Wqkv.weight', 'lang_encoder.old_decoder_blocks.6.attn.q_ln.weight', 'lang_encoder.old_decoder_blocks.6.attn.k_ln.weight', 'lang_encoder.old_decoder_blocks.6.attn.out_proj.weight', 'lang_encoder.old_decoder_blocks.6.ln_2.weight', 'lang_encoder.old_decoder_blocks.6.mlp.mlp_up.weight', 'lang_encoder.old_decoder_blocks.6.mlp.mlp_down.weight', 'lang_encoder.old_decoder_blocks.7.ln_1.weight', 'lang_encoder.old_decoder_blocks.7.attn.Wqkv.weight', 'lang_encoder.old_decoder_blocks.7.attn.q_ln.weight', 'lang_encoder.old_decoder_blocks.7.attn.k_ln.weight', 'lang_encoder.old_decoder_blocks.7.attn.out_proj.weight', 'lang_encoder.old_decoder_blocks.7.ln_2.weight', 'lang_encoder.old_decoder_blocks.7.mlp.mlp_up.weight', 'lang_encoder.old_decoder_blocks.7.mlp.mlp_down.weight', 
'lang_encoder.old_decoder_blocks.8.ln_1.weight', 'lang_encoder.old_decoder_blocks.8.attn.Wqkv.weight', 'lang_encoder.old_decoder_blocks.8.attn.q_ln.weight', 'lang_encoder.old_decoder_blocks.8.attn.k_ln.weight', 'lang_encoder.old_decoder_blocks.8.attn.out_proj.weight', 'lang_encoder.old_decoder_blocks.8.ln_2.weight', 'lang_encoder.old_decoder_blocks.8.mlp.mlp_up.weight', 'lang_encoder.old_decoder_blocks.8.mlp.mlp_down.weight', 'lang_encoder.old_decoder_blocks.9.ln_1.weight', 'lang_encoder.old_decoder_blocks.9.attn.Wqkv.weight', 'lang_encoder.old_decoder_blocks.9.attn.q_ln.weight', 'lang_encoder.old_decoder_blocks.9.attn.k_ln.weight', 'lang_encoder.old_decoder_blocks.9.attn.out_proj.weight', 'lang_encoder.old_decoder_blocks.9.ln_2.weight', 'lang_encoder.old_decoder_blocks.9.mlp.mlp_up.weight', 'lang_encoder.old_decoder_blocks.9.mlp.mlp_down.weight', 'lang_encoder.old_decoder_blocks.10.ln_1.weight', 'lang_encoder.old_decoder_blocks.10.attn.Wqkv.weight', 'lang_encoder.old_decoder_blocks.10.attn.q_ln.weight', 'lang_encoder.old_decoder_blocks.10.attn.k_ln.weight', 'lang_encoder.old_decoder_blocks.10.attn.out_proj.weight', 'lang_encoder.old_decoder_blocks.10.ln_2.weight', 'lang_encoder.old_decoder_blocks.10.mlp.mlp_up.weight', 'lang_encoder.old_decoder_blocks.10.mlp.mlp_down.weight', 'lang_encoder.old_decoder_blocks.11.ln_1.weight', 'lang_encoder.old_decoder_blocks.11.attn.Wqkv.weight', 'lang_encoder.old_decoder_blocks.11.attn.q_ln.weight', 'lang_encoder.old_decoder_blocks.11.attn.k_ln.weight', 'lang_encoder.old_decoder_blocks.11.attn.out_proj.weight', 'lang_encoder.old_decoder_blocks.11.ln_2.weight', 'lang_encoder.old_decoder_blocks.11.mlp.mlp_up.weight', 'lang_encoder.old_decoder_blocks.11.mlp.mlp_down.weight', 'lang_encoder.old_decoder_blocks.12.ln_1.weight', 'lang_encoder.old_decoder_blocks.12.attn.Wqkv.weight', 'lang_encoder.old_decoder_blocks.12.attn.q_ln.weight', 'lang_encoder.old_decoder_blocks.12.attn.k_ln.weight', 
'lang_encoder.old_decoder_blocks.12.attn.out_proj.weight', 'lang_encoder.old_decoder_blocks.12.ln_2.weight', 'lang_encoder.old_decoder_blocks.12.mlp.mlp_up.weight', 'lang_encoder.old_decoder_blocks.12.mlp.mlp_down.weight', 'lang_encoder.old_decoder_blocks.13.ln_1.weight', 'lang_encoder.old_decoder_blocks.13.attn.Wqkv.weight', 'lang_encoder.old_decoder_blocks.13.attn.q_ln.weight', 'lang_encoder.old_decoder_blocks.13.attn.k_ln.weight', 'lang_encoder.old_decoder_blocks.13.attn.out_proj.weight', 'lang_encoder.old_decoder_blocks.13.ln_2.weight', 'lang_encoder.old_decoder_blocks.13.mlp.mlp_up.weight', 'lang_encoder.old_decoder_blocks.13.mlp.mlp_down.weight', 'lang_encoder.old_decoder_blocks.14.ln_1.weight', 'lang_encoder.old_decoder_blocks.14.attn.Wqkv.weight', 'lang_encoder.old_decoder_blocks.14.attn.q_ln.weight', 'lang_encoder.old_decoder_blocks.14.attn.k_ln.weight', 'lang_encoder.old_decoder_blocks.14.attn.out_proj.weight', 'lang_encoder.old_decoder_blocks.14.ln_2.weight', 'lang_encoder.old_decoder_blocks.14.mlp.mlp_up.weight', 'lang_encoder.old_decoder_blocks.14.mlp.mlp_down.weight', 'lang_encoder.old_decoder_blocks.15.ln_1.weight', 'lang_encoder.old_decoder_blocks.15.attn.Wqkv.weight', 'lang_encoder.old_decoder_blocks.15.attn.q_ln.weight', 'lang_encoder.old_decoder_blocks.15.attn.k_ln.weight', 'lang_encoder.old_decoder_blocks.15.attn.out_proj.weight', 'lang_encoder.old_decoder_blocks.15.ln_2.weight', 'lang_encoder.old_decoder_blocks.15.mlp.mlp_up.weight', 'lang_encoder.old_decoder_blocks.15.mlp.mlp_down.weight', 'lang_encoder.old_decoder_blocks.16.ln_1.weight', 'lang_encoder.old_decoder_blocks.16.attn.Wqkv.weight', 'lang_encoder.old_decoder_blocks.16.attn.q_ln.weight', 'lang_encoder.old_decoder_blocks.16.attn.k_ln.weight', 'lang_encoder.old_decoder_blocks.16.attn.out_proj.weight', 'lang_encoder.old_decoder_blocks.16.ln_2.weight', 'lang_encoder.old_decoder_blocks.16.mlp.mlp_up.weight', 'lang_encoder.old_decoder_blocks.16.mlp.mlp_down.weight', 
'lang_encoder.old_decoder_blocks.17.ln_1.weight', 'lang_encoder.old_decoder_blocks.17.attn.Wqkv.weight', 'lang_encoder.old_decoder_blocks.17.attn.q_ln.weight', 'lang_encoder.old_decoder_blocks.17.attn.k_ln.weight', 'lang_encoder.old_decoder_blocks.17.attn.out_proj.weight', 'lang_encoder.old_decoder_blocks.17.ln_2.weight', 'lang_encoder.old_decoder_blocks.17.mlp.mlp_up.weight', 'lang_encoder.old_decoder_blocks.17.mlp.mlp_down.weight', 'lang_encoder.old_decoder_blocks.18.ln_1.weight', 'lang_encoder.old_decoder_blocks.18.attn.Wqkv.weight', 'lang_encoder.old_decoder_blocks.18.attn.q_ln.weight', 'lang_encoder.old_decoder_blocks.18.attn.k_ln.weight', 'lang_encoder.old_decoder_blocks.18.attn.out_proj.weight', 'lang_encoder.old_decoder_blocks.18.ln_2.weight', 'lang_encoder.old_decoder_blocks.18.mlp.mlp_up.weight', 'lang_encoder.old_decoder_blocks.18.mlp.mlp_down.weight', 'lang_encoder.old_decoder_blocks.19.ln_1.weight', 'lang_encoder.old_decoder_blocks.19.attn.Wqkv.weight', 'lang_encoder.old_decoder_blocks.19.attn.q_ln.weight', 'lang_encoder.old_decoder_blocks.19.attn.k_ln.weight', 'lang_encoder.old_decoder_blocks.19.attn.out_proj.weight', 'lang_encoder.old_decoder_blocks.19.ln_2.weight', 'lang_encoder.old_decoder_blocks.19.mlp.mlp_up.weight', 'lang_encoder.old_decoder_blocks.19.mlp.mlp_down.weight', 'lang_encoder.old_decoder_blocks.20.ln_1.weight', 'lang_encoder.old_decoder_blocks.20.attn.Wqkv.weight', 'lang_encoder.old_decoder_blocks.20.attn.q_ln.weight', 'lang_encoder.old_decoder_blocks.20.attn.k_ln.weight', 'lang_encoder.old_decoder_blocks.20.attn.out_proj.weight', 'lang_encoder.old_decoder_blocks.20.ln_2.weight', 'lang_encoder.old_decoder_blocks.20.mlp.mlp_up.weight', 'lang_encoder.old_decoder_blocks.20.mlp.mlp_down.weight', 'lang_encoder.old_decoder_blocks.21.ln_1.weight', 'lang_encoder.old_decoder_blocks.21.attn.Wqkv.weight', 'lang_encoder.old_decoder_blocks.21.attn.q_ln.weight', 'lang_encoder.old_decoder_blocks.21.attn.k_ln.weight', 
'lang_encoder.old_decoder_blocks.21.attn.out_proj.weight', 'lang_encoder.old_decoder_blocks.21.ln_2.weight', 'lang_encoder.old_decoder_blocks.21.mlp.mlp_up.weight', 'lang_encoder.old_decoder_blocks.21.mlp.mlp_down.weight', 'lang_encoder.old_decoder_blocks.22.ln_1.weight', 'lang_encoder.old_decoder_blocks.22.attn.Wqkv.weight', 'lang_encoder.old_decoder_blocks.22.attn.q_ln.weight', 'lang_encoder.old_decoder_blocks.22.attn.k_ln.weight', 'lang_encoder.old_decoder_blocks.22.attn.out_proj.weight', 'lang_encoder.old_decoder_blocks.22.ln_2.weight', 'lang_encoder.old_decoder_blocks.22.mlp.mlp_up.weight', 'lang_encoder.old_decoder_blocks.22.mlp.mlp_down.weight', 'lang_encoder.old_decoder_blocks.23.ln_1.weight', 'lang_encoder.old_decoder_blocks.23.attn.Wqkv.weight', 'lang_encoder.old_decoder_blocks.23.attn.q_ln.weight', 'lang_encoder.old_decoder_blocks.23.attn.k_ln.weight', 'lang_encoder.old_decoder_blocks.23.attn.out_proj.weight', 'lang_encoder.old_decoder_blocks.23.ln_2.weight', 'lang_encoder.old_decoder_blocks.23.mlp.mlp_up.weight', 'lang_encoder.old_decoder_blocks.23.mlp.mlp_down.weight', 'lang_encoder.gated_cross_attn_layers.0.attn_gate', 'lang_encoder.gated_cross_attn_layers.0.ff_gate', 'lang_encoder.gated_cross_attn_layers.0.attn.norm.weight', 'lang_encoder.gated_cross_attn_layers.0.attn.norm.bias', 'lang_encoder.gated_cross_attn_layers.0.attn.to_q.weight', 'lang_encoder.gated_cross_attn_layers.0.attn.to_kv.weight', 'lang_encoder.gated_cross_attn_layers.0.attn.to_out.weight', 'lang_encoder.gated_cross_attn_layers.0.ff.0.weight', 'lang_encoder.gated_cross_attn_layers.0.ff.0.bias', 'lang_encoder.gated_cross_attn_layers.0.ff.1.weight', 'lang_encoder.gated_cross_attn_layers.0.ff.3.weight', 'lang_encoder.gated_cross_attn_layers.1.attn_gate', 'lang_encoder.gated_cross_attn_layers.1.ff_gate', 'lang_encoder.gated_cross_attn_layers.1.attn.norm.weight', 'lang_encoder.gated_cross_attn_layers.1.attn.norm.bias', 'lang_encoder.gated_cross_attn_layers.1.attn.to_q.weight', 
'lang_encoder.gated_cross_attn_layers.1.attn.to_kv.weight', 'lang_encoder.gated_cross_attn_layers.1.attn.to_out.weight', 'lang_encoder.gated_cross_attn_layers.1.ff.0.weight', 'lang_encoder.gated_cross_attn_layers.1.ff.0.bias', 'lang_encoder.gated_cross_attn_layers.1.ff.1.weight', 'lang_encoder.gated_cross_attn_layers.1.ff.3.weight', 'lang_encoder.gated_cross_attn_layers.2.attn_gate', 'lang_encoder.gated_cross_attn_layers.2.ff_gate', 'lang_encoder.gated_cross_attn_layers.2.attn.norm.weight', 'lang_encoder.gated_cross_attn_layers.2.attn.norm.bias', 'lang_encoder.gated_cross_attn_layers.2.attn.to_q.weight', 'lang_encoder.gated_cross_attn_layers.2.attn.to_kv.weight', 'lang_encoder.gated_cross_attn_layers.2.attn.to_out.weight', 'lang_encoder.gated_cross_attn_layers.2.ff.0.weight', 'lang_encoder.gated_cross_attn_layers.2.ff.0.bias', 'lang_encoder.gated_cross_attn_layers.2.ff.1.weight', 'lang_encoder.gated_cross_attn_layers.2.ff.3.weight', 'lang_encoder.gated_cross_attn_layers.3.attn_gate', 'lang_encoder.gated_cross_attn_layers.3.ff_gate', 'lang_encoder.gated_cross_attn_layers.3.attn.norm.weight', 'lang_encoder.gated_cross_attn_layers.3.attn.norm.bias', 'lang_encoder.gated_cross_attn_layers.3.attn.to_q.weight', 'lang_encoder.gated_cross_attn_layers.3.attn.to_kv.weight', 'lang_encoder.gated_cross_attn_layers.3.attn.to_out.weight', 'lang_encoder.gated_cross_attn_layers.3.ff.0.weight', 'lang_encoder.gated_cross_attn_layers.3.ff.0.bias', 'lang_encoder.gated_cross_attn_layers.3.ff.1.weight', 'lang_encoder.gated_cross_attn_layers.3.ff.3.weight', 'lang_encoder.gated_cross_attn_layers.4.attn_gate', 'lang_encoder.gated_cross_attn_layers.4.ff_gate', 'lang_encoder.gated_cross_attn_layers.4.attn.norm.weight', 'lang_encoder.gated_cross_attn_layers.4.attn.norm.bias', 'lang_encoder.gated_cross_attn_layers.4.attn.to_q.weight', 'lang_encoder.gated_cross_attn_layers.4.attn.to_kv.weight', 'lang_encoder.gated_cross_attn_layers.4.attn.to_out.weight', 
'lang_encoder.gated_cross_attn_layers.4.ff.0.weight', 'lang_encoder.gated_cross_attn_layers.4.ff.0.bias', 'lang_encoder.gated_cross_attn_layers.4.ff.1.weight', 'lang_encoder.gated_cross_attn_layers.4.ff.3.weight', 'lang_encoder.gated_cross_attn_layers.5.attn_gate', 'lang_encoder.gated_cross_attn_layers.5.ff_gate', 'lang_encoder.gated_cross_attn_layers.5.attn.norm.weight', 'lang_encoder.gated_cross_attn_layers.5.attn.norm.bias', 'lang_encoder.gated_cross_attn_layers.5.attn.to_q.weight', 'lang_encoder.gated_cross_attn_layers.5.attn.to_kv.weight', 'lang_encoder.gated_cross_attn_layers.5.attn.to_out.weight', 'lang_encoder.gated_cross_attn_layers.5.ff.0.weight', 'lang_encoder.gated_cross_attn_layers.5.ff.0.bias', 'lang_encoder.gated_cross_attn_layers.5.ff.1.weight', 'lang_encoder.gated_cross_attn_layers.5.ff.3.weight', 'lang_encoder.gated_cross_attn_layers.6.attn_gate', 'lang_encoder.gated_cross_attn_layers.6.ff_gate', 'lang_encoder.gated_cross_attn_layers.6.attn.norm.weight', 'lang_encoder.gated_cross_attn_layers.6.attn.norm.bias', 'lang_encoder.gated_cross_attn_layers.6.attn.to_q.weight', 'lang_encoder.gated_cross_attn_layers.6.attn.to_kv.weight', 'lang_encoder.gated_cross_attn_layers.6.attn.to_out.weight', 'lang_encoder.gated_cross_attn_layers.6.ff.0.weight', 'lang_encoder.gated_cross_attn_layers.6.ff.0.bias', 'lang_encoder.gated_cross_attn_layers.6.ff.1.weight', 'lang_encoder.gated_cross_attn_layers.6.ff.3.weight', 'lang_encoder.gated_cross_attn_layers.7.attn_gate', 'lang_encoder.gated_cross_attn_layers.7.ff_gate', 'lang_encoder.gated_cross_attn_layers.7.attn.norm.weight', 'lang_encoder.gated_cross_attn_layers.7.attn.norm.bias', 'lang_encoder.gated_cross_attn_layers.7.attn.to_q.weight', 'lang_encoder.gated_cross_attn_layers.7.attn.to_kv.weight', 'lang_encoder.gated_cross_attn_layers.7.attn.to_out.weight', 'lang_encoder.gated_cross_attn_layers.7.ff.0.weight', 'lang_encoder.gated_cross_attn_layers.7.ff.0.bias', 'lang_encoder.gated_cross_attn_layers.7.ff.1.weight', 
'lang_encoder.gated_cross_attn_layers.7.ff.3.weight', 'lang_encoder.gated_cross_attn_layers.8.attn_gate', 'lang_encoder.gated_cross_attn_layers.8.ff_gate', 'lang_encoder.gated_cross_attn_layers.8.attn.norm.weight', 'lang_encoder.gated_cross_attn_layers.8.attn.norm.bias', 'lang_encoder.gated_cross_attn_layers.8.attn.to_q.weight', 'lang_encoder.gated_cross_attn_layers.8.attn.to_kv.weight', 'lang_encoder.gated_cross_attn_layers.8.attn.to_out.weight', 'lang_encoder.gated_cross_attn_layers.8.ff.0.weight', 'lang_encoder.gated_cross_attn_layers.8.ff.0.bias', 'lang_encoder.gated_cross_attn_layers.8.ff.1.weight', 'lang_encoder.gated_cross_attn_layers.8.ff.3.weight', 'lang_encoder.gated_cross_attn_layers.9.attn_gate', 'lang_encoder.gated_cross_attn_layers.9.ff_gate', 'lang_encoder.gated_cross_attn_layers.9.attn.norm.weight', 'lang_encoder.gated_cross_attn_layers.9.attn.norm.bias', 'lang_encoder.gated_cross_attn_layers.9.attn.to_q.weight', 'lang_encoder.gated_cross_attn_layers.9.attn.to_kv.weight', 'lang_encoder.gated_cross_attn_layers.9.attn.to_out.weight', 'lang_encoder.gated_cross_attn_layers.9.ff.0.weight', 'lang_encoder.gated_cross_attn_layers.9.ff.0.bias', 'lang_encoder.gated_cross_attn_layers.9.ff.1.weight', 'lang_encoder.gated_cross_attn_layers.9.ff.3.weight', 'lang_encoder.gated_cross_attn_layers.10.attn_gate', 'lang_encoder.gated_cross_attn_layers.10.ff_gate', 'lang_encoder.gated_cross_attn_layers.10.attn.norm.weight', 'lang_encoder.gated_cross_attn_layers.10.attn.norm.bias', 'lang_encoder.gated_cross_attn_layers.10.attn.to_q.weight', 'lang_encoder.gated_cross_attn_layers.10.attn.to_kv.weight', 'lang_encoder.gated_cross_attn_layers.10.attn.to_out.weight', 'lang_encoder.gated_cross_attn_layers.10.ff.0.weight', 'lang_encoder.gated_cross_attn_layers.10.ff.0.bias', 'lang_encoder.gated_cross_attn_layers.10.ff.1.weight', 'lang_encoder.gated_cross_attn_layers.10.ff.3.weight', 'lang_encoder.gated_cross_attn_layers.11.attn_gate', 
'lang_encoder.gated_cross_attn_layers.11.ff_gate', 'lang_encoder.gated_cross_attn_layers.11.attn.norm.weight', 'lang_encoder.gated_cross_attn_layers.11.attn.norm.bias', 'lang_encoder.gated_cross_attn_layers.11.attn.to_q.weight', 'lang_encoder.gated_cross_attn_layers.11.attn.to_kv.weight', 'lang_encoder.gated_cross_attn_layers.11.attn.to_out.weight', 'lang_encoder.gated_cross_attn_layers.11.ff.0.weight', 'lang_encoder.gated_cross_attn_layers.11.ff.0.bias', 'lang_encoder.gated_cross_attn_layers.11.ff.1.weight', 'lang_encoder.gated_cross_attn_layers.11.ff.3.weight', 'lang_encoder.gated_cross_attn_layers.12.attn_gate', 'lang_encoder.gated_cross_attn_layers.12.ff_gate', 'lang_encoder.gated_cross_attn_layers.12.attn.norm.weight', 'lang_encoder.gated_cross_attn_layers.12.attn.norm.bias', 'lang_encoder.gated_cross_attn_layers.12.attn.to_q.weight', 'lang_encoder.gated_cross_attn_layers.12.attn.to_kv.weight', 'lang_encoder.gated_cross_attn_layers.12.attn.to_out.weight', 'lang_encoder.gated_cross_attn_layers.12.ff.0.weight', 'lang_encoder.gated_cross_attn_layers.12.ff.0.bias', 'lang_encoder.gated_cross_attn_layers.12.ff.1.weight', 'lang_encoder.gated_cross_attn_layers.12.ff.3.weight', 'lang_encoder.gated_cross_attn_layers.13.attn_gate', 'lang_encoder.gated_cross_attn_layers.13.ff_gate', 'lang_encoder.gated_cross_attn_layers.13.attn.norm.weight', 'lang_encoder.gated_cross_attn_layers.13.attn.norm.bias', 'lang_encoder.gated_cross_attn_layers.13.attn.to_q.weight', 'lang_encoder.gated_cross_attn_layers.13.attn.to_kv.weight', 'lang_encoder.gated_cross_attn_layers.13.attn.to_out.weight', 'lang_encoder.gated_cross_attn_layers.13.ff.0.weight', 'lang_encoder.gated_cross_attn_layers.13.ff.0.bias', 'lang_encoder.gated_cross_attn_layers.13.ff.1.weight', 'lang_encoder.gated_cross_attn_layers.13.ff.3.weight', 'lang_encoder.gated_cross_attn_layers.14.attn_gate', 'lang_encoder.gated_cross_attn_layers.14.ff_gate', 'lang_encoder.gated_cross_attn_layers.14.attn.norm.weight', 
'lang_encoder.gated_cross_attn_layers.14.attn.norm.bias', 'lang_encoder.gated_cross_attn_layers.14.attn.to_q.weight', 'lang_encoder.gated_cross_attn_layers.14.attn.to_kv.weight', 'lang_encoder.gated_cross_attn_layers.14.attn.to_out.weight', 'lang_encoder.gated_cross_attn_layers.14.ff.0.weight', 'lang_encoder.gated_cross_attn_layers.14.ff.0.bias', 'lang_encoder.gated_cross_attn_layers.14.ff.1.weight', 'lang_encoder.gated_cross_attn_layers.14.ff.3.weight', 'lang_encoder.gated_cross_attn_layers.15.attn_gate', 'lang_encoder.gated_cross_attn_layers.15.ff_gate', 'lang_encoder.gated_cross_attn_layers.15.attn.norm.weight', 'lang_encoder.gated_cross_attn_layers.15.attn.norm.bias', 'lang_encoder.gated_cross_attn_layers.15.attn.to_q.weight', 'lang_encoder.gated_cross_attn_layers.15.attn.to_kv.weight', 'lang_encoder.gated_cross_attn_layers.15.attn.to_out.weight', 'lang_encoder.gated_cross_attn_layers.15.ff.0.weight', 'lang_encoder.gated_cross_attn_layers.15.ff.0.bias', 'lang_encoder.gated_cross_attn_layers.15.ff.1.weight', 'lang_encoder.gated_cross_attn_layers.15.ff.3.weight', 'lang_encoder.gated_cross_attn_layers.16.attn_gate', 'lang_encoder.gated_cross_attn_layers.16.ff_gate', 'lang_encoder.gated_cross_attn_layers.16.attn.norm.weight', 'lang_encoder.gated_cross_attn_layers.16.attn.norm.bias', 'lang_encoder.gated_cross_attn_layers.16.attn.to_q.weight', 'lang_encoder.gated_cross_attn_layers.16.attn.to_kv.weight', 'lang_encoder.gated_cross_attn_layers.16.attn.to_out.weight', 'lang_encoder.gated_cross_attn_layers.16.ff.0.weight', 'lang_encoder.gated_cross_attn_layers.16.ff.0.bias', 'lang_encoder.gated_cross_attn_layers.16.ff.1.weight', 'lang_encoder.gated_cross_attn_layers.16.ff.3.weight', 'lang_encoder.gated_cross_attn_layers.17.attn_gate', 'lang_encoder.gated_cross_attn_layers.17.ff_gate', 'lang_encoder.gated_cross_attn_layers.17.attn.norm.weight', 'lang_encoder.gated_cross_attn_layers.17.attn.norm.bias', 'lang_encoder.gated_cross_attn_layers.17.attn.to_q.weight', 
'lang_encoder.gated_cross_attn_layers.17.attn.to_kv.weight', 'lang_encoder.gated_cross_attn_layers.17.attn.to_out.weight', 'lang_encoder.gated_cross_attn_layers.17.ff.0.weight', 'lang_encoder.gated_cross_attn_layers.17.ff.0.bias', 'lang_encoder.gated_cross_attn_layers.17.ff.1.weight', 'lang_encoder.gated_cross_attn_layers.17.ff.3.weight', 'lang_encoder.gated_cross_attn_layers.18.attn_gate', 'lang_encoder.gated_cross_attn_layers.18.ff_gate', 'lang_encoder.gated_cross_attn_layers.18.attn.norm.weight', 'lang_encoder.gated_cross_attn_layers.18.attn.norm.bias', 'lang_encoder.gated_cross_attn_layers.18.attn.to_q.weight', 'lang_encoder.gated_cross_attn_layers.18.attn.to_kv.weight', 'lang_encoder.gated_cross_attn_layers.18.attn.to_out.weight', 'lang_encoder.gated_cross_attn_layers.18.ff.0.weight', 'lang_encoder.gated_cross_attn_layers.18.ff.0.bias', 'lang_encoder.gated_cross_attn_layers.18.ff.1.weight', 'lang_encoder.gated_cross_attn_layers.18.ff.3.weight', 'lang_encoder.gated_cross_attn_layers.19.attn_gate', 'lang_encoder.gated_cross_attn_layers.19.ff_gate', 'lang_encoder.gated_cross_attn_layers.19.attn.norm.weight', 'lang_encoder.gated_cross_attn_layers.19.attn.norm.bias', 'lang_encoder.gated_cross_attn_layers.19.attn.to_q.weight', 'lang_encoder.gated_cross_attn_layers.19.attn.to_kv.weight', 'lang_encoder.gated_cross_attn_layers.19.attn.to_out.weight', 'lang_encoder.gated_cross_attn_layers.19.ff.0.weight', 'lang_encoder.gated_cross_attn_layers.19.ff.0.bias', 'lang_encoder.gated_cross_attn_layers.19.ff.1.weight', 'lang_encoder.gated_cross_attn_layers.19.ff.3.weight', 'lang_encoder.gated_cross_attn_layers.20.attn_gate', 'lang_encoder.gated_cross_attn_layers.20.ff_gate', 'lang_encoder.gated_cross_attn_layers.20.attn.norm.weight', 'lang_encoder.gated_cross_attn_layers.20.attn.norm.bias', 'lang_encoder.gated_cross_attn_layers.20.attn.to_q.weight', 'lang_encoder.gated_cross_attn_layers.20.attn.to_kv.weight', 'lang_encoder.gated_cross_attn_layers.20.attn.to_out.weight', 
'lang_encoder.gated_cross_attn_layers.20.ff.0.weight', 'lang_encoder.gated_cross_attn_layers.20.ff.0.bias', 'lang_encoder.gated_cross_attn_layers.20.ff.1.weight', 'lang_encoder.gated_cross_attn_layers.20.ff.3.weight', 'lang_encoder.gated_cross_attn_layers.21.attn_gate', 'lang_encoder.gated_cross_attn_layers.21.ff_gate', 'lang_encoder.gated_cross_attn_layers.21.attn.norm.weight', 'lang_encoder.gated_cross_attn_layers.21.attn.norm.bias', 'lang_encoder.gated_cross_attn_layers.21.attn.to_q.weight', 'lang_encoder.gated_cross_attn_layers.21.attn.to_kv.weight', 'lang_encoder.gated_cross_attn_layers.21.attn.to_out.weight', 'lang_encoder.gated_cross_attn_layers.21.ff.0.weight', 'lang_encoder.gated_cross_attn_layers.21.ff.0.bias', 'lang_encoder.gated_cross_attn_layers.21.ff.1.weight', 'lang_encoder.gated_cross_attn_layers.21.ff.3.weight', 'lang_encoder.gated_cross_attn_layers.22.attn_gate', 'lang_encoder.gated_cross_attn_layers.22.ff_gate', 'lang_encoder.gated_cross_attn_layers.22.attn.norm.weight', 'lang_encoder.gated_cross_attn_layers.22.attn.norm.bias', 'lang_encoder.gated_cross_attn_layers.22.attn.to_q.weight', 'lang_encoder.gated_cross_attn_layers.22.attn.to_kv.weight', 'lang_encoder.gated_cross_attn_layers.22.attn.to_out.weight', 'lang_encoder.gated_cross_attn_layers.22.ff.0.weight', 'lang_encoder.gated_cross_attn_layers.22.ff.0.bias', 'lang_encoder.gated_cross_attn_layers.22.ff.1.weight', 'lang_encoder.gated_cross_attn_layers.22.ff.3.weight', 'lang_encoder.gated_cross_attn_layers.23.attn_gate', 'lang_encoder.gated_cross_attn_layers.23.ff_gate', 'lang_encoder.gated_cross_attn_layers.23.attn.norm.weight', 'lang_encoder.gated_cross_attn_layers.23.attn.norm.bias', 'lang_encoder.gated_cross_attn_layers.23.attn.to_q.weight', 'lang_encoder.gated_cross_attn_layers.23.attn.to_kv.weight', 'lang_encoder.gated_cross_attn_layers.23.attn.to_out.weight', 'lang_encoder.gated_cross_attn_layers.23.ff.0.weight', 'lang_encoder.gated_cross_attn_layers.23.ff.0.bias', 
'lang_encoder.gated_cross_attn_layers.23.ff.1.weight', 'lang_encoder.gated_cross_attn_layers.23.ff.3.weight'], unexpected_keys=[])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# grab model checkpoint from huggingface hub\n",
    "from huggingface_hub import hf_hub_download\n",
    "import torch\n",
    "\n",
    "# checkpoint_path = hf_hub_download(\"openflamingo/OpenFlamingo-3B-vitl-mpt1b\", \"checkpoint.pt\")\n",
    "model.load_state_dict(torch.load(\"../OpenFlamingo-3B-vitl-mpt1b/checkpoint.pt\"), strict=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<bound method Module.modules of Flamingo(\n",
      "  (vision_encoder): VisionTransformer(\n",
      "    (conv1): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)\n",
      "    (patch_dropout): Identity()\n",
      "    (ln_pre): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
      "    (transformer): Transformer(\n",
      "      (resblocks): ModuleList(\n",
      "        (0-23): 24 x ResidualAttentionBlock(\n",
      "          (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
      "          (attn): MultiheadAttention(\n",
      "            (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)\n",
      "          )\n",
      "          (ls_1): Identity()\n",
      "          (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
      "          (mlp): Sequential(\n",
      "            (c_fc): Linear(in_features=1024, out_features=4096, bias=True)\n",
      "            (gelu): GELU(approximate='none')\n",
      "            (c_proj): Linear(in_features=4096, out_features=1024, bias=True)\n",
      "          )\n",
      "          (ls_2): Identity()\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (ln_post): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
      "  )\n",
      "  (perceiver): PerceiverResampler(\n",
      "    (layers): ModuleList(\n",
      "      (0-5): 6 x ModuleList(\n",
      "        (0): PerceiverAttention(\n",
      "          (norm_media): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
      "          (norm_latents): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
      "          (to_q): Linear(in_features=1024, out_features=512, bias=False)\n",
      "          (to_kv): Linear(in_features=1024, out_features=1024, bias=False)\n",
      "          (to_out): Linear(in_features=512, out_features=1024, bias=False)\n",
      "        )\n",
      "        (1): Sequential(\n",
      "          (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
      "          (1): Linear(in_features=1024, out_features=4096, bias=False)\n",
      "          (2): GELU(approximate='none')\n",
      "          (3): Linear(in_features=4096, out_features=1024, bias=False)\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
      "  )\n",
      "  (lang_encoder): MosaicGPT(\n",
      "    (transformer): ModuleDict(\n",
      "      (wte): Embedding(50280, 2048)\n",
      "      (emb_drop): Dropout(p=0, inplace=False)\n",
      "      (blocks): ModuleList(\n",
      "        (0-23): 24 x FlamingoLayer(\n",
      "          (gated_cross_attn_layer): GatedCrossAttentionBlock(\n",
      "            (attn): MaskedCrossAttention(\n",
      "              (norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n",
      "              (to_q): Linear(in_features=2048, out_features=512, bias=False)\n",
      "              (to_kv): Linear(in_features=1024, out_features=1024, bias=False)\n",
      "              (to_out): Linear(in_features=512, out_features=2048, bias=False)\n",
      "            )\n",
      "            (ff): Sequential(\n",
      "              (0): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n",
      "              (1): Linear(in_features=2048, out_features=8192, bias=False)\n",
      "              (2): GELU(approximate='none')\n",
      "              (3): Linear(in_features=8192, out_features=2048, bias=False)\n",
      "            )\n",
      "          )\n",
      "          (decoder_layer): GPTBlock(\n",
      "            (ln_1): LPLayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n",
      "            (attn): MultiheadAttention(\n",
      "              (Wqkv): Linear(in_features=2048, out_features=6144, bias=False)\n",
      "              (q_ln): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n",
      "              (k_ln): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n",
      "              (out_proj): Linear(in_features=2048, out_features=2048, bias=False)\n",
      "            )\n",
      "            (ln_2): LPLayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n",
      "            (mlp): GPTMLP(\n",
      "              (mlp_up): Linear(in_features=2048, out_features=8192, bias=False)\n",
      "              (mlp_act): GELU(approximate='none')\n",
      "              (mlp_down): Linear(in_features=8192, out_features=2048, bias=False)\n",
      "            )\n",
      "            (resid_attn_dropout): Dropout(p=0, inplace=False)\n",
      "            (resid_mlp_dropout): Dropout(p=0, inplace=False)\n",
      "          )\n",
      "        )\n",
      "      )\n",
      "      (ln_f): LPLayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n",
      "    )\n",
      "    (old_decoder_blocks): ModuleList(\n",
      "      (0-23): 24 x GPTBlock(\n",
      "        (ln_1): LPLayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n",
      "        (attn): MultiheadAttention(\n",
      "          (Wqkv): Linear(in_features=2048, out_features=6144, bias=False)\n",
      "          (q_ln): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n",
      "          (k_ln): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n",
      "          (out_proj): Linear(in_features=2048, out_features=2048, bias=False)\n",
      "        )\n",
      "        (ln_2): LPLayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n",
      "        (mlp): GPTMLP(\n",
      "          (mlp_up): Linear(in_features=2048, out_features=8192, bias=False)\n",
      "          (mlp_act): GELU(approximate='none')\n",
      "          (mlp_down): Linear(in_features=8192, out_features=2048, bias=False)\n",
      "        )\n",
      "        (resid_attn_dropout): Dropout(p=0, inplace=False)\n",
      "        (resid_mlp_dropout): Dropout(p=0, inplace=False)\n",
      "      )\n",
      "    )\n",
      "    (gated_cross_attn_layers): ModuleList(\n",
      "      (0-23): 24 x GatedCrossAttentionBlock(\n",
      "        (attn): MaskedCrossAttention(\n",
      "          (norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n",
      "          (to_q): Linear(in_features=2048, out_features=512, bias=False)\n",
      "          (to_kv): Linear(in_features=1024, out_features=1024, bias=False)\n",
      "          (to_out): Linear(in_features=512, out_features=2048, bias=False)\n",
      "        )\n",
      "        (ff): Sequential(\n",
      "          (0): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n",
      "          (1): Linear(in_features=2048, out_features=8192, bias=False)\n",
      "          (2): GELU(approximate='none')\n",
      "          (3): Linear(in_features=8192, out_features=2048, bias=False)\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "  )\n",
      ")>\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "print(str(model.modules))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:50277 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generated text:  <image>An image of two cats.<|endofchunk|><image>An image of a bathroom sink.<|endofchunk|><image>An image of a Christmas buffet.<|endofchunk|>\n"
     ]
    }
   ],
   "source": [
    "from io import BytesIO\n",
    "\n",
    "from PIL import Image\n",
    "import requests\n",
    "import torch\n",
    "\n",
    "# Seconds to wait on the image server before failing, so a stalled\n",
    "# connection cannot hang the kernel indefinitely.\n",
    "REQUEST_TIMEOUT_S = 30\n",
    "\n",
    "\n",
    "def load_image(url, timeout=REQUEST_TIMEOUT_S):\n",
    "    \"\"\"Download `url` and return its contents as a PIL image.\n",
    "\n",
    "    Args:\n",
    "        url: HTTP(S) address of the image file.\n",
    "        timeout: seconds passed to `requests.get` as its timeout.\n",
    "\n",
    "    Returns:\n",
    "        The `PIL.Image.Image` opened from the downloaded bytes.\n",
    "\n",
    "    Raises:\n",
    "        requests.HTTPError: if the server answers with an error status.\n",
    "    \"\"\"\n",
    "    response = requests.get(url, timeout=timeout)\n",
    "    # Fail here on 4xx/5xx rather than handing an error page to PIL.\n",
    "    response.raise_for_status()\n",
    "    # Reading from an in-memory buffer means PIL's lazy decoding never\n",
    "    # depends on a network socket that is still open.\n",
    "    return Image.open(BytesIO(response.content))\n",
    "\n",
    "\n",
    "# Step 1: Load images\n",
    "demo_image_one = load_image(\n",
    "    \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n",
    ")\n",
    "demo_image_two = load_image(\n",
    "    \"http://images.cocodataset.org/test-stuff2017/000000028137.jpg\"\n",
    ")\n",
    "query_image = load_image(\n",
    "    \"http://images.cocodataset.org/test-stuff2017/000000028352.jpg\"\n",
    ")\n",
    "\n",
    "# Step 2: Preprocessing images\n",
    "# Details: For OpenFlamingo, we expect the image to be a torch tensor of shape\n",
    "#  batch_size x num_media x num_frames x channels x height x width.\n",
    "#  In this case batch_size = 1, num_media = 3, num_frames = 1,\n",
    "#  channels = 3, height = 224, width = 224.\n",
    "context_images = [demo_image_one, demo_image_two, query_image]\n",
    "# Stack the per-image tensors along a new leading num_media axis.\n",
    "vision_x = torch.stack([image_processor(image) for image in context_images], dim=0)\n",
    "# Insert the num_frames axis (dim 1), then the batch axis (dim 0).\n",
    "vision_x = vision_x.unsqueeze(1).unsqueeze(0)\n",
    "\n",
    "# Step 3: Preprocessing text\n",
    "# Details: In the text we expect an <image> special token to indicate where an image is.\n",
    "#  We also expect an <|endofchunk|> special token to indicate the end of the text\n",
    "#  portion associated with an image.\n",
    "tokenizer.padding_side = \"left\" # For generation padding tokens should be on the left\n",
    "prompt = \"<image>An image of two cats.<|endofchunk|><image>An image of a bathroom sink.<|endofchunk|><image>An image of\"\n",
    "lang_x = tokenizer(\n",
    "    [prompt],\n",
    "    return_tensors=\"pt\",\n",
    ")\n",
    "\n",
    "# Step 4: Generate text\n",
    "# Pure inference: disable gradient tracking so no autograd state is kept.\n",
    "with torch.no_grad():\n",
    "    generated_text = model.generate(\n",
    "        vision_x=vision_x,\n",
    "        lang_x=lang_x[\"input_ids\"],\n",
    "        attention_mask=lang_x[\"attention_mask\"],\n",
    "        max_new_tokens=20,\n",
    "        num_beams=3,\n",
    "    )\n",
    "\n",
    "print(\"Generated text: \", tokenizer.decode(generated_text[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "openflamingo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
