{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ff725208",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from types import SimpleNamespace\n",
    "from models.meshtok import MeshTok"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a59efa83",
   "metadata": {},
   "outputs": [],
   "source": [
    "class _ConfigNamespace(SimpleNamespace):\n",
    "    \"\"\"SimpleNamespace with a dict-style ``get`` accessor.\n",
    "\n",
    "    ``get`` is a real method (not a per-instance lambda), so it stays out of\n",
    "    ``vars()``/``repr``, reads the right object after ``copy.deepcopy`` and\n",
    "    does not prevent pickling.\n",
    "    \"\"\"\n",
    "\n",
    "    def get(self, k, default=None):\n",
    "        \"\"\"Return attribute ``k`` if it exists, otherwise ``default``.\"\"\"\n",
    "        return getattr(self, k, default)\n",
    "\n",
    "\n",
    "def _ns(d):\n",
    "    \"\"\"Recursively convert nested dicts/lists into attribute-access config objects.\n",
    "\n",
    "    Dicts become ``_ConfigNamespace`` instances (attribute access plus ``.get``),\n",
    "    lists are converted element-wise, and any other value is returned unchanged.\n",
    "    Note: a dict key literally named ``\"get\"`` would shadow the ``get`` method.\n",
    "    \"\"\"\n",
    "    if isinstance(d, dict):\n",
    "        return _ConfigNamespace(**{k: _ns(v) for k, v in d.items()})\n",
    "    if isinstance(d, list):\n",
    "        return [_ns(x) for x in d]\n",
    "    return d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3ddeedb7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MeshTok(\n",
       "  (embedder): ConvEmbedder(\n",
       "    (positional_encoding_3d): SinusoidalPositionalEncoding3D()\n",
       "    (in_proj): Conv2d(4, 512, kernel_size=(16, 16), stride=(16, 16), bias=False)\n",
       "    (conv_proj): Sequential(\n",
       "      (0): GELU(approximate='none')\n",
       "      (1): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "    )\n",
       "    (in_proj_sub): Conv2d(4, 512, kernel_size=(8, 8), stride=(8, 8), bias=False)\n",
       "    (conv_proj_sub): Sequential(\n",
       "      (0): GELU(approximate='none')\n",
       "      (1): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "    )\n",
       "    (post_proj): Sequential(\n",
       "      (0): Rearrange('b (h w) d -> b d h w', h=8, w=8)\n",
       "      (1): ConvTranspose2d(512, 32, kernel_size=(16, 16), stride=(16, 16), bias=False)\n",
       "    )\n",
       "    (post_proj_sub): Sequential(\n",
       "      (0): Rearrange('b (h w) d -> b d h w', h=16, w=16)\n",
       "      (1): ConvTranspose2d(512, 32, kernel_size=(8, 8), stride=(8, 8), bias=False)\n",
       "    )\n",
       "    (head): Sequential(\n",
       "      (0): GELU(approximate='none')\n",
       "      (1): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))\n",
       "      (2): GELU(approximate='none')\n",
       "      (3): Conv2d(32, 4, kernel_size=(1, 1), stride=(1, 1))\n",
       "    )\n",
       "  )\n",
       "  (transformer): CacheCustomTransformerEncoder(\n",
       "    (layers): ModuleList(\n",
       "      (0): CacheCustomTransformerEncoderLayer(\n",
       "        (self_attn): MultiheadAttention(\n",
       "          (linear_q): Linear(in_features=512, out_features=512, bias=True)\n",
       "          (linear_k): Linear(in_features=512, out_features=512, bias=True)\n",
       "          (linear_v): Linear(in_features=512, out_features=512, bias=True)\n",
       "          (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "          (q_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
       "          (k_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "        (grid_depth): GridLevelFiLM(\n",
       "          (grid_proj): Sequential(\n",
       "            (0): Linear(in_features=2, out_features=256, bias=True)\n",
       "            (1): GELU(approximate='none')\n",
       "          )\n",
       "          (level_embed): Embedding(2, 256)\n",
       "          (level_act): GELU(approximate='none')\n",
       "          (to_gamma_beta): Sequential(\n",
       "            (0): Linear(in_features=512, out_features=512, bias=True)\n",
       "            (1): GELU(approximate='none')\n",
       "            (2): Linear(in_features=512, out_features=1024, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (ffn): FFN(\n",
       "          (fc1): Linear(in_features=512, out_features=1280, bias=True)\n",
       "          (fc_gate): Linear(in_features=512, out_features=1280, bias=True)\n",
       "          (activation): SwiGLU(\n",
       "            (act): SiLU()\n",
       "          )\n",
       "          (dropout): Identity()\n",
       "          (fc2): Linear(in_features=1280, out_features=512, bias=True)\n",
       "        )\n",
       "        (norm1): RMSNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (norm2): RMSNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (dropout1): Identity()\n",
       "        (dropout2): Identity()\n",
       "        (se_attn): Identity()\n",
       "        (se_ffn): Identity()\n",
       "      )\n",
       "      (1-7): 7 x CacheCustomTransformerEncoderLayer(\n",
       "        (self_attn): MultiheadAttention(\n",
       "          (linear_q): Linear(in_features=512, out_features=512, bias=True)\n",
       "          (linear_k): Linear(in_features=512, out_features=512, bias=True)\n",
       "          (linear_v): Linear(in_features=512, out_features=512, bias=True)\n",
       "          (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "          (q_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
       "          (k_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "        (ffn): FFN(\n",
       "          (fc1): Linear(in_features=512, out_features=1280, bias=True)\n",
       "          (fc_gate): Linear(in_features=512, out_features=1280, bias=True)\n",
       "          (activation): SwiGLU(\n",
       "            (act): SiLU()\n",
       "          )\n",
       "          (dropout): Identity()\n",
       "          (fc2): Linear(in_features=1280, out_features=512, bias=True)\n",
       "        )\n",
       "        (norm1): RMSNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (norm2): RMSNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (dropout1): Identity()\n",
       "        (dropout2): Identity()\n",
       "        (se_attn): Identity()\n",
       "        (se_ffn): Identity()\n",
       "      )\n",
       "    )\n",
       "    (norm): RMSNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Synthetic problem size: a 128x128 grid with 4 fields, 10 context frames + 10 generated frames.\n",
    "torch.manual_seed(0)  # fixed seed so the randomly initialised weights are reproducible across runs\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "bs = 2\n",
    "x_num = 128\n",
    "data_dim = 4\n",
    "input_len = 10\n",
    "gen_len = 10\n",
    "t_total = input_len + gen_len  # 20\n",
    "\n",
    "# Shared model sizes, referenced below so duplicated config entries cannot drift apart.\n",
    "n_layer = 8\n",
    "dim_emb = 512\n",
    "\n",
    "cfg = _ns({\n",
    "    \"name\": \"meshtok_auto\",\n",
    "\n",
    "    \"all_exp\": 16,\n",
    "    \"training\": True,  # NOTE(review): model.eval() is called below; confirm which one MeshTok honours\n",
    "    \"topk\": 4,\n",
    "    \"n_shared_experts\": 2,\n",
    "    \"n_layer\": n_layer,\n",
    "    \"dim_emb\": dim_emb,\n",
    "    \"moe_intermediate_size\": 256,\n",
    "    \"dim_ffn\": 1280,\n",
    "    \"dropout\": 0.0,\n",
    "    \"attn_dropout\": 0.0,\n",
    "    \"n_head\": 8,\n",
    "    \"norm_first\": True,\n",
    "    \"positional_embedding\": None,\n",
    "    \"qk_norm\": True,       # YAML: 1\n",
    "    \"norm\": \"rms\",\n",
    "    \"activation\": \"swiglu\",\n",
    "    \"rotary\": False,       # YAML: 0\n",
    "\n",
    "    \"refine_ratio\": 0.25,\n",
    "\n",
    "    \"flex_attn\": False,    # YAML: 0\n",
    "    \"kv_cache\": True,      # YAML: 1\n",
    "\n",
    "    # presumably one flag per transformer layer, so sized from n_layer\n",
    "    \"dense\": [True] * n_layer,\n",
    "\n",
    "    \"patch_num\": 8,\n",
    "    \"patch_num_output\": 8,\n",
    "\n",
    "    \"embedder\": {\n",
    "        \"type\": \"conv\",\n",
    "        \"dim\": dim_emb,\n",
    "        \"patch_num\": 8,\n",
    "        \"patch_num_output\": 8,\n",
    "        \"time_embed\": \"learnable\",\n",
    "        \"select\": \"physical\",\n",
    "        \"max_time_len\": t_total,  # tracks input_len + gen_len (= 20 with the sizes above)\n",
    "\n",
    "        \"conv_dim\": 32,\n",
    "        \"early_conv\": False,  # YAML: 0\n",
    "        \"deep\": False,        # YAML: 0\n",
    "    },\n",
    "})\n",
    "# `_ns` already gives every namespace a `.get` accessor, so no extra patching is needed here.\n",
    "\n",
    "model = MeshTok(cfg, x_num=x_num, max_output_dim=data_dim, max_data_len=t_total).to(device)\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3becf3da",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generate out: torch.Size([2, 10, 128, 128, 4])\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(0)  # reseed so re-running just this cell reproduces the same random input\n",
    "\n",
    "times = torch.linspace(0, 1, t_total, device=device)[None, :, None].repeat(bs, 1, 1)  # (bs, t_total, 1)\n",
    "data_input = torch.randn(bs, input_len, x_num, x_num, data_dim, device=device)       # (bs, input_len, x_num, x_num, data_dim)\n",
    "data_mask = torch.ones(bs, 1, x_num, x_num, data_dim, device=device)                 # (bs, 1, x_num, x_num, data_dim)\n",
    "\n",
    "with torch.no_grad():\n",
    "    out = model(\n",
    "        \"generate\",\n",
    "        data_input=data_input,\n",
    "        times=times,\n",
    "        input_len=input_len,\n",
    "        data_mask=data_mask,\n",
    "        carry_over_c=-1,\n",
    "    )\n",
    "\n",
    "# Shape check matching the recorded run, torch.Size([2, 10, 128, 128, 4]).\n",
    "# NOTE(review): input_len == gen_len here, so this cannot tell whether dim 1 is the generated length; confirm in MeshTok.\n",
    "expected_shape = (bs, gen_len, x_num, x_num, data_dim)\n",
    "assert out.shape == expected_shape, f\"unexpected generate output shape {tuple(out.shape)}, expected {expected_shape}\"\n",
    "print(\"generate out:\", out.shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "bcat",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
