{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import copy\n",
    "sys.path.append(\"./associative-recurrent-memory-transformer\")\n",
    "sys.path.append(\"./\")\n",
    "sys.path.append(\"../\")\n",
    "from tqdm import tqdm\n",
    "\n",
    "import logging\n",
    "import math\n",
    "import os\n",
    "from pathlib import Path\n",
    "from itertools import chain\n",
    "from datetime import timedelta\n",
    "import torch\n",
    "import datasets\n",
    "import numpy as np\n",
    "from torch.utils.data import DataLoader\n",
    "from transformers import set_seed\n",
    "\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "from peft import get_peft_model, LoraConfig, TaskType\n",
    "from babilong.babilong_utils import TaskDataset, SentenceSampler, NoiseInjectionDataset\n",
    "\n",
    "from grouped_batching.batching import GroupedBatcher\n",
    "from grouped_batching.executor import ArmtGroupedExecutor\n",
    "from grouped_batching.fast_executor import FastGroupedArmtExecutor, GroupedLayerContext, associate_with_context, update_mem_with_context\n",
    "from grouped_batching.llama1b_grouping import (\n",
    "    wrap_model_with_armt, get_grouped_states, \n",
    "    make_grouped_layer_from_single_layer, make_grouped_model_from_naive,\n",
    "    make_grouped_sliced_layer_from_single_layer\n",
    ")\n",
    "\n",
    "torch.autograd.set_detect_anomaly(True)\n",
    "from grouped_batching.llama1b_grouping_autograd import make_grouped_training_layer_from_single_layer, make_grouped_sliced_training_layer_from_single_layer\n",
    "from modeling_amt.language_modeling import AssociativeRecurrentWrapper, AssociativeMemoryCell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForCausalLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_opt_armt_train(original_model, segment_size):\n",
    "    \"\"\"Build a grouped-batching training executor from a LoRA-wrapped ARMT model.\n",
    "\n",
    "    Steps: merge the LoRA adapters into the base weights, cast the model to\n",
    "    bfloat16, build a grouped (sliced) training layer from the decoder layers,\n",
    "    wrap everything in a ``FastGroupedArmtExecutor`` and run one warm-up pass\n",
    "    (associate -> forward -> memory update) on random inputs.\n",
    "\n",
    "    Args:\n",
    "        original_model: ARMT wrapper whose ``memory_cell.model`` is a PEFT (LoRA)\n",
    "            model. Modified in place: the adapters are merged and\n",
    "            ``memory_cell.model`` is replaced by the unwrapped base model.\n",
    "        segment_size: Sequence length of the random warm-up segments.\n",
    "\n",
    "    Returns:\n",
    "        The ``FastGroupedArmtExecutor`` built around the grouped model.\n",
    "    \"\"\"\n",
    "    dtype = torch.bfloat16\n",
    "    device = \"cuda\"\n",
    "\n",
    "    # Merge the LoRA weights into the base model, then drop the PEFT wrapper.\n",
    "    peft_model = original_model.memory_cell.model\n",
    "    peft_model.merge_and_unload()\n",
    "    original_model.memory_cell.model = peft_model.base_model.model\n",
    "\n",
    "    # Transform the model into its grouped version.\n",
    "    original_model = original_model.to(dtype)\n",
    "    base_model = original_model.memory_cell.model\n",
    "    layers = base_model.model.layers\n",
    "    # Read the sizes from the model instead of hard-coding 16 layers / hidden size 2048.\n",
    "    # Taken before make_grouped_model_from_naive, which may rearrange the layer list.\n",
    "    num_layers = len(layers)\n",
    "    hidden_size = base_model.config.hidden_size\n",
    "\n",
    "    grouped_context = GroupedLayerContext()\n",
    "    grouped_context.is_training = True\n",
    "    grouped_layer = make_grouped_sliced_training_layer_from_single_layer(\n",
    "        grouped_context, copy.deepcopy(layers[0]), layers\n",
    "    )\n",
    "    grouped_layer = grouped_layer.to(dtype)\n",
    "    grouped_layer = grouped_layer.to(device)\n",
    "    armt_grouped_model, _ = make_grouped_model_from_naive(original_model, grouped_layer)\n",
    "    armt_grouped_model.to(device)\n",
    "    executor = FastGroupedArmtExecutor(\n",
    "        armt_grouped_model,\n",
    "        grouped_layer,\n",
    "        grouped_context,\n",
    "        num_layers,\n",
    "        vanilla_armt_model=original_model,\n",
    "    )\n",
    "\n",
    "    # Warm-up over the full layer group (the original \"compile full layers\" step).\n",
    "    # The statement order below is kept exactly as in the original code.\n",
    "    segments_input = torch.rand((num_layers, segment_size, hidden_size), device=device, dtype=dtype)\n",
    "    grouped_context.start_idx = 0\n",
    "    grouped_context.end_idx = num_layers\n",
    "    grouped_context.is_full = True\n",
    "\n",
    "    associate_with_context(grouped_layer, grouped_context, segments_input)\n",
    "    grouped_layer.generate_mode = True\n",
    "    armt_grouped_model.memory_cell.model.model(inputs_embeds=segments_input, use_cache=False)\n",
    "    update_mem_with_context(grouped_layer, grouped_context, segments_input)\n",
    "    return executor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint of the pretrained ARMT model. This is a relative path, so it is resolved\n",
    "# against the working directory that is current when the checkpoint is loaded.\n",
    "armt_cpt_path = \"../../data/pretrained_models/RMT-Llama-3.2-1B-Instruct-8x1024-mem16-lora-babilong-qa1-5_ct-v3.1/model.safetensors\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make CUDA device 0 the default device for tensors created without an explicit device.\n",
    "torch.set_default_device(\"cuda:0\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use bfloat16 as the default floating-point dtype; `dtype` is reused when loading the base model.\n",
    "dtype = torch.bfloat16\n",
    "torch.set_default_dtype(dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the base Llama model on CPU in `dtype`, requesting the FlashAttention-2 attention implementation.\n",
    "source_model = AutoModelForCausalLM.from_pretrained(\"unsloth/Llama-3.2-1B-Instruct\",\n",
    "                                                    attn_implementation=\"flash_attention_2\",\n",
    "                                                    torch_dtype=dtype,\n",
    "                                                    device_map=\"cpu\")\n",
    "# Switch the module to eval mode.\n",
    "source_model.eval()\n",
    "# Disabled experiments, kept for reference:\n",
    "#source_model.lm_head = torch.nn.Identity()\n",
    "#reference_model = copy.deepcopy(source_model)\n",
    "\n",
    "# Tokenizer matching the base model; used later for the datasets and the chat template.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"unsloth/Llama-3.2-1B-Instruct\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the base model in a PEFT LoRA model (r=8, alpha=32, dropout=0.1, inference mode).\n",
    "# NOTE(review): presumably this mirrors the adapter layout stored in the checkpoint that is\n",
    "# loaded later - confirm against the configuration used to train that checkpoint.\n",
    "peft_config = LoraConfig(\n",
    "    task_type=TaskType.CAUSAL_LM, \n",
    "    inference_mode=True, \n",
    "    r=8, \n",
    "    lora_alpha=32, \n",
    "    lora_dropout=0.1,\n",
    "    )\n",
    "# `source_model` is rebound to the PEFT-wrapped model from here on.\n",
    "source_model = get_peft_model(source_model, peft_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Wrap the base model in the original ARMT (memory cell + recurrent wrapper) and load the pretrained weights.\n",
    "# The grouped-batching version is built from this model in a later cell.\n",
    "# The actual segment_size for this model is segment_size - mem_size, so we will use it later.\n",
    "segment_size = 1024\n",
    "mem_size = 16\n",
    "segment_alignment = \"left\"\n",
    "attend_to_previous_input = False\n",
    "# The wrapper is assembled on CPU first; it is moved to CUDA at the end of this cell.\n",
    "device = \"cpu\"\n",
    "max_n_segments = 32\n",
    "# Arguments for the associative memory cell built around the LoRA-wrapped base model.\n",
    "mem_cell_args = dict(\n",
    "    base_model=source_model,\n",
    "    num_mem_tokens=mem_size,\n",
    "    d_mem=64,\n",
    "    layers_attr=\"model.model.layers\",\n",
    "    wrap_pos=False,\n",
    "    correction=True,\n",
    ")\n",
    "\n",
    "cell = AssociativeMemoryCell(**mem_cell_args)\n",
    "original_model = AssociativeRecurrentWrapper(cell,\n",
    "                                            segment_size=segment_size,\n",
    "                                            max_n_segments=max_n_segments,\n",
    "                                            vary_n_segments=True,\n",
    "                                            k2=-1,\n",
    "                                            return_all_logits=False,\n",
    ").to(device)\n",
    "\n",
    "# Pick the loader from the checkpoint path: safetensors files go through safetensors' load_model,\n",
    "# anything else is treated as a torch state dict and loaded strictly.\n",
    "if \"safetensors\" in armt_cpt_path:\n",
    "    from safetensors.torch import load_model\n",
    "    load_model(original_model, armt_cpt_path, device=\"cuda:0\")\n",
    "else:\n",
    "    cpt = torch.load(armt_cpt_path, map_location=device)\n",
    "    original_model.load_state_dict(cpt, strict=True)\n",
    "# Move the fully loaded model to the GPU.\n",
    "original_model = original_model.to(\"cuda\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run parameters and data locations.\n",
    "output_dir = \"./runs/test/babilong_multitask/train_test\"\n",
    "# Switch into the ARMT repository: the relative paths used afterwards (babi_path, output_dir)\n",
    "# are resolved from there. The guard keeps this cell re-runnable - without it a second run\n",
    "# would resolve the relative directory from inside the repository and os.chdir would fail.\n",
    "working_dir = \"./associative-recurrent-memory-transformer\"\n",
    "if Path.cwd().name == Path(working_dir).name:\n",
    "    working_dir = str(Path.cwd())\n",
    "else:\n",
    "    working_dir = str(Path(working_dir).expanduser().absolute())\n",
    "    os.chdir(working_dir)\n",
    "seed = 1\n",
    "set_seed(seed)\n",
    "# set bfloat16 for compatibility with grouped version\n",
    "dtype = torch.bfloat16\n",
    "torch.set_default_dtype(dtype)\n",
    "model_name = \"unsloth/Llama-3.2-1B-Instruct\"\n",
    "# Background (noise) text corpus and its configuration name.\n",
    "noise_dataset = \"wikitext\"\n",
    "noise_dataset_split = \"wikitext-103-raw-v1\"\n",
    "# bAbI task files, relative to working_dir.\n",
    "babi_path = \"../babilong/data/tasks_1-20_v1-2/en-10k\"\n",
    "max_n_facts = 800\n",
    "# Training lengths: samples span up to max_n_segments segments of segment_size tokens.\n",
    "segment_size = 1024\n",
    "max_n_segments = 8\n",
    "sample_size = segment_size * max_n_segments\n",
    "learning_rate = 1e-05\n",
    "# Optional placement range of the task facts inside the sample (None = not restricted here).\n",
    "task_start_pct = None\n",
    "task_end_pct = None\n",
    "mixed_length_ratio = 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "// Gemm operator cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8\n",
      "using cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8_base =\n",
      "  typename cutlass::gemm::kernel::DefaultGemmGrouped<\n",
      "    cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
      "    cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
      "    cutlass::bfloat16_t, cutlass::layout::RowMajor,\n",
      "    float,\n",
      "    cutlass::arch::OpClassTensorOp,\n",
      "    cutlass::arch::Sm80,\n",
      "    cutlass::gemm::GemmShape<256, 128, 64>,\n",
      "    cutlass::gemm::GemmShape<64, 64, 64>,\n",
      "    cutlass::gemm::GemmShape<16, 8, 16>,\n",
      "    cutlass::epilogue::thread::LinearCombination<cutlass::bfloat16_t, 8, float, float>,\n",
      "    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
      "    3,\n",
      "    cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,\n",
      "    cutlass::arch::OpMultiplyAdd\n",
      ">::GemmKernel;\n",
      "\n",
      "// Define named type\n",
      "struct cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8_type :\n",
      "  public cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8_base { };\n",
      "\n",
      "USE_EFFICIENT_ALLOCATION\n"
     ]
    }
   ],
   "source": [
    "# Build the grouped-batching training executor. The warm-up length is segment_size - mem_size,\n",
    "# i.e. the segment length without the memory tokens. Note that `model` is the executor returned\n",
    "# by prepare_opt_armt_train, not the bare ARMT module.\n",
    "model = prepare_opt_armt_train(original_model, segment_size-mem_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Background (noise) text. The loaded DatasetDict gets its own name so that `noise_dataset`\n",
    "# keeps holding the dataset identifier and this cell can be re-run safely.\n",
    "noise_dataset_dict = datasets.load_dataset(noise_dataset, noise_dataset_split)\n",
    "noise_dataset_train = noise_dataset_dict['train']\n",
    "noise_dataset_test = noise_dataset_dict['test']\n",
    "\n",
    "# Task dataset: only the first entry of task_datasets is used below.\n",
    "task_datasets = [\"qa1_single-supporting-fact\"]\n",
    "train_paths = [os.path.join(babi_path, f\"{td}_train.txt\") for td in task_datasets]\n",
    "test_paths = [os.path.join(babi_path, f\"{td}_test.txt\") for td in task_datasets]\n",
    "\n",
    "task_dataset_train = TaskDataset(train_paths[0], max_n_facts=max_n_facts)\n",
    "task_dataset_test = TaskDataset(test_paths[0], max_n_facts=max_n_facts)\n",
    "\n",
    "# Sample lengths: one per segment count for training, the full length for testing.\n",
    "qa_margin = 70          # leave space for questions and answers\n",
    "train_sample_size = [int(segment_size * i) for i in range(1, max_n_segments)] + [sample_size]\n",
    "train_sample_size = [s - qa_margin for s in train_sample_size]\n",
    "\n",
    "test_sample_size = sample_size - qa_margin\n",
    "max_sentence_len = None\n",
    "if (task_start_pct is not None) and (task_end_pct is not None):\n",
    "    # do not sample sentences longer than task position range * 0.5\n",
    "    max_sentence_len = int((task_end_pct - task_start_pct) * 0.5 * sample_size)\n",
    "\n",
    "noise_sampler_train = SentenceSampler(noise_dataset_train, tokenizer=tokenizer, max_sentence_len=max_sentence_len, shuffle=True, random_seed=None)\n",
    "noise_sampler_test = SentenceSampler(noise_dataset_test, tokenizer=tokenizer, max_sentence_len=max_sentence_len, shuffle=True, random_seed=42)\n",
    "\n",
    "train_dataset = NoiseInjectionDataset(task_dataset=task_dataset_train,\n",
    "                                        noise_sampler=noise_sampler_train,\n",
    "                                        tokenizer=tokenizer,\n",
    "                                        sample_size=train_sample_size,\n",
    "                                        mixed_length_ratio=mixed_length_ratio,\n",
    "                                        task_start_pct=task_start_pct,\n",
    "                                        task_end_pct=task_end_pct\n",
    "                                        )\n",
    "\n",
    "test_dataset = NoiseInjectionDataset(task_dataset=task_dataset_test,\n",
    "                                        noise_sampler=noise_sampler_test,\n",
    "                                        tokenizer=tokenizer,\n",
    "                                        sample_size=test_sample_size,\n",
    "                                        mixed_length_ratio=mixed_length_ratio,\n",
    "                                        task_start_pct=task_start_pct,\n",
    "                                        task_end_pct=task_end_pct\n",
    "                                        )\n",
    "\n",
    "# Special token ids. Padding falls back to EOS when the tokenizer defines no pad token.\n",
    "id_pad_value = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id\n",
    "# Encode without special tokens so that [0] is the first token of 'GEN' itself and not a\n",
    "# BOS token prepended by the tokenizer.\n",
    "gen_token = tokenizer.encode('GEN', add_special_tokens=False)[0]\n",
    "eos_token = tokenizer.eos_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_input_ids(sample):\n",
    "    \"\"\"Turn one dataset sample into chat-template token ids.\n",
    "\n",
    "    The sample's ``input_tokens`` are decoded back to text and combined with its\n",
    "    ``question`` into the user turn; its ``answer`` becomes the assistant turn.\n",
    "    Uses the notebook-level ``tokenizer``.\n",
    "\n",
    "    Returns:\n",
    "        The result of ``tokenizer.apply_chat_template`` with ``tokenize=True`` and\n",
    "        no generation prompt.\n",
    "    \"\"\"\n",
    "    decoded_context = tokenizer.decode(sample['input_tokens'])\n",
    "    user_prompt = \"{} {}Answer with a single word.\".format(decoded_context, sample['question'])\n",
    "    chat = [\n",
    "        dict(role=\"user\", content=user_prompt),\n",
    "        dict(role=\"assistant\", content=sample['answer']),\n",
    "    ]\n",
    "    return tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def collate_fn(batch):\n",
    "    \"\"\"Collate dataset samples into right-padded tensors for causal-LM training.\n",
    "\n",
    "    Uses the notebook-level ``get_input_ids`` and ``id_pad_value``.\n",
    "\n",
    "    Args:\n",
    "        batch: Samples accepted by ``get_input_ids``.\n",
    "\n",
    "    Returns:\n",
    "        dict with\n",
    "        ``input_ids``: token ids padded with ``id_pad_value`` (batch first),\n",
    "        ``labels``: a separate copy of ``input_ids`` with padded positions set to\n",
    "        -100, the default ``ignore_index`` of ``CrossEntropyLoss``,\n",
    "        ``attention_mask``: bool mask, True on real tokens.\n",
    "    \"\"\"\n",
    "    sequences = [torch.tensor(get_input_ids(sample)) for sample in batch]\n",
    "    masks = [torch.ones_like(seq, dtype=torch.bool) for seq in sequences]\n",
    "\n",
    "    input_ids = pad_sequence(sequences, padding_value=id_pad_value, batch_first=True)\n",
    "    attention_mask = pad_sequence(masks, padding_value=0, batch_first=True).bool()\n",
    "\n",
    "    # Previously labels was the very same tensor as input_ids, so padding kept the pad id and\n",
    "    # counted towards the loss. masked_fill returns a new tensor with padding ignored;\n",
    "    # for a batch of one sample the values are unchanged.\n",
    "    labels = input_ids.masked_fill(~attention_mask, -100)\n",
    "\n",
    "    return {\n",
    "        'input_ids': input_ids,\n",
    "        'labels': labels,\n",
    "        'attention_mask': attention_mask,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run train\n",
    "def train_one_epoch(training_loader, segment_size, max_steps=-1, device=\"cuda\"):\n",
    "    \"\"\"Run next-token training over ``training_loader`` for one epoch.\n",
    "\n",
    "    Uses the notebook-level ``model``, ``optimizer`` and ``loss_fn``.\n",
    "\n",
    "    Args:\n",
    "        training_loader: Iterable of dicts with ``input_ids`` and ``labels`` tensors.\n",
    "        segment_size: Inputs are left-padded so that their length is a multiple of it.\n",
    "            NOTE(review): the notebook passes 1024 here while the executor was built with\n",
    "            segment_size - mem_size; confirm which length the model expects.\n",
    "        max_steps: Maximum number of optimizer steps; a negative value means no limit.\n",
    "        device: Device the batch tensors are moved to.\n",
    "\n",
    "    Returns:\n",
    "        Mean loss over the most recently reported window (0.0 if nothing was reported).\n",
    "    \"\"\"\n",
    "    log_every = 10\n",
    "    running_loss = 0.\n",
    "    batches_in_window = 0\n",
    "    last_loss = 0.\n",
    "    for i, data in tqdm(enumerate(training_loader)):\n",
    "        # Stop once exactly max_steps batches have been processed. The old `i > max_steps`\n",
    "        # check at the end of the loop ran two extra steps and stopped after the first\n",
    "        # batch for the default of -1.\n",
    "        if 0 <= max_steps <= i:\n",
    "            break\n",
    "\n",
    "        inputs, labels = data[\"input_ids\"].to(device), data[\"labels\"].to(device)\n",
    "\n",
    "        # Left-pad to a whole number of segments. (-length) % segment_size is 0 for an\n",
    "        # already aligned length, so no empty all-padding segment gets prepended.\n",
    "        pad_len = (-inputs.shape[-1]) % segment_size\n",
    "        pad_shape = (pad_len, 0)\n",
    "        # Padding the last dimension directly keeps the batch dimension intact.\n",
    "        inputs = torch.nn.functional.pad(inputs, pad_shape)\n",
    "        labels = torch.nn.functional.pad(labels, pad_shape, value=-100)\n",
    "\n",
    "        # Zero the gradients for every batch.\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        logits = model(inputs).logits\n",
    "        # Shift so that position t is scored against token t + 1, then flatten.\n",
    "        shifted_labels = labels[..., 1:].reshape(-1)\n",
    "        shifted_logits = logits[..., :-1, :]\n",
    "        shifted_logits = shifted_logits.reshape(-1, shifted_logits.size(-1))\n",
    "\n",
    "        # Compute the loss and its gradients, then update the weights.\n",
    "        loss = loss_fn(shifted_logits, shifted_labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Report the mean over the batches seen since the last report.\n",
    "        running_loss += loss.item()\n",
    "        batches_in_window += 1\n",
    "        if i % log_every == 0:\n",
    "            last_loss = running_loss / batches_in_window\n",
    "            print('  batch {} loss: {}'.format(i + 1, last_loss))\n",
    "            tb_x = i + 1\n",
    "            print('Loss/train', last_loss, tb_x)\n",
    "            running_loss = 0.\n",
    "            batches_in_window = 0\n",
    "\n",
    "    return last_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "// Gemm operator cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8\n",
      "using cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8_base =\n",
      "  typename cutlass::gemm::kernel::DefaultGemmGrouped<\n",
      "    cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
      "    cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
      "    cutlass::bfloat16_t, cutlass::layout::RowMajor,\n",
      "    float,\n",
      "    cutlass::arch::OpClassTensorOp,\n",
      "    cutlass::arch::Sm80,\n",
      "    cutlass::gemm::GemmShape<256, 128, 64>,\n",
      "    cutlass::gemm::GemmShape<64, 64, 64>,\n",
      "    cutlass::gemm::GemmShape<16, 8, 16>,\n",
      "    cutlass::epilogue::thread::LinearCombination<cutlass::bfloat16_t, 8, float, float>,\n",
      "    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
      "    3,\n",
      "    cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,\n",
      "    cutlass::arch::OpMultiplyAdd\n",
      ">::GemmKernel;\n",
      "\n",
      "// Define named type\n",
      "struct cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8_type :\n",
      "  public cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8_base { };\n",
      "\n",
      "USE_EFFICIENT_ALLOCATION\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1it [00:02,  2.42s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  batch 1 loss: 0.0119375\n",
      "Loss/train 0.0119375 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "11it [00:18,  1.56s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  batch 11 loss: 0.06053125\n",
      "Loss/train 0.06053125 11\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "21it [00:35,  1.53s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  batch 21 loss: 0.05021875\n",
      "Loss/train 0.05021875 21\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "31it [00:50,  1.39s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  batch 31 loss: 0.046125\n",
      "Loss/train 0.046125 31\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "41it [01:06,  1.60s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  batch 41 loss: 0.04334375\n",
      "Loss/train 0.04334375 41\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "51it [01:23,  1.55s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  batch 51 loss: 0.042078125\n",
      "Loss/train 0.042078125 51\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "61it [01:38,  1.44s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  batch 61 loss: 0.040796875\n",
      "Loss/train 0.040796875 61\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "71it [01:54,  1.65s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  batch 71 loss: 0.040578125\n",
      "Loss/train 0.040578125 71\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "81it [02:12,  1.73s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  batch 81 loss: 0.0395\n",
      "Loss/train 0.0395 81\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "91it [02:26,  1.57s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  batch 91 loss: 0.038859375\n",
      "Loss/train 0.038859375 91\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "101it [02:41,  1.45s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  batch 101 loss: 0.039015625\n",
      "Loss/train 0.039015625 101\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "101it [02:43,  1.62s/it]\n"
     ]
    }
   ],
   "source": [
    "set_seed(421)\n",
    "device = \"cuda\"\n",
    "# Quick fix, only for BS=1, w/o any mask and normal padding, etc.\n",
    "train_dl = DataLoader(train_dataset, batch_size=1,\n",
    "                      shuffle=False, collate_fn=collate_fn)\n",
    "\n",
    "loss_fn = torch.nn.CrossEntropyLoss().to(device)\n",
    "# NOTE(review): only model.armt_model parameters are handed to the optimizer, although\n",
    "# model.grouped_layer is checkpointed below as well - confirm whether its parameters\n",
    "# should be optimized too.\n",
    "optimizer = torch.optim.AdamW(model.armt_model.parameters(), lr=learning_rate)\n",
    "\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "for epoch in range(1):\n",
    "    train_one_epoch(train_dl, segment_size, max_steps=100, device=device)\n",
    "    # Checkpoint after every epoch. The files written after the last epoch are the final\n",
    "    # model, so the identical second save that used to follow the loop is not needed.\n",
    "    torch.save(model.grouped_layer.state_dict(), os.path.join(output_dir, \"grouped_layer.pth\"))\n",
    "    torch.save(model.armt_model.state_dict(), os.path.join(output_dir, \"armt_model.pth\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:conda_rmt_it_venv]",
   "language": "python",
   "name": "conda-env-conda_rmt_it_venv-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
