{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/anaconda3/envs/cramming/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from cramming import construct_model\n",
    "import os\n",
    "import json\n",
    "from omegaconf import OmegaConf\n",
    "import transformers\n",
    "import torch\n",
    "import datasets\n",
    "\n",
    "# Machine-specific locations of the pretrained checkpoint and the preprocessed dataset.\n",
    "# NOTE(review): absolute paths - adjust them for your environment.\n",
    "dir_name = \"/home/ubuntu/cramming_test/outputs/pretrain-ours/checkpoints/ScriptableMaskedLM_2023-09-24_2.1270\"\n",
    "data_path = \"/home/ubuntu/cramming_test/outputs/data/sanity-check-2_BPEx32768_aa4b98dc480e637aa82f59461e1b1729\"\n",
    "VOCAB_SIZE = 32768  # vocabulary size passed to construct_model (the dataset is BPEx32768)\n",
    "NUM_SAMPLES = 10000  # upper bound on the number of examples used below\n",
    "\n",
    "with open(os.path.join(dir_name, \"model_config.json\"), \"r\") as file:\n",
    "    cfg_arch = OmegaConf.create(json.load(file))  # Could have done pure hydra here, but wanted interop\n",
    "model = construct_model(cfg_arch, VOCAB_SIZE)\n",
    "\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(dir_name)\n",
    "# The collator draws random replacement ids from range(len(tokenizer)); presumably these must be\n",
    "# valid rows of the model's vocab-sized embedding, so check that the sizes agree.\n",
    "assert len(tokenizer) <= VOCAB_SIZE, f\"tokenizer has {len(tokenizer)} tokens, model vocab is {VOCAB_SIZE}\"\n",
    "\n",
    "dataset = datasets.load_from_disk(data_path)\n",
    "# Cast to tensors after loading from arrow:\n",
    "dataset.set_format(\"torch\")\n",
    "# Select (at most) the first NUM_SAMPLES samples; guards against datasets smaller than NUM_SAMPLES.\n",
    "dataset = dataset.select(range(min(NUM_SAMPLES, len(dataset))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"This is a minor modification of huggingface's token masking:\"\"\"\n",
    "\"\"\"original source:\n",
    "https://github.com/huggingface/transformers/blob/130b987880a9b1ade5c76dc1413c12c8924fda50/src/transformers/data/data_collator.py#L748\n",
    "at commit f00f22a3e290fd377b979124dcf9800b3d73eb11\"\"\"\n",
    "\n",
    "\n",
    "class PatchedDataCollatorForLanguageModeling(transformers.DataCollatorForLanguageModeling):\n",
    "    \"\"\"MLM collator for pre-tokenized examples that all share the same shape (no padding is done).\n",
    "\n",
    "    Keyword arguments on top of those of ``transformers.DataCollatorForLanguageModeling``:\n",
    "        use_80_20_rule: if True, selected positions become [MASK] 80% of the time, a random token 10% of\n",
    "            the time and stay unchanged 10% of the time; if False they always become [MASK].\n",
    "        token_drop: fraction in [0, 1) of positions randomly removed from every sequence after masking\n",
    "            (0 or False disables dropping).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``token_drop`` is outside [0, 1).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, *args, use_80_20_rule=True, token_drop=False, **kwargs):\n",
    "        super().__init__(*args, **kwargs)\n",
    "        if not 0 <= token_drop < 1:\n",
    "            # Dropping a fraction >= 1 would leave no tokens in any sequence.\n",
    "            raise ValueError(f\"token_drop must be in [0, 1), got {token_drop}\")\n",
    "        self.use_80_20_rule = use_80_20_rule\n",
    "        self.token_drop = token_drop\n",
    "\n",
    "        self.mask_token = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)\n",
    "\n",
    "    def torch_mask_tokens(self, inputs=None, special_tokens_mask=None):\n",
    "        \"\"\"\n",
    "        Prepare masked tokens inputs/labels for masked language modeling.\n",
    "\n",
    "        Each non-special position is selected with probability ``self.mlm_probability``. Labels keep the\n",
    "        original id at selected positions and are -100 (ignored by the loss) everywhere else. Selected\n",
    "        inputs follow the 80/10/10 rule (80% MASK, 10% random, 10% original) when ``self.use_80_20_rule``\n",
    "        is set and are always replaced by MASK otherwise. ``inputs`` is modified in place. When\n",
    "        ``self.token_drop > 0`` a random subset of positions is afterwards dropped from both tensors.\n",
    "\n",
    "        Returns:\n",
    "            ``(inputs, labels)``\n",
    "        \"\"\"\n",
    "        labels = inputs.clone()\n",
    "        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)\n",
    "        probability_matrix = torch.full(labels.shape, self.mlm_probability)\n",
    "        if special_tokens_mask is None:\n",
    "            special_tokens_mask = [self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()]\n",
    "            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)\n",
    "        else:\n",
    "            special_tokens_mask = special_tokens_mask.bool()\n",
    "\n",
    "        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)\n",
    "        masked_indices = torch.bernoulli(probability_matrix).bool()\n",
    "        labels[~masked_indices] = -100  # We only compute loss on masked tokens\n",
    "\n",
    "        if self.use_80_20_rule:\n",
    "            # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])\n",
    "            indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices\n",
    "            inputs[indices_replaced] = self.mask_token\n",
    "\n",
    "            # 10% of the time (half of the remaining 20%), we replace masked input tokens with random word\n",
    "            indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced\n",
    "            random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=inputs.dtype)\n",
    "            inputs[indices_random] = random_words[indices_random]\n",
    "\n",
    "            # The rest of the time (10% of the time) we keep the masked input tokens unchanged\n",
    "        else:\n",
    "            # 100% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])\n",
    "            inputs[masked_indices] = self.mask_token\n",
    "\n",
    "        if self.token_drop > 0:\n",
    "            inputs, labels = self._drop_tokens(inputs, labels)\n",
    "        return inputs, labels\n",
    "\n",
    "    def torch_call(self, examples):\n",
    "        \"\"\"Simplified call assuming all dicts in the list of examples have the same layout and contain tensors.\n",
    "        Assume further that, per key, all tensors share one shape and dtype (token ids should be Long Tensors).\n",
    "\n",
    "        Each key is stacked along a new leading batch dimension without padding. A precomputed\n",
    "        ``special_tokens_mask`` is consumed; ``labels`` come from ``torch_mask_tokens`` when ``self.mlm`` is\n",
    "        set, else they are a copy of ``input_ids`` with padding positions set to -100.\n",
    "        \"\"\"\n",
    "        # The upstream path (self.tokenizer.pad / _torch_collate_batch) raises dumb warnings with my latest\n",
    "        # setup, so this is the handmade version.\n",
    "        batch = dict()\n",
    "        for key in examples[0].keys():\n",
    "            elem = examples[0][key]\n",
    "            out = None\n",
    "            if torch.utils.data.get_worker_info() is not None:\n",
    "                # Inside a DataLoader worker: stack directly into shared memory (as torch's default_collate\n",
    "                # does). The storage size is given in bytes, hence numel * element_size (8 per long element).\n",
    "                nbytes = len(examples) * elem.numel() * elem.element_size()\n",
    "                storage = elem._storage()._new_shared(nbytes, device=elem.device)\n",
    "                out = elem.new(storage).resize_(len(examples), *elem.shape)\n",
    "            batch[key] = torch.stack([example[key] for example in examples], 0, out=out).contiguous()\n",
    "\n",
    "        # If special token mask has been preprocessed, pop it from the dict.\n",
    "        special_tokens_mask = batch.pop(\"special_tokens_mask\", None)\n",
    "        if self.mlm:\n",
    "            batch[\"input_ids\"], batch[\"labels\"] = self.torch_mask_tokens(batch[\"input_ids\"], special_tokens_mask=special_tokens_mask)\n",
    "        else:\n",
    "            labels = batch[\"input_ids\"].clone()\n",
    "            if self.tokenizer.pad_token_id is not None:\n",
    "                labels[labels == self.tokenizer.pad_token_id] = -100\n",
    "            batch[\"labels\"] = labels\n",
    "        return batch\n",
    "\n",
    "    def _drop_tokens(self, input_ids, labels):\n",
    "        \"\"\"Drop random tokens. Hou et al., \"Token Dropping for Efficient BERT Pretraining\" also discuss dropping tokens\n",
    "        based on more advanced strategies, which might also be helpful.\n",
    "\n",
    "        This is the simplest strategy, randomly dropping a bunch of tokens for all layers: each row keeps\n",
    "        ``int(seq_len * (1 - token_drop))`` (at least one) randomly chosen positions in their original order.\n",
    "        NOTE(review): only input_ids and labels are reduced; any other batch entry keeps its full length.\n",
    "        \"\"\"\n",
    "        reduced_seq_length = max(1, int(input_ids.shape[1] * (1 - self.token_drop)))\n",
    "        # A random permutation per row; its first reduced_seq_length entries are the positions to keep.\n",
    "        token_mask = torch.argsort(torch.rand_like(input_ids, dtype=torch.float), dim=-1)\n",
    "        # Mark kept positions with -1 (assumes -1 never occurs as a real token id), then compare to get a mask.\n",
    "        fixed_mask = input_ids.scatter(1, token_mask[:, :reduced_seq_length], -1) == -1\n",
    "        return input_ids[fixed_mask].view(input_ids.shape[0], -1), labels[fixed_mask].view(input_ids.shape[0], -1)\n",
    "\n",
    "\n",
    "class InfiniteDataLoader(torch.utils.data.DataLoader):\n",
    "    \"\"\"Lazy copy-paste from https://gist.github.com/MFreidank/821cc87b012c53fade03b0c7aba13958.\"\"\"\n",
    "\n",
    "    def __init__(self, *args, **kwargs):\n",
    "        super().__init__(*args, **kwargs)\n",
    "        # Initialize an iterator over the dataset.\n",
    "        self.dataset_iterator = super().__iter__()\n",
    "        self.epoch_counter = 0\n",
    "\n",
    "    def __iter__(self):\n",
    "        return self\n",
    "\n",
    "    def __next__(self):\n",
    "        try:\n",
    "            batch = next(self.dataset_iterator)\n",
    "        except StopIteration:\n",
    "            # Dataset exhausted, use a new fresh iterator.\n",
    "            self.dataset_iterator = super().__iter__()\n",
    "            self.epoch_counter += 1\n",
    "            if hasattr(self.sampler, \"set_epoch\"):\n",
    "                self.sampler.set_epoch(self.epoch_counter)\n",
    "            batch = next(self.dataset_iterator)\n",
    "        return batch\n",
    "\n",
    "# MLM collator: every non-special token is selected with probability 0.15, selected tokens follow the\n",
    "# 80/10/10 replacement rule, and no tokens are dropped.\n",
    "# NOTE(review): PatchedDataCollatorForLanguageModeling.torch_call stacks examples without padding, so\n",
    "# pad_to_multiple_of does not appear to be used by this collator.\n",
    "collate_fn = PatchedDataCollatorForLanguageModeling(\n",
    "    tokenizer=tokenizer,\n",
    "    mlm=True,\n",
    "    mlm_probability=0.15,\n",
    "    use_80_20_rule=True,\n",
    "    token_drop=0.0,\n",
    "    pad_to_multiple_of=8,\n",
    ")\n",
    "\n",
    "# Endless stream of shuffled batches of 32 examples; with the default num_workers=0 everything is\n",
    "# loaded in the main process.\n",
    "sampler = torch.utils.data.RandomSampler(dataset)\n",
    "dataloader = InfiniteDataLoader(\n",
    "    dataset,\n",
    "    batch_size=32,\n",
    "    sampler=sampler,\n",
    "    collate_fn=collate_fn,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "N: 619, k: 418, mask sum: 419\n",
      "Step 0: 10.432843208312988\n",
      "Average loss: 10.432843208312988\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(\"cuda:0\")\n",
    "model.to(device=device)\n",
    "\n",
    "# NOTE(review): optimizer.step() below is disabled, so parameters never change; this cell only measures\n",
    "# the loss and runs backward.\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=5e-5, eps=1e-12)\n",
    "\n",
    "N = 1  # number of batches to evaluate\n",
    "total_loss = 0.0\n",
    "num_steps = 0\n",
    "# zip stops after N batches without pulling (and masking) one extra batch from the endless dataloader.\n",
    "for step, batch in zip(range(N), dataloader):\n",
    "    device_batch = {\n",
    "        k: v.to(device=device, dtype=torch.long if k == \"input_ids\" else None, non_blocking=True)\n",
    "        for k, v in batch.items()\n",
    "        if k in [\"input_ids\", \"labels\", \"attention_mask\"]  # Add more keywords here if needed\n",
    "    }\n",
    "    # NOTE(review): fp16 autocast without a GradScaler - small gradients may underflow to zero in backward.\n",
    "    with torch.autocast(device_type=device.type, dtype=torch.float16):\n",
    "        loss = model(**device_batch)[\"loss\"]\n",
    "    loss_value = loss.item()\n",
    "    total_loss += loss_value\n",
    "    loss.backward()\n",
    "    # optimizer.step()\n",
    "    model.zero_grad()\n",
    "    print(f\"Step {step}: {loss_value}\")\n",
    "    num_steps += 1\n",
    "# Average over the batches actually processed (guards against a loader that yields fewer than N).\n",
    "total_loss /= max(num_steps, 1)\n",
    "print(f\"Average loss: {total_loss}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "N: 32, k: 21, mask sum: 18\n",
      "torch.Size([32, 32, 768])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.autograd import Function\n",
    "\n",
    "def soft_topk(x: torch.Tensor, k: int):\n",
    "    N = x.shape[0]\n",
    "    if k >= N:\n",
    "        return torch.ones_like(x), torch.ones_like(x, dtype=torch.bool)\n",
    "    x_sorted, indices = torch.sort(x, dim=0, descending=False)\n",
    "    # indices_inverse = torch.argsort(indices)\n",
    "    prefix_sum = torch.cumsum(x_sorted, dim=0)\n",
    "    cmp = torch.arange(k + 1 - N, k + 1, 1, device=x.device) - prefix_sum / (x_sorted + 1e-9)\n",
    "    m = torch.searchsorted(cmp, 0).clamp(N - k + 1, N)\n",
    "    prob = torch.empty_like(x)\n",
    "    rand = torch.rand_like(x)\n",
    "    prob = x * (k - N + m) / (prefix_sum[m-1] + 1e-9) # will overflow 1 but doesn't matter since we will clamp later\n",
    "    # prob = prob[indices_inverse]\n",
    "    prob.clamp_(min=1e-9, max=1.0)\n",
    "    mask = rand < prob\n",
    "    return prob, mask\n",
    "\n",
    "class sampling(Function):\n",
    "    \"\"\"Identity in the forward pass; randomly keeps rows of the gradient in the backward pass.\n",
    "\n",
    "    Rows (slices along dim 0) are scored by the L2 norm of their gradient, about ``sample_ratio * N`` of\n",
    "    them are kept via ``soft_topk`` and every kept row is rescaled by ``1 / prob`` (dropped rows become\n",
    "    zero), so the sampled gradient equals the true one in expectation.\n",
    "    \"\"\"\n",
    "\n",
    "    # Fraction of rows expected to be kept in backward.\n",
    "    sample_ratio = 0.6763\n",
    "    # Print sampling statistics on every backward call.\n",
    "    verbose = True\n",
    "\n",
    "    @staticmethod\n",
    "    def forward(ctx, input, idx):\n",
    "        # idx is only kept for the diagnostic message printed in backward.\n",
    "        ctx.idx = idx\n",
    "        return input\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_output):\n",
    "        idx = ctx.idx\n",
    "        N = grad_output.shape[0]\n",
    "        # One score per row: L2 norm over all trailing dims (identical to norm(dim=1) for 2-D gradients).\n",
    "        rows = grad_output.flatten(1) if grad_output.dim() > 1 else grad_output.unsqueeze(1)\n",
    "        norms = rows.norm(dim=1, p=2)\n",
    "        k = int(N * sampling.sample_ratio)\n",
    "        prob, mask = soft_topk(norms, k)\n",
    "        if sampling.verbose:\n",
    "            print(\"N: {}, k: {}, mask sum: {}\".format(N, k, mask.sum()))\n",
    "        if not mask.any():\n",
    "            # Nothing was sampled (e.g. k == 0): fall back to the unmodified gradient.\n",
    "            print(\"mask is zero, idx: {}, mask: {}, prob: {}, k: {}, norm: {}\".format(idx, mask, prob, k, norms))\n",
    "            return grad_output, None\n",
    "        # Broadcast the per-row scale over all trailing dims; the product is computed only once.\n",
    "        scale = (mask / prob).view(-1, *([1] * (grad_output.dim() - 1)))\n",
    "        grad_input = grad_output * scale\n",
    "        if sampling.verbose:\n",
    "            print(grad_input.shape)\n",
    "        return grad_input, None\n",
    "\n",
    "class Sample(torch.nn.Module):\n",
    "    \"\"\"Wraps the ``sampling`` autograd Function as a module: identity forward, sampled gradient backward.\n",
    "\n",
    "    ``idx`` is a class-level tag handed to ``sampling`` (used only in its diagnostic output).\n",
    "    NOTE(review): the ``sample`` constructor argument is accepted but currently ignored.\n",
    "    \"\"\"\n",
    "\n",
    "    idx = 0\n",
    "\n",
    "    def __init__(self, sample=True):\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self, input):\n",
    "        return sampling.apply(input, self.idx)\n",
    "    \n",
    "\n",
    "# Smoke test: backprop a sum through Sample so the sampling backward runs on a (32, 768) gradient.\n",
    "torch.manual_seed(0)  # reproducible input and Bernoulli draws\n",
    "a = torch.rand(32, 768, requires_grad=True)\n",
    "\n",
    "Sample()(a).sum().backward()\n",
    "# The sampled gradient must keep the input's shape.\n",
    "assert a.grad.shape == a.shape\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cramming",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
