{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# dataloaders\n",
    "\n",
    "> Dataset classes and a batch collator for loading encoded Uniclust30 MSAs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp dataloaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from ProtMamba_ssm.utils import AA_TO_ID\n",
    "from ProtMamba_ssm.fim import NoFIM, SingleSpanFIM, MultipleSpanFIM\n",
    "import pickle\n",
    "import torch\n",
    "import numpy as np\n",
    "from torch.utils.data import Dataset\n",
    "from torch.utils.data import DataLoader\n",
    "from dataclasses import dataclass\n",
    "from typing import Dict, Sequence\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "\n",
    "# Make dataset\n",
    "class Uniclust30_Dataset(Dataset):\n",
    "    \"\"\"\n",
    "        Dataset class used to load the encoded Uniclust30 MSAs from a pickle file.\n",
    "\n",
    "        The file at `filepath + filename` is unpickled into a mapping from cluster names to arrays of token ids.\n",
    "        The start of each sequence in a cluster is located wherever the token id is 0.\n",
    "        If `filename` is empty, nothing is loaded and the dataset has length 0.\n",
    "        If `sample` = True, it will sample a random number of sequences from each cluster.\n",
    "        If `sample` = False, it will load all the sequences from each cluster (and shuffle them).\n",
    "        To limit the length of the MSAs, set `max_msa_len` to a positive integer.\n",
    "        If `reverse` = True, it will reverse the sequences with probability 0.5 and move the last token to the front.\n",
    "        `fim_strategy` selects the fill-in-the-middle class applied to the sampled sequences:\n",
    "        \"no-scramble\" uses `NoFIM`, \"one_span\" uses `SingleSpanFIM` and \"multiple_span\" uses `MultipleSpanFIM`.\n",
    "        Any other value raises a `ValueError`.\n",
    "        `max_patches`, `mask_fraction`, `always_mask` and `troubleshoot` are forwarded to that class.\n",
    "        If `add_position_ids` = \"none\", each item only contains `input_ids` and `labels`.\n",
    "        If `add_position_ids` = \"1d\", each item also contains `position_ids`, clamped to `max_position_embeddings - 1`.\n",
    "        If `add_position_ids` = \"2d\", each item also contains `seq_position_ids`, the running count of `<cls>` tokens\n",
    "        clamped to `max_seq_position_embeddings - 1`.\n",
    "        `seed` seeds the global NumPy random generator, which drives all the sampling done by this class.\n",
    "    \"\"\"\n",
    "    _FIM = {\"no-scramble\": NoFIM, \"one_span\": SingleSpanFIM, \"multiple_span\": MultipleSpanFIM}\n",
    "    # values of `add_position_ids` that `__getitem__` distinguishes (not validated in `__init__`)\n",
    "    _POSIDS = {\"none\", \"1d\", \"2d\"}\n",
    "\n",
    "    def __init__(self, filename=\"encoded_MSAs_train.pkl\",\n",
    "                 filepath=\"/nvme1/common/OpenProteinSet/\",\n",
    "                 sample=False,\n",
    "                 max_msa_len=-1,\n",
    "                 reverse=False,\n",
    "                 seed=42,\n",
    "                 troubleshoot=False,\n",
    "                 fim_strategy=\"no-scramble\",\n",
    "                 max_patches=5,\n",
    "                 mask_fraction=0.2,\n",
    "                 always_mask=False,\n",
    "                 max_position_embeddings=2048,\n",
    "                 max_seq_position_embeddings=512,\n",
    "                 add_position_ids=\"none\", ):\n",
    "        # this seeds the *global* NumPy generator, so it also affects other users of np.random\n",
    "        np.random.seed(seed)\n",
    "        self.path = filepath\n",
    "        if filename:\n",
    "            # NOTE: unpickling can execute arbitrary code, only load files from a trusted source.\n",
    "            # The context manager closes the file handle as soon as the dataset is in memory.\n",
    "            with open(self.path + filename, \"rb\") as handle:\n",
    "                self.dataset = pickle.load(handle)\n",
    "            self.cluster_names = list(self.dataset.keys())\n",
    "        else:\n",
    "            self.dataset = None\n",
    "            self.cluster_names = []\n",
    "        self.sample = sample\n",
    "        self.max_msa_len = max_msa_len\n",
    "        self.reverse = reverse\n",
    "        self.fim_strategy = fim_strategy\n",
    "        if fim_strategy not in Uniclust30_Dataset._FIM:\n",
    "            raise ValueError(f'Fill in the middle strategy \"{fim_strategy}\" not recognized.')\n",
    "        self.fim = Uniclust30_Dataset._FIM[fim_strategy](max_patches=max_patches,\n",
    "                                                         mask_fraction=mask_fraction,\n",
    "                                                         always_mask=always_mask,\n",
    "                                                         add_position_ids=add_position_ids != \"none\",\n",
    "                                                         troubleshoot=troubleshoot)\n",
    "        self.max_position_embeddings = max_position_embeddings\n",
    "        self.max_seq_position_embeddings = max_seq_position_embeddings\n",
    "        self.add_position_ids = add_position_ids\n",
    "\n",
    "        self.troubleshoot = troubleshoot\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of clusters in the dataset.\"\"\"\n",
    "        return len(self.cluster_names)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"\n",
    "        Build one example from the cluster with index `idx`.\n",
    "\n",
    "        Returns a dict with `input_ids` and `labels` (the same int64 tensor), plus `position_ids` when\n",
    "        `add_position_ids` is \"1d\" or \"2d\" and `seq_position_ids` when it is \"2d\".\n",
    "        \"\"\"\n",
    "        # get all the sequences in the cluster\n",
    "        sequences = self.get_sequences(idx)\n",
    "        # get total number of sequences in the cluster and choose how many to sample\n",
    "        orig_num_sequences = len(self.get_index_start_of_sequences(sequences))\n",
    "        num_sequences = np.random.randint(1, orig_num_sequences + 1) if self.sample else orig_num_sequences\n",
    "        # sample the sequences\n",
    "        sequences, position_ids = self.sample_sequences(sequences, num_sequences)\n",
    "        # with probability 0.5, reverse the sequences and move the last token to the front.\n",
    "        # This has to be an `if` statement: written as `a, b = f() if cond else a, b` the conditional\n",
    "        # only covers the first element, so `sequences` would receive the whole returned tuple.\n",
    "        if self.reverse and np.random.rand() > 0.5:\n",
    "            sequences, position_ids = self.reverse_sequences(sequences, position_ids)\n",
    "        use_position_ids = self.add_position_ids != \"none\"\n",
    "        # limit the length of the MSA\n",
    "        if self.max_msa_len > 0:\n",
    "            sequences = sequences[:self.max_msa_len]\n",
    "            if use_position_ids:\n",
    "                position_ids = position_ids[:self.max_msa_len]\n",
    "        # convert to tensor\n",
    "        sequences = torch.asarray(sequences, dtype=torch.int64)\n",
    "        if use_position_ids:\n",
    "            # clamp so that the ids never exceed `max_position_embeddings - 1`\n",
    "            position_ids = torch.asarray(position_ids, dtype=torch.int64).clamp(0, self.max_position_embeddings - 1)\n",
    "        else:\n",
    "            position_ids = None\n",
    "\n",
    "        if self.troubleshoot:\n",
    "            print(\n",
    "                f\"Cluster {idx} has {orig_num_sequences} sequences, of which {num_sequences} sampled now. Total MSA length: {len(sequences)}\")\n",
    "        if self.add_position_ids == \"1d\":\n",
    "            return dict(input_ids=sequences, position_ids=position_ids, labels=sequences)\n",
    "        if self.add_position_ids == \"2d\":\n",
    "            # running count of <cls> tokens seen so far, clamped to `max_seq_position_embeddings - 1`\n",
    "            seq_position_ids = (sequences == AA_TO_ID[\"<cls>\"]).int().cumsum(-1).clamp(\n",
    "                0, self.max_seq_position_embeddings - 1).contiguous()\n",
    "            return dict(input_ids=sequences, position_ids=position_ids, seq_position_ids=seq_position_ids,\n",
    "                        labels=sequences)\n",
    "        return dict(input_ids=sequences, labels=sequences)\n",
    "\n",
    "    def get_sequences(self, idx):\n",
    "        \"\"\"Get the sequences in the cluster with index `idx`.\"\"\"\n",
    "        cluster_name = self.cluster_names[idx]\n",
    "        return self.dataset[cluster_name]\n",
    "\n",
    "    def get_index_start_of_sequences(self, sequences):\n",
    "        \"\"\"Get the positions of the start of each sequence in the cluster (positions where the token id is 0).\"\"\"\n",
    "        # NOTE(review): 0 is presumably AA_TO_ID[\"<cls>\"], which `__getitem__` uses for the 2d ids - confirm\n",
    "        return np.where(sequences == 0)[0]\n",
    "\n",
    "    def reverse_sequences(self, sequence, position_ids=None):\n",
    "        \"\"\"\n",
    "        Reverse the sequences and move the last token to the front.\n",
    "\n",
    "        The same transformation is applied to `position_ids` when they are given.\n",
    "        Returns the tuple `(sequence, position_ids)`, where `position_ids` is None if none were given.\n",
    "        \"\"\"\n",
    "        sequence = sequence[::-1]\n",
    "        sequence = np.concatenate([sequence[-1:], sequence[:-1]])\n",
    "        if position_ids is None:\n",
    "            return sequence, None\n",
    "        position_ids = position_ids[::-1]\n",
    "        return sequence, np.concatenate([position_ids[-1:], position_ids[:-1]])\n",
    "\n",
    "    def sample_sequences(self, sequences, num_sequences):\n",
    "        \"\"\"\n",
    "        Sample `num_sequences` from the sequences in the cluster, without replacement and in random order.\n",
    "\n",
    "        Returns whatever `self.fim.apply` returns for the sampled (start, end) index pairs, unpacked as\n",
    "        `(sequences, position_ids)`.\n",
    "        \"\"\"\n",
    "        L = len(sequences)\n",
    "        # get the indexes of the start of each sequence\n",
    "        inds = self.get_index_start_of_sequences(sequences)\n",
    "        # check that there are sequences in the cluster and that there are enough of them\n",
    "        assert len(inds) > 0, \"No sequences found in cluster.\"\n",
    "        assert len(inds) >= num_sequences, \"Not enough sequences in cluster.\"\n",
    "        # sample n_sequences randomly from the sequences\n",
    "        which_seqs = np.random.choice(np.arange(len(inds)), num_sequences, replace=False)\n",
    "        # get the tuples of start and end indexes of the sequences (the last sequence ends at L)\n",
    "        tuples = [(inds[i], inds[i + 1]) if i < len(inds) - 1 else (inds[i], L) for i in which_seqs]\n",
    "        if self.troubleshoot:\n",
    "            print(f\"Sampled sequences: {tuples}\")\n",
    "        # concatenate the sequences\n",
    "        sequences, position_ids = self.fim.apply(sequences, tuples)\n",
    "        return sequences, position_ids\n",
    "\n",
    "\n",
    "def make_dataloader(dataset):\n",
    "    \"\"\"Wrap `dataset` in a `DataLoader` created with the default settings.\"\"\"\n",
    "    return DataLoader(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "\n",
    "@dataclass\n",
    "class DataCollatorForUniclust30Dataset(object):\n",
    "    \"\"\"\n",
    "    Merge a list of dataset examples into one batch, padding every field to the longest example.\n",
    "\n",
    "    The batch always contains `input_ids` (padded with the `<pad>` id), `labels` (built from the same\n",
    "    `input_ids`, padded with -100) and `attention_mask` (True where `input_ids` is not `<pad>`).\n",
    "    `position_ids` is added, padded with 0, when the first example has it; `seq_position_ids` is added,\n",
    "    padded with 0, only when the first example has both `position_ids` and `seq_position_ids`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:\n",
    "        pad_id = AA_TO_ID[\"<pad>\"]\n",
    "\n",
    "        def padded(key, fill):\n",
    "            # stack the `key` field of every example into a (batch, longest) tensor\n",
    "            return torch.nn.utils.rnn.pad_sequence(\n",
    "                [example[key] for example in instances],\n",
    "                batch_first=True, padding_value=fill)\n",
    "\n",
    "        # the labels are read from \"input_ids\" too, only the padding value differs\n",
    "        batch = dict(input_ids=padded(\"input_ids\", pad_id), labels=padded(\"input_ids\", -100))\n",
    "        first = instances[0]\n",
    "        if \"position_ids\" in first:\n",
    "            batch[\"position_ids\"] = padded(\"position_ids\", 0)\n",
    "            if \"seq_position_ids\" in first:\n",
    "                batch[\"seq_position_ids\"] = padded(\"seq_position_ids\", 0)\n",
    "        batch[\"attention_mask\"] = batch[\"input_ids\"].ne(pad_id)\n",
    "        return batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from protmamba.utils import MASK_TO_ID, AA_TO_ID\n",
    "\n",
    "import cProfile\n",
    "import pstats\n",
    "import io\n",
    "\n",
    "def profile(func):\n",
    "    \"\"\"\n",
    "    Decorator that runs `func` under cProfile and prints the statistics sorted by cumulative time.\n",
    "\n",
    "    The wrapper returns whatever `func` returns and keeps its name and docstring.\n",
    "    If `func` raises, the profiler is stopped and the exception propagates without printing statistics.\n",
    "    \"\"\"\n",
    "    import functools\n",
    "\n",
    "    @functools.wraps(func)\n",
    "    def wrapper(*args, **kwargs):\n",
    "        pr = cProfile.Profile()\n",
    "        pr.enable()\n",
    "        try:\n",
    "            retval = func(*args, **kwargs)\n",
    "        finally:\n",
    "            # never leave the profiler running, even when `func` raises\n",
    "            pr.disable()\n",
    "        s = io.StringIO()\n",
    "        sortby = 'cumulative'\n",
    "        ps = pstats.Stats(pr, stream=s).sort_stats(sortby)\n",
    "        ps.print_stats()\n",
    "        print(s.getvalue())\n",
    "        return retval\n",
    "    return wrapper\n",
    "\n",
    "class ConcatenateSequences:\n",
    "    \"\"\"Concatenate the sampled sequences of a cluster, optionally scrambling each of them first.\"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 max_patches=5,\n",
    "                 mask_fraction=0.2,\n",
    "                 scrambling_strategy=\"\",\n",
    "                 mask_tokens=MASK_TO_ID,\n",
    "                 eos_token=AA_TO_ID[\"<eos>\"],\n",
    "                 troubleshoot=False):\n",
    "        \"\"\"\n",
    "        This class is designed to concatenate sequences based on different scrambling strategies.\n",
    "\n",
    "        Args:\n",
    "            max_patches (int): Upper limit on the number of masked patches per sequence (\"inpaint\" strategy).\n",
    "            mask_fraction (float): Fraction of the available length that bounds the length of one patch.\n",
    "            scrambling_strategy (str): \"no-scramble\", \"OpenAI\" or \"inpaint\"; see `concatenate`.\n",
    "            mask_tokens (dict): Maps the names \"<mask-1>\", \"<mask-2>\", ... to token ids.\n",
    "            eos_token (int): Token id appended after the unmasked part of a sequence that has masked patches.\n",
    "            troubleshoot (bool): If True, print the sampled patches of every sequence.\n",
    "        \"\"\"\n",
    "        self.troubleshoot = troubleshoot\n",
    "        self.max_patches = max_patches\n",
    "        self.scrambling_strategy = scrambling_strategy\n",
    "        self.mask_fraction = mask_fraction\n",
    "        self.mask_tokens = mask_tokens\n",
    "        # every patch of a sequence needs its own mask token\n",
    "        assert len(self.mask_tokens)>=self.max_patches, \"Number of mask tokens must be bigger than max number of patches.\"\n",
    "        self.eos_token = eos_token\n",
    "\n",
    "    def concatenate(self, sequences, tuples):\n",
    "        \"\"\"\n",
    "        This function concatenates the sequences based on the scrambling strategy.\n",
    "\n",
    "        Args:\n",
    "            sequences: The token ids of the whole cluster.\n",
    "            tuples (list): The (start, end) indices of the sequences to concatenate, in output order.\n",
    "        Returns:\n",
    "            The array produced by `np.concatenate` over the (possibly scrambled) sequences.\n",
    "        Raises:\n",
    "            ValueError: If `scrambling_strategy` is not \"no-scramble\", \"OpenAI\" or \"inpaint\".\n",
    "        \"\"\"\n",
    "        if self.scrambling_strategy == \"no-scramble\":\n",
    "            return np.concatenate([sequences[t[0]:t[1]] for t in tuples])\n",
    "        # We could remove this, same as max_patches = 1\n",
    "        if self.scrambling_strategy == \"OpenAI\":\n",
    "            return np.concatenate([self.create_and_concatenate_parts_openAI(sequences, t) for t in tuples])\n",
    "        if self.scrambling_strategy == \"inpaint\":\n",
    "            return np.concatenate([self.create_and_concatenate_parts_inpaint(sequences, t) for t in tuples])\n",
    "        # an unknown strategy used to fall through and return None, which only failed later in the caller\n",
    "        raise ValueError(f'Scrambling strategy \"{self.scrambling_strategy}\" not recognized.')\n",
    "\n",
    "    def split_sequences(self, sequences, t, masked_tuples):\n",
    "        \"\"\"\n",
    "        This function splits one sequence into unmasked and masked parts based on the given tuples.\n",
    "        Args:\n",
    "            sequences: The token ids of the whole cluster.\n",
    "            t (tuple): The start and end index of the sequence.\n",
    "            masked_tuples (list): A list of (start, end) tuples specifying the masked regions, in increasing order.\n",
    "        Returns:\n",
    "            unmasked_parts (list): The unmasked tokens, with every masked region replaced by its mask token,\n",
    "                followed by the eos token if there is at least one masked region.\n",
    "            masked_parts (list): For every masked region, its mask token followed by the masked tokens.\n",
    "        \"\"\"\n",
    "        start, end = t\n",
    "        unmasked_parts, masked_parts = [], []\n",
    "        for i, region in enumerate(masked_tuples):\n",
    "            mask_token = self.mask_tokens[f\"<mask-{i+1}>\"]\n",
    "            # tokens between the previous region (or the sequence start) and this region stay in place\n",
    "            unmasked_parts.extend(sequences[start:region[0]])\n",
    "            unmasked_parts.append(mask_token)\n",
    "            masked_parts.append(mask_token)\n",
    "            masked_parts.extend(sequences[region[0]:region[1]])\n",
    "            start = region[1]\n",
    "        # tail of the sequence after the last masked region (the whole sequence if nothing is masked)\n",
    "        unmasked_parts.extend(sequences[start:end])\n",
    "        if len(masked_tuples) > 0:\n",
    "            unmasked_parts.append(self.eos_token)\n",
    "        return unmasked_parts, masked_parts\n",
    "\n",
    "    def sample_lengths(self, start,end):\n",
    "        \"\"\"\n",
    "        Sample the length of a patch that starts at `start` and must end before or at `end`.\n",
    "        With max_L = end-start, the length is drawn from 1 to max(int(max_L*self.mask_fraction),2)-1\n",
    "        and is never larger than max_L.\n",
    "        \"\"\"\n",
    "        max_L = end-start\n",
    "        length = 1+int(random.random() * (max(int(max_L*self.mask_fraction),2)-1))\n",
    "        return min(length, max_L)\n",
    "\n",
    "    def create_and_concatenate_parts_openAI(self, sequences, t):\n",
    "        \"\"\"\n",
    "        This function creates and concatenates parts of the sequences based on the OpenAI scrambling strategy.\n",
    "        It randomly selects two distinct indices within the range of the given tuple (excluding its first position),\n",
    "        splits the sequence into three parts based on these indices, and then concatenates them with the\n",
    "        masked patch at the end: part1, <mask-1>, part3, <mask-1>, part2.\n",
    "        `np.random.choice` raises a ValueError if fewer than two indices are available.\n",
    "        \"\"\"\n",
    "        new_tuple = tuple(np.sort(np.random.choice(np.arange(t[0]+1, t[1]), 2, replace=False)))\n",
    "        part1 = sequences[t[0]:new_tuple[0]]\n",
    "        part2 = sequences[new_tuple[0]:new_tuple[1]]\n",
    "        part3 = sequences[new_tuple[1]:t[1]]\n",
    "        return np.concatenate([part1, [self.mask_tokens[\"<mask-1>\"]], part3, [self.mask_tokens[\"<mask-1>\"]], part2])\n",
    "\n",
    "    def create_and_concatenate_parts_inpaint(self, sequences, t):\n",
    "        \"\"\"\n",
    "        This function creates and concatenates parts of the sequences based on the inpaint scrambling strategy.\n",
    "        It samples the number of patches from a Poisson(1) distribution (resampling while it exceeds\n",
    "        `self.max_patches`), samples that many distinct starting points within the range of the given tuple\n",
    "        (excluding its first position) and a length for every patch, and splits the sequence into unmasked\n",
    "        and masked parts. The result is the list of all unmasked parts (interleaved with mask tokens, and\n",
    "        closed by the eos token if anything is masked) followed by all masked parts (each preceded by its\n",
    "        mask token). If the sequence has fewer candidate starting points than sampled patches, the number\n",
    "        of patches is reduced accordingly.\n",
    "        \"\"\"\n",
    "        # sample num_patches from a discrete poisson distribution with upper limit max_patches\n",
    "        num_patches = np.random.poisson(1)\n",
    "        while num_patches > self.max_patches:\n",
    "            num_patches = np.random.poisson(1)\n",
    "\n",
    "        # candidate starting points: every position of the sequence except the first one\n",
    "        candidates = range(t[0]+1, t[1])\n",
    "        # random.sample raises if asked for more items than available, which happens for very short sequences\n",
    "        num_patches = min(num_patches, len(candidates))\n",
    "\n",
    "        # sample num_patches starting points for the masked positions (+ final position)\n",
    "        start_patches = sorted(random.sample(candidates, num_patches)) + [t[1]]\n",
    "\n",
    "        # sample num_patches lengths of the patches, each bounded by the start of the next patch\n",
    "        len_patches = [self.sample_lengths(start_patches[i],start_patches[i+1])\n",
    "                       for i in range(len(start_patches)-1)]\n",
    "\n",
    "        # create masked tuples with start and end indices of the patches\n",
    "        masked_tuples = [(start_patches[i], start_patches[i]+len_patches[i]) for i in range(len(start_patches)-1)]\n",
    "        # split the sequences into unmasked and masked parts\n",
    "        unmasked_sequence, masked_sequence = self.split_sequences(sequences, t, masked_tuples)\n",
    "        if self.troubleshoot:\n",
    "            print(f\"For sequence in {t}: sampled {num_patches=}, {start_patches=}, {len_patches=}, {masked_tuples=}\")\n",
    "        # concatenate the unmasked and masked parts\n",
    "        return unmasked_sequence + masked_sequence\n",
    "\n",
    "# Make dataset\n",
    "class Uniclust30_Dataset_old(Dataset):\n",
    "    \"\"\"\n",
    "        Dataset class used to load the encoded Uniclust30 MSAs from a pickle file.\n",
    "\n",
    "        The file at `filepath + filename` is unpickled into a mapping from cluster names to arrays of token ids.\n",
    "        The start of each sequence in a cluster is located wherever the token id is 0.\n",
    "        If `sample` = True, it will sample a random number of sequences from each cluster.\n",
    "        If `sample` = False, it will load all the sequences from each cluster (and shuffle them).\n",
    "        To limit the length of the MSAs, set `max_msa_len` to a positive integer.\n",
    "        If `reverse` = True, it will reverse the sequences with probability 0.5 and move the last token to the front.\n",
    "        `fim_strategy` is passed to `ConcatenateSequences` as its `scrambling_strategy`:\n",
    "        If `fim_strategy` = \"no-scramble\", it will not scramble the sequences and simply concatenate them.\n",
    "        If `fim_strategy` = \"OpenAI\", it will scramble the sequences using the OpenAI strategy.\n",
    "        If `fim_strategy` = \"inpaint\", it will scramble the sequences using the inpaint strategy. In this case it will use\n",
    "        at most `max_patches` patches per sequence, with patch lengths bounded through `mask_fraction`.\n",
    "        `seed` seeds the global NumPy random generator.\n",
    "    \"\"\"\n",
    "    def __init__(self, filename=\"encoded_MSAs_train.pkl\",\n",
    "                 filepath=\"/nvme1/common/OpenProteinSet/\",\n",
    "                 sample=False,\n",
    "                 max_msa_len=-1,\n",
    "                 reverse=False,\n",
    "                 seed=42,\n",
    "                 troubleshoot=False,\n",
    "                 fim_strategy=\"no-scramble\",\n",
    "                 max_patches=5,\n",
    "                 mask_fraction=0.2):\n",
    "        # this seeds the *global* NumPy generator, so it also affects other users of np.random\n",
    "        np.random.seed(seed)\n",
    "        self.path = filepath\n",
    "        # NOTE: unpickling can execute arbitrary code, only load files from a trusted source.\n",
    "        # The context manager closes the file handle as soon as the dataset is in memory.\n",
    "        with open(self.path + filename, \"rb\") as handle:\n",
    "            self.dataset = pickle.load(handle)\n",
    "        self.cluster_names = list(self.dataset.keys())\n",
    "        self.sample = sample\n",
    "        self.max_msa_len = max_msa_len\n",
    "        self.reverse = reverse\n",
    "        self.Concatenate = ConcatenateSequences(max_patches=max_patches,\n",
    "                                                mask_fraction=mask_fraction,\n",
    "                                                scrambling_strategy=fim_strategy)\n",
    "        self.troubleshoot = troubleshoot\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of clusters in the dataset.\"\"\"\n",
    "        return len(self.cluster_names)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Build one example from the cluster with index `idx`: a dict with `input_ids` and `labels` (the same tensor).\"\"\"\n",
    "        # get all the sequences in the cluster\n",
    "        sequences = self.get_sequences(idx)\n",
    "        # get total number of sequences in the cluster and choose how many to sample\n",
    "        orig_num_sequences = self.get_number_of_sequences(sequences)\n",
    "        num_sequences = np.random.randint(1, orig_num_sequences+1) if self.sample else orig_num_sequences\n",
    "        # sample the sequences\n",
    "        sequences = self.sample_sequences(sequences, num_sequences)\n",
    "        # with probability 0.5, reverse the sequences and move the last token to the front\n",
    "        if self.reverse and np.random.rand() > 0.5:\n",
    "            sequences = self.reverse_sequences(sequences)\n",
    "        # limit the length of the MSA\n",
    "        if self.max_msa_len > 0:\n",
    "            sequences = sequences[:self.max_msa_len]\n",
    "        # convert to tensor\n",
    "        sequences = torch.asarray(sequences, dtype=torch.int64)\n",
    "        if self.troubleshoot:\n",
    "            print(f\"Cluster {idx} has {orig_num_sequences} sequences, of which {num_sequences} sampled now. Total MSA length: {len(sequences)}\")\n",
    "        return dict(input_ids=sequences, labels=sequences)\n",
    "\n",
    "    def get_sequences(self, idx):\n",
    "        \"\"\"Get the sequences in the cluster with index `idx`.\"\"\"\n",
    "        cluster_name = self.cluster_names[idx]\n",
    "        return self.dataset[cluster_name]\n",
    "\n",
    "    def get_index_start_of_sequences(self, sequences):\n",
    "        \"\"\"Get the positions of the start of each sequence in the cluster (positions where the token id is 0).\"\"\"\n",
    "        return np.where(sequences == 0)[0]\n",
    "\n",
    "    def get_number_of_sequences(self, sequences):\n",
    "        \"\"\"Get the number of sequences in the cluster.\"\"\"\n",
    "        return len(self.get_index_start_of_sequences(sequences))\n",
    "\n",
    "    def reverse_sequences(self, sequence):\n",
    "        \"\"\"Reverse the sequences and move the last token to the front.\"\"\"\n",
    "        sequence = sequence[::-1]\n",
    "        return np.concatenate([sequence[-1:], sequence[:-1]])\n",
    "\n",
    "    # @profile\n",
    "    def sample_sequences(self, sequences, num_sequences):\n",
    "        \"\"\"Sample `num_sequences` from the sequences in the cluster, without replacement and in random order.\"\"\"\n",
    "        L = len(sequences)\n",
    "        # get the indexes of the start of each sequence\n",
    "        inds = self.get_index_start_of_sequences(sequences)\n",
    "        # check that there are sequences in the cluster and that there are enough of them\n",
    "        assert len(inds) > 0, \"No sequences found in cluster.\"\n",
    "        assert len(inds) >= num_sequences, \"Not enough sequences in cluster.\"\n",
    "        # sample n_sequences randomly from the sequences\n",
    "        which_seqs = np.random.choice(np.arange(len(inds)), num_sequences, replace=False)\n",
    "        # get the tuples of start and end indexes of the sequences (the last sequence ends at L)\n",
    "        tuples = [(inds[i],inds[i+1]) if i<len(inds)-1 else (inds[i], L) for i in which_seqs]\n",
    "        if self.troubleshoot:\n",
    "            print(f\"Sampled sequences: {tuples}\")\n",
    "        # concatenate the sequences\n",
    "        return self.Concatenate.concatenate(sequences, tuples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import nbdev; nbdev.nbdev_export()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
