{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ProtMamba_ssm.core import *\n",
    "from ProtMamba_ssm.dataloaders import *\n",
    "from ProtMamba_ssm.utils import *\n",
    "from ProtMamba_ssm.modules import *\n",
    "import torch\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_sequences(dataset,\n",
    "                     model,\n",
    "                     n_samples_per_family,\n",
    "                     max_length=1000,\n",
    "                     family_idxs=[],\n",
    "                     parameters_list=[],\n",
    "                     fim_generation=False,\n",
    "                     save_path=None):\n",
    "    \"\"\"\n",
    "    Function to sample sequences from the model. Given a dataset, a list of families (their indexes in the dataset)\n",
    "    and a set of generating parameters, it generates `n_samples_per_family` sequences for each family and each parameter set.\n",
    "    The function returns a dictionary with the following structure:\n",
    "    gen_seqs = {family_idx: {parameters: {sequence: perplexity}}}\n",
    "    The parameters are in a list of tuples with the following structure:    \n",
    "    parameters_list = [(nr_seqs_ctx, temperature, top_k, top_p)]\n",
    "    \"\"\"        \n",
    "    gen_seqs = {}\n",
    "    for j in family_idxs:\n",
    "        gen_seqs[j] = {}\n",
    "        print(\"Sampling sequences for family {}\".format(j))\n",
    "        for params in tqdm(parameters_list):\n",
    "            gen_seqs[j][params] = {}\n",
    "            n_seqs_ctx , temperature, top_k, top_p = params\n",
    "            for _ in range(n_samples_per_family):\n",
    "                # Sample the dataset to get the input\n",
    "                data = dataset[j]\n",
    "                tokens = data[\"input_ids\"][None,:].to(\"cuda\")\n",
    "                pos_ids = data[\"position_ids\"][None,:].to(\"cuda\")\n",
    "                start_seqs = torch.argwhere(tokens[0]==0)[:,0].cpu().numpy()\n",
    "                if fim_generation:\n",
    "                    n_seqs_ctx = len(start_seqs)-1 if len(start_seqs) < n_seqs_ctx+1 else n_seqs_ctx\n",
    "                    L = start_seqs[n_seqs_ctx+1] if n_seqs_ctx>0 else start_seqs[n_seqs_ctx]\n",
    "                    context_tokens, context_pos_ids, tokens_fim, pos_ids_fim, is_fim_dict = prepare_dataset_for_fim_generation(tokens[:,:L], pos_ids[:,:L])\n",
    "                    is_fim = is_fim_dict\n",
    "                else:\n",
    "                    n_seqs_ctx = len(start_seqs) if len(start_seqs) < n_seqs_ctx else n_seqs_ctx\n",
    "                    L = start_seqs[n_seqs_ctx]+1\n",
    "                    context_tokens = tokens[:,:L]\n",
    "                    context_pos_ids = pos_ids[:,:L]\n",
    "                    is_fim=False\n",
    "                # Generate the new sequence               \n",
    "                output = generate_sequence(model,\n",
    "                                        context_tokens,\n",
    "                                        position_ids=context_pos_ids,\n",
    "                                        is_fim=is_fim,\n",
    "                                        max_length=(L+max_length),\n",
    "                                        temperature=temperature,\n",
    "                                        top_k=top_k,\n",
    "                                        top_p=top_p,\n",
    "                                        return_dict_in_generate=True,\n",
    "                                        output_scores=True,\n",
    "                                        eos_token_id=torch.tensor([AA_TO_ID[\"<cls>\"]]).to(\"cuda\"),\n",
    "                                        device=\"cuda\")\n",
    "                # Get the perplexity of the generated sequence\n",
    "                output_seq = output[\"generated\"] \n",
    "                loss = torch.nn.functional.cross_entropy(torch.from_numpy(output[\"scores\"]).permute(0, 2, 1),\n",
    "                                                        torch.from_numpy(output[\"generated_tokens\"][0][None,:]))\n",
    "                # save only sequences with length < max_length\n",
    "                if len(output_seq[0]) < max_length:\n",
    "                    if fim_generation:\n",
    "                        original_input = output[\"input\"][0].split(\"<cls>\")[-1]\n",
    "                        original_input_continuation = decode_sequence(tokens_fim[0].cpu().numpy())+\"<cls>\"\n",
    "                        generated_input_continuation = output_seq[0]\n",
    "                        if len(original_input_continuation) == len(generated_input_continuation):\n",
    "                            outp_str = reorder_masked_sequence(original_input + generated_input_continuation)\n",
    "                            gen_seqs[j][params][outp_str] = {\"original_input\": original_input,\n",
    "                                                            \"original_input_fim\": original_input_continuation,\n",
    "                                                            \"generated_input_fim\": generated_input_continuation,\n",
    "                                                            \"perplexity\": torch.exp(loss).item()}\n",
    "                        else:\n",
    "                            print(\"Lengths of original and generated FIM do not match. {} vs {}\".format(original_input_continuation, generated_input_continuation))\n",
    "                    else:\n",
    "                        gen_seqs[j][params][output_seq[0]] = {\"perplexity\": torch.exp(loss).item()}\n",
    "        if save_path is not None:\n",
    "            with open(save_path, \"wb\") as f:\n",
    "                pickle.dump(gen_seqs, f)\n",
    "    return gen_seqs"
   ]
  },
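  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The next cell is a small, self-contained sketch (not part of the original pipeline) of the context-selection logic used inside `sample_sequences`: assuming, as above, that token id 0 marks the start of each concatenated sequence, it finds the start positions with `torch.argwhere` and keeps the first `n_seqs_ctx` sequences as context. The toy token tensor and the number of context sequences are purely illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy illustration of the context-selection logic in `sample_sequences`.\n",
    "# The token ids below are hypothetical; 0 is assumed to mark the start of each sequence.\n",
    "toy_tokens = torch.tensor([[0, 5, 7, 9, 0, 3, 4, 0, 8, 6, 2]])\n",
    "toy_start_seqs = torch.argwhere(toy_tokens[0] == 0)[:, 0].cpu().numpy()\n",
    "print(\"Sequence start positions:\", toy_start_seqs)\n",
    "\n",
    "toy_n_seqs_ctx = 2  # illustrative number of context sequences\n",
    "# cap the context at the number of available sequences, as in the function above\n",
    "toy_n_seqs_ctx = len(toy_start_seqs) if len(toy_start_seqs) < toy_n_seqs_ctx else toy_n_seqs_ctx\n",
    "toy_L = toy_start_seqs[toy_n_seqs_ctx] + 1  # keep the first two sequences plus the next start token\n",
    "print(\"Context tokens:\", toy_tokens[:, :toy_L])"
   ]
  },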
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the dataset used for training\n",
    "dataset_name = \"encoded_MSAs_test.pkl\"\n",
    "fim_strategy = \"multiple_span\"\n",
    "mask_fraction = 0.2\n",
    "dataset = Uniclust30_Dataset(dataset_name,\n",
    "                            filepath=\"/data1/common/OpenProteinSet/\",\n",
    "                            sample=False,\n",
    "                            max_msa_len=-1,\n",
    "                            max_patches=5,\n",
    "                            mask_fraction=mask_fraction,\n",
    "                            fim_strategy=fim_strategy,\n",
    "                            max_position_embeddings=2048,\n",
    "                            add_position_ids=\"1d\")\n",
    "    \n",
    "# Load pretrained model\n",
    "checkpoint = \"../../nbs/results/train_100M_FIM_restart-spikes_merged/checkpoint_131k-3750\"\n",
    "model = load_model(checkpoint,\n",
    "                   model_class=MambaLMHeadModelwithPosids,\n",
    "                   device=\"cuda\",\n",
    "                   dtype=torch.bfloat16,\n",
    "                   checkpoint_mixer=False\n",
    "                   )\n",
    "model = model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sample from different families using different generation methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# family_idxs = [0, 1, 2, 3, 4, 5, 6, 8, 9, 10]\n",
    "family_idxs = [11, 13, 14, 16, 18] # [20, 21, 23, 24, 25] # \n",
    "\n",
    "# # parameters: (nr_seqs_ctx, temperature, top_k, top_p)\n",
    "# parameters_list =  [(100,1.,10,0.), (-1,0.9,10,0.95)]\n",
    "parameters_list = [(10,1.,10,0.), (10,1.,15,0.), (10,1.,10,0.95), (10,0.9,10,0.95), (10,0.8,10,0.9),\n",
    "                   (100,1.,10,0.), (100,1.,15,0.), (100,1.,10,0.95), (100,0.9,10,0.95), (100,0.8,10,0.9),\n",
    "                   (500,1.,10,0.), (500,1.,15,0.), (500,1.,10,0.95), (500,0.9,10,0.95), (500,0.8,10,0.9),\n",
    "                   (1000,1.,10,0.), (1000,1.,15,0.), (1000,1.,10,0.95), (1000,0.9,10,0.95), (1000,0.8,10,0.9),\n",
    "                   (-1,1.,10,0.), (-1,1.,15,0.), (-1,1.,10,0.95), (-1,0.9,10,0.95), (-1,0.8,10,0.9)]\n",
    "n_samples_per_family = 100\n",
    "generate_fim = False"
   ]
  },
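  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Quick sanity check (illustrative, not part of the original workflow): the settings above imply one generation call per family, per parameter tuple and per sample."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total number of generation calls implied by the settings above.\n",
    "total_calls = len(family_idxs) * len(parameters_list) * n_samples_per_family\n",
    "print(f\"{len(family_idxs)} families x {len(parameters_list)} parameter sets x \"\n",
    "      f\"{n_samples_per_family} samples = {total_calls} generation calls\")"
   ]
  },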
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in family_idxs:\n",
    "    data = dataset[i]\n",
    "    tokens = data[\"input_ids\"][None,:].to(\"cuda\")\n",
    "    start_seqs = torch.argwhere(tokens[0]==0)[:,0].cpu().numpy()\n",
    "    inds = start_seqs < 131072\n",
    "    print(f\"Family name: {dataset.cluster_names[i]}\\t\", \"\\tNumber of sequences: \", len(start_seqs), \"\\tNum sequences < 131072: \", len(start_seqs[inds]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "end_str = \"_fim\" if generate_fim else \"_full\"\n",
    "save_path = f\"figures/generated_sequences/check-131k(11-18)_gen_seqs{end_str}.pkl\"\n",
    "gen_seqs = sample_sequences(dataset,\n",
    "                            model,\n",
    "                            n_samples_per_family=n_samples_per_family,\n",
    "                            max_length=1000,\n",
    "                            family_idxs=family_idxs,\n",
    "                            parameters_list=parameters_list,\n",
    "                            fim_generation=generate_fim,\n",
    "                            save_path=save_path\n",
    "                            )\n",
    "with open(save_path, \"wb\") as f:\n",
    "    pickle.dump(gen_seqs, f)"
   ]
  },
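  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The saved results follow the structure documented in `sample_sequences`: `{family_idx: {parameters: {sequence: info}}}`, where `info` contains the perplexity of each generated sequence. The cell below is a minimal sketch of how one might reload the pickle and summarize it; the aggregation itself is illustrative and not part of the original pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the saved results and print a short per-family / per-parameter summary.\n",
    "# The traversal follows the dictionary structure documented in `sample_sequences`;\n",
    "# the summary statistics are only an illustrative example.\n",
    "with open(save_path, \"rb\") as f:\n",
    "    gen_seqs_loaded = pickle.load(f)\n",
    "\n",
    "for fam_idx, by_params in gen_seqs_loaded.items():\n",
    "    print(f\"Family {fam_idx}:\")\n",
    "    for params, seqs in by_params.items():\n",
    "        if len(seqs) == 0:\n",
    "            continue\n",
    "        perplexities = [info[\"perplexity\"] for info in seqs.values()]\n",
    "        print(f\"  (n_seqs_ctx, temperature, top_k, top_p)={params}: \"\n",
    "              f\"{len(seqs)} sequences, mean perplexity {np.mean(perplexities):.3f}\")"
   ]
  },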
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ProtMamba",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
