{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import gc\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from transformer_lens import HookedTransformer\n",
    "from transformer_lens.loading_from_pretrained import get_pretrained_model_config, get_pretrained_state_dict, get_checkpoint_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.autograd.grad_mode.set_grad_enabled at 0x7f3c8df91150>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inference only: turn autograd off globally so no later cell builds a graph.\n",
    "# The trailing `;` suppresses the echoed context-manager repr.\n",
    "torch.set_grad_enabled(False);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hugging Face repo id of the model studied in this notebook.\n",
    "MODEL_NAME = \"EleutherAI/pythia-6.9b\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[0;31mSignature:\u001b[0m\n",
      "\u001b[0mHookedTransformer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_pretrained\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mmodel_name\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mfold_ln\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mbool\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mcenter_writing_weights\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mbool\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mcenter_unembed\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mbool\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mrefactor_factored_attn_matrices\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mbool\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mcheckpoint_index\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mcheckpoint_value\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mhf_model\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtransformers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodeling_auto\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAutoModelForCausalLM\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mdevice\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNoneType\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mn_devices\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mtokenizer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtransformers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtokenization_utils_base\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mPreTrainedTokenizerBase\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mmove_to_device\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mbool\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mfold_value_biases\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mbool\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mdefault_prepend_bos\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mbool\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mdefault_padding_side\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mLiteral\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'left'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'right'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'right'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'float32'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0;34m**\u001b[0m\u001b[0mfrom_pretrained_kwargs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;34m'HookedTransformer'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mDocstring:\u001b[0m\n",
      "Load in a Pretrained Model.\n",
      "\n",
      "Load in pretrained model weights to the HookedTransformer format and optionally to do some\n",
      "processing to make the model easier to interpret. Currently supports loading from most\n",
      "autoregressive HuggingFace models (``gpt2``, ``neo``, ``gptj``, ``opt``...) and from a range\n",
      "of toy models and SoLU models trained by Neel Nanda. The full list is available in the docs\n",
      "under :doc:`model properties</generated/model_properties_table>`. Also supports loading from\n",
      "a checkpoint for checkpointed models (currently, models trained by NeelNanda and the\n",
      "stanford-crfm models (using parameters ``checkpoint_index`` and ``checkpoint_value``).\n",
      "\n",
      "See :meth:`load_and_process_state_dict` for details on the processing (folding layer norm,\n",
      "centering the unembedding and centering the writing weights).\n",
      "\n",
      "Example:\n",
      "\n",
      ">>> from transformer_lens import HookedTransformer\n",
      ">>> model = HookedTransformer.from_pretrained(\"tiny-stories-1M\")\n",
      "Loaded pretrained model tiny-stories-1M into HookedTransformer\n",
      "\n",
      "Args:\n",
      "    model_name: The model name - must be an element of\n",
      "        :const:`transformer_lens.loading_from_pretrained.OFFICIAL_MODEL_NAMES` or an alias\n",
      "        of one. The full list of available models can be found in the docs under :doc:`model\n",
      "        properties</generated/model_properties_table>`.\n",
      "    fold_ln: Whether to fold in the LayerNorm weights to the\n",
      "        subsequent linear layer. This does not change the computation.\n",
      "\n",
      "        `LayerNorm\n",
      "        <https://wandb.ai/wandb_fc/LayerNorm/reports/Layer-Normalization-in-Pytorch-With-Examples---VmlldzoxMjk5MTk1>`_\n",
      "        is a common regularization technique used in transformers. Unlike BatchNorm, it\n",
      "        cannot be turned off at inference time, as it significantly alters the mathematical\n",
      "        function implemented by the transformer.\n",
      "\n",
      "        When `fold_ln` is set to True, LayerNorm (with weights :math:`w_{ln}` and\n",
      "        :math:`b_{ln}`) followed by a linear layer (:math:`W + b`) is optimized to\n",
      "        LayerNormPre (just centering & normalizing) followed by a new linear layer with\n",
      "        :math:`W_{eff} = w[:,   ext{None}] * W` (element-wise multiplication) and\n",
      "        :math:`b_{eff} = b + b_{ln} @ W`. This transformation is computationally equivalent\n",
      "        and simplifies the model's interpretability. It essentially merges LayerNorm weights\n",
      "        into the subsequent linear layer's weights, which is handled by HookedTransformer\n",
      "        when loading pre-trained weights. Set `fold_ln` to False when loading a state dict\n",
      "        if you wish to turn this off.\n",
      "\n",
      "        Mathematically, LayerNorm is defined as follows:\n",
      "\n",
      "        .. math::\n",
      "            x_1 &= x_0 - \\text{mean}(x_0)\n",
      "\n",
      "            x_2 &= \\frac{x_1}{\\sqrt{\\text{mean}(x_1^2)}}\n",
      "\n",
      "            x_3 &= x_2 \\cdot w\n",
      "\n",
      "            x_4 &= x_3 + b\n",
      "\n",
      "        For further details, refer to `this document\n",
      "        <https://transformer-circuits.pub/2021/framework/index.html#:~:text=Handling%20Layer%20Normalization>`_.\n",
      "    center_writing_weights: Whether to center weights\n",
      "        writing to the residual stream (ie set mean to be zero). Due to LayerNorm this\n",
      "        doesn't change the computation.\n",
      "\n",
      "        A related idea to folding layernorm (``fold_ln``) - *every* component reading an\n",
      "        input from the residual stream is preceded by a LayerNorm, which means that the mean\n",
      "        of a residual stream vector (ie the component in the direction of all ones) never\n",
      "        matters. This means we can remove the all ones component of weights and biases whose\n",
      "        output *writes* to the residual stream. Mathematically, ``W_writing -=\n",
      "        W_writing.mean(dim=1, keepdim=True)``.\n",
      "    center_unembed: Whether to center W_U (ie set mean\n",
      "        to be zero). Softmax is translation invariant so this doesn't affect log probs or\n",
      "        loss, but does change logits.\n",
      "\n",
      "        The logits are fed into a softmax. Softmax is translation invariant (eg, adding 1 to\n",
      "        every logit doesn't change the output), so we can simplify things by setting the\n",
      "        mean of the logits to be zero. This is equivalent to setting the mean of every\n",
      "        output vector of ``W_U`` to zero. In code, ``W_U -= W_U.mean(dim=-1,\n",
      "        keepdim=True)``.\n",
      "    refactor_factored_attn_matrices: Whether to convert the factored\n",
      "        matrices (W_Q & W_K, and W_O & W_V) to be \"even\". Defaults to False\n",
      "    checkpoint_index: If loading from a checkpoint, the index of\n",
      "        the checkpoint to load.\n",
      "    checkpoint_value: If loading from a checkpoint, the value of\n",
      "        the checkpoint to load, ie the step or token number (each model has checkpoints\n",
      "        labelled with exactly one of these). E.g. ``1000`` for a checkpoint taken at step\n",
      "        1000 or after 1000 tokens. If `checkpoint_index` is also specified, this will be\n",
      "        ignored.\n",
      "    hf_model: If you have already loaded in the\n",
      "        HuggingFace model, you can pass it in here rather than needing to recreate the\n",
      "        object. Defaults to None.\n",
      "    device: The device to load the model onto. By\n",
      "        default will load to CUDA if available, else CPU.\n",
      "    n_devices: The number of devices to split the model\n",
      "        across. Defaults to 1. If greater than 1, `device` must be cuda.\n",
      "    tokenizer: The tokenizer to use for the model. If not\n",
      "        provided, it is inferred from cfg.tokenizer_name or initialized to None. If None,\n",
      "        then the model cannot be passed strings, and d_vocab must be explicitly set.\n",
      "    move_to_device: Whether to move the model to the device specified in\n",
      "        cfg. device. Must be true if `n_devices` in the config is greater than 1, since the\n",
      "        model's layers will be split across multiple devices.\n",
      "    fold_value_biases: Each attention head has a value bias. Values are averaged to create\n",
      "        mixed values (``z``), weighted by the attention pattern, but as the bias is\n",
      "        constant, its contribution to ``z`` is exactly the same. The output of a head is ``z\n",
      "        @ W_O``, and so the value bias just linearly adds to the output of the head. This\n",
      "        means that the value bias of a head has nothing to do with the head, and is just a\n",
      "        constant added to the attention layer outputs. We can take the sum across these and\n",
      "        b_O to get an \"effective bias\" for the layer. In code, we set ``b_V=0``. and ``b_O =\n",
      "        (b_V @ W_O).sum(dim=0) + b_O``.\n",
      "\n",
      "        The technical derivation of this is as follows. ``v = residual @ W_V[h] +\n",
      "        broadcast_b_V[h]`` for each head ``h`` (where ``b_V`` is broadcast up from shape\n",
      "        ``d_head`` to shape ``[position, d_head]``). And ``z = pattern[h] @ v = pattern[h] @\n",
      "        residual @ W_V[h] + pattern[h] @ broadcast_b_V[h]``. Because ``pattern[h]`` is\n",
      "        ``[destination_position, source_position]`` and ``broadcast_b_V`` is constant along\n",
      "        the ``(source_)position`` dimension, we're basically just multiplying it by the sum\n",
      "        of the pattern across the ``source_position`` dimension, which is just ``1``. So it\n",
      "        remains exactly the same, and so is just broadcast across the destination positions.\n",
      "    default_prepend_bos: Default behavior of whether to prepend the BOS\n",
      "        token when the methods of HookedTransformer process input text to tokenize (only\n",
      "        when input is a string). Defaults to True - even for models not explicitly trained\n",
      "        with this, heads often use the first position as a resting position and accordingly\n",
      "        lose information from the first token, so this empirically seems to give better\n",
      "        results. To change the default behavior to False, pass in default_prepend_bos=False.\n",
      "        Note that you can also locally override the default behavior by passing in\n",
      "        prepend_bos=True/False when you call a method that processes the input string.\n",
      "    from_pretrained_kwargs: Any other optional argument passed to\n",
      "        HuggingFace's from_pretrained (e.g. \"cache_dir\" or \"torch_dtype\"). Also passed to\n",
      "        other HuggingFace functions when compatible. For some models or arguments it doesn't\n",
      "        work, especially for models that are not internally loaded with HuggingFace's\n",
      "        from_pretrained (e.g. SoLU models).\n",
      "    dtype: What data type to load the model in (also sets the dtype of\n",
      "        the HuggingFace model). Set to bfloat16 or float16 if you get out of memory errors when loading\n",
      "        the model.\n",
      "    default_padding_side: Which side to pad on when tokenizing. Defaults to\n",
      "        \"right\".\n",
      "\u001b[0;31mFile:\u001b[0m      /workspace/transformer_lens/transformer_lens/HookedTransformer.py\n",
      "\u001b[0;31mType:\u001b[0m      method"
     ]
    }
   ],
   "source": [
    "# Dev aid: show the signature and docstring of `from_pretrained` (equivalent to `HookedTransformer.from_pretrained?`).\n",
    "%pinfo HookedTransformer.from_pretrained"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:Pythia models on HF were updated on 4/3/23! add '-v0' to model name to access the old models.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a8589e34c5564b7db8a0d6a404a6e15f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "14dc00976a524fe2af765a0e8b0876cd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "pytorch_model-00001-of-00002.bin:   0%|          | 0.00/9.91G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6856579183b44c6bbe6713110157ff0b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "pytorch_model-00002-of-00002.bin:   0%|          | 0.00/3.94G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "724f699ee7d04f0883206bd34096812f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
      "WARNING:root:With reduced precision, it is advised to use `from_pretrained_no_processing` instead of `from_pretrained`.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded pretrained model EleutherAI/pythia-6.9b into HookedTransformer\n"
     ]
    }
   ],
   "source": [
    "# Release any previously loaded model first: otherwise it stays referenced until the\n",
    "# assignment below completes, so two 6.9B models would briefly co-exist in memory.\n",
    "model = None\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "model = HookedTransformer.from_pretrained(\n",
    "    MODEL_NAME,\n",
    "    checkpoint_value=141000,  # training-step checkpoint to load\n",
    "    center_unembed=True,\n",
    "    center_writing_weights=True,\n",
    "    fold_ln=True,\n",
    "    dtype=torch.bfloat16,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:Pythia models on HF were updated on 4/3/23! add '-v0' to model name to access the old models.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "86dccaeb1a014ab3a06bc3aea30274b9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f35409b9270745bf94520d07c329ed8d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "pytorch_model-00001-of-00002.bin:   0%|          | 0.00/9.91G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ff1f3f922c284c7687127866233703f4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "pytorch_model-00002-of-00002.bin:   0%|          | 0.00/3.94G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "628ac18b244441d68a3e41437413f7d1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
      "WARNING:root:With reduced precision, it is advised to use `from_pretrained_no_processing` instead of `from_pretrained`.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded pretrained model EleutherAI/pythia-6.9b into HookedTransformer\n"
     ]
    }
   ],
   "source": [
    "# Release any previously loaded model first: otherwise it stays referenced until the\n",
    "# assignment below completes, so two 6.9B models would briefly co-exist in memory.\n",
    "model = None\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "model = HookedTransformer.from_pretrained(\n",
    "    MODEL_NAME,\n",
    "    checkpoint_value=142000,  # training-step checkpoint to load\n",
    "    center_unembed=True,\n",
    "    center_writing_weights=True,\n",
    "    fold_ln=True,\n",
    "    dtype=torch.bfloat16,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:Pythia models on HF were updated on 4/3/23! add '-v0' to model name to access the old models.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4dc94a89ec4a491685ad9a921ea2e002",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "82343e614f044a8fb81705615ff6d171",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "pytorch_model-00001-of-00002.bin:   0%|          | 0.00/9.91G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "299cd6938a95445ebe5aed58ef537a61",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "pytorch_model-00002-of-00002.bin:   0%|          | 0.00/3.94G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4d33fa8de2ba4a3d81c9f9ba07b04702",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
      "WARNING:root:With reduced precision, it is advised to use `from_pretrained_no_processing` instead of `from_pretrained`.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded pretrained model EleutherAI/pythia-6.9b into HookedTransformer\n"
     ]
    }
   ],
   "source": [
    "# Release any previously loaded model first: otherwise it stays referenced until the\n",
    "# assignment below completes, so two 6.9B models would briefly co-exist in memory.\n",
    "model = None\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "cache_dir = \"model_cache\"  # where the HF weights are downloaded/cached\n",
    "model = HookedTransformer.from_pretrained(\n",
    "    MODEL_NAME,\n",
    "    checkpoint_value=130000,  # training-step checkpoint to load\n",
    "    center_unembed=True,\n",
    "    center_writing_weights=True,\n",
    "    fold_ln=True,\n",
    "    dtype=torch.bfloat16,\n",
    "    cache_dir=cache_dir,  # forwarded to HuggingFace's from_pretrained via **from_pretrained_kwargs\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3cb129f40bfc4a7cac283b6935802430",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "'Mary and John went to the store. Afterwards, Mary gave some milk to a neighbor’s child. The child drank the'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.generate(\n",
    "    \"Mary and John went to the store. Afterwards, Mary gave some milk to\",\n",
    "    max_new_tokens=10,  # the library default, spelled out\n",
    "    temperature=0.0,  # intended as greedy decoding so the continuation is reproducible\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d246130db69746298985c4f8e40342f9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "beb300b7669a4561979410b4fde698ec",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    }
   ],
   "source": [
    "# Manual load path: convert an HF checkpoint into a HookedTransformer without the\n",
    "# fold_ln / centering processing that `from_pretrained` applies.\n",
    "# Release any previously loaded model first so two copies never co-exist in memory.\n",
    "model = None\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "with torch.no_grad():\n",
    "    hf_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, revision=\"step143000\", torch_dtype=torch.float16)\n",
    "    # The TransformerLens helpers take `dtype`, not `torch_dtype`: with `torch_dtype` the config keeps its\n",
    "    # default (float32) dtype, so the model was built in float32 and silently upcast the fp16 weights.\n",
    "    cfg = get_pretrained_model_config(MODEL_NAME, dtype=torch.float16)\n",
    "    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "    model = HookedTransformer(cfg, tokenizer=tokenizer)\n",
    "    state_dict = get_pretrained_state_dict(MODEL_NAME, cfg, hf_model, dtype=torch.float16)\n",
    "    # strict=False tolerates model keys the converted state dict does not provide; an *unexpected*\n",
    "    # key, however, means the weight conversion and the model architecture disagree.\n",
    "    load_result = model.load_state_dict(state_dict, strict=False)\n",
    "    assert not load_result.unexpected_keys, load_result.unexpected_keys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "21"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# `load_state_dict` copied the weights into `model`, so both the HF model and the converted\n",
    "# `state_dict` are now redundant copies of the weights; drop both.\n",
    "del hf_model, state_dict\n",
    "gc.collect()  # collect first so the freed tensors are really released...\n",
    "torch.cuda.empty_cache()  # ...then hand any cached CUDA blocks back to the driver"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "403afcf1232842eb9805c4914d865881",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "'Testing, testing, ercim.\\n\\nWe’ve just submitted'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# No temperature is passed, so generate() samples (library default); seed torch's RNG so the\n",
    "# sampled continuation is reproducible on Restart & Run All.\n",
    "torch.manual_seed(0)\n",
    "model.generate(\"Testing, testing, \")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `hf_model` may already have been deleted by an earlier cell; guard the delete so\n",
    "# Restart & Run All does not stop here with a NameError.\n",
    "if \"hf_model\" in globals():\n",
    "    del hf_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "29"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Full collection of every generation (same as the no-argument call); displays the unreachable-object count.\n",
    "gc.collect(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "70a86cccc4a84381ab2fdc59fa668602",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c598838d57f348dd830970b082c6b4f9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "pytorch_model-00001-of-00002.bin:   0%|          | 0.00/9.91G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4c426a355e4246bf874269bccf5a61fb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "pytorch_model-00002-of-00002.bin:   0%|          | 0.00/3.94G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[14], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m hf_model \u001b[39m=\u001b[39m AutoModelForCausalLM\u001b[39m.\u001b[39;49mfrom_pretrained(MODEL_NAME, revision\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mstep50000\u001b[39;49m\u001b[39m\"\u001b[39;49m, torch_dtype\u001b[39m=\u001b[39;49mtorch\u001b[39m.\u001b[39;49mfloat16)\n",
      "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py:566\u001b[0m, in \u001b[0;36m_BaseAutoModelClass.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m    564\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39mtype\u001b[39m(config) \u001b[39min\u001b[39;00m \u001b[39mcls\u001b[39m\u001b[39m.\u001b[39m_model_mapping\u001b[39m.\u001b[39mkeys():\n\u001b[1;32m    565\u001b[0m     model_class \u001b[39m=\u001b[39m _get_model_class(config, \u001b[39mcls\u001b[39m\u001b[39m.\u001b[39m_model_mapping)\n\u001b[0;32m--> 566\u001b[0m     \u001b[39mreturn\u001b[39;00m model_class\u001b[39m.\u001b[39;49mfrom_pretrained(\n\u001b[1;32m    567\u001b[0m         pretrained_model_name_or_path, \u001b[39m*\u001b[39;49mmodel_args, config\u001b[39m=\u001b[39;49mconfig, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mhub_kwargs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs\n\u001b[1;32m    568\u001b[0m     )\n\u001b[1;32m    569\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m    570\u001b[0m     \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mUnrecognized configuration class \u001b[39m\u001b[39m{\u001b[39;00mconfig\u001b[39m.\u001b[39m\u001b[39m__class__\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m for this kind of AutoModel: \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mcls\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m.\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\"\u001b[39m\n\u001b[1;32m    571\u001b[0m     \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mModel type should be one of \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m, \u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mjoin(c\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m\u001b[39m \u001b[39m\u001b[39mfor\u001b[39;00m\u001b[39m \u001b[39mc\u001b[39m \u001b[39m\u001b[39min\u001b[39;00m\u001b[39m 
\u001b[39m\u001b[39mcls\u001b[39m\u001b[39m.\u001b[39m_model_mapping\u001b[39m.\u001b[39mkeys())\u001b[39m}\u001b[39;00m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m    572\u001b[0m )\n",
      "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/transformers/modeling_utils.py:3351\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m   3348\u001b[0m \u001b[39m# We'll need to download and cache each checkpoint shard if the checkpoint is sharded.\u001b[39;00m\n\u001b[1;32m   3349\u001b[0m \u001b[39mif\u001b[39;00m is_sharded:\n\u001b[1;32m   3350\u001b[0m     \u001b[39m# rsolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.\u001b[39;00m\n\u001b[0;32m-> 3351\u001b[0m     resolved_archive_file, sharded_metadata \u001b[39m=\u001b[39m get_checkpoint_shard_files(\n\u001b[1;32m   3352\u001b[0m         pretrained_model_name_or_path,\n\u001b[1;32m   3353\u001b[0m         resolved_archive_file,\n\u001b[1;32m   3354\u001b[0m         cache_dir\u001b[39m=\u001b[39;49mcache_dir,\n\u001b[1;32m   3355\u001b[0m         force_download\u001b[39m=\u001b[39;49mforce_download,\n\u001b[1;32m   3356\u001b[0m         proxies\u001b[39m=\u001b[39;49mproxies,\n\u001b[1;32m   3357\u001b[0m         resume_download\u001b[39m=\u001b[39;49mresume_download,\n\u001b[1;32m   3358\u001b[0m         local_files_only\u001b[39m=\u001b[39;49mlocal_files_only,\n\u001b[1;32m   3359\u001b[0m         token\u001b[39m=\u001b[39;49mtoken,\n\u001b[1;32m   3360\u001b[0m         user_agent\u001b[39m=\u001b[39;49muser_agent,\n\u001b[1;32m   3361\u001b[0m         revision\u001b[39m=\u001b[39;49mrevision,\n\u001b[1;32m   3362\u001b[0m         subfolder\u001b[39m=\u001b[39;49msubfolder,\n\u001b[1;32m   3363\u001b[0m         _commit_hash\u001b[39m=\u001b[39;49mcommit_hash,\n\u001b[1;32m   3364\u001b[0m     )\n\u001b[1;32m   3366\u001b[0m \u001b[39mif\u001b[39;00m (\n\u001b[1;32m   3367\u001b[0m     is_safetensors_available()\n\u001b[1;32m   
3368\u001b[0m     \u001b[39mand\u001b[39;00m \u001b[39misinstance\u001b[39m(resolved_archive_file, \u001b[39mstr\u001b[39m)\n\u001b[1;32m   3369\u001b[0m     \u001b[39mand\u001b[39;00m resolved_archive_file\u001b[39m.\u001b[39mendswith(\u001b[39m\"\u001b[39m\u001b[39m.safetensors\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m   3370\u001b[0m ):\n\u001b[1;32m   3371\u001b[0m     \u001b[39mwith\u001b[39;00m safe_open(resolved_archive_file, framework\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mpt\u001b[39m\u001b[39m\"\u001b[39m) \u001b[39mas\u001b[39;00m f:\n",
      "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/transformers/utils/hub.py:1017\u001b[0m, in \u001b[0;36mget_checkpoint_shard_files\u001b[0;34m(pretrained_model_name_or_path, index_filename, cache_dir, force_download, proxies, resume_download, local_files_only, token, user_agent, revision, subfolder, _commit_hash, **deprecated_kwargs)\u001b[0m\n\u001b[1;32m   1014\u001b[0m \u001b[39mfor\u001b[39;00m shard_filename \u001b[39min\u001b[39;00m tqdm(shard_filenames, desc\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mDownloading shards\u001b[39m\u001b[39m\"\u001b[39m, disable\u001b[39m=\u001b[39m\u001b[39mnot\u001b[39;00m show_progress_bar):\n\u001b[1;32m   1015\u001b[0m     \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m   1016\u001b[0m         \u001b[39m# Load from URL\u001b[39;00m\n\u001b[0;32m-> 1017\u001b[0m         cached_filename \u001b[39m=\u001b[39m cached_file(\n\u001b[1;32m   1018\u001b[0m             pretrained_model_name_or_path,\n\u001b[1;32m   1019\u001b[0m             shard_filename,\n\u001b[1;32m   1020\u001b[0m             cache_dir\u001b[39m=\u001b[39;49mcache_dir,\n\u001b[1;32m   1021\u001b[0m             force_download\u001b[39m=\u001b[39;49mforce_download,\n\u001b[1;32m   1022\u001b[0m             proxies\u001b[39m=\u001b[39;49mproxies,\n\u001b[1;32m   1023\u001b[0m             resume_download\u001b[39m=\u001b[39;49mresume_download,\n\u001b[1;32m   1024\u001b[0m             local_files_only\u001b[39m=\u001b[39;49mlocal_files_only,\n\u001b[1;32m   1025\u001b[0m             token\u001b[39m=\u001b[39;49mtoken,\n\u001b[1;32m   1026\u001b[0m             user_agent\u001b[39m=\u001b[39;49muser_agent,\n\u001b[1;32m   1027\u001b[0m             revision\u001b[39m=\u001b[39;49mrevision,\n\u001b[1;32m   1028\u001b[0m             subfolder\u001b[39m=\u001b[39;49msubfolder,\n\u001b[1;32m   1029\u001b[0m             _commit_hash\u001b[39m=\u001b[39;49m_commit_hash,\n\u001b[1;32m   1030\u001b[0m         )\n\u001b[1;32m   1031\u001b[0m     \u001b[39m# 
We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so\u001b[39;00m\n\u001b[1;32m   1032\u001b[0m     \u001b[39m# we don't have to catch them here.\u001b[39;00m\n\u001b[1;32m   1033\u001b[0m     \u001b[39mexcept\u001b[39;00m EntryNotFoundError:\n",
      "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/transformers/utils/hub.py:389\u001b[0m, in \u001b[0;36mcached_file\u001b[0;34m(path_or_repo_id, filename, cache_dir, force_download, resume_download, proxies, token, revision, local_files_only, subfolder, repo_type, user_agent, _raise_exceptions_for_missing_entries, _raise_exceptions_for_connection_errors, _commit_hash, **deprecated_kwargs)\u001b[0m\n\u001b[1;32m    386\u001b[0m user_agent \u001b[39m=\u001b[39m http_user_agent(user_agent)\n\u001b[1;32m    387\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m    388\u001b[0m     \u001b[39m# Load from URL or cache if already cached\u001b[39;00m\n\u001b[0;32m--> 389\u001b[0m     resolved_file \u001b[39m=\u001b[39m hf_hub_download(\n\u001b[1;32m    390\u001b[0m         path_or_repo_id,\n\u001b[1;32m    391\u001b[0m         filename,\n\u001b[1;32m    392\u001b[0m         subfolder\u001b[39m=\u001b[39;49m\u001b[39mNone\u001b[39;49;00m \u001b[39mif\u001b[39;49;00m \u001b[39mlen\u001b[39;49m(subfolder) \u001b[39m==\u001b[39;49m \u001b[39m0\u001b[39;49m \u001b[39melse\u001b[39;49;00m subfolder,\n\u001b[1;32m    393\u001b[0m         repo_type\u001b[39m=\u001b[39;49mrepo_type,\n\u001b[1;32m    394\u001b[0m         revision\u001b[39m=\u001b[39;49mrevision,\n\u001b[1;32m    395\u001b[0m         cache_dir\u001b[39m=\u001b[39;49mcache_dir,\n\u001b[1;32m    396\u001b[0m         user_agent\u001b[39m=\u001b[39;49muser_agent,\n\u001b[1;32m    397\u001b[0m         force_download\u001b[39m=\u001b[39;49mforce_download,\n\u001b[1;32m    398\u001b[0m         proxies\u001b[39m=\u001b[39;49mproxies,\n\u001b[1;32m    399\u001b[0m         resume_download\u001b[39m=\u001b[39;49mresume_download,\n\u001b[1;32m    400\u001b[0m         token\u001b[39m=\u001b[39;49mtoken,\n\u001b[1;32m    401\u001b[0m         local_files_only\u001b[39m=\u001b[39;49mlocal_files_only,\n\u001b[1;32m    402\u001b[0m     )\n\u001b[1;32m    403\u001b[0m \u001b[39mexcept\u001b[39;00m GatedRepoError 
\u001b[39mas\u001b[39;00m e:\n\u001b[1;32m    404\u001b[0m     \u001b[39mraise\u001b[39;00m \u001b[39mEnvironmentError\u001b[39;00m(\n\u001b[1;32m    405\u001b[0m         \u001b[39m\"\u001b[39m\u001b[39mYou are trying to access a gated repo.\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39mMake sure to request access at \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m    406\u001b[0m         \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mhttps://huggingface.co/\u001b[39m\u001b[39m{\u001b[39;00mpath_or_repo_id\u001b[39m}\u001b[39;00m\u001b[39m and pass a token having permission to this repo either \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m    407\u001b[0m         \u001b[39m\"\u001b[39m\u001b[39mby logging in with `huggingface-cli login` or by passing `token=<your_token>`.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m    408\u001b[0m     ) \u001b[39mfrom\u001b[39;00m \u001b[39me\u001b[39;00m\n",
      "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py:118\u001b[0m, in \u001b[0;36mvalidate_hf_hub_args.<locals>._inner_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    115\u001b[0m \u001b[39mif\u001b[39;00m check_use_auth_token:\n\u001b[1;32m    116\u001b[0m     kwargs \u001b[39m=\u001b[39m smoothly_deprecate_use_auth_token(fn_name\u001b[39m=\u001b[39mfn\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m, has_token\u001b[39m=\u001b[39mhas_token, kwargs\u001b[39m=\u001b[39mkwargs)\n\u001b[0;32m--> 118\u001b[0m \u001b[39mreturn\u001b[39;00m fn(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
      "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/huggingface_hub/file_download.py:1461\u001b[0m, in \u001b[0;36mhf_hub_download\u001b[0;34m(repo_id, filename, subfolder, repo_type, revision, library_name, library_version, cache_dir, local_dir, local_dir_use_symlinks, user_agent, force_download, force_filename, proxies, etag_timeout, resume_download, token, local_files_only, legacy_cache_layout, endpoint)\u001b[0m\n\u001b[1;32m   1458\u001b[0m         \u001b[39mif\u001b[39;00m local_dir \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m   1459\u001b[0m             _check_disk_space(expected_size, local_dir)\n\u001b[0;32m-> 1461\u001b[0m     http_get(\n\u001b[1;32m   1462\u001b[0m         url_to_download,\n\u001b[1;32m   1463\u001b[0m         temp_file,\n\u001b[1;32m   1464\u001b[0m         proxies\u001b[39m=\u001b[39;49mproxies,\n\u001b[1;32m   1465\u001b[0m         resume_size\u001b[39m=\u001b[39;49mresume_size,\n\u001b[1;32m   1466\u001b[0m         headers\u001b[39m=\u001b[39;49mheaders,\n\u001b[1;32m   1467\u001b[0m         expected_size\u001b[39m=\u001b[39;49mexpected_size,\n\u001b[1;32m   1468\u001b[0m     )\n\u001b[1;32m   1470\u001b[0m \u001b[39mif\u001b[39;00m local_dir \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m   1471\u001b[0m     logger\u001b[39m.\u001b[39mdebug(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mStoring \u001b[39m\u001b[39m{\u001b[39;00murl\u001b[39m}\u001b[39;00m\u001b[39m in cache at \u001b[39m\u001b[39m{\u001b[39;00mblob_path\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n",
      "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/huggingface_hub/file_download.py:541\u001b[0m, in \u001b[0;36mhttp_get\u001b[0;34m(url, temp_file, proxies, resume_size, headers, expected_size, _nb_retries)\u001b[0m\n\u001b[1;32m    539\u001b[0m new_resume_size \u001b[39m=\u001b[39m resume_size\n\u001b[1;32m    540\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m--> 541\u001b[0m     \u001b[39mfor\u001b[39;00m chunk \u001b[39min\u001b[39;00m r\u001b[39m.\u001b[39miter_content(chunk_size\u001b[39m=\u001b[39mDOWNLOAD_CHUNK_SIZE):\n\u001b[1;32m    542\u001b[0m         \u001b[39mif\u001b[39;00m chunk:  \u001b[39m# filter out keep-alive new chunks\u001b[39;00m\n\u001b[1;32m    543\u001b[0m             progress\u001b[39m.\u001b[39mupdate(\u001b[39mlen\u001b[39m(chunk))\n",
      "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/requests/models.py:816\u001b[0m, in \u001b[0;36mResponse.iter_content.<locals>.generate\u001b[0;34m()\u001b[0m\n\u001b[1;32m    814\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mhasattr\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mraw, \u001b[39m\"\u001b[39m\u001b[39mstream\u001b[39m\u001b[39m\"\u001b[39m):\n\u001b[1;32m    815\u001b[0m     \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m--> 816\u001b[0m         \u001b[39myield from\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mraw\u001b[39m.\u001b[39mstream(chunk_size, decode_content\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[1;32m    817\u001b[0m     \u001b[39mexcept\u001b[39;00m ProtocolError \u001b[39mas\u001b[39;00m e:\n\u001b[1;32m    818\u001b[0m         \u001b[39mraise\u001b[39;00m ChunkedEncodingError(e)\n",
      "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/urllib3/response.py:628\u001b[0m, in \u001b[0;36mHTTPResponse.stream\u001b[0;34m(self, amt, decode_content)\u001b[0m\n\u001b[1;32m    626\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m    627\u001b[0m     \u001b[39mwhile\u001b[39;00m \u001b[39mnot\u001b[39;00m is_fp_closed(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_fp):\n\u001b[0;32m--> 628\u001b[0m         data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mread(amt\u001b[39m=\u001b[39;49mamt, decode_content\u001b[39m=\u001b[39;49mdecode_content)\n\u001b[1;32m    630\u001b[0m         \u001b[39mif\u001b[39;00m data:\n\u001b[1;32m    631\u001b[0m             \u001b[39myield\u001b[39;00m data\n",
      "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/urllib3/response.py:567\u001b[0m, in \u001b[0;36mHTTPResponse.read\u001b[0;34m(self, amt, decode_content, cache_content)\u001b[0m\n\u001b[1;32m    564\u001b[0m fp_closed \u001b[39m=\u001b[39m \u001b[39mgetattr\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_fp, \u001b[39m\"\u001b[39m\u001b[39mclosed\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39mFalse\u001b[39;00m)\n\u001b[1;32m    566\u001b[0m \u001b[39mwith\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_error_catcher():\n\u001b[0;32m--> 567\u001b[0m     data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_fp_read(amt) \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m fp_closed \u001b[39melse\u001b[39;00m \u001b[39mb\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m    568\u001b[0m     \u001b[39mif\u001b[39;00m amt \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m    569\u001b[0m         flush_decoder \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n",
      "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/urllib3/response.py:533\u001b[0m, in \u001b[0;36mHTTPResponse._fp_read\u001b[0;34m(self, amt)\u001b[0m\n\u001b[1;32m    530\u001b[0m     \u001b[39mreturn\u001b[39;00m buffer\u001b[39m.\u001b[39mgetvalue()\n\u001b[1;32m    531\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m    532\u001b[0m     \u001b[39m# StringIO doesn't like amt=None\u001b[39;00m\n\u001b[0;32m--> 533\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_fp\u001b[39m.\u001b[39;49mread(amt) \u001b[39mif\u001b[39;00m amt \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_fp\u001b[39m.\u001b[39mread()\n",
      "File \u001b[0;32m/opt/conda/lib/python3.10/http/client.py:466\u001b[0m, in \u001b[0;36mHTTPResponse.read\u001b[0;34m(self, amt)\u001b[0m\n\u001b[1;32m    463\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlength \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m amt \u001b[39m>\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlength:\n\u001b[1;32m    464\u001b[0m     \u001b[39m# clip the read to the \"end of response\"\u001b[39;00m\n\u001b[1;32m    465\u001b[0m     amt \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlength\n\u001b[0;32m--> 466\u001b[0m s \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mfp\u001b[39m.\u001b[39;49mread(amt)\n\u001b[1;32m    467\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m s \u001b[39mand\u001b[39;00m amt:\n\u001b[1;32m    468\u001b[0m     \u001b[39m# Ideally, we would raise IncompleteRead if the content-length\u001b[39;00m\n\u001b[1;32m    469\u001b[0m     \u001b[39m# wasn't satisfied, but it might break compatibility.\u001b[39;00m\n\u001b[1;32m    470\u001b[0m     \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_close_conn()\n",
      "File \u001b[0;32m/opt/conda/lib/python3.10/socket.py:705\u001b[0m, in \u001b[0;36mSocketIO.readinto\u001b[0;34m(self, b)\u001b[0m\n\u001b[1;32m    703\u001b[0m \u001b[39mwhile\u001b[39;00m \u001b[39mTrue\u001b[39;00m:\n\u001b[1;32m    704\u001b[0m     \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m--> 705\u001b[0m         \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_sock\u001b[39m.\u001b[39;49mrecv_into(b)\n\u001b[1;32m    706\u001b[0m     \u001b[39mexcept\u001b[39;00m timeout:\n\u001b[1;32m    707\u001b[0m         \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_timeout_occurred \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n",
      "File \u001b[0;32m/opt/conda/lib/python3.10/ssl.py:1307\u001b[0m, in \u001b[0;36mSSLSocket.recv_into\u001b[0;34m(self, buffer, nbytes, flags)\u001b[0m\n\u001b[1;32m   1303\u001b[0m     \u001b[39mif\u001b[39;00m flags \u001b[39m!=\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[1;32m   1304\u001b[0m         \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m   1305\u001b[0m           \u001b[39m\"\u001b[39m\u001b[39mnon-zero flags not allowed in calls to recv_into() on \u001b[39m\u001b[39m%s\u001b[39;00m\u001b[39m\"\u001b[39m \u001b[39m%\u001b[39m\n\u001b[1;32m   1306\u001b[0m           \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__class__\u001b[39m)\n\u001b[0;32m-> 1307\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mread(nbytes, buffer)\n\u001b[1;32m   1308\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m   1309\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39msuper\u001b[39m()\u001b[39m.\u001b[39mrecv_into(buffer, nbytes, flags)\n",
      "File \u001b[0;32m/opt/conda/lib/python3.10/ssl.py:1163\u001b[0m, in \u001b[0;36mSSLSocket.read\u001b[0;34m(self, len, buffer)\u001b[0m\n\u001b[1;32m   1161\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m   1162\u001b[0m     \u001b[39mif\u001b[39;00m buffer \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m-> 1163\u001b[0m         \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_sslobj\u001b[39m.\u001b[39;49mread(\u001b[39mlen\u001b[39;49m, buffer)\n\u001b[1;32m   1164\u001b[0m     \u001b[39melse\u001b[39;00m:\n\u001b[1;32m   1165\u001b[0m         \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_sslobj\u001b[39m.\u001b[39mread(\u001b[39mlen\u001b[39m)\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Load the HF checkpoint, then port its weights into a HookedTransformer.\n",
    "# NOTE(review): hf_model is pinned to revision step143000, but cfg and tokenizer load the default revision - confirm they match.\n",
    "with torch.no_grad():\n",
    "    hf_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, revision=\"step143000\", torch_dtype=torch.float16)\n",
    "    cfg = get_pretrained_model_config(MODEL_NAME, torch_dtype=torch.float16)\n",
    "    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "    model = HookedTransformer(cfg, tokenizer=tokenizer)\n",
    "    state_dict = get_pretrained_state_dict(MODEL_NAME, cfg, hf_model, torch_dtype=torch.float16)\n",
    "    # strict=False presumably because the converted state dict omits non-parameter buffers;\n",
    "    # still fail loudly if any real parameter would be left at its random initialisation.\n",
    "    load_result = model.load_state_dict(state_dict, strict=False)\n",
    "    param_names = {name for name, _ in model.named_parameters()}\n",
    "    missing_params = sorted(set(load_result.missing_keys) & param_names)\n",
    "    assert not missing_params, f\"parameters not loaded from checkpoint: {missing_params}\"\n",
    "    assert not load_result.unexpected_keys, f\"unexpected keys in state dict: {load_result.unexpected_keys}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([50432, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([12288])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([50432, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 50432])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 50432])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([50432])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([16384, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([4096, 50432])\n",
      "<class 'torch.Tensor'> torch.Size([50432])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 1, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 32, 6, 6])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 16384])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 4096])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.Tensor'> torch.Size([1, 6, 1, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([50432, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 4096, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([32, 128])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 2048])\n",
      "<class 'torch.Tensor'> torch.Size([])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.Tensor'> torch.Size([2048, 32])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096, 16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([16384, 4096])\n",
      "<class 'torch.nn.parameter.Parameter'> torch.Size([4096])\n"
     ]
    }
   ],
   "source": [
    "# Debugging aid: list every tensor the garbage collector is still tracking,\n",
    "# e.g. to find leftover references that keep model weights in memory.\n",
    "# `gc` and `torch` are imported in the first cell.\n",
    "for obj in gc.get_objects():\n",
    "    try:\n",
    "        if torch.is_tensor(obj):\n",
    "            tensor = obj\n",
    "        elif hasattr(obj, 'data') and torch.is_tensor(obj.data):\n",
    "            # Wrappers expose their tensor via `.data`; report that tensor, not the wrapper.\n",
    "            tensor = obj.data\n",
    "        else:\n",
    "            continue\n",
    "        print(type(obj), tensor.size(), tensor.dtype, tensor.device)\n",
    "    except Exception:\n",
    "        # Some tracked objects raise on attribute access; skip them, but let\n",
    "        # KeyboardInterrupt/SystemExit propagate so the cell stays interruptible.\n",
    "        continue\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
