{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "18d9eeaa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.append(\"../associative-recurrent-memory-transformer\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2b181ebd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "*** Setting default RWKV_MY_TESTING = x060 ***\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2025-08-30 13:49:37,359] [INFO] [logging.py:107:log_dist] [Rank -1] [TorchCheckpointEngine] Initialized with serialization = False\n",
      "[RWKV.model] Running RWKV infctx using 'torch-jit' with torch '2.5.1+cu124'\n",
      "[RWKV.model] Running RWKV infctx using 'torch-jit' with torch '2.5.1+cu124'\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Current Triton version 3.1.0 is below the recommended 3.2.0 version. Errors may occur and these issues will not be fixed. Please consider upgrading Triton.\n",
      "Current Python version 3.9 is below the recommended 3.11 version. It is recommended to upgrade to Python 3.11 or higher for the best experience.\n"
     ]
    },
    {
     "ename": "TypeError",
     "evalue": "unsupported operand type(s) for |: 'torch._C._TensorMeta' and 'NoneType'",
     "output_type": "error",
     "traceback": [
     ]
    }
   ],
   "source": [
    "from baselines.rwkv.language_modeling import RWKV_v6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c12fb4ec",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "A new version of the following files was downloaded from https://huggingface.co/fla-hub/rwkv7-0.4B-world:\n",
      "- modeling_rwkv7.py\n",
      ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n",
      "A new version of the following files was downloaded from https://huggingface.co/fla-hub/rwkv7-0.4B-world:\n",
      "- hf_rwkv_tokenizer.py\n",
      ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "# Load model and tokeniser\n",
    "model = AutoModelForCausalLM.from_pretrained('fla-hub/rwkv7-0.4B-world', trust_remote_code=True)\n",
    "tokenizer = AutoTokenizer.from_pretrained('fla-hub/rwkv7-0.4B-world', trust_remote_code=True)\n",
    "model = model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7dc5286c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "RWKV7ForCausalLM(\n",
       "  (model): RWKV7Model(\n",
       "    (embeddings): Embedding(65536, 1024)\n",
       "    (layers): ModuleList(\n",
       "      (0): RWKV7Block(\n",
       "        (pre_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): RWKV7Attention(\n",
       "          (time_shift): ZeroPad2d((0, 0, 1, -1))\n",
       "          (r_proj): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "          (v_proj): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "          (w_lora): LoRA(input_dim=1024, low_rank_dim=64, output_dim=1024)\n",
       "          (a_lora): LoRA(input_dim=1024, low_rank_dim=64, output_dim=1024)\n",
       "          (g_lora): LoRA(input_dim=1024, low_rank_dim=128, output_dim=1024, bias=False)\n",
       "          (g_norm): GroupNorm(16, 1024, eps=0.00064, affine=True)\n",
       "        )\n",
       "        (ffn_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        (ffn): RWKV7FeedForward(\n",
       "          (time_shift): ZeroPad2d((0, 0, 1, -1))\n",
       "          (key): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "          (value): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "        )\n",
       "      )\n",
       "      (1-23): 23 x RWKV7Block(\n",
       "        (attn_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): RWKV7Attention(\n",
       "          (time_shift): ZeroPad2d((0, 0, 1, -1))\n",
       "          (r_proj): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "          (v_proj): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "          (w_lora): LoRA(input_dim=1024, low_rank_dim=64, output_dim=1024)\n",
       "          (v_lora): LoRA(input_dim=1024, low_rank_dim=32, output_dim=1024)\n",
       "          (a_lora): LoRA(input_dim=1024, low_rank_dim=64, output_dim=1024)\n",
       "          (g_lora): LoRA(input_dim=1024, low_rank_dim=128, output_dim=1024, bias=False)\n",
       "          (g_norm): GroupNorm(16, 1024, eps=0.00064, affine=True)\n",
       "        )\n",
       "        (ffn_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        (ffn): RWKV7FeedForward(\n",
       "          (time_shift): ZeroPad2d((0, 0, 1, -1))\n",
       "          (key): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "          (value): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "  )\n",
       "  (lm_head): Linear(in_features=1024, out_features=65536, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "16162c25",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from fla.models.rwkv7.configuration_rwkv7 import RWKV7Config\n",
    "from fla.models.rwkv7.modeling_rwkv7 import RWKV7ForCausalLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "04fd8c74",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "RWKV7Config {\n",
       "  \"a_low_rank_dim\": 64,\n",
       "  \"attn\": null,\n",
       "  \"attn_mode\": \"chunk\",\n",
       "  \"bos_token_id\": 1,\n",
       "  \"decay_low_rank_dim\": 64,\n",
       "  \"eos_token_id\": 2,\n",
       "  \"fuse_cross_entropy\": true,\n",
       "  \"fuse_norm\": true,\n",
       "  \"gate_low_rank_dim\": 128,\n",
       "  \"head_dim\": 64,\n",
       "  \"hidden_act\": \"sqrelu\",\n",
       "  \"hidden_ratio\": 4,\n",
       "  \"hidden_size\": 2048,\n",
       "  \"initializer_range\": 0.02,\n",
       "  \"intermediate_size\": null,\n",
       "  \"max_position_embeddings\": 2048,\n",
       "  \"model_type\": \"rwkv7\",\n",
       "  \"norm_bias\": true,\n",
       "  \"norm_eps\": 1e-05,\n",
       "  \"norm_first\": true,\n",
       "  \"num_heads\": 32,\n",
       "  \"num_hidden_layers\": 24,\n",
       "  \"tie_word_embeddings\": false,\n",
       "  \"transformers_version\": \"4.53.2\",\n",
       "  \"use_cache\": true,\n",
       "  \"use_l2warp\": true,\n",
       "  \"v_low_rank_dim\": 16,\n",
       "  \"value_dim\": [\n",
       "    2048,\n",
       "    2048,\n",
       "    2048,\n",
       "    2048,\n",
       "    2048,\n",
       "    2048,\n",
       "    2048,\n",
       "    2048,\n",
       "    2048,\n",
       "    2048,\n",
       "    2048,\n",
       "    2048,\n",
       "    2048,\n",
       "    2048,\n",
       "    2048,\n",
       "    2048,\n",
       "    2048,\n",
       "    2048,\n",
       "    2048,\n",
       "    2048,\n",
       "    2048,\n",
       "    2048,\n",
       "    2048,\n",
       "    2048\n",
       "  ],\n",
       "  \"vocab_size\": 32000\n",
       "}"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "RWKV7Config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8b147a9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model.config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a7d67b25",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = RWKV7Config(\n",
    "    num_hidden_layers=24,\n",
    "    \n",
    "    a_low_rank_dim=64,\n",
    "    hidden_size=1024,\n",
    "    intermediate_size=4096,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "983e3cd9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "model_h = RWKV7ForCausalLM(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "dfd183cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_parameters_in_billions(model):\n",
    "    \"\"\"Count the number of parameters in a model and return in billions.\"\"\"\n",
    "    total_params = sum(p.numel() for p in model.parameters())\n",
    "    return total_params / 1e9\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9d48d02f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(f\"Model has {count_parameters_in_billions(model):.2f} billion parameters\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "54a7a108",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model has 0.38 billion parameters\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "RWKV7ForCausalLM(\n",
       "  (model): RWKV7Model(\n",
       "    (embeddings): Embedding(32000, 1024)\n",
       "    (layers): ModuleList(\n",
       "      (0): RWKV7Block(\n",
       "        (pre_norm): LayerNorm(1024, eps=1e-05)\n",
       "        (attn_norm): LayerNorm(1024, eps=1e-05)\n",
       "        (attn): RWKV7Attention(\n",
       "          (time_shift): ZeroPad2d((0, 0, 1, -1))\n",
       "          (r_proj): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "          (v_proj): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "          (w_lora): LoRA(input_dim=1024, low_rank_dim=64, output_dim=1024)\n",
       "          (a_lora): LoRA(input_dim=1024, low_rank_dim=64, output_dim=1024)\n",
       "          (g_lora): LoRA(input_dim=1024, low_rank_dim=128, output_dim=1024, bias=False)\n",
       "          (g_norm): GroupNorm(16, 1024, eps=0.00064)\n",
       "        )\n",
       "        (ffn_norm): LayerNorm(1024, eps=1e-05)\n",
       "        (ffn): RWKV7FeedForward(\n",
       "          (time_shift): ZeroPad2d((0, 0, 1, -1))\n",
       "          (key): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "          (value): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "        )\n",
       "      )\n",
       "      (1-23): 23 x RWKV7Block(\n",
       "        (attn_norm): LayerNorm(1024, eps=1e-05)\n",
       "        (attn): RWKV7Attention(\n",
       "          (time_shift): ZeroPad2d((0, 0, 1, -1))\n",
       "          (r_proj): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "          (v_proj): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "          (w_lora): LoRA(input_dim=1024, low_rank_dim=64, output_dim=1024)\n",
       "          (v_lora): LoRA(input_dim=1024, low_rank_dim=16, output_dim=1024)\n",
       "          (a_lora): LoRA(input_dim=1024, low_rank_dim=64, output_dim=1024)\n",
       "          (g_lora): LoRA(input_dim=1024, low_rank_dim=128, output_dim=1024, bias=False)\n",
       "          (g_norm): GroupNorm(16, 1024, eps=0.00064)\n",
       "        )\n",
       "        (ffn_norm): LayerNorm(1024, eps=1e-05)\n",
       "        (ffn): RWKV7FeedForward(\n",
       "          (time_shift): ZeroPad2d((0, 0, 1, -1))\n",
       "          (key): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "          (value): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (norm): LayerNorm(1024, eps=1e-05)\n",
       "  )\n",
       "  (lm_head): Linear(in_features=1024, out_features=32000, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(f\"Model has {count_parameters_in_billions(model_h):.2f} billion parameters\")\n",
    "model_h"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6770c4c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_h.to(\"cuda\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "147404da",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "4358e232",
   "metadata": {},
   "outputs": [],
   "source": [
    "input = torch.randint(0, 1000, (1, 1024*3), device='cuda')\n",
    "\n",
    "del o"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6223e94a",
   "metadata": {},
   "outputs": [
    {
     "ename": "",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
      "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
      "\u001b[1;31mClick <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. \n",
      "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
     ]
    }
   ],
   "source": [
    "o = model_h(input)\n",
    "\n",
    "torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1124387e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "armt-kernel",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
