{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ProtMamba_ssm.core import *\n",
    "from ProtMamba_ssm.dataloaders import *\n",
    "from ProtMamba_ssm.utils import *\n",
    "from ProtMamba_ssm.modules import *\n",
    "import torch\n",
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import pickle\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def smooth(scalars: list[float], weight: float) -> list[float]:\n",
    "    \"\"\"\n",
    "    Exponential moving average\n",
    "    \"\"\"\n",
    "    last = 0\n",
    "    smoothed = []\n",
    "    num_acc = 0\n",
    "    for next_val in scalars:\n",
    "        last = last * weight + (1 - weight) * next_val\n",
    "        num_acc += 1\n",
    "        # de-bias\n",
    "        debias_weight = 1\n",
    "        if weight != 1:\n",
    "            debias_weight = 1 - math.pow(weight, num_acc)\n",
    "        smoothed_val = last / debias_weight\n",
    "        smoothed.append(smoothed_val)\n",
    "\n",
    "    return smoothed\n",
    "\n",
    "def find_fim_indices(is_cls_tokens, is_eos_tokens):\n",
    "    # add a cls token at the beginning\n",
    "    is_cls_tokens = torch.cat([torch.ones_like(is_cls_tokens[:, :1]), is_cls_tokens], dim=1)\n",
    "    is_eos_tokens = torch.cat([torch.zeros_like(is_eos_tokens[:, :1]), is_eos_tokens], dim=1)\n",
    "    # both eos and cls tokens\n",
    "    bol = is_cls_tokens | is_eos_tokens\n",
    "    tmp = torch.zeros_like(is_cls_tokens, dtype=torch.int)\n",
    "    tmp[torch.nonzero(is_cls_tokens, as_tuple=True)] = 1\n",
    "    tmp[torch.nonzero(is_eos_tokens, as_tuple=True)] = -1\n",
    "    bol1 = torch.clone(bol)\n",
    "    for batch_ind in range(tmp.size(0)):\n",
    "        tmp1 = tmp[batch_ind,bol[batch_ind]]\n",
    "        # find all positions where a 1 if preceeded by a -1\n",
    "        tmp1 = tmp1[:-1]*tmp1[1:]\n",
    "        # add the first element to make the sequence start with a 1\n",
    "        tmp1 = torch.cat([torch.ones_like(tmp1[:1]).to(tmp1.device), tmp1])\n",
    "        new_bol = tmp1<0\n",
    "        # bool array True only in the positions where a 1 is preceeded by a -1\n",
    "        bol1[batch_ind,bol[batch_ind]] = False if new_bol.size(0) == 0 else new_bol\n",
    "    cumulative_sum = torch.cumsum(bol1, dim=1)\n",
    "    # Use modulo operation to get the desired tensor\n",
    "    bol2 = cumulative_sum % 2 == 1\n",
    "    bol2[is_eos_tokens]= False\n",
    "    return bol2[:,1:]\n",
    "\n",
    "\n",
    "def compute_metrics(predictions, labels, full_fim=False):\n",
    "    predictions = predictions.permute(0, 2, 1)\n",
    "    labels = labels\n",
    "    # shift labels to align them with predictions and remove last prediction to match the length\n",
    "    predictions = predictions[:, :, :-1].contiguous()\n",
    "    labels = labels[:, 1:].contiguous()\n",
    "    # compute unreduced elementwise loss\n",
    "    unreduced_loss = torch.nn.functional.cross_entropy(predictions, labels, reduction=\"none\")\n",
    "    # compute reconstruction accuracy\n",
    "    reconstruction = (predictions.argmax(1) == labels)\n",
    "\n",
    "    # start and end tokens\n",
    "    is_cls_tokens = (labels == AA_TO_ID[\"<cls>\"])\n",
    "    is_eos_tokens = (labels == AA_TO_ID[\"<eos>\"])\n",
    "    # fill in the middle tokens\n",
    "    if full_fim:\n",
    "        print(\"Using for loop fim\")\n",
    "        fim_tokens = torch.zeros(is_cls_tokens.size(0), is_cls_tokens.size(1), dtype=torch.bool).to(is_cls_tokens.device)\n",
    "        in_mask_vector = torch.zeros(is_cls_tokens.size(0), dtype=torch.bool).to(is_cls_tokens.device)\n",
    "        for j in range(is_cls_tokens.size(1)):\n",
    "            in_mask_vector = in_mask_vector & ~is_cls_tokens[:, j]\n",
    "            fim_tokens[:, j] = in_mask_vector\n",
    "            in_mask_vector = in_mask_vector | is_eos_tokens[:, j]\n",
    "    else:\n",
    "        fim_tokens = find_fim_indices(is_cls_tokens, is_eos_tokens)\n",
    "\n",
    "    number_sequences = torch.cumsum(torch.cat([torch.zeros(is_cls_tokens.size(0),1, dtype=torch.int32).to(is_cls_tokens.device), is_cls_tokens[:,:-1]],1), -1)\n",
    "    # sequence tokens\n",
    "    sequence_perplexity, sequence_fim_perplexity = [], []\n",
    "    sequence_loss, sequence_fim_loss = [], []\n",
    "    sequence_size_fim_part = []\n",
    "    num_tokens_preceeding = []\n",
    "    num_tokens = torch.arange(labels.size(1)).to(labels.device)\n",
    "    num_tokens = torch.cat([num_tokens[None, :] for _ in range(labels.size(0))], dim=0)\n",
    "    for i in range(torch.max(number_sequences.max(1).values[:, None] - 1).item()):\n",
    "        i_sequence_tokens = ((~fim_tokens & (labels < 33)) | fim_tokens) & (number_sequences == i)\n",
    "        i_sequence_tokens_fim = fim_tokens & (number_sequences == i) \n",
    "        i_sequence_cls_tokens = is_cls_tokens & (number_sequences == i)\n",
    "        num_tokens_preceeding.append(num_tokens[i_sequence_cls_tokens])\n",
    "        sequence_perplexity.append(torch.exp(torch.mean(unreduced_loss[i_sequence_tokens])).item())\n",
    "        sequence_fim_perplexity.append(torch.exp(torch.mean(unreduced_loss[i_sequence_tokens_fim])).item())\n",
    "        sequence_loss.append(torch.mean(unreduced_loss[i_sequence_tokens]).item())\n",
    "        sequence_fim_loss.append(torch.mean(unreduced_loss[i_sequence_tokens_fim]).item())\n",
    "        tmp1 = i_sequence_tokens_fim & (labels < 33)\n",
    "        tmp2 = i_sequence_tokens_fim & (labels >= 33)\n",
    "        sequence_size_fim_part.append((tmp1.sum()/tmp2.sum()).item())\n",
    "    # metrics\n",
    "    # cum_loss = torch.cumsum(unreduced_loss, dim=1)/torch.cumsum(torch.ones_like(unreduced_loss), dim=1)\n",
    "    return {\"loss\": unreduced_loss,\n",
    "            \"sequence_losses\": torch.tensor(sequence_loss),\n",
    "            \"sequence_perplexities\": torch.tensor(sequence_perplexity),\n",
    "            \"sequence_fim_losses\": torch.tensor(sequence_fim_loss),\n",
    "            \"sequence_fim_perplexities\": torch.tensor(sequence_fim_perplexity),\n",
    "            \"sequence_size_fim_part\": torch.tensor(sequence_size_fim_part),\n",
    "            \"reconstruction\": reconstruction,\n",
    "            \"num_tokens_preceeding\": torch.tensor(num_tokens_preceeding),\n",
    "            }"
   ]
  },
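  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy sanity check (illustrative only; the token layout below is made up).\n",
    "# find_fim_indices should flag exactly the tokens strictly between an <eos>\n",
    "# and the next <cls>: here positions 4 and 5.\n",
    "_cls = torch.tensor([[True, False, False, False, False, False, True, False]])\n",
    "_eos = torch.tensor([[False, False, False, True, False, False, False, False]])\n",
    "print(find_fim_indices(_cls, _eos))\n",
    "\n",
    "# smooth() is a debiased exponential moving average; weight=0 is the identity.\n",
    "assert smooth([1.0, 2.0, 3.0], weight=0.0) == [1.0, 2.0, 3.0]"
   ]
  },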
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Perplexity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "is_fim = True\n",
    "dataset_name = \"encoded_MSAs_test.pkl\"\n",
    "fim_strategy = \"multiple_span\" if is_fim else \"no-scramble\"\n",
    "# Load dataset\n",
    "dataset = Uniclust30_Dataset(dataset_name,\n",
    "                             filepath=\"/data1/common/OpenProteinSet/\",\n",
    "                             max_msa_len=135000,\n",
    "                             sample=False,\n",
    "                             max_patches=1,\n",
    "                             mask_fraction=0.1,\n",
    "                             fim_strategy=fim_strategy,\n",
    "                             max_position_embeddings=2048,\n",
    "                             add_position_ids=\"1d\")\n",
    "device = \"cuda\"\n",
    "# Load pretrained model\n",
    "model = load_model(\"../../nbs/results/train_100M_FIM_restart-spikes_merged/finetuned_FIM_checkpoint_131k-3200\",#checkpoint_131k-8750\n",
    "                   model_class=MambaLMHeadModelwithPosids,\n",
    "                   device=device,\n",
    "                   dtype=torch.bfloat16,\n",
    "                   checkpoint_mixer=False)\n",
    "model.eval()\n",
    "# model = torch.compile(model)\n",
    "# data_collator = DataCollatorForUniclust30Dataset()"
   ]
  },
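  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick check (illustrative): report the size of the loaded model.\n",
    "n_params = sum(p.numel() for p in model.parameters())\n",
    "print(f\"model parameters: {n_params / 1e6:.1f}M\")"
   ]
  },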
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# histogram of length sequences in dataset\n",
    "lengths = []\n",
    "for i in range(len(dataset)):\n",
    "    data = dataset[i][\"input_ids\"]\n",
    "    start = dataset.get_index_start_of_sequences(dataset[i][\"input_ids\"])\n",
    "    lengths.append(np.mean(start[1:] - start[:-1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(lengths, bins=20)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "what = input(\"How should the files be named?:\") # \"0.05-135k_new\"\n",
    "loss = []\n",
    "# reconstruction = []\n",
    "seq_size_fim_part = []\n",
    "seq_loss, seq_perp = [], []\n",
    "seq_fim_loss, seq_fim_perp = [], []\n",
    "num_tokens_preceeding = []\n",
    "with torch.no_grad():\n",
    "    for i in tqdm(range(len(dataset))):\n",
    "        for j in range(100):\n",
    "            data = dataset[i]#data_collator([dataset[i] for j in range(4)])\n",
    "            tokens = data[\"input_ids\"][None,:].to(device)\n",
    "            pos_ids = data[\"position_ids\"][None,:].to(device)\n",
    "            # tokens = data[\"input_ids\"].to(device)\n",
    "            # pos_ids = data[\"position_ids\"].to(device)\n",
    "            out = model(tokens, pos_ids)\n",
    "            logits = out.logits\n",
    "            metrics = compute_metrics(logits, tokens)\n",
    "            loss += [metrics[\"loss\"].cpu().to(torch.float).numpy()]\n",
    "            # reconstruction += [metrics[\"reconstruction\"].cpu().to(torch.float).numpy()]\n",
    "            # cum_loss += [metrics[\"cumulative_loss\"][0].cpu().to(torch.float).numpy()]\n",
    "            # cum_perp += [metrics[\"cumulative_perplexity\"][0].cpu().to(torch.float).numpy()]\n",
    "            # cum_rec += [metrics[\"cumulative_reconstruction\"][0].cpu().to(torch.float).numpy()]\n",
    "            seq_loss.append(metrics[\"sequence_losses\"])\n",
    "            seq_perp.append(metrics[\"sequence_perplexities\"])\n",
    "            seq_fim_loss.append(metrics[\"sequence_fim_losses\"])\n",
    "            seq_fim_perp.append(metrics[\"sequence_fim_perplexities\"])\n",
    "            seq_size_fim_part.append(metrics[\"sequence_size_fim_part\"])\n",
    "            num_tokens_preceeding.append(metrics[\"num_tokens_preceeding\"])"
   ]
  },
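  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick look (illustrative): median perplexity of the first sequence in each\n",
    "# context, before any aggregation by position or sequence length.\n",
    "first_pos_perp = np.array([sp[0].item() for sp in seq_perp if len(sp) > 0])\n",
    "print(f\"median perplexity at context position 0: {np.nanmedian(first_pos_perp):.2f}\")"
   ]
  },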
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dict_seq_loss, dict_seq_perp = {}, {}\n",
    "dict_seq_length, dict_seq_fim_length = {}, {}\n",
    "dict_seq_fim_loss, dict_seq_fim_perp = {}, {}\n",
    "dict_size_fim = {}\n",
    "ctx_lengths = [1024,2048,4096,8192,16384,32768,65536,131072,262144]\n",
    "dict_ctx_perp, dict_ctx_fim_perp = {l_ctx: [] for l_ctx in ctx_lengths}, {l_ctx: [] for l_ctx in ctx_lengths}\n",
    "dict_ctx_lenghts = {l_ctx: [] for l_ctx in ctx_lengths}\n",
    "dict_ctx_len, dict_ctx_fim_len = {l_ctx: [] for l_ctx in ctx_lengths}, {l_ctx: [] for l_ctx in ctx_lengths}\n",
    "for el in zip(seq_loss, seq_perp, seq_fim_loss, seq_fim_perp, seq_size_fim_part, num_tokens_preceeding):\n",
    "    avg_l = (el[5][-1]/len(el[5])).item()\n",
    "    for i in range(len(el[0])):\n",
    "        if i in dict_seq_loss:\n",
    "            dict_seq_loss[i].append(el[0][i].item())\n",
    "            dict_seq_perp[i].append(el[1][i].item())\n",
    "            dict_seq_length[i].append(avg_l)\n",
    "        else:\n",
    "            dict_seq_perp[i] = [el[1][i].item()]\n",
    "            dict_seq_loss[i] = [el[0][i].item()]\n",
    "            dict_seq_length[i] = [avg_l]\n",
    "    for i in range(len(el[2])):\n",
    "        if i in dict_seq_fim_loss:\n",
    "            if el[2][i] >0:\n",
    "                dict_seq_fim_loss[i].append(el[2][i].item())\n",
    "                dict_seq_fim_length[i].append(avg_l)\n",
    "            if el[3][i]  >0:\n",
    "                dict_seq_fim_perp[i].append(el[3][i].item())\n",
    "            if el[4][i] >0:\n",
    "                dict_size_fim[i].append(el[4][i].item())\n",
    "        else:\n",
    "            if el[2][i] >0:\n",
    "                dict_seq_fim_loss[i] = [el[2][i].item()]\n",
    "                dict_seq_fim_length[i] = [avg_l]\n",
    "            if el[3][i] >0:\n",
    "                dict_seq_fim_perp[i] = [el[3][i].item()]\n",
    "            if el[4][i] >0:\n",
    "                dict_size_fim[i] = [el[4][i].item()]\n",
    "    # context legths\n",
    "    for l_ctx in ctx_lengths:\n",
    "        if el[5][-1].item() > l_ctx and el[5][0].item() < l_ctx:\n",
    "            indx = torch.argwhere(el[5]<=l_ctx)[-1]\n",
    "            dict_ctx_lenghts[l_ctx].append(el[5][indx].item())\n",
    "            dict_ctx_perp[l_ctx].append(el[1][indx].item())\n",
    "            dict_ctx_len[l_ctx].append(avg_l)\n",
    "            if el[3][indx] > 0:\n",
    "                dict_ctx_fim_perp[l_ctx].append(el[3][indx].item())\n",
    "                dict_ctx_fim_len[l_ctx].append(avg_l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(f\"figures/{what}\", exist_ok=True)\n",
    "\n",
    "with open(f\"figures/{what}/dict_seq_loss_all_test.pkl\", \"wb\") as f:\n",
    "    pickle.dump(dict_seq_loss, f)\n",
    "with open(f\"figures/{what}/dict_seq_perp_all_test.pkl\", \"wb\") as f:\n",
    "    pickle.dump(dict_seq_perp, f)\n",
    "with open(f\"figures/{what}/dict_seq_fim_loss_all_test.pkl\", \"wb\") as f:\n",
    "    pickle.dump(dict_seq_fim_loss, f)\n",
    "with open(f\"figures/{what}/dict_seq_fim_perp_all_test.pkl\", \"wb\") as f:\n",
    "    pickle.dump(dict_seq_fim_perp, f)\n",
    "with open(f\"figures/{what}/dict_size_fim_all_test.pkl\", \"wb\") as f:\n",
    "    pickle.dump(dict_size_fim, f)\n",
    "with open(f\"figures/{what}/dict_ctx_lengths_all_test.pkl\", \"wb\") as f:\n",
    "    pickle.dump(dict_ctx_lenghts, f)\n",
    "with open(f\"figures/{what}/dict_ctx_perp_all_test.pkl\", \"wb\") as f:\n",
    "    pickle.dump(dict_ctx_perp, f)\n",
    "with open(f\"figures/{what}/dict_ctx_fim_perp_all_test.pkl\", \"wb\") as f:\n",
    "    pickle.dump(dict_ctx_fim_perp, f)\n",
    "with open(f\"figures/{what}/dict_seq_length_all_test.pkl\", \"wb\") as f:\n",
    "    pickle.dump(dict_seq_length, f)\n",
    "with open(f\"figures/{what}/dict_seq_fim_length_all_test.pkl\", \"wb\") as f:\n",
    "    pickle.dump(dict_seq_fim_length, f)\n",
    "with open(f\"figures/{what}/dict_ctx_len_all_test.pkl\", \"wb\") as f:\n",
    "    pickle.dump(dict_ctx_len, f)\n",
    "with open(f\"figures/{what}/dict_ctx_fim_len_all_test.pkl\", \"wb\") as f:\n",
    "    pickle.dump(dict_ctx_fim_len, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Figures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "sns.set_theme(\n",
    "    context='notebook', style='ticks', palette='bright',\n",
    "    color_codes=True)  #other contexts: “paper”, “talk”, and “poster”,\n",
    "\n",
    "# Plotting settings\n",
    "SMALL_SIZE = 15\n",
    "MEDIUM_SIZE = 20\n",
    "BIGGER_SIZE = 30\n",
    "\n",
    "plt.rcParams.update({\n",
    "    \"text.usetex\": False,\n",
    "    \"font.family\": \"sans-serif\", #sans-serif\n",
    "    \"font.serif\": [\"Arial\"],\n",
    "    \"font.size\": MEDIUM_SIZE,\n",
    "    \"axes.titlesize\": MEDIUM_SIZE,\n",
    "    \"axes.labelsize\": MEDIUM_SIZE,\n",
    "    \"figure.titlesize\": MEDIUM_SIZE,\n",
    "    \"figure.labelsize\": MEDIUM_SIZE,\n",
    "    \"xtick.labelsize\": SMALL_SIZE,\n",
    "    \"ytick.labelsize\": SMALL_SIZE,\n",
    "    \"legend.fontsize\": MEDIUM_SIZE,\n",
    "})\n",
    "\n",
    "color_ = [(64, 83, 211), (0, 178, 93), (181, 29, 20), (221, 179, 16), (0, 190, 255), (251, 73, 176), (202, 202, 202)]\n",
    "color =[]\n",
    "for t in color_:\n",
    "    color.append(tuple(ti/255 for ti in t))\n",
    "    \n",
    "# split a dictionary based on the length of the sequences\n",
    "def split_lengths(dic, lengths, L=[100, 200]):\n",
    "    new_dic = {ll: {j: [] for j in dic.keys()} for ll in L+[-1]}\n",
    "    for j in dic.keys():\n",
    "        lst = dic[j]\n",
    "        lst_l = lengths[j]\n",
    "        for i in range(len(lst)):\n",
    "            for ll in L:\n",
    "                if lst_l[i] <= ll:\n",
    "                    new_dic[ll][j].append(lst[i])\n",
    "                elif ll == L[-1]:\n",
    "                    new_dic[-1][j].append(lst[i])\n",
    "    return new_dic, L+[-1]"
   ]
  },
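  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch of split_lengths on made-up values: with L=[100, 200], each\n",
    "# entry goes to the first bucket whose bound covers its length; longer\n",
    "# entries go to the catch-all bucket -1.\n",
    "_toy, _buckets = split_lengths({0: [1.5, 2.0, 3.0]}, {0: [80, 150, 400]}, L=[100, 200])\n",
    "print(_buckets)  # [100, 200, -1]\n",
    "print(_toy)      # {100: {0: [1.5]}, 200: {0: [2.0]}, -1: {0: [3.0]}}"
   ]
  },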
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "what = input(\"Name:\")\n",
    "with open(f\"figures/{what}/dict_seq_loss_all_test.pkl\", \"rb\") as f:\n",
    "    dict_seq_loss = pickle.load(f)\n",
    "with open(f\"figures/{what}/dict_seq_perp_all_test.pkl\", \"rb\") as f:\n",
    "    dict_seq_perp = pickle.load(f)\n",
    "with open(f\"figures/{what}/dict_seq_fim_loss_all_test.pkl\", \"rb\") as f:\n",
    "    dict_seq_fim_loss = pickle.load(f)\n",
    "with open(f\"figures/{what}/dict_seq_fim_perp_all_test.pkl\", \"rb\") as f:\n",
    "    dict_seq_fim_perp = pickle.load(f)\n",
    "with open(f\"figures/{what}/dict_size_fim_all_test.pkl\", \"rb\") as f:\n",
    "    dict_size_fim = pickle.load(f)\n",
    "with open(f\"figures/{what}/dict_ctx_lengths_all_test.pkl\", \"rb\") as f:\n",
    "    dict_ctx_lenghts = pickle.load(f)\n",
    "with open(f\"figures/{what}/dict_ctx_perp_all_test.pkl\", \"rb\") as f:\n",
    "    dict_ctx_perp = pickle.load(f)\n",
    "with open(f\"figures/{what}/dict_ctx_fim_perp_all_test.pkl\", \"rb\") as f:\n",
    "    dict_ctx_fim_perp = pickle.load(f)\n",
    "with open(f\"figures/{what}/dict_seq_length_all_test.pkl\", \"rb\") as f:\n",
    "    dict_seq_length = pickle.load(f)\n",
    "with open(f\"figures/{what}/dict_seq_fim_length_all_test.pkl\", \"rb\") as f:\n",
    "    dict_seq_fim_length = pickle.load(f)\n",
    "with open(f\"figures/{what}/dict_ctx_len_all_test.pkl\", \"rb\") as f:\n",
    "    dict_ctx_len = pickle.load(f)\n",
    "with open(f\"figures/{what}/dict_ctx_fim_len_all_test.pkl\", \"rb\") as f:\n",
    "    dict_ctx_fim_len = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(2,5, figsize=(20,10), sharex=True, constrained_layout=True)\n",
    "\n",
    "keys = list(dict_seq_perp.keys())\n",
    "vals = [5,25,55,105,155,205,255,505,755,1005]\n",
    "keys = {i: [keys[j-5+vals[i]] for j in range(11)] for i in range(len(vals))}\n",
    "do_fim = True if len(dict_seq_fim_perp) > 0 else False\n",
    "for i, keys_t in enumerate(keys):\n",
    "    all_vals0, all_vals1 = [], []\n",
    "    for key in keys[keys_t]:\n",
    "        all_vals0 += dict_seq_perp[key]\n",
    "        if do_fim:\n",
    "            all_vals1 += dict_seq_fim_perp[key]\n",
    "    axs[i//5,i%5].axvline(np.median(all_vals0), color=color[0], linestyle=\"--\")\n",
    "    axs[i//5,i%5].hist(all_vals0, bins=np.linspace(0,30,40), alpha=0.5, color=color[0], density=True, label=\"full\")\n",
    "    if do_fim:\n",
    "        axs[i//5,i%5].axvline(np.median(all_vals1), color=color[1], linestyle=\"--\")\n",
    "        axs[i//5,i%5].hist(all_vals1, bins=np.linspace(0,30,40), alpha=0.5, color=color[1], density=True, label=\"fim\")\n",
    "    axs[i//5,i%5].set_title(f\"Seq. positions: ({keys[keys_t][0]}-{keys[keys_t][-1]})\")\n",
    "axs[0,0].legend()\n",
    "fig.supxlabel(\"Perplexity\")\n",
    "fig.supylabel(\"Count\")\n",
    "fig.suptitle(\"Perplexity distributions\")\n",
    "fig.savefig(f\"figures/{what}/perplexity_distributions.pdf\")\n",
    "plt.show()\n",
    "\n",
    "fig, axs = plt.subplots(2,5, figsize=(20,10), sharex=True, constrained_layout=True)\n",
    "for i, keys_t in enumerate(keys):\n",
    "    all_vals0, all_vals1 = [], []\n",
    "    for key in keys[keys_t]:\n",
    "        all_vals0 += dict_seq_loss[key]\n",
    "        if do_fim:\n",
    "            all_vals1 += dict_seq_fim_loss[key]\n",
    "    axs[i//5,i%5].axvline(np.median(all_vals0), color=color[0], linestyle=\"--\")\n",
    "    axs[i//5,i%5].hist(all_vals0, bins=np.linspace(0,4,40), alpha=0.5, color=color[0], density=True, label=\"full\")\n",
    "    if do_fim:\n",
    "        axs[i//5,i%5].axvline(np.median(all_vals1), color=color[1], linestyle=\"--\")\n",
    "        axs[i//5,i%5].hist(all_vals1, bins=np.linspace(0,4,40), alpha=0.5, color=color[1], density=True, label=\"fim\")\n",
    "    axs[i//5,i%5].set_title(f\"Seq. positions: ({keys[keys_t][0]}-{keys[keys_t][-1]})\")\n",
    "axs[0,0].legend()\n",
    "fig.supxlabel(\"Loss\")\n",
    "fig.supylabel(\"Count\")\n",
    "fig.suptitle(\"Loss distributions\")\n",
    "fig.savefig(f\"figures/{what}/loss_distributions.pdf\")\n",
    "plt.show()\n",
    "\n",
    "if do_fim:\n",
    "    fig, axs = plt.subplots(2,5, figsize=(20,10), sharex=True, sharey=True, constrained_layout=True)\n",
    "    for i, keys_t in enumerate(list(keys.keys())):\n",
    "        all_vals0 = []\n",
    "        size_vals = []\n",
    "        for key in keys[keys_t]:\n",
    "            all_vals0 += dict_seq_fim_perp[key]\n",
    "            size_vals += dict_size_fim[key]\n",
    "        axs[i//5,i%5].plot(size_vals, all_vals0, \"o\", color=color[1], alpha=0.5)\n",
    "        axs[i//5,i%5].axhline(np.median(all_vals0), color=\"k\", linestyle=\"--\")\n",
    "        axs[i//5,i%5].axvline(np.median(size_vals), color=\"k\", linestyle=\"--\")\n",
    "        axs[i//5,i%5].set_title(f\"Seq. positions: ({keys[keys_t][0]}-{keys[keys_t][-1]})\")\n",
    "        axs[i//5,i%5].set_ylim(0,50)\n",
    "\n",
    "    fig.supxlabel(\"Size FIM masks (per sequence)\")\n",
    "    fig.supylabel(\"Perplexity\")\n",
    "    fig.savefig(f\"figures/{what}/fim_perplexity_vs_size.pdf\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Lvec1 = [150, 250]\n",
    "dict_seq_perps, Lvec = split_lengths(dict_seq_perp, dict_seq_length, L=Lvec1)\n",
    "dict_seq_fim_perps, Lvec = split_lengths(dict_seq_fim_perp, dict_seq_fim_length, L=Lvec1)\n",
    "median_seq_perp = np.zeros(len(dict_seq_loss)+1)\n",
    "median_seq_loss, median_seq_perps = np.zeros(len(dict_seq_loss)+1), {ll: np.zeros(len(dict_seq_perps[ll])+1) for ll in Lvec}\n",
    "if do_fim:\n",
    "    median_seq_fim_loss, median_seq_fim_perp = np.zeros(max(dict_seq_fim_loss.keys())+1), {ll: np.zeros(max(dict_seq_fim_perps[ll].keys())+1) for ll in Lvec}\n",
    "for key in dict_seq_loss:\n",
    "    # avg_seq_loss[key] = np.mean(dict_seq_loss[key])\n",
    "    median_seq_loss[key] = np.median(dict_seq_loss[key])\n",
    "    median_seq_perp[key] = np.median(dict_seq_perp[key])\n",
    "    for ll in Lvec:\n",
    "        median_seq_perps[ll][key] = np.median(dict_seq_perps[ll][key])\n",
    "if do_fim:\n",
    "    for key in dict_seq_fim_loss:\n",
    "        median_seq_fim_loss[key] = np.median(dict_seq_fim_loss[key])\n",
    "        for ll in Lvec:\n",
    "            median_seq_fim_perp[ll][key] = np.median(dict_seq_fim_perps[ll][key])\n",
    "    \n",
    "fig, axs = plt.subplots(1,2, figsize=(20,5), constrained_layout=True)\n",
    "# inds = np.random.choice(len(dict_seq_loss), 1000)\n",
    "# for i in inds:\n",
    "#     el = dict_seq_loss[i], dict_seq_perps[i]\n",
    "#     axs[0].plot(np.ones_like(el[0])*i,el[0], \".\", color=color[0], alpha=0.05)\n",
    "#     axs[1].plot(np.ones_like(el[1])*i,el[1], \".\", color=color[0], alpha=0.05)\n",
    "# axs[0].plot(avg_seq_loss, \"k-\", label=\"Average\")\n",
    "axs[0].plot(smooth(median_seq_loss,0.6), \"k-\", label=\"Median\")\n",
    "# axs[1].plot(avg_seq_perp, \"k-\", label=\"Average\")\n",
    "# axs_new = axs[1].twinx()\n",
    "for i, ll in enumerate(Lvec):\n",
    "    if i==0:\n",
    "        string = f\"$0<L<${ll}\"\n",
    "    elif ll == Lvec[-1]:\n",
    "        string = f\"$L>${Lvec[-2]}\"\n",
    "    else:\n",
    "        string = f\"{Lvec[i-1]}$<L<${ll}\"\n",
    "    axs[1].plot(smooth(median_seq_perps[ll],0.8), \"-\", color=color[i+1], label=string)\n",
    "    axs[1].plot(smooth(median_seq_perp,0.8), \"--\", color=\"k\")\n",
    "    # axs_new.plot([len(dict_seq_perps[ll][key])/len(dict_seq_perps[ll][0]) for key in dict_seq_perps[ll].keys()], \"-\", color=color[0])\n",
    "axs[1].axhline(5, color=\"k\", linestyle=\"--\", alpha=0.5)\n",
    "axs[1].axhline(10, color=\"k\", linestyle=\"--\", alpha=0.5)\n",
    "# axs_new.set_ylabel(\"Fraction of clusters\", color=color[0])\n",
    "# axs_new.tick_params(axis='y', labelcolor=color[0])\n",
    "axs[0].set_ylim(top=4)\n",
    "# axs[1].set_ylim(0.1, 30)\n",
    "axs[0].set_ylabel(\"Loss\")\n",
    "axs[1].set_ylabel(\"Perplexity\")\n",
    "# axs[0].legend()\n",
    "axs[1].legend()\n",
    "axs[0].set_xlabel(\"Number of sequences in context\")\n",
    "axs[1].set_xlabel(\"Number of sequences in context\")\n",
    "fig.savefig(f\"figures/{what}/seq_pos_perplexity_loss.pdf\", dpi=20)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if do_fim:\n",
    "    # dict_sizes_fim, Lvec = split_lengths(dict_size_fim, dict_seq_fim_length, L=Lvec1)\n",
    "    # median_seq_fim_loss_sz, median_seq_fim_perp_sz = {}, {}\n",
    "    # num_vals = {}\n",
    "    # median_seq_fim_perps_sz = {ll: {} for ll in Lvec}\n",
    "    # std_seq_fim_loss_sz, std_seq_fim_perp_sz = {}, {}\n",
    "    # sz_lst = [1,5,10,-10]\n",
    "    # for key in dict_seq_fim_loss:\n",
    "    #     sizes = dict_size_fim[key]\n",
    "    #     all_sizes = {ll: dict_sizes_fim[ll][key] for ll in Lvec}\n",
    "    #     for i,sz in enumerate(sz_lst):\n",
    "    #         if sz not in median_seq_fim_loss_sz:\n",
    "    #             median_seq_fim_loss_sz[sz] = np.zeros(max(dict_seq_fim_loss.keys())+1)\n",
    "    #             median_seq_fim_perp_sz[sz] = np.zeros(max(dict_seq_fim_loss.keys())+1)\n",
    "    #             num_vals[sz] = np.zeros(max(dict_seq_fim_loss.keys())+1)\n",
    "    #             for ll in Lvec:\n",
    "    #                 median_seq_fim_perps_sz[ll][sz] = np.zeros(max(dict_seq_fim_loss.keys())+1)\n",
    "    #             std_seq_fim_loss_sz[sz] = np.zeros(max(dict_seq_fim_loss.keys())+1)\n",
    "    #             std_seq_fim_perp_sz[sz] = np.zeros(max(dict_seq_fim_loss.keys())+1)\n",
    "    #         if sz<0:\n",
    "    #             indxs = np.argwhere((np.array(sizes)-1) > -sz)[:,0]\n",
    "    #         else:\n",
    "    #             indxs = np.argwhere(((np.array(sizes)-1) <= sz) & ((np.array(sizes)-1) > (sz_lst[i-1] if i>0 else 0)))[:,0]\n",
    "    #         median_seq_fim_loss_sz[sz][key] = np.median([dict_seq_fim_loss[key][i] for i in indxs])\n",
    "    #         median_seq_fim_perp_sz[sz][key] = np.median([dict_seq_fim_perp[key][i] for i in indxs])\n",
    "    #         num_vals[sz][key] = len([dict_seq_fim_perp[key][i] for i in indxs])\n",
    "    #         for ll in Lvec:\n",
    "    #             if sz<0:\n",
    "    #                 all_indxs = np.argwhere((np.array(all_sizes[ll])-1) > -sz)[:,0]\n",
    "    #             else:\n",
    "    #                 all_indxs = np.argwhere(((np.array(all_sizes[ll])-1) <= sz) & ((np.array(all_sizes[ll])-1) > (sz_lst[i-1] if i>0 else 0)))[:,0]\n",
    "    #             median_seq_fim_perps_sz[ll][sz][key] = np.median([dict_seq_fim_perps[ll][key][i] for i in all_indxs])\n",
    "    #         std_seq_fim_loss_sz[sz][key] = np.std([dict_seq_fim_loss[key][i] for i in indxs])\n",
    "    #         std_seq_fim_perp_sz[sz][key] = np.std([dict_seq_fim_perp[key][i] for i in indxs])\n",
    "            \n",
    "    fig, axs = plt.subplots(1,1, figsize=(13,5), constrained_layout=True)\n",
    "\n",
    "    # moving_avg = lambda x, w: np.convolve(x, np.ones(w)/w, 'valid')\n",
    "    smoothing = 0.8\n",
    "    axs= [axs]\n",
    "    x_ticks = np.arange(len(median_seq_fim_loss))\n",
    "    x_ticks = x_ticks[median_seq_fim_loss>1e-10]\n",
    "    # axs[0].plot(x_ticks, smooth(median_seq_fim_perp[median_seq_fim_loss>1e-10],0.6), \"k-\", label=\"Median\")\n",
    "    for i, sz in enumerate(sz_lst):\n",
    "        str_sz = sz_lst[i-1] if i>0 else 0\n",
    "        lbl_name = f\"{str_sz} $ < N_m \\leq$ {sz}\" if sz>0 else f\"$N_m > $ {-sz}\"\n",
    "        lbl_name = f\"$N_m = $ {sz}\" if sz == 1 else lbl_name\n",
    "        where_vals = num_vals[sz][median_seq_fim_loss>1e-10]>0\n",
    "        axs[0].plot(x_ticks[where_vals], smooth(median_seq_fim_perp_sz[sz][median_seq_fim_loss>1e-10][where_vals], smoothing), label=lbl_name, color=color[i])\n",
    "        # for ll in Lvec:\n",
    "        #     axs[0].plot(x_ticks, smooth(median_seq_fim_perps_sz[ll][sz][median_seq_fim_loss>1e-10], smoothing), \"--\", color=color[i], alpha=0.5)\n",
    "        # axs[0].fill_between(x_ticks, smooth(median_seq_fim_perp_sz[sz][median_seq_fim_loss>1e-10]-std_seq_fim_perp_sz[sz][median_seq_fim_loss>1e-10], smoothing),\n",
    "        #                     smooth(median_seq_fim_perp_sz[sz][median_seq_fim_loss>1e-10]+std_seq_fim_perp_sz[sz][median_seq_fim_loss>1e-10], smoothing), alpha=0.3)\n",
    "\n",
    "    # axs[0].plot(x_ticks, smooth(median_seq_fim_perp[median_seq_fim_loss>1e-10],smoothing), \"k-\", label=\"All FIM\")\n",
    "    # axs[0].plot(smooth(median_seq_perp,smoothing), \"k--\", label=\"Full\")\n",
    "    # axs[0].axhline(5, color=\"k\", linestyle=\"--\", alpha=0.5)\n",
    "    # axs[0].axhline(10, color=\"k\", linestyle=\"--\", alpha=0.5)\n",
    "    # axs[0].set_ylim(2,11)\n",
    "    axs[0].set_xlim(left=-10)\n",
    "    axs[0].set_ylabel(\"Perplexity\")\n",
    "    axs[0].legend(frameon=False, loc=\"upper right\")\n",
    "    axs[0].set_xlabel(\"Number of sequences in context\")\n",
    "    fig.savefig(f\"figures/{what}/seq_pos_perplexity_fim.pdf\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(1,1, figsize=(13,5), constrained_layout=True)\n",
    "\n",
    "smoothing = 0.0\n",
    "axs= [axs]\n",
    "x_ticks = np.arange(len(median_seq_fim_loss))\n",
    "x_ticks = x_ticks[median_seq_fim_loss>1e-10]\n",
    "# axs[0].plot(x_ticks, smooth(median_seq_fim_perp[median_seq_fim_loss>1e-10],0.6), \"k-\", label=\"Median\")\n",
    "for i, sz in enumerate(sz_lst):\n",
    "    str_sz = sz_lst[i-1] if i>0 else 0\n",
    "    lbl_name = f\"{str_sz} $ < N_m \\leq$ {sz}\" if sz>0 else f\"$N_m > $ {-sz}\"\n",
    "    lbl_name = f\"$N_m = $ {sz}\" if sz == 1 else lbl_name\n",
    "    axs[0].plot(x_ticks, num_vals[sz][median_seq_fim_loss>1e-10], label=lbl_name, color=color[i])\n",
    "axs[0].set_yscale(\"log\")\n",
    "axs[0].set_xlim(left=-10)\n",
    "axs[0].set_ylabel(\"Nvals\")\n",
    "axs[0].legend(frameon=False, loc=\"upper right\")\n",
    "axs[0].set_xlabel(\"Number of sequences in context\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ctx_lengths = list(dict_ctx_lenghts.keys())\n",
    "dict_ctx_perps, Lvec = split_lengths(dict_ctx_perp, dict_ctx_len, Lvec1)\n",
    "dict_ctx_fim_perps, Lvec = split_lengths(dict_ctx_fim_perp, dict_ctx_fim_len, Lvec1)\n",
    "median_ctx_perp = {ll: np.zeros(len(ctx_lengths)) for ll in Lvec}\n",
    "median_ctx_fim_perp = {ll: np.zeros(len(ctx_lengths)) for ll in Lvec}\n",
    "median_ctx_lenghts = np.zeros(len(ctx_lengths))\n",
    "std_ctx_perp, std_ctx_fim_perp = np.zeros(len(ctx_lengths)), np.zeros(len(ctx_lengths))\n",
    "\n",
    "for i,key in enumerate(dict_ctx_perp):\n",
    "    for ll in Lvec:\n",
    "        median_ctx_perp[ll][i] = np.median(dict_ctx_perps[ll][key])\n",
    "        if do_fim:\n",
    "            median_ctx_fim_perp[ll][i] = np.median(dict_ctx_fim_perps[ll][key])\n",
    "    median_ctx_lenghts[i] = np.median(dict_ctx_lenghts[key])\n",
    "\n",
    "fig, axs = plt.subplots(1,2, figsize=(20,5), constrained_layout=True)\n",
    " \n",
    "for i, ll in enumerate(Lvec):\n",
    "    if i==0:\n",
    "        string = f\"$0<L<${ll}\"\n",
    "    elif ll == Lvec[-1]:\n",
    "        string = f\"$L>${Lvec[-2]}\"\n",
    "    else:\n",
    "        string = f\"{Lvec[i-1]}$<L<${ll}\"\n",
    "    axs[0].plot(median_ctx_lenghts, median_ctx_perp[ll], \"-\", color=color[i+1], label=string)\n",
    "    if do_fim:\n",
    "        axs[1].plot(median_ctx_lenghts, median_ctx_fim_perp[ll], \"-\", color=color[i+1], label=string)\n",
    "axs[0].set_xscale(\"log\")\n",
    "axs[1].set_xscale(\"log\")\n",
    "axs[0].set_xticks(ctx_lengths[:7]+[131072, 262144, 524288])\n",
    "axs[1].set_xticks(ctx_lengths[:7]+[131072, 262144, 524288])\n",
    "axs[0].set_xticklabels(ctx_lengths[:7]+[131072, 262144, 524288], rotation=45)\n",
    "axs[1].set_xticklabels(ctx_lengths[:7]+[131072, 262144, 524288], rotation=45)\n",
    "axs[0].set_ylabel(\"Perplexity full sequence\")\n",
    "axs[1].set_ylabel(\"Perplexity FIM tokens\")\n",
    "axs[0].legend()\n",
    "axs[0].set_xlabel(\"Number of tokens in context\")\n",
    "axs[1].set_xlabel(\"Number of tokens in context\")\n",
    "# fig.suptitle(\"Perplexity as function of context length\")\n",
    "fig.savefig(f\"figures/{what}/ctx_length_perplexity.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import training logs and plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# training 2048 - 70k steps\n",
    "print((64*2048*10**4+2*64*2048*6*10**4) / 10**10, \"10^10 tokens\")\n",
    "# training 16k - 164k steps\n",
    "print(8*2*16384*164000 /10**10, \"10^10 tokens\")\n",
    "# training 32k - 116k steps\n",
    "print(4*4*32768*118250 /10**10, \"10^10 tokens\")\n",
    "# training 131k - 8750 steps\n",
    "print(2*2*16*131072*8750 /10**10, \"10^10 tokens\")\n",
    "\n",
    "def interpolate(steps_array=np.linspace(10, 352000, 100), stops=[61000, 225000, 342100, -1], num_tokens=[64*2048*50,\n",
    "                                                                64*2048*10**4+2*64*2048*6*10**4,\n",
    "                                                                64*2048*10**4+2*64*2048*6*10**4+8*2*16384*164000,\n",
    "                                                                64*2048*10**4+2*64*2048*6*10**4+8*2*16384*164000+4*4*32768*118250,\n",
    "                                                                64*2048*10**4+2*64*2048*6*10**4+8*2*16384*164000+4*4*32768*118250+2*2*16*131072*8750,]):\n",
    "    tmp_length = 0\n",
    "    new_steps_array = np.zeros(len(steps_array))\n",
    "    if stops[-1] == -1:\n",
    "        stops[-1] = steps_array[-1]\n",
    "    for i in range(len(stops)):\n",
    "        length = sum(steps_array<=stops[i]) - tmp_length\n",
    "        if i == len(stops)-1:\n",
    "            length = len(steps_array) - tmp_length\n",
    "        new_steps_array[tmp_length:tmp_length+length] = np.linspace(num_tokens[i], num_tokens[i+1], length)\n",
    "        tmp_length += length\n",
    "        print(tmp_length, num_tokens[i], num_tokens[i+1])\n",
    "    return new_steps_array\n",
    "\n",
    "new = interpolate()\n",
    "new_logging = load_tensorboard_data(\"../../nbs/results/train_100M_FIM_restart-spikes_merged/merged-131072\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "smoothing = 0.5\n",
    "steps = np.array([el.step for el in new_logging['eval/train_loss/all']])\n",
    "steps = interpolate(steps)\n",
    "t_loss = np.array(smooth([el.value for el in new_logging['eval/train_loss/all']], smoothing))\n",
    "v_loss = np.array(smooth([el.value for el in new_logging['eval/valid_loss/all']], smoothing))\n",
    "t_perp = np.array(smooth([el.value for el in new_logging['eval/train_perplexity/batch']], smoothing))\n",
    "v_perp = np.array(smooth([el.value for el in new_logging['eval/valid_perplexity/batch']], smoothing))\n",
    "t_perp_fim = np.array(smooth([el.value for el in new_logging['eval/train_perplexity/fim']], smoothing))\n",
    "v_perp_fim = np.array(smooth([el.value for el in new_logging['eval/valid_perplexity/fim']], smoothing))\n",
    "\n",
    "# logarithmically spaced indices for plots with x axis in log scale\n",
    "idxs_log = np.unique([int(el)-1 for el in np.logspace(np.log10(1), np.log10(len(t_loss)), num = 4000)])\n",
    "# linearly spaced indices for plots with x axis in linear scale\n",
    "# idxs_log = np.linspace(0, len(t_loss)-1, 1000, dtype=int)\n",
    "\n",
    "fig, axs = plt.subplots(3,1, figsize=(10,8), constrained_layout=True)\n",
    "\n",
    "axs[0].plot(steps[idxs_log],t_loss[idxs_log], color=color[0], label=\"Training\")\n",
    "axs[0].plot(steps[idxs_log],v_loss[idxs_log], color=color[2], label=\"Validation\")\n",
    "axs[1].plot(steps[idxs_log],t_perp[idxs_log], color=color[0])\n",
    "axs[1].plot(steps[idxs_log],v_perp[idxs_log], color=color[2])\n",
    "axs[2].plot(steps[idxs_log],t_perp_fim[idxs_log], color=color[0])\n",
    "axs[2].plot(steps[idxs_log],v_perp_fim[idxs_log], color=color[2])\n",
    "\n",
    "axs[0].set_xscale(\"log\")\n",
    "axs[0].set_ylabel(\"Loss\")\n",
    "axs[1].set_xscale(\"log\")\n",
    "axs[1].set_ylabel(\"Perplexity\\nall tokens\")\n",
    "axs[2].set_xscale(\"log\")\n",
    "\n",
    "axs[2].set_ylabel(\"Perplexity\\nFIM tokens\")\n",
    "axs[2].set_xlabel(\"Training tokens\")\n",
    "axs[0].legend()\n",
    "fig.suptitle(\"Loss and Perplexity during training\")\n",
    "plt.show()\n",
    "fig.savefig(\"figures/training_loss_perplexity.pdf\", format=\"pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "df = pd.read_pickle(\"figures/generated_sequences/dataframe_check-131k_gen_seqs_full.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ProtMamba",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
