{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "os.environ['HF_HOME'] = '/home/jovyan/USR/data/.cache/huggingface'\n",
    "os.chdir('/home/jovyan/USR/data/test_time_gd/')\n",
    "\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GPT2Model(\n",
      "  (wte): Embedding(50257, 768)\n",
      "  (wpe): Embedding(1024, 768)\n",
      "  (drop): Dropout(p=0.1, inplace=False)\n",
      "  (h): ModuleList(\n",
      "    (0-11): 12 x GPT2Block(\n",
      "      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
      "      (attn): GPT2Attention(\n",
      "        (c_attn): Conv1D(nf=2304, nx=768)\n",
      "        (c_proj): Conv1D(nf=768, nx=768)\n",
      "        (attn_dropout): Dropout(p=0.1, inplace=False)\n",
      "        (resid_dropout): Dropout(p=0.1, inplace=False)\n",
      "      )\n",
      "      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
      "      (mlp): GPT2MLP(\n",
      "        (c_fc): Conv1D(nf=3072, nx=768)\n",
      "        (c_proj): Conv1D(nf=768, nx=3072)\n",
      "        (act): NewGELUActivation()\n",
      "        (dropout): Dropout(p=0.1, inplace=False)\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "from transformers import AutoModel\n",
    "\n",
    "model = AutoModel.from_pretrained('gpt2')\n",
    "\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'4.51.3'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import transformers\n",
    "transformers.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'2.6.0+cu124'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "torch.__version__\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/home/jovyan/USR/envs/py311_pt2.6_cu12.4/lib/python3.11/site-packages/torch/__init__.py'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# torch package path:\n",
    "torch.__file__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO 05-08 01:12:29 [__init__.py:239] Automatically detected platform cuda.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'0.8.4'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import vllm\n",
    "vllm.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from kv_dataset_utils import create_tokenizer\n",
    "tokenizer = create_tokenizer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GPTNeoXConfig {\n",
       "  \"architectures\": [\n",
       "    \"GPTNeoXForCausalLM\"\n",
       "  ],\n",
       "  \"attention_bias\": true,\n",
       "  \"attention_dropout\": 0.0,\n",
       "  \"bos_token_id\": 1,\n",
       "  \"classifier_dropout\": 0.1,\n",
       "  \"eos_token_id\": 2,\n",
       "  \"hidden_act\": \"gelu\",\n",
       "  \"hidden_dropout\": 0.0,\n",
       "  \"hidden_size\": 128,\n",
       "  \"initializer_range\": 0.02,\n",
       "  \"intermediate_size\": 512,\n",
       "  \"layer_norm_eps\": 1e-05,\n",
       "  \"max_position_embeddings\": 2048,\n",
       "  \"model_type\": \"gpt_neox\",\n",
       "  \"num_attention_heads\": 4,\n",
       "  \"num_hidden_layers\": 4,\n",
       "  \"pad_token_id\": 0,\n",
       "  \"partial_rotary_factor\": 0.25,\n",
       "  \"rope_scaling\": null,\n",
       "  \"rope_theta\": 10000,\n",
       "  \"rotary_emb_base\": 10000,\n",
       "  \"rotary_pct\": 0.25,\n",
       "  \"tie_word_embeddings\": false,\n",
       "  \"torch_dtype\": \"float16\",\n",
       "  \"transformers_version\": \"4.51.3\",\n",
       "  \"use_cache\": true,\n",
       "  \"use_parallel_residual\": true,\n",
       "  \"vocab_size\": 128\n",
       "}"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import AutoConfig\n",
    "config = AutoConfig.from_pretrained('EleutherAI/pythia-160m')\n",
    "config.num_hidden_layers = 4\n",
    "config.num_attention_heads = 4\n",
    "config.hidden_size = 128\n",
    "config.vocab_size = 128\n",
    "config.intermediate_size = config.hidden_size * 4\n",
    "config.pad_token_id = tokenizer.convert_tokens_to_ids('[PAD]')\n",
    "config.bos_token_id = tokenizer.convert_tokens_to_ids('[BOS]')\n",
    "config.eos_token_id = tokenizer.convert_tokens_to_ids('[EOS]')\n",
    "config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM\n",
    "model = AutoModelForCausalLM.from_config(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GPTNeoXForCausalLM(\n",
       "  (gpt_neox): GPTNeoXModel(\n",
       "    (embed_in): Embedding(128, 128)\n",
       "    (emb_dropout): Dropout(p=0.0, inplace=False)\n",
       "    (layers): ModuleList(\n",
       "      (0-3): 4 x GPTNeoXLayer(\n",
       "        (input_layernorm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "        (post_attention_layernorm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "        (post_attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "        (post_mlp_dropout): Dropout(p=0.0, inplace=False)\n",
       "        (attention): GPTNeoXAttention(\n",
       "          (query_key_value): Linear(in_features=128, out_features=384, bias=True)\n",
       "          (dense): Linear(in_features=128, out_features=128, bias=True)\n",
       "        )\n",
       "        (mlp): GPTNeoXMLP(\n",
       "          (dense_h_to_4h): Linear(in_features=128, out_features=512, bias=True)\n",
       "          (dense_4h_to_h): Linear(in_features=512, out_features=128, bias=True)\n",
       "          (act): GELUActivation()\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (final_layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "    (rotary_emb): GPTNeoXRotaryEmbedding()\n",
       "  )\n",
       "  (embed_out): Linear(in_features=128, out_features=128, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_ids = tokenizer('|ABCDEF|', return_tensors='pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "out = model(**input_ids, labels=input_ids['input_ids'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "odict_keys(['loss', 'logits', 'past_key_values'])"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 8, 128])"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out.logits.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(4.8883, grad_fn=<NllLossBackward0>)\n",
      "tensor(4.3410)\n"
     ]
    }
   ],
   "source": [
    "model = AutoModelForCausalLM.from_config(config)\n",
    "model.to(torch.bfloat16)\n",
    "\n",
    "out = model(**input_ids, labels=input_ids['input_ids'])\n",
    "print(out.loss)\n",
    "\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=1e-04)\n",
    "optimizer.zero_grad()\n",
    "out.loss.backward()\n",
    "optimizer.step()\n",
    "\n",
    "\n",
    "with torch.no_grad():\n",
    "    out = model(**input_ids, labels=input_ids['input_ids'])\n",
    "    print(out.loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'gpt_neox'"
      ]
     },
     "execution_count": 79,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.config.model_type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GPTNeoXForCausalLM(\n",
       "  (gpt_neox): GPTNeoXModel(\n",
       "    (embed_in): Embedding(128, 128)\n",
       "    (emb_dropout): Dropout(p=0.0, inplace=False)\n",
       "    (layers): ModuleList(\n",
       "      (0-3): 4 x GPTNeoXLayer(\n",
       "        (input_layernorm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "        (post_attention_layernorm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "        (post_attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "        (post_mlp_dropout): Dropout(p=0.0, inplace=False)\n",
       "        (attention): GPTNeoXAttention(\n",
       "          (query_key_value): Linear(in_features=128, out_features=384, bias=True)\n",
       "          (dense): Linear(in_features=128, out_features=128, bias=True)\n",
       "        )\n",
       "        (mlp): GPTNeoXMLP(\n",
       "          (dense_h_to_4h): Linear(in_features=128, out_features=512, bias=True)\n",
       "          (dense_4h_to_h): Linear(in_features=512, out_features=128, bias=True)\n",
       "          (act): GELUActivation()\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (final_layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "    (rotary_emb): GPTNeoXRotaryEmbedding()\n",
       "  )\n",
       "  (embed_out): Linear(in_features=128, out_features=128, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 80,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GPTNeoXModel(\n",
       "  (embed_in): Embedding(128, 128)\n",
       "  (emb_dropout): Dropout(p=0.0, inplace=False)\n",
       "  (layers): ModuleList(\n",
       "    (0-3): 4 x GPTNeoXLayer(\n",
       "      (input_layernorm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "      (post_attention_layernorm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "      (post_attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "      (post_mlp_dropout): Dropout(p=0.0, inplace=False)\n",
       "      (attention): GPTNeoXAttention(\n",
       "        (query_key_value): Linear(in_features=128, out_features=384, bias=True)\n",
       "        (dense): Linear(in_features=128, out_features=128, bias=True)\n",
       "      )\n",
       "      (mlp): GPTNeoXMLP(\n",
       "        (dense_h_to_4h): Linear(in_features=128, out_features=512, bias=True)\n",
       "        (dense_4h_to_h): Linear(in_features=512, out_features=128, bias=True)\n",
       "        (act): GELUActivation()\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (final_layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "  (rotary_emb): GPTNeoXRotaryEmbedding()\n",
       ")"
      ]
     },
     "execution_count": 81,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.gpt_neox"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[-0.0170, -0.0432,  0.0082,  ..., -0.0093, -0.0177,  0.0028],\n",
       "         [-0.0250,  0.0354,  0.0165,  ...,  0.0134,  0.0182, -0.0297],\n",
       "         [ 0.0040,  0.0089,  0.0050,  ..., -0.0025, -0.0176, -0.0113],\n",
       "         ...,\n",
       "         [-0.0435,  0.0160, -0.0070,  ...,  0.0060,  0.0117, -0.0063],\n",
       "         [ 0.0038, -0.0095, -0.0153,  ...,  0.0220, -0.0190,  0.0269],\n",
       "         [-0.0170, -0.0432,  0.0082,  ..., -0.0093, -0.0177,  0.0028]]],\n",
       "       dtype=torch.bfloat16, grad_fn=<EmbeddingBackward0>)"
      ]
     },
     "execution_count": 84,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
