{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the target HF-style Llama model from its config; its weights are filled in further down.\n",
    "from transformers import AutoConfig\n",
    "\n",
    "from modeling_llama import LlamaForCausalLM  # explicit import instead of `import *`\n",
    "\n",
    "HF_CONFIG_PATH = \"/path/to/VisionSpec/LlamaGen/autoregressive/models/t2i_config.json\"\n",
    "config = AutoConfig.from_pretrained(HF_CONFIG_PATH)\n",
    "\n",
    "model = LlamaForCausalLM(config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LlamaForCausalLM(\n",
       "  (model): LlamaModel(\n",
       "    (embed_tokens): Embedding(16384, 1280)\n",
       "    (tok_dropout): Dropout(p=0.1, inplace=False)\n",
       "    (cls_embedding): CaptionEmbedder(\n",
       "      (cap_proj): MLP(\n",
       "        (fc1): Linear(in_features=2048, out_features=1280, bias=False)\n",
       "        (act): GELU(approximate='tanh')\n",
       "        (fc2): Linear(in_features=1280, out_features=1280, bias=False)\n",
       "      )\n",
       "    )\n",
       "    (layers): ModuleList(\n",
       "      (0-35): 36 x LlamaDecoderLayer(\n",
       "        (self_attn): LlamaAttention(\n",
       "          (q_proj): Linear(in_features=1280, out_features=1280, bias=False)\n",
       "          (k_proj): Linear(in_features=1280, out_features=1280, bias=False)\n",
       "          (v_proj): Linear(in_features=1280, out_features=1280, bias=False)\n",
       "          (o_proj): Linear(in_features=1280, out_features=1280, bias=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (mlp): LlamaMLP(\n",
       "          (gate_proj): Linear(in_features=1280, out_features=3584, bias=False)\n",
       "          (up_proj): Linear(in_features=1280, out_features=3584, bias=False)\n",
       "          (down_proj): Linear(in_features=3584, out_features=1280, bias=False)\n",
       "          (ffn_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (input_layernorm): LlamaRMSNorm()\n",
       "        (post_attention_layernorm): LlamaRMSNorm()\n",
       "      )\n",
       "    )\n",
       "    (norm): LlamaRMSNorm()\n",
       "  )\n",
       "  (lm_head): Linear(in_features=1280, out_features=16384, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same as model.eval(): switches the Dropout layers (tok/resid/ffn dropout) to inference mode.\n",
    "# train() returns the module itself, so the cell still displays the model structure.\n",
    "model.train(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from gpt import GPT_models\n",
    "\n",
    "# Native LlamaGen model. The key conversion below works on the raw checkpoint, not on this model.\n",
    "BLOCK_SIZE = 1024    # NOTE(review): presumably the number of image tokens (32x32 grid)\n",
    "CLS_TOKEN_NUM = 120  # NOTE(review): presumably the caption-condition prefix length\n",
    "gpt_model = GPT_models[\"GPT-XL\"](\n",
    "    vocab_size=16384,\n",
    "    block_size=BLOCK_SIZE,\n",
    "    cls_token_num=CLS_TOKEN_NUM,\n",
    "    model_type='t2i'\n",
    ")\n",
    "# Cache length = BLOCK_SIZE + CLS_TOKEN_NUM = 1144 (the value that used to be hard-coded).\n",
    "# NOTE(review): max batch size 2 is presumably the cond/uncond pair used for CFG - confirm.\n",
    "gpt_model.setup_caches(2, BLOCK_SIZE + CLS_TOKEN_NUM, torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# NOTE(review): depending on the PyTorch version's weights_only default, torch.load may unpickle\n",
    "# arbitrary objects from this file - only load checkpoints from a trusted source.\n",
    "checkpoint = torch.load(\n",
    "    \"/path/to/t2i_XL_stage2_512.pt\",\n",
    "    map_location=torch.device(\"cpu\"),  # keep every tensor on CPU; the conversion needs no GPU\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([3584, 1280])\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: LlamaGen's feed_forward.w1 is renamed to HF mlp.gate_proj, so the shapes must agree.\n",
    "w1_shape = checkpoint['model']['layers.0.feed_forward.w1.weight'].shape\n",
    "gate_proj_shape = model.model.layers[0].mlp.gate_proj.weight.shape\n",
    "assert w1_shape == gate_proj_shape, f'checkpoint w1 {tuple(w1_shape)} != HF gate_proj {tuple(gate_proj_shape)}'\n",
    "print(w1_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Transformer(\n",
       "  (cls_embedding): CaptionEmbedder(\n",
       "    (cap_proj): MLP(\n",
       "      (fc1): Linear(in_features=2048, out_features=1280, bias=False)\n",
       "      (act): GELU(approximate='tanh')\n",
       "      (fc2): Linear(in_features=1280, out_features=1280, bias=False)\n",
       "    )\n",
       "  )\n",
       "  (tok_embeddings): Embedding(16384, 1280)\n",
       "  (tok_dropout): Dropout(p=0.1, inplace=False)\n",
       "  (layers): ModuleList(\n",
       "    (0-35): 36 x TransformerBlock(\n",
       "      (attention): Attention(\n",
       "        (wqkv): Linear(in_features=1280, out_features=3840, bias=False)\n",
       "        (wo): Linear(in_features=1280, out_features=1280, bias=False)\n",
       "        (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        (kv_cache): KVCache()\n",
       "      )\n",
       "      (feed_forward): FeedForward(\n",
       "        (w1): Linear(in_features=1280, out_features=3584, bias=False)\n",
       "        (w3): Linear(in_features=1280, out_features=3584, bias=False)\n",
       "        (w2): Linear(in_features=3584, out_features=1280, bias=False)\n",
       "        (ffn_dropout): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "      (attention_norm): RMSNorm()\n",
       "      (ffn_norm): RMSNorm()\n",
       "      (drop_path): Identity()\n",
       "    )\n",
       "  )\n",
       "  (norm): RMSNorm()\n",
       "  (output): Linear(in_features=1280, out_features=16384, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The weights live under checkpoint['model'] (as the other cells index them). Passing the top-level\n",
    "# dict with strict=False matched no parameter and left gpt_model with its initial weights.\n",
    "load_result = gpt_model.load_state_dict(checkpoint['model'], strict=False)\n",
    "# strict=False is kept because setup_caches() presumably registered KV-cache state the checkpoint lacks;\n",
    "# any other missing key means real weights were not loaded.\n",
    "missing_weights = [k for k in load_result.missing_keys if 'kv_cache' not in k]\n",
    "assert not missing_weights, f'missing from checkpoint: {missing_weights[:5]}'\n",
    "if load_result.unexpected_keys:\n",
    "    print('ignored checkpoint keys:', load_result.unexpected_keys)\n",
    "gpt_model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import OrderedDict\n",
    "\n",
    "import torch\n",
    "\n",
    "# Rename LlamaGen (gpt.py) state-dict keys to the LlamaForCausalLM layout built in the first cell.\n",
    "src_state = checkpoint['model']\n",
    "new_checkpoint = OrderedDict()\n",
    "\n",
    "# Non-layer tensors: LlamaGen key -> HF key. A missing key raises KeyError on purpose.\n",
    "TOP_LEVEL_RENAMES = {\n",
    "    'tok_embeddings.weight': 'model.embed_tokens.weight',\n",
    "    'cls_embedding.cap_proj.fc1.weight': 'model.cls_embedding.cap_proj.fc1.weight',\n",
    "    'cls_embedding.cap_proj.fc2.weight': 'model.cls_embedding.cap_proj.fc2.weight',\n",
    "    'cls_embedding.uncond_embedding': 'model.cls_embedding.uncond_embedding',\n",
    "    'norm.weight': 'model.norm.weight',\n",
    "    'output.weight': 'lm_head.weight',\n",
    "}\n",
    "# Per-layer tensors: suffix after 'layers.<i>.' -> HF suffix. The fused wqkv is split separately.\n",
    "LAYER_RENAMES = {\n",
    "    'attention.wo.weight': 'self_attn.o_proj.weight',\n",
    "    'feed_forward.w1.weight': 'mlp.gate_proj.weight',\n",
    "    'feed_forward.w3.weight': 'mlp.up_proj.weight',\n",
    "    'feed_forward.w2.weight': 'mlp.down_proj.weight',\n",
    "    'attention_norm.weight': 'input_layernorm.weight',\n",
    "    'ffn_norm.weight': 'post_attention_layernorm.weight',\n",
    "}\n",
    "\n",
    "for old_key, new_key in TOP_LEVEL_RENAMES.items():\n",
    "    new_checkpoint[new_key] = src_state[old_key]\n",
    "\n",
    "# Exact matching on the key structure instead of substring tests ('wo' in key, 'norm' in key, ...),\n",
    "# which depended on branch order and silently dropped anything unrecognised.\n",
    "skipped_keys = []\n",
    "for key, tensor in src_state.items():\n",
    "    if key in TOP_LEVEL_RENAMES:\n",
    "        continue\n",
    "    parts = key.split('.', 2)  # expected form: 'layers.<i>.<suffix>'\n",
    "    if len(parts) != 3 or parts[0] != 'layers':\n",
    "        skipped_keys.append(key)\n",
    "        continue\n",
    "    prefix, suffix = f'model.layers.{parts[1]}.', parts[2]\n",
    "    if suffix == 'attention.wqkv.weight':\n",
    "        # Fused projection split into three equal row blocks q, k, v along dim 0 (same order as before).\n",
    "        # NOTE(review): an equal split assumes n_kv_head == n_head (no GQA); holds for this 1280 -> 3840 wqkv.\n",
    "        # NOTE(review): if gpt.py and modeling_llama lay out RoPE pairs differently (interleaved vs\n",
    "        # rotate_half), q/k rows would need permuting - confirm the two implementations match.\n",
    "        q_weight, k_weight, v_weight = torch.chunk(tensor, 3, dim=0)\n",
    "        assert q_weight.shape == k_weight.shape == v_weight.shape, f'{key}: cannot split {tuple(tensor.shape)} into equal q/k/v'\n",
    "        new_checkpoint[prefix + 'self_attn.q_proj.weight'] = q_weight\n",
    "        new_checkpoint[prefix + 'self_attn.k_proj.weight'] = k_weight\n",
    "        new_checkpoint[prefix + 'self_attn.v_proj.weight'] = v_weight\n",
    "    elif suffix in LAYER_RENAMES:\n",
    "        new_checkpoint[prefix + LAYER_RENAMES[suffix]] = tensor\n",
    "    else:\n",
    "        skipped_keys.append(key)\n",
    "\n",
    "# Keys not converted (e.g. freqs_cis, if the checkpoint stores it) are reported instead of dropped silently.\n",
    "print(f'converted {len(new_checkpoint)} tensors; skipped {len(skipped_keys)} keys: {skipped_keys}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# strict=True raises if any HF parameter was not produced by the conversion or any converted key\n",
    "# is unknown, instead of silently saving a partially converted model.\n",
    "model.load_state_dict(new_checkpoint, strict=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Serialize the converted model in Hugging Face format (reloadable with from_pretrained).\n",
    "OUTPUT_DIR = \"/path/to/model_hf/LlamaGen-T2I-2\"\n",
    "model.save_pretrained(OUTPUT_DIR)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llamagen",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
