{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# fim\n",
    "\n",
    "> Fill-in-the-middle (FIM) strategies that mask spans of a sequence and move them to the end of the input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp fim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from ProtMamba_ssm.utils import MASK_TO_ID, AA_TO_ID\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "\n",
    "class AbstractFIM(object):\n",
    "    def __init__(self,\n",
    "                 max_patches=5,\n",
    "                 mask_fraction=0.2,\n",
    "                 always_mask=False,\n",
    "                 mask_tokens=MASK_TO_ID,\n",
    "                 eos_token=AA_TO_ID[\"<eos>\"],\n",
    "                 add_position_ids=False,\n",
    "                 troubleshoot=False):\n",
    "        \"\"\"\n",
    "        Base class for fill-in-the-middle (FIM) strategies. It stores the shared settings and\n",
    "        leaves the rearrangement of a single sequence (`fim`) to subclasses.\n",
    "\n",
    "        Args:\n",
    "            max_patches: Stored as `self.max_patches`; `mask_tokens` must hold at least this many entries.\n",
    "            mask_fraction: Stored as `self.mask_fraction` for subclasses to use.\n",
    "            always_mask: Stored as `self.always_mask` for subclasses to use.\n",
    "            mask_tokens: Collection of mask tokens (defaults to `MASK_TO_ID`).\n",
    "            eos_token: Stored as `self.eos_token` (defaults to `AA_TO_ID[\"<eos>\"]`).\n",
    "            add_position_ids: When true, `apply` also collects and returns position ids.\n",
    "            troubleshoot: Stored as `self.troubleshoot` for subclasses to use.\n",
    "        \"\"\"\n",
    "        # There must be at least `max_patches` mask tokens available.\n",
    "        assert len(mask_tokens) >= max_patches, \"Number of mask tokens must be bigger than max number of patches.\"\n",
    "        self.max_patches = max_patches\n",
    "        self.mask_fraction = mask_fraction\n",
    "        self.always_mask = always_mask\n",
    "        self.mask_tokens = mask_tokens\n",
    "        self.eos_token = eos_token\n",
    "        self.add_position_ids = add_position_ids\n",
    "        self.troubleshoot = troubleshoot\n",
    "\n",
    "    def apply(self, sequences, tuples):\n",
    "        \"\"\"\n",
    "        Run `self.fim` once per entry of `tuples` and concatenate the results in order.\n",
    "\n",
    "        Args:\n",
    "            sequences: Passed unchanged to `self.fim` for every tuple.\n",
    "            tuples: Iterable of the `t` arguments handed to `self.fim`, one per sequence.\n",
    "        Returns:\n",
    "            `(input_ids, position_ids)`: two flat lists, or `(input_ids, None)` when\n",
    "            `self.add_position_ids` is false.\n",
    "        \"\"\"\n",
    "        collected_ids = []\n",
    "        collected_positions = []\n",
    "        for span in tuples:\n",
    "            span_ids, span_positions = self.fim(sequences, span)\n",
    "            collected_ids.extend(span_ids)\n",
    "            if self.add_position_ids:\n",
    "                collected_positions.extend(span_positions)\n",
    "        return collected_ids, (collected_positions if self.add_position_ids else None)\n",
    "\n",
    "    def fim(self, sequences, t):\n",
    "        \"\"\"\n",
    "        Rearrange the single sequence identified by `t`; must be implemented by subclasses.\n",
    "        `apply` expects a pair `(input_ids, position_ids)` as the return value.\n",
    "        \"\"\"\n",
    "        raise NotImplementedError\n",
    "\n",
    "\n",
    "class NoFIM(AbstractFIM):\n",
    "    def __init__(self,\n",
    "                 max_patches=5,\n",
    "                 mask_fraction=0.2,\n",
    "                 always_mask=False,\n",
    "                 mask_tokens=MASK_TO_ID,\n",
    "                 eos_token=AA_TO_ID[\"<eos>\"],\n",
    "                 add_position_ids=False,\n",
    "                 troubleshoot=False):\n",
    "        super().__init__(max_patches=max_patches,\n",
    "                         mask_fraction=mask_fraction,\n",
    "                         always_mask=always_mask,\n",
    "                         mask_tokens=mask_tokens,\n",
    "                         eos_token=eos_token,\n",
    "                         add_position_ids=add_position_ids,\n",
    "                         troubleshoot=troubleshoot)\n",
    "\n",
    "    def fim(self, sequences, t):\n",
    "        \"\"\"\n",
    "        Identity strategy: return the slice `sequences[t[0]:t[1]]` untouched.\n",
    "\n",
    "        Returns:\n",
    "            `(slice, position_ids)`, where `position_ids` counts from 0 up to the sequence length\n",
    "            when `self.add_position_ids` is true and is `None` otherwise.\n",
    "        \"\"\"\n",
    "        first, last = t[0], t[1]\n",
    "        if not self.add_position_ids:\n",
    "            return sequences[first:last], None\n",
    "        # Position ids are relative to the start of this sequence.\n",
    "        relative_positions = np.arange(first, last) - first\n",
    "        return sequences[first:last], relative_positions\n",
    "\n",
    "\n",
    "class SingleSpanFIM(AbstractFIM):\n",
    "\n",
    "    def __init__(self,\n",
    "                 max_patches=5,\n",
    "                 mask_fraction=0.2,\n",
    "                 always_mask=False,\n",
    "                 mask_tokens=MASK_TO_ID,\n",
    "                 eos_token=AA_TO_ID[\"<eos>\"],\n",
    "                 add_position_ids=False,\n",
    "                 troubleshoot=False):\n",
    "        super().__init__(max_patches, mask_fraction, always_mask, mask_tokens, eos_token, add_position_ids, troubleshoot)\n",
    "\n",
    "    def fim(self, sequences, t):\n",
    "        \"\"\"\n",
    "        Move one randomly chosen span of the sequence to the end (OpenAI-style scrambling strategy).\n",
    "\n",
    "        Two distinct cut points are drawn strictly inside `t`, splitting the sequence into\n",
    "        `part1 | part2 | part3`. The output is `part1, <mask-1>, part3, <mask-1>, part2`: the span\n",
    "        `part2` is replaced by a mask token and appended at the end after a second mask token.\n",
    "\n",
    "        Args:\n",
    "            sequences: Indexable of tokens containing the sequence at `sequences[t[0]:t[1]]`.\n",
    "            t: `(start, end)` indices of the sequence inside `sequences`.\n",
    "        Returns:\n",
    "            `(input_ids, position_ids)` as numpy arrays; `position_ids` is `None` when\n",
    "            `self.add_position_ids` is false. Position ids are relative to `t[0]`, and both mask\n",
    "            tokens carry the position id of the first token of the moved span.\n",
    "        Raises:\n",
    "            ValueError: If the sequence has fewer than 3 tokens, because two distinct interior\n",
    "                cut points cannot be drawn.\n",
    "        \"\"\"\n",
    "        start, end = t[0], t[1]\n",
    "        if end - start < 3:\n",
    "            # np.random.choice would raise ValueError here anyway; fail with a readable message instead.\n",
    "            raise ValueError(f\"SingleSpanFIM needs a sequence of at least 3 tokens, got {end - start} for {t}.\")\n",
    "        # Cut points lie strictly inside the sequence, so none of the three parts is empty.\n",
    "        cut_left, cut_right = np.sort(np.random.choice(np.arange(start + 1, end), 2, replace=False))\n",
    "        part1 = sequences[start:cut_left]\n",
    "        part2 = sequences[cut_left:cut_right]\n",
    "        part3 = sequences[cut_right:end]\n",
    "        mask_token = self.mask_tokens[\"<mask-1>\"]\n",
    "        sequence = np.concatenate([part1, [mask_token], part3, [mask_token], part2])\n",
    "        position_ids_sequence = None\n",
    "        if self.add_position_ids:\n",
    "            position_ids = np.arange(start, end) - start\n",
    "            # `position_ids` is zero-based, so the absolute cut points must be shifted by `start`\n",
    "            # before slicing; slicing with absolute indices is only correct when `start == 0`.\n",
    "            rel_left, rel_right = cut_left - start, cut_right - start\n",
    "            position_ids_part1 = position_ids[:rel_left]\n",
    "            position_ids_part2 = position_ids[rel_left:rel_right]\n",
    "            position_ids_part3 = position_ids[rel_right:]\n",
    "            position_ids_sequence = np.concatenate(\n",
    "                [position_ids_part1, [position_ids_part2[0]], position_ids_part3, [position_ids_part2[0]],\n",
    "                 position_ids_part2])\n",
    "\n",
    "        return sequence, position_ids_sequence\n",
    "\n",
    "\n",
    "class MultipleSpanFIM(AbstractFIM):\n",
    "    def __init__(self,\n",
    "                 max_patches=5,\n",
    "                 mask_fraction=0.2,\n",
    "                 always_mask=False,\n",
    "                 mask_tokens=MASK_TO_ID,\n",
    "                 eos_token=AA_TO_ID[\"<eos>\"],\n",
    "                 add_position_ids=False,\n",
    "                 troubleshoot=False):\n",
    "        super().__init__(max_patches, mask_fraction, always_mask, mask_tokens, eos_token, add_position_ids, troubleshoot)\n",
    "\n",
    "    def fim(self, sequences, t):\n",
    "        \"\"\"\n",
    "        Mask several patches of one sequence and move them after the unmasked part (inpaint strategy).\n",
    "\n",
    "        The number of patches is drawn from a Poisson distribution with mean 1, resampled until it is\n",
    "        at most `self.max_patches`, and raised to at least 1 when `self.always_mask` is set. It is then\n",
    "        capped at the number of possible patch start positions, so that short sequences get fewer\n",
    "        patches instead of making `np.random.choice` raise.\n",
    "        Patch starts are drawn without replacement from the positions strictly inside `t`; each patch\n",
    "        runs from its start for a sampled length and never extends past the next patch start.\n",
    "        The output is all unmasked parts (each patch replaced by its `<mask-i>` token, followed by\n",
    "        `self.eos_token` when at least one patch exists) and afterwards all masked patches, each\n",
    "        preceded by its `<mask-i>` token.\n",
    "\n",
    "        Args:\n",
    "            sequences: Indexable of tokens containing the sequence at `sequences[t[0]:t[1]]`.\n",
    "            t: `(start, end)` indices of the sequence inside `sequences`.\n",
    "        Returns:\n",
    "            `(input_ids, position_ids)` as lists; `position_ids` is `None` when\n",
    "            `self.add_position_ids` is false.\n",
    "        \"\"\"\n",
    "        def sample_lengths(start, end):\n",
    "            \"\"\"\n",
    "            Sample a patch length uniformly from 1 to `max(int(max_L * self.mask_fraction), 2) - 1`\n",
    "            (the upper bound of `np.random.randint` is exclusive), capped at `max_L = end - start`.\n",
    "            \"\"\"\n",
    "            max_L = end - start\n",
    "            length = np.random.randint(1, max(int(max_L * self.mask_fraction), 2))\n",
    "            return min(length, max_L)\n",
    "\n",
    "        # sample num_patches from a discrete poisson distribution with upper limit max_patches\n",
    "        # (draw first instead of using a sentinel, which could leak out if max_patches were huge)\n",
    "        num_patches = np.random.poisson(1)\n",
    "        while num_patches > self.max_patches:\n",
    "            num_patches = np.random.poisson(1)\n",
    "        if self.always_mask:\n",
    "            num_patches = max(num_patches, 1)\n",
    "        # a patch can only start strictly inside the sequence: never ask for more starts than exist,\n",
    "        # otherwise np.random.choice(..., replace=False) raises ValueError on short sequences\n",
    "        num_patches = min(num_patches, max(t[1] - t[0] - 1, 0))\n",
    "        # sample num_patches starting points for the masked positions (+ final position)\n",
    "        start_patches = list(np.sort(np.random.choice(np.arange(t[0] + 1, t[1]),\n",
    "                                                      num_patches,\n",
    "                                                      replace=False))) + [t[1]]\n",
    "        # sample num_patches lengths of the patches\n",
    "        len_patches = [sample_lengths(start_patches[i], start_patches[i + 1])\n",
    "                       for i in range(len(start_patches) - 1)]\n",
    "        # create masked tuples with start and end indices of the patches\n",
    "        masked_tuples = [(start_patches[i], start_patches[i] + len_patches[i]) for i in range(len(start_patches) - 1)]\n",
    "        # split the sequences into unmasked and masked parts\n",
    "        unmasked_sequence, masked_sequence, unmasked_position_ids, masked_position_ids = self.split_sequences(\n",
    "            sequences, t, masked_tuples)\n",
    "\n",
    "        if self.troubleshoot:\n",
    "            print(f\"For sequence in {t}: sampled {num_patches=}, {start_patches=}, {len_patches=}, {masked_tuples=}\")\n",
    "        # concatenate the unmasked and masked parts\n",
    "        input_ids = unmasked_sequence + masked_sequence\n",
    "        if not self.add_position_ids:\n",
    "            return input_ids, None\n",
    "        return input_ids, unmasked_position_ids + masked_position_ids\n",
    "\n",
    "    def split_sequences(self, sequences, t, masked_tuples):\n",
    "        \"\"\"\n",
    "        Split one sequence into its unmasked and masked parts based on the given patches.\n",
    "\n",
    "        Args:\n",
    "            sequences: Indexable of tokens containing the sequence at `sequences[t[0]:t[1]]`.\n",
    "            t (tuple): The start and end index of the sequence.\n",
    "            masked_tuples (list): `(start, end)` index pairs of the masked patches, in increasing order.\n",
    "        Returns:\n",
    "            unmasked_parts (list): The unmasked tokens with every patch replaced by its mask token,\n",
    "                followed by `self.eos_token` when at least one patch is given.\n",
    "            masked_parts (list): The masked patches, each preceded by its mask token.\n",
    "            unmasked_positions (list): Position ids (relative to `t[0]`) matching `unmasked_parts`;\n",
    "                a mask token gets the position of the first token of its patch and the eos token gets 0.\n",
    "                Empty when `self.add_position_ids` is false.\n",
    "            masked_positions (list): Position ids matching `masked_parts`, built the same way.\n",
    "                Empty when `self.add_position_ids` is false.\n",
    "        \"\"\"\n",
    "        unmasked_parts, masked_parts = [], []\n",
    "        unmasked_positions, masked_positions = [], []\n",
    "        position_ids = None\n",
    "        start, end = t\n",
    "        # `position_ids` is zero-based, so absolute indices are shifted by `offset` before indexing it.\n",
    "        offset = t[0]\n",
    "        if self.add_position_ids:\n",
    "            position_ids = np.arange(start, end) - start\n",
    "        for i, region in enumerate(masked_tuples):\n",
    "            mask_token = self.mask_tokens[f\"<mask-{i + 1}>\"]\n",
    "            # tokens between the previous patch and this one stay in place, then the patch placeholder\n",
    "            unmasked_parts.extend(sequences[start:region[0]])\n",
    "            unmasked_parts.append(mask_token)\n",
    "            # the patch itself is announced by the same mask token in the masked section\n",
    "            masked_parts.append(mask_token)\n",
    "            masked_parts.extend(sequences[region[0]:region[1]])\n",
    "            if self.add_position_ids:\n",
    "                patch_start_position = position_ids[region[0] - offset]\n",
    "                unmasked_positions.extend(position_ids[start - offset:region[0] - offset])\n",
    "                unmasked_positions.append(patch_start_position)\n",
    "                masked_positions.append(patch_start_position)\n",
    "                masked_positions.extend(position_ids[region[0] - offset:region[1] - offset])\n",
    "\n",
    "            start = region[1]\n",
    "        # remaining tokens after the last patch\n",
    "        unmasked_parts.extend(sequences[start:end])\n",
    "        if self.add_position_ids:\n",
    "            unmasked_positions.extend(position_ids[start - offset:end - offset])\n",
    "        if len(masked_tuples) > 0:\n",
    "            # the eos token marks the end of the unmasked section\n",
    "            unmasked_parts.append(self.eos_token)\n",
    "            if self.add_position_ids:\n",
    "                unmasked_positions.append(0)\n",
    "        return unmasked_parts, masked_parts, unmasked_positions, masked_positions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import nbdev; nbdev.nbdev_export()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
