{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# utils\n",
    "\n",
    "> Fill in a module description here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "\n",
    "# Constants\n",
    "AA_TO_ID = {'<cls>': 0,\n",
    "            '<pad>': 1,\n",
    "            '<eos>': 2,\n",
    "            '<unk>': 3,\n",
    "            'L': 4,\n",
    "            'A': 5,\n",
    "            'G': 6,\n",
    "            'V': 7,\n",
    "            'S': 8,\n",
    "            'E': 9,\n",
    "            'R': 10,\n",
    "            'T': 11,\n",
    "            'I': 12,\n",
    "            'D': 13,\n",
    "            'P': 14,\n",
    "            'K': 15,\n",
    "            'Q': 16,\n",
    "            'N': 17,\n",
    "            'F': 18,\n",
    "            'Y': 19,\n",
    "            'M': 20,\n",
    "            'H': 21,\n",
    "            'W': 22,\n",
    "            'C': 23,\n",
    "            'X': 24,\n",
    "            'B': 25,\n",
    "            'U': 26,\n",
    "            'Z': 27,\n",
    "            'O': 28,\n",
    "            '.': 29,\n",
    "            '-': 30,\n",
    "            '<null_1>': 31,\n",
    "            '<mask>': 32}\n",
    "\n",
    "MASK_TO_ID = {\"<mask-1>\": 33,\n",
    "              \"<mask-2>\": 34,\n",
    "              \"<mask-3>\": 35,\n",
    "              \"<mask-4>\": 36,\n",
    "              \"<mask-5>\": 37,}\n",
    "\n",
    "AA_TO_ID.update(MASK_TO_ID)\n",
    "\n",
    "ID_TO_AA = {v: k for k, v in AA_TO_ID.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import numpy as np\n",
    "import torch\n",
    "from Bio import SeqIO\n",
    "\n",
    "# Encoder\n",
    "def encode_sequence(sequence):\n",
    "    \"\"\"Tokenize a sequence of amino acids and add a cls token at the beginning.\"\"\"\n",
    "    tokenized_sequence = [AA_TO_ID[aa] if aa in AA_TO_ID else AA_TO_ID['<unk>'] for aa in sequence]\n",
    "    return [AA_TO_ID['<cls>']] + tokenized_sequence\n",
    "\n",
    "def decode_sequence(sequence):\n",
    "    \"\"\"Decode a sequence of tokens.\"\"\"\n",
    "    return \"\".join([ID_TO_AA[token] if token in ID_TO_AA else \"<unk>\" for token in sequence])\n",
    "\n",
    "def clean_sequence(sequence):\n",
    "    \"\"\"Remove gaps and convert all residues to upper case.\"\"\"\n",
    "    return sequence.replace(\"-\", \"\").upper()\n",
    "\n",
    "def tokenizer(sequence_list, concatenate=True):\n",
    "    \"\"\"Tokenize a collection of sequences. If the sequences are aligned, the gaps will be removed\n",
    "    and the insertions (lower case) will be promoted to upper case.\"\"\"\n",
    "    # clean and encode all sequences\n",
    "    sequence_list = [encode_sequence(clean_sequence(sequence)) for sequence in sequence_list]\n",
    "    if concatenate:\n",
    "        # concatenate all sequences\n",
    "        sequences = np.concatenate(sequence_list)\n",
    "        # convert to tensor and add batch dimension\n",
    "        return torch.asarray(sequences, dtype=torch.int64)[None,:]\n",
    "    else:\n",
    "        return [torch.asarray(sequence, dtype=torch.int64) for sequence in sequence_list]\n",
    "\n",
    "\n",
    "def reorder_masked_sequence(mask_seq, return_ids=False):\n",
    "    \"\"\"\n",
    "    Reorder a masked sequence to fill the masked positions with the tokens\n",
    "    that should be there but are positioned after the <eos> token.\n",
    "    \"\"\"\n",
    "    mask_seq = mask_seq.split(\"<cls>\")[0]\n",
    "    try:\n",
    "        # Split the sequence and masks\n",
    "        seq, masks = mask_seq.split(\"<eos>\")\n",
    "    except:\n",
    "        return mask_seq\n",
    "    full_seq = \"\"\n",
    "    ids_mask = []\n",
    "    # Iterate over each mask tag\n",
    "    for mm in [\"<mask-1>\", \"<mask-2>\", \"<mask-3>\", \"<mask-4>\", \"<mask-5>\",\"<mask-?>\"]:\n",
    "        try:\n",
    "            # Split the sequence in before and after the mask tag\n",
    "            seq1, seq2 = seq.split(mm)\n",
    "            if mm==\"<mask-1>\":\n",
    "                # If the mask is the first one, add the sequence before the mask and update the masks\n",
    "                masks = masks.split(\"<mask-1>\")[1]\n",
    "                full_seq += seq1\n",
    "            else:\n",
    "                # If the mask is not the first one, insert the mask between the two sequence parts\n",
    "                masks1, masks2 = masks.split(mm)\n",
    "                ids_mask += [(len(full_seq), len(full_seq)+len(masks1))]\n",
    "                full_seq += masks1 + seq1\n",
    "                # Update the masks\n",
    "                masks = masks2 \n",
    "            # Update the sequence with the part after the mask\n",
    "            seq = seq2\n",
    "        except:\n",
    "            # If the mask is not found, add the remaining sequence\n",
    "            ids_mask += [(len(full_seq), len(full_seq)+len(masks))]\n",
    "            full_seq += masks + seq\n",
    "            break\n",
    "    if return_ids:\n",
    "        return full_seq, ids_mask\n",
    "    return full_seq\n",
    "\n",
    "def load_from_file(file_path):\n",
    "    \"\"\"Load a collection of sequences from an a3m file.\"\"\"\n",
    "    with open(file_path, \"r\") as f:\n",
    "        sequences = [str(record.seq) for record in SeqIO.parse(f, \"fasta\")]\n",
    "    return sequences\n",
    "\n",
    "def generate_sequence(model, tokens, position_ids=None, seq_position_ids=None, is_fim=False, max_length=2000, temperature=1., top_p=0.0, top_k=1,\n",
    "                      return_dict_in_generate=False, output_scores=False, eos_token_id=AA_TO_ID[\"<cls>\"], device=\"cuda\"):\n",
    "    \"\"\"Generating, either greedy or with top-k or top-p sampling.\n",
    "    If top-k = 0, don't limit the number of candidates (pure sampling).\n",
    "    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,\n",
    "    then top-p. We assume that all sequences in the same batch have the same length.\n",
    "    \"\"\"\n",
    "    input_ids = tokens.to(device)\n",
    "    position_ids = position_ids.to(device) if position_ids is not None else None\n",
    "    seq_position_ids = seq_position_ids.to(device) if seq_position_ids is not None else None\n",
    "    # generate sequence\n",
    "    out = model.generate(input_ids=input_ids,\n",
    "                         position_ids=position_ids,\n",
    "                         seq_position_ids=seq_position_ids,\n",
    "                         is_fim=is_fim,\n",
    "                         max_length=max_length,\n",
    "                         temperature=temperature,\n",
    "                         top_p=top_p,\n",
    "                         top_k=top_k,\n",
    "                         return_dict_in_generate=return_dict_in_generate,\n",
    "                         output_scores=output_scores,\n",
    "                         eos_token_id=eos_token_id)\n",
    "    sequences = out.sequences\n",
    "    dic = {\"input\": [decode_sequence(seq) for seq in sequences[:, :input_ids.shape[-1]].cpu().numpy()],\n",
    "            \"generated\": [decode_sequence(seq) for seq in sequences[:, input_ids.shape[-1]:].cpu().numpy()],\n",
    "            \"input_tokens\": [seq for seq in sequences[:, :input_ids.shape[-1]].cpu().numpy()],\n",
    "            \"generated_tokens\": [seq for seq in sequences[:, input_ids.shape[-1]:].cpu().numpy()]}\n",
    "    if output_scores:\n",
    "        dic[\"scores\"] = np.array([el.to(torch.float32).cpu().numpy() for el in out.scores]).transpose(1, 0, 2)\n",
    "    return dic\n",
    "\n",
    "def prepare_dataset_for_fim_generation(tokens, pos_ids):\n",
    "    \"\"\"\n",
    "    Function to transform the tokenized training dataset into a format that can be used for FIM generation.\n",
    "    Splits the input tokens and pos_ids into the FIM part (of the last sequence) and the context part (all\n",
    "    the previous sequences and the masked part of the last sequence).\n",
    "    Also returns a dictionary with the positions of the mask tokens in the FIM part.\n",
    "    \"\"\"\n",
    "    def find_mask_positions(tokens_fim):\n",
    "        \"\"\"\n",
    "        Function to find the positions of the mask tokens in the FIM part of the last sequence.\n",
    "        \"\"\"\n",
    "        bool_mask = None\n",
    "        inds_masks = []\n",
    "        for ind in MASK_TO_ID.values():\n",
    "            tmp_bool = tokens_fim[0].cpu().numpy() == ind\n",
    "            bool_mask = tmp_bool if bool_mask is None else bool_mask | tmp_bool\n",
    "            inds_masks += [ind]\n",
    "        return bool_mask, inds_masks\n",
    "    # find where the FIM part of the last sequence starts\n",
    "    start_last_fim = np.where(tokens[0].cpu().numpy() == AA_TO_ID[\"<eos>\"])[0][-1]\n",
    "    start_next_seqs = np.where(tokens[0,start_last_fim+1:].cpu().numpy() == AA_TO_ID[\"<cls>\"])[0]\n",
    "    end_last_fim = start_last_fim+ 1 +start_next_seqs[0] if len(start_next_seqs) > 0 else tokens.shape[1]\n",
    "    # split tokens and pos_ids into FIM part and context part\n",
    "    tokens_to_fim = tokens[:,:start_last_fim+1]\n",
    "    pos_ids_to_fim = pos_ids[:,:start_last_fim+1]\n",
    "    tokens_fim = tokens[:,start_last_fim+1:end_last_fim]\n",
    "    pos_ids_fim = pos_ids[:,start_last_fim+1:end_last_fim]\n",
    "    # find positions of mask tokens\n",
    "    bool_mask, inds_masks = find_mask_positions(tokens_fim)\n",
    "    masked_positions = pos_ids_fim[0,bool_mask]\n",
    "    mask_dict = {ind: int(pos) for ind, pos in zip(inds_masks, masked_positions)}\n",
    "    return tokens_to_fim, pos_ids_to_fim, tokens_fim, pos_ids_fim, mask_dict\n",
    "\n",
    "def prepare_tokens(context_tokens,\n",
    "                   target_tokens,\n",
    "                   target_pos_ids,\n",
    "                   DatasetClass,\n",
    "                   num_sequences=1,\n",
    "                   fim_strategy=\"no-scramble\", # \"multiple_span\"\n",
    "                   mask_fraction=0.2,\n",
    "                   max_patches=5,\n",
    "                   add_position_ids=\"1d\"): \n",
    "    \"\"\"Prepare the tokens for the model by applying the FIM strategy and masking the tokens.\n",
    "    It uses custom tokenized sequences and position ids.\"\"\"\n",
    "\n",
    "    data_class = DatasetClass(None,\n",
    "                            fim_strategy=fim_strategy,\n",
    "                            mask_fraction=mask_fraction,\n",
    "                            max_patches=max_patches,\n",
    "                            add_position_ids=add_position_ids)\n",
    "    seq, pos_ids = data_class.sample_sequences(context_tokens.numpy()[0], num_sequences=num_sequences)\n",
    "    # convert to tensor and add batch dimension\n",
    "    seq = torch.asarray(seq, dtype=torch.int64)[None,:]\n",
    "    pos_ids = torch.asarray(pos_ids, dtype=torch.int64)[None,:]\n",
    "    seq = torch.cat([seq, target_tokens], dim=1)\n",
    "    pos_ids = torch.cat([pos_ids, target_pos_ids], dim=1)\n",
    "    return seq, pos_ids\n",
    "\n",
    "def prepare_target(target, use_fim=None):\n",
    "    \"\"\"Prepare the target sequence for the model using a custom tokenized sequence.\n",
    "    use_fim is a dictionary with the positions that should be masked.\n",
    "    use_fim = {\"<cls>\": 1} _-> start to generate autoregressively from 1st position\n",
    "    use_fim = {\"<cls>\": 10} -> start to generate autoregressively from 10th position\n",
    "    use_fim = {\"<mask-1>\": ((10,13), 6)} -> mask positions from 10 to 13 (i.e. 10,11,12) and fill it with 6 tokens,\n",
    "    use_fim = {\"<mask-1>\": ((10,13), 6), \"<mask-2>\": ((15,20), 2)} -> mask positions from 10 to 13 and 15 to 20 and fill it with 6 and 2 tokens\n",
    "    \"\"\"\n",
    "    if \"<cls>\" in use_fim:\n",
    "        target = target[:,:use_fim[\"<cls>\"]]\n",
    "        pos_ids = torch.arange(target.shape[1], dtype=torch.int64)[None,:]\n",
    "        assert target.shape[1] == pos_ids.shape[1]\n",
    "        return target, pos_ids\n",
    "    else:\n",
    "        is_fim_dict = {}\n",
    "        pos_ids = torch.arange(target.shape[1], dtype=torch.int64)[None,:] # default position ids\n",
    "        diff_length = 0\n",
    "        for mask in use_fim:\n",
    "            assert \"mask\" in mask\n",
    "            mask_positions, length = use_fim[mask]\n",
    "            # update mask_positions to take into account the inserted parts\n",
    "            mask_positions = (mask_positions[0]-diff_length, mask_positions[1]-diff_length)\n",
    "            diff_length = mask_positions[1] - mask_positions[0]\n",
    "            new_target = torch.cat([target[:,:mask_positions[0]],\n",
    "                                    torch.full((target.shape[0], 1), AA_TO_ID[mask], dtype=torch.int64),\n",
    "                                    target[:,mask_positions[1]:]], dim=1)\n",
    "            new_pos_ids = torch.cat([pos_ids[:,:mask_positions[0]+1],\n",
    "                                    pos_ids[:,mask_positions[1]:]+length-diff_length], dim=1)\n",
    "            is_fim_dict[AA_TO_ID[mask]] = pos_ids[:,mask_positions[0]].squeeze().item()\n",
    "            target = new_target\n",
    "            pos_ids = new_pos_ids\n",
    "            diff_length -= 1\n",
    "\n",
    "        new_target = torch.cat([target,\n",
    "                                torch.full((target.shape[0], 1), AA_TO_ID[\"<eos>\"], dtype=torch.int64)], dim=1)\n",
    "        new_pos_ids = torch.cat([pos_ids,\n",
    "                                torch.full((target.shape[0], 1), 0, dtype=torch.int64)], dim=1)\n",
    "        assert new_target.shape[1] == new_pos_ids.shape[1]\n",
    "        return new_target, new_pos_ids, is_fim_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "\n",
    "from tensorboard.backend.event_processing import event_accumulator\n",
    "from tensorboard.backend.event_processing.event_accumulator import ScalarEvent\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "import glob\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def load_tensorboard_data(path):\n",
    "    \"\"\"\n",
    "    Load the TensorBoard data\n",
    "    \"\"\"\n",
    "    ea = event_accumulator.EventAccumulator(path)\n",
    "    ea.Reload()\n",
    "    # Assuming you're interested in scalar summaries; adjust if otherwise\n",
    "    # list all tags\n",
    "    tags = ea.Tags()['scalars']\n",
    "    \n",
    "    scalars = {tag: ea.scalars.Items(tag) for tag in tags}\n",
    "    return scalars\n",
    "\n",
    "def filter_datapoints(scalars, condition):\n",
    "    \"\"\"\n",
    "    Filter out the datapoint you want to delete (customize this logic)\n",
    "    \"\"\"\n",
    "    # Example condition: lambda x: x.step != step_to_delete\n",
    "    return {tag: [s for s in scalars if not condition(s)] for tag, scalars in scalars.items()}\n",
    "\n",
    "def save_to_tensorboard(filtered_data, output_path):\n",
    "    \"\"\"\n",
    "    Save modified data back to a new TensorBoard file\n",
    "    \"\"\"\n",
    "    writer = SummaryWriter(output_path) \n",
    "    for tag, scalars in filtered_data.items():\n",
    "        for data in scalars:\n",
    "            writer.add_scalar(tag, data.value, global_step=data.step, walltime=data.wall_time)\n",
    "\n",
    "def merge_loggings(directory, output_path, plot_metric=None):\n",
    "    \"\"\"\n",
    "    Merge all the TensorBoard files in a directory into a single file.\n",
    "    Keeps only the metrics with latest wall time for each step (i.e. the last logged value for each step)\n",
    "    \"\"\"\n",
    "    # Find all the TensorBoard files\n",
    "    def find_files(directory, pattern):\n",
    "        return glob.glob(f\"{directory}/{pattern}\")\n",
    "    all_paths = find_files(directory, \"events.out.tfevents.*\")\n",
    "    # Merge all the data\n",
    "    best_wall_times = {}\n",
    "    updated_metrics = {}\n",
    "    for elem in all_paths:\n",
    "        try:\n",
    "            # Load the data from one logging\n",
    "            scalars = load_tensorboard_data(elem)\n",
    "            # Make a dictionary with step number as key\n",
    "            all_metrics = {k: {s.step: s for s in scalars[k]} for k in scalars.keys()}\n",
    "            if plot_metric is not None:\n",
    "                plt.plot([s.wall_time for s in scalars[plot_metric]], [s.value for s in scalars[plot_metric]])\n",
    "            # iterate over all the metrics\n",
    "            for k in all_metrics.keys():\n",
    "                if k not in updated_metrics.keys():\n",
    "                    updated_metrics[k] = {}\n",
    "                if k not in best_wall_times:\n",
    "                    best_wall_times[k] = {}\n",
    "                # Get wall time of each step\n",
    "                steps_time = {step: s.wall_time for step,s in all_metrics[k].items()}\n",
    "                # iterate over steps and pick only the metrics associated with the best wall time\n",
    "                for key, value in steps_time.items():\n",
    "                    if key not in updated_metrics[k]:\n",
    "                        best_wall_times[k][key] = value\n",
    "                        updated_metrics[k][key] = all_metrics[k][key]\n",
    "                    elif value > best_wall_times[k][key]:\n",
    "                        best_wall_times[k][key] = value\n",
    "                        updated_metrics[k][key] = all_metrics[k][key]\n",
    "                    else:\n",
    "                        continue\n",
    "        except:\n",
    "            print(\"Could not load.\\t\", elem.split(\"/\")[-1])\n",
    "    # Sort the metrics by step\n",
    "    new_logging = {k: list(updated_metrics[k].values()) for k in updated_metrics.keys()}\n",
    "    new_logging = {k: sorted(v, key=lambda x: x.step) for k, v in new_logging.items()}\n",
    "    # Save the merged data\n",
    "    save_to_tensorboard(new_logging, output_path)\n",
    "    plt.title(f\"Metric: {plot_metric}\")\n",
    "    plt.show()\n",
    "    return new_logging\n",
    "    \n",
    "def concatenate_loggings(logging1_path, logging2_path, step_range1, step_range2, output_path):\n",
    "    \"\"\"\n",
    "    Concatenate the two loggings, assuming they have the same metrics. Use steps from step_range1[0] to step_range1[1]\n",
    "    for logging1 and from step_range2[0] to step_range2[1] for logging2.\n",
    "    Change the step numbers of logging2 to be continuous with logging1. and verify that the steps taken in each logging\n",
    "    are the ones specified by step_range1 and step_range2.\n",
    "    \"\"\"\n",
    "    logging1 = load_tensorboard_data(logging1_path)\n",
    "    logging2 = load_tensorboard_data(logging2_path)\n",
    "    if step_range1 is None:\n",
    "        k = list(logging1.keys())[0]\n",
    "        step_range1 = (logging1[k][0].step, logging1[k][-1].step)\n",
    "    if step_range2 is None:\n",
    "        k = list(logging2.keys())[0]\n",
    "        step_range2 = (logging2[k][0].step, logging2[k][-1].step)\n",
    "    new_logging = {}\n",
    "    for key in logging1.keys():\n",
    "        new_logging[key] = [el for el in logging1[key] if el.step >= step_range1[0] and el.step < step_range1[1]]\n",
    "        for el in logging2[key]:\n",
    "            if el.step >= step_range2[0] and el.step <= step_range2[1]:\n",
    "                # not possible to assign to the step attribute of the object, make a new ScalarEvent object identical\n",
    "                # to el but with the step attribute changed\n",
    "                new_step_value = el.step - step_range2[0] + step_range1[1]\n",
    "                new_el = ScalarEvent(step=new_step_value, wall_time=el.wall_time, value=el.value)        \n",
    "                \n",
    "                new_logging[key].append(new_el)\n",
    "        new_logging[key] = sorted(new_logging[key], key=lambda x: x.step)\n",
    "    save_to_tensorboard(new_logging, output_path)\n",
    "    return new_logging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "\n",
    "def print_number_of_parameters(model):\n",
    "    print(\"Number of trainable parameters: \", sum(p.numel() for p in model.parameters() if p.requires_grad))\n",
    "\n",
    "def find_fim_indices(is_cls_tokens, is_eos_tokens):\n",
    "    \"\"\"Function to find the indices of the FIM tokens in the sequences.\n",
    "    \"\"\"\n",
    "    # add a cls token at the beginning\n",
    "    is_cls_tokens = torch.cat([torch.ones_like(is_cls_tokens[:, :1]), is_cls_tokens], dim=1)\n",
    "    is_eos_tokens = torch.cat([torch.zeros_like(is_eos_tokens[:, :1]), is_eos_tokens], dim=1)\n",
    "    # both eos and cls tokens\n",
    "    bol = is_cls_tokens | is_eos_tokens\n",
    "    tmp = torch.zeros_like(is_cls_tokens, dtype=torch.int)\n",
    "    tmp[torch.nonzero(is_cls_tokens, as_tuple=True)] = 1\n",
    "    tmp[torch.nonzero(is_eos_tokens, as_tuple=True)] = -1\n",
    "    bol1 = torch.clone(bol)\n",
    "    for batch_ind in range(tmp.size(0)):\n",
    "        tmp1 = tmp[batch_ind,bol[batch_ind]]\n",
    "        # find all positions where a 1 if preceeded by a -1\n",
    "        tmp1 = tmp1[:-1]*tmp1[1:]\n",
    "        # add the first element to make the sequence start with a 1\n",
    "        tmp1 = torch.cat([torch.ones_like(tmp1[:1]).to(tmp1.device), tmp1])\n",
    "        new_bol = tmp1<0\n",
    "        # bool array True only in the positions where a 1 is preceeded by a -1\n",
    "        bol1[batch_ind,bol[batch_ind]] = False if new_bol.size(0) == 0 else new_bol\n",
    "    cumulative_sum = torch.cumsum(bol1, dim=1)\n",
    "    # Use modulo operation to get the desired tensor\n",
    "    bol2 = cumulative_sum % 2 == 1\n",
    "    bol2[is_eos_tokens]= False\n",
    "    return bol2[:,1:]\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    predictions, labels = eval_pred\n",
    "    predictions = torch.tensor(predictions).permute(0, 2, 1)\n",
    "    labels = torch.tensor(labels)\n",
    "    # shift labels to align them with predictions and remove last prediction to match the length\n",
    "    predictions = predictions[:, :, :-1].contiguous()\n",
    "    labels = labels[:, 1:].contiguous()\n",
    "    # compute unreduced elementwise loss\n",
    "    unreduced_loss = torch.nn.functional.cross_entropy(predictions, labels, reduction=\"none\")\n",
    "    # compute reconstruction accuracy\n",
    "    reconstruction = (predictions.argmax(1) == labels)\n",
    "\n",
    "    # start and end tokens\n",
    "    is_cls_tokens = (labels == AA_TO_ID[\"<cls>\"])\n",
    "    is_eos_tokens = (labels == AA_TO_ID[\"<eos>\"])\n",
    "    # fill in the middle tokens\n",
    "    if False:\n",
    "        fim_tokens = torch.zeros(is_cls_tokens.size(0), is_cls_tokens.size(1), dtype=torch.bool)\n",
    "        in_mask_vector = torch.zeros(is_cls_tokens.size(0), dtype=torch.bool)\n",
    "        for j in range(is_cls_tokens.size(1)):\n",
    "            in_mask_vector = in_mask_vector & ~is_cls_tokens[:, j]\n",
    "            fim_tokens[:, j] = in_mask_vector\n",
    "            in_mask_vector = in_mask_vector | is_eos_tokens[:, j]\n",
    "    fim_tokens = find_fim_indices(is_cls_tokens, is_eos_tokens)\n",
    "        \n",
    "    number_sequences = torch.cumsum(torch.cat([torch.zeros(is_cls_tokens.size(0),1, dtype=torch.int32), is_cls_tokens[:,:-1]],1), -1)\n",
    "    # fist, second and last sequence tokens\n",
    "    first_sequence_tokens = ((~fim_tokens & (labels < 33)) | fim_tokens) & (number_sequences == 0)\n",
    "    second_sequence_tokens = ((~fim_tokens & (labels < 33)) | fim_tokens) & (number_sequences == 1)\n",
    "    last_sequence_tokens = ((~fim_tokens & (labels < 33)) | fim_tokens) & (number_sequences == (number_sequences.max(1).values[:, None] - 1))\n",
    "    # end of mask tokens\n",
    "    end_of_masks = (fim_tokens & (labels > 33)) | is_cls_tokens | is_eos_tokens\n",
    "    return {\"loss/all\": torch.mean(unreduced_loss).item(),\n",
    "            \"loss/end_span\": torch.mean(unreduced_loss[end_of_masks]).item(),\n",
    "            \"perplexity/seq\": torch.mean(torch.exp(torch.mean(unreduced_loss, dim=1))).item(),\n",
    "            \"perplexity/end_span\": torch.exp(torch.mean(unreduced_loss[end_of_masks])).item(),\n",
    "            \"perplexity/batch\": torch.exp(torch.mean(unreduced_loss)).item(),\n",
    "            \"perplexity/first_seq\": torch.exp(torch.mean(unreduced_loss[first_sequence_tokens])).item(),\n",
    "            \"perplexity/second_seq\": torch.exp(torch.mean(unreduced_loss[second_sequence_tokens])).item(),\n",
    "            \"perplexity/last_seq\": torch.exp(torch.mean(unreduced_loss[last_sequence_tokens])).item(),\n",
    "            \"perplexity/fim\": torch.exp(torch.mean(unreduced_loss[fim_tokens])).item(),\n",
    "            \"reconstruction/all\": torch.mean(reconstruction.float()).item(),\n",
    "            \"reconstruction/end_span\": torch.mean(reconstruction[end_of_masks].float()).item(),\n",
    "            \"reconstruction/first_seq\": torch.mean(reconstruction[first_sequence_tokens].float()).item(),\n",
    "            \"reconstruction/second_seq\": torch.mean(reconstruction[second_sequence_tokens].float()).item(),\n",
    "            \"reconstruction/last_seq\": torch.mean(reconstruction[last_sequence_tokens].float()).item(),\n",
    "            \"reconstruction/fim\": torch.mean(reconstruction[fim_tokens].float()).item(),}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import nbdev; nbdev.nbdev_export()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
