{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "notebookRunGroups": {
     "groupValue": "12"
    }
   },
   "source": [
    "# Random Number Generation Experiments\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "notebookRunGroups": {
     "groupValue": "2"
    }
   },
   "source": [
    "![RNG diagram](../rng.png)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "\n",
    "This notebook sets up experiments that compare methods for fine-tuning a language model to sample random numbers from a specified distribution.\n",
    "\n",
    "The models we will compare are:\n",
    "\n",
    "- GFN-fine-tuned LM: fine-tuned via generative flow networks\n",
    "- Likelihood-trained LM: fine-tuned with supervised maximum-likelihood training (SFT)\n",
    "- RL-tuned LM: fine-tuned via reinforcement learning (PPO)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The experiments vary along several axes:\n",
    "\n",
    "- vary the distribution\n",
    "  - discrete: uniform, Poisson, binomial, geometric, etc.\n",
    "  - continuous: uniform, Gaussian, exponential, etc.\n",
    "- vary the hyperparameters of the distribution (specified in the prompt context)\n",
    "  - uniform: between 0 and `n_max`\n",
    "  - Poisson: `lambda` between `λ_min` and `λ_max`\n",
    "  - etc.\n",
    "- vary the prompt\n",
    "  - `Randomly generate (uniformly) one single random integer between 0 and {num_test}, and then stop: `\n",
    "  - `Randomly generate (uniformly) one single random integer in the interval [0, {num_test}]: `\n",
    "  - `Here is one single random integer sampled uniformly between 0 and {num_test}: `\n",
    "  - `The following is a random integer drawn uniformly between 0 and {num_test}: `\n",
    "  - etc.\n",
    "- vary the model\n",
    "  - GFN-fine-tuned LM (GFN-LM)\n",
    "  - RL-tuned LM (PPO)\n",
    "  - likelihood-trained LM (MLE / SFT)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## General imports\n"
   ]
  },
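  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import hydra\n",
    "\n",
    "# The Compose API lives in the top-level `hydra` package since Hydra 1.1;\n",
    "# importing it from `hydra.experimental` is deprecated.\n",
    "from hydra import compose, initialize"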
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "xq77AgKWJM-N"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2023-09-28 08:07:53,774] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "import json\n",
    "import numpy as np\n",
    "from itertools import chain\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from peft import LoraConfig, get_peft_model, PeftModel\n",
    "\n",
    "from utils import generate, generate_and_return_eos_logprob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "import matplotlib.font_manager\n",
    "import matplotlib.pyplot as plt\n",
    "import shutil\n",
    "\n",
    "import seaborn as sns\n",
    "from IPython.display import display, Markdown\n",
    "\n",
    "# Uncomment to clear the matplotlib font cache (e.g. after installing a new font)\n",
    "# shutil.rmtree(matplotlib.get_cachedir())\n",
    "\n",
    "# Family names of the TrueType font files found on the system\n",
    "font_paths = matplotlib.font_manager.findSystemFonts(fontpaths=None, fontext=\"ttf\")\n",
    "font_names = [matplotlib.font_manager.get_font(p).family_name for p in font_paths]\n",
    "print(font_names)\n",
    "\n",
    "# Font names registered with matplotlib's font manager\n",
    "fonts = [f.name for f in matplotlib.font_manager.fontManager.ttflist]\n",
    "print(fonts)\n",
    "print(\"Times New Roman\" in fonts)\n",
    "\n",
    "# Use Times New Roman for all figures (this sets rcParams[\"font.family\"])\n",
    "matplotlib.rc(\"font\", family=\"Times New Roman\")\n",
    "\n",
    "# Where matplotlib keeps its configuration and its font cache\n",
    "print(matplotlib.get_configdir())\n",
    "print(matplotlib.get_cachedir())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Thu Sep 28 08:07:56 2023       \n",
      "+---------------------------------------------------------------------------------------+\n",
      "| NVIDIA-SMI 535.54.03              Driver Version: 535.54.03    CUDA Version: 12.2     |\n",
      "|-----------------------------------------+----------------------+----------------------+\n",
      "| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
      "| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |\n",
      "|                                         |                      |               MIG M. |\n",
      "|=========================================+======================+======================|\n",
      "|   0  NVIDIA A100-SXM4-80GB          On  | 00000000:0F:00.0 Off |                    0 |\n",
      "| N/A   24C    P0              60W / 400W |      7MiB / 81920MiB |      0%      Default |\n",
      "|                                         |                      |             Disabled |\n",
      "+-----------------------------------------+----------------------+----------------------+\n",
      "                                                                                         \n",
      "+---------------------------------------------------------------------------------------+\n",
      "| Processes:                                                                            |\n",
      "|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |\n",
      "|        ID   ID                                                             Usage      |\n",
      "|=======================================================================================|\n",
      "|  No running processes found                                                           |\n",
      "+---------------------------------------------------------------------------------------+\n"
     ]
    }
   ],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "\n",
    "# Free unreferenced Python objects, then release cached CUDA memory\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load the pretrained model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize Hydra\n",
    "hydra.core.global_hydra.GlobalHydra.instance().clear()\n",
    "initialize(config_path=\"multiobjective-lm/rng/configs\")\n",
    "cfg = compose(config_name=\"config\")\n",
    "\n",
    "bsz = cfg.hparams.bsz\n",
    "grad_acc = cfg.hparams.grad_acc\n",
    "lr = cfg.hparams.lr\n",
    "warmup_steps = cfg.hparams.warmup_steps\n",
    "epochs = cfg.hparams.epochs\n",
    "max_len = cfg.hparams.max_len\n",
    "min_len = cfg.hparams.min_len\n",
    "eval_interval = cfg.hparams.eval_interval\n",
    "log_interval = cfg.hparams.log_interval\n",
    "model_to_use = cfg.hparams.model_to_use\n",
    "seed = cfg.hparams.seed\n",
    "save_dir = cfg.hparams.save_dir\n",
    "sft_epochs = cfg.hparams.SFT.epochs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed every RNG in use so that runs are reproducible\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "torch.cuda.manual_seed_all(seed)\n",
    "# Prefer deterministic cuDNN kernels over autotuned ones\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.backends.cudnn.benchmark = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GPTJForCausalLM(\n",
       "  (transformer): GPTJModel(\n",
       "    (wte): Embedding(50400, 4096)\n",
       "    (drop): Dropout(p=0.0, inplace=False)\n",
       "    (h): ModuleList(\n",
       "      (0-27): 28 x GPTJBlock(\n",
       "        (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): GPTJAttention(\n",
       "          (attn_dropout): Dropout(p=0.0, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.0, inplace=False)\n",
       "          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "          (out_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "        )\n",
       "        (mlp): GPTJMLP(\n",
       "          (fc_in): Linear(in_features=4096, out_features=16384, bias=True)\n",
       "          (fc_out): Linear(in_features=16384, out_features=4096, bias=True)\n",
       "          (act): NewGELUActivation()\n",
       "          (dropout): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (ln_f): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
       "  )\n",
       "  (lm_head): Linear(in_features=4096, out_features=50400, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "if model_to_use == \"gpt-j\":\n",
    "    tokenizer = AutoTokenizer.from_pretrained(\"nlpcloud/instruct-gpt-j-fp16\")\n",
    "    model = AutoModelForCausalLM.from_pretrained(\n",
    "        \"nlpcloud/instruct-gpt-j-fp16\", torch_dtype=torch.bfloat16\n",
    "    )\n",
    "elif model_to_use == \"gpt2\":\n",
    "    tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n",
    "    model = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n",
    "else:\n",
    "    # Fail early instead of hitting a NameError on `model` below\n",
    "    raise ValueError(f\"Unsupported model_to_use: {model_to_use!r}\")\n",
    "\n",
    "model.to(\"cuda\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LoRA, Optimizer, and Scheduler\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "lora_config = LoraConfig(\n",
    "    r=cfg.lora.r,\n",
    "    lora_alpha=cfg.lora.lora_alpha,\n",
    "    target_modules=[\"k_proj\", \"v_proj\"] if model_to_use == \"gpt-j\" else [\"c_attn\"],\n",
    "    lora_dropout=cfg.lora.lora_dropout,\n",
    "    bias=cfg.lora.bias,\n",
    ")\n",
    "inference_model = get_peft_model(model, lora_config)\n",
    "\n",
    "opt = torch.optim.AdamW(\n",
    "    [{\"params\": inference_model.parameters(), \"lr\": lr}],\n",
    "    betas=(cfg.adamw.b1, cfg.adamw.b2),\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataloaders\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from tqdm import tqdm\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from sklearn.model_selection import train_test_split\n",
    "from torch.nn.utils.rnn import pad_sequence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Tokenizing dataset...: 4096it [00:01, 2764.58it/s]\n"
     ]
    }
   ],
   "source": [
    "from rng.rng_dataset import get_dataloader_from_dataframe\n",
    "\n",
    "df_train = pd.read_csv(cfg.file_name.train)\n",
    "\n",
    "train_loader = get_dataloader_from_dataframe(\n",
    "    df_train, tokenizer, bsz=bsz, shuffle=True, method=\"SFT\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The training loop below runs for `sft_epochs` epochs, so size the schedule to match\n",
    "total_steps = sft_epochs * len(train_loader)\n",
    "\n",
    "\n",
    "# Learning-rate schedule: linear warmup to the base rate, then linear decay to zero\n",
    "def get_lr_mult_at_step(step):\n",
    "    # Strict `<` avoids a division by zero when warmup_steps == 0\n",
    "    if step < warmup_steps:\n",
    "        return step / warmup_steps\n",
    "    return max((total_steps - step) / max(total_steps - warmup_steps, 1), 0.0)\n",
    "\n",
    "\n",
    "sched = torch.optim.lr_scheduler.LambdaLR(opt, get_lr_mult_at_step)"
   ]
  },
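  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reading aid, the multiplier returned by `get_lr_mult_at_step` can be written as a piecewise-linear function of the optimizer step $s$. The symbols $W$ (for `warmup_steps`) and $T$ (for `total_steps`) are shorthand introduced here, not names used in the code, and the formula assumes $0 < W < T$:\n",
    "\n",
    "$$\n",
    "m(s) = \\begin{cases} \\dfrac{s}{W}, & s < W \\\\\\\\ \\max\\left(0, \\dfrac{T - s}{T - W}\\right), & s \\ge W \\end{cases}\n",
    "$$\n",
    "\n",
    "`LambdaLR` multiplies the base learning rate `lr` by $m(s)$, so the rate rises linearly from zero to `lr` over the first $W$ steps and then decays linearly, reaching zero at step $T$.\n"
   ]
  },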
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Prompt</th>\n",
       "      <th>Value</th>\n",
       "      <th>Distribution and Parameters</th>\n",
       "      <th>Distribution</th>\n",
       "      <th>Parameters</th>\n",
       "      <th>Data Type</th>\n",
       "      <th>Distribution and Parameters Index</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>The following is a random integer drawn from a...</td>\n",
       "      <td>80</td>\n",
       "      <td>{\"distribution\": \"uniform discrete\", \"data_typ...</td>\n",
       "      <td>uniform discrete</td>\n",
       "      <td>{\"a\": 0, \"b\": 100}</td>\n",
       "      <td>int</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>The following is a random integer drawn from a...</td>\n",
       "      <td>69</td>\n",
       "      <td>{\"distribution\": \"uniform discrete\", \"data_typ...</td>\n",
       "      <td>uniform discrete</td>\n",
       "      <td>{\"a\": 0, \"b\": 100}</td>\n",
       "      <td>int</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>The following is a random integer drawn from a...</td>\n",
       "      <td>67</td>\n",
       "      <td>{\"distribution\": \"uniform discrete\", \"data_typ...</td>\n",
       "      <td>uniform discrete</td>\n",
       "      <td>{\"a\": 0, \"b\": 100}</td>\n",
       "      <td>int</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Randomly generate one single random integer fr...</td>\n",
       "      <td>99</td>\n",
       "      <td>{\"distribution\": \"uniform discrete\", \"data_typ...</td>\n",
       "      <td>uniform discrete</td>\n",
       "      <td>{\"a\": 0, \"b\": 100}</td>\n",
       "      <td>int</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Here is one single random integer sampled from...</td>\n",
       "      <td>53</td>\n",
       "      <td>{\"distribution\": \"uniform discrete\", \"data_typ...</td>\n",
       "      <td>uniform discrete</td>\n",
       "      <td>{\"a\": 0, \"b\": 100}</td>\n",
       "      <td>int</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                              Prompt  Value  \\\n",
       "0  The following is a random integer drawn from a...     80   \n",
       "1  The following is a random integer drawn from a...     69   \n",
       "2  The following is a random integer drawn from a...     67   \n",
       "3  Randomly generate one single random integer fr...     99   \n",
       "4  Here is one single random integer sampled from...     53   \n",
       "\n",
       "                         Distribution and Parameters      Distribution  \\\n",
       "0  {\"distribution\": \"uniform discrete\", \"data_typ...  uniform discrete   \n",
       "1  {\"distribution\": \"uniform discrete\", \"data_typ...  uniform discrete   \n",
       "2  {\"distribution\": \"uniform discrete\", \"data_typ...  uniform discrete   \n",
       "3  {\"distribution\": \"uniform discrete\", \"data_typ...  uniform discrete   \n",
       "4  {\"distribution\": \"uniform discrete\", \"data_typ...  uniform discrete   \n",
       "\n",
       "           Parameters Data Type  Distribution and Parameters Index  \n",
       "0  {\"a\": 0, \"b\": 100}       int                                  0  \n",
       "1  {\"a\": 0, \"b\": 100}       int                                  0  \n",
       "2  {\"a\": 0, \"b\": 100}       int                                  0  \n",
       "3  {\"a\": 0, \"b\": 100}       int                                  0  \n",
       "4  {\"a\": 0, \"b\": 100}       int                                  0  "
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_train.head()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training loop\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "dph6LMlH0ooD",
    "outputId": "cb17a778-4221-421b-f6b2-f163cad72797",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "inference_model.train()\n",
    "for epoch in range(sft_epochs):\n",
    "    # Wrap the loader (not `enumerate`) so tqdm knows the total number of batches\n",
    "    for i, batch in enumerate(tqdm(train_loader, desc=f\"Epoch {epoch}\")):\n",
    "        opt.zero_grad()\n",
    "        input_ids, target_ids = batch\n",
    "        # With `labels` set, the model returns the cross-entropy loss in `outputs.loss`\n",
    "        outputs = inference_model(input_ids, labels=target_ids)\n",
    "        loss = outputs.loss\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "        sched.step()\n",
    "        if i % log_interval == 0:\n",
    "            print(f\"Epoch: {epoch}, Batch: {i}, Loss: {loss.item()}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save the model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode the run's hyperparameters in the checkpoint name.\n",
    "# Note: len(train_loader) counts batches, not individual training samples.\n",
    "ckpt_name = (\n",
    "    f\"rng-SFT_{model_to_use}_bsz_{bsz}_grad_acc_{grad_acc}_lr_{lr}\"\n",
    "    f\"_warmup_steps_{warmup_steps}_epochs_{sft_epochs}\"\n",
    "    f\"_max_len_{max_len}_min_len_{min_len}\"\n",
    "    f\"_eval_interval_{eval_interval}_log_interval_{log_interval}\"\n",
    "    f\"_training_samples_{len(train_loader)}_seed_{seed}\"\n",
    ")\n",
    "inference_model.save_pretrained(f\"{save_dir}/{ckpt_name}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluation\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ckpt_name = \"rng-SFT_gpt-j_bsz_8_grad_acc_1_lr_5e-05_warmup_steps_100_epochs_10_max_len_512_min_len_1_eval_interval_100_log_interval_10_seed_42\"\n",
    "# model_path = f\"{save_dir}/{ckpt_name}\"\n",
    "# inference_model = PeftModel.from_pretrained(model, model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n",
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['51', '<|endoftext|>']\n",
      "['0', '<|endoftext|>']\n"
     ]
    }
   ],
   "source": [
    "inference_model.eval()\n",
    "\n",
    "test_prompts = [\n",
    "    \"Randomly generate (uniformly) one single random integer between 0 and 520, and then stop: \",\n",
    "    \"Here is one single random integer sampled uniformly between 0 and 520: \",\n",
    "]\n",
    "\n",
    "with torch.inference_mode():\n",
    "    for prompt_test in test_prompts:\n",
    "        inputs = tokenizer(prompt_test, return_tensors=\"pt\").to(\"cuda\")\n",
    "        output_ids = inference_model.generate(\n",
    "            **inputs,\n",
    "            max_new_tokens=30,\n",
    "            do_sample=False,  # greedy decoding, the intent of the earlier temperature=0\n",
    "        )[0]\n",
    "        # Keep only the newly generated tokens and decode them one at a time\n",
    "        new_tokens = output_ids[inputs[\"input_ids\"].shape[1] :]\n",
    "        print([tokenizer.decode(t) for t in new_tokens])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rng.rng_utils import get_distribution\n",
    "\n",
    "n_max = 100\n",
    "intro_prompt = \"The following is a random integer drawn uniformly between 0 and \"\n",
    "# The upper bound named in the prompt is n_max - 1, i.e. \"... between 0 and 99: \"\n",
    "prompt = f\"{intro_prompt}{n_max - 1}: \""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rng.rng_plot import plot_distribution\n",
    "\n",
    "n_samples = 1000 * 512"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [15:18<00:00,  1.09it/s]\n"
     ]
    }
   ],
   "source": [
    "inference_model.base_model.enable_adapter_layers()\n",
    "\n",
    "dist_inference, number_of_NaNs_inference = get_distribution(\n",
    "    inference_model, tokenizer, prompt, num_samples=n_samples\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "## SFT-finetuned Model: Distribution of generated numbers"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjMAAAG2CAYAAACKxwc0AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAA2wUlEQVR4nO3df3RU9Z3/8dfk5ySZ0iSFkoABKkiBUgLCSYMUIq2pPygoB3uqdneBFNKyLa6/iqlFylFgaURTK4KFoNtG1y6u2l3XdRGE0simsIEWyy8BNYnUASJOk4X8MGTu9w++ucvkB7kzzGTyCc/HOR7PfOZzP/dz3/fOzYt77yQuy7IsAQAAGCom2hMAAAC4HIQZAABgNMIMAAAwGmEGAAAYjTADAACMRpgBAABGI8wAAACjEWYAAIDR4qI9gUj74x//KMuyFB8fH+2pAAAAh1paWuRyuTRhwoRu+/b5KzOWZSlSv+TYsix9+umnERsfF1DnnkGdewZ17jnUumdEqs7B/Pzu81dm2q7IfPnLXw772A0NDTp8+LBGjBih5OTksI+PC6hzz6DOPYM69xxq3TMiVec///nPjvv2+SszAACgbyPMAAAAoxFmAACA0QgzAADAaIQZAABgNMIMAAAwGmEGAAAYjTADAACMRpgBAABGC/o3ALe2tqqkpER+v18+n0+zZ89WTk5Op3137typrVu3yuPxKDk5WYsXL5bL5ZIkVVZW6pFHHtGHH36oa6+9VqtWrdLAgQPtZd955x29+OKLSktL0/nz5/XAAw8oISEhxM0EAAB9VdBXZtasWaOEhAQtWbJEy5cv19KlS/Xhhx926Hfo0CEVFxdr2bJlKioq0tmzZ1VaWipJOnPmjHbv3q1f/epXeuGFF3Ts2DE9/vjj9rKnTp3SPffcowcffFBLlizR4MGDtWrVqsvYTAAA0FcFFWZ8Pp/Kyso0Y8YMSVJiYqImTpxoh5SLrVu3TtOnT7evpuTn52vjxo1qbm5WcnKyfvCDHygtLU1jxozRjBkzFBsbay/77LPPKjs7W6mpqfaymzdv1qlTp0LdTgAA0EcFdZupoqJCLS0tGjJkiN02fPhwbd68OaCfZVnatWuXpk2bFtCvrq5OBw4c0MSJEwP619XV6e6777Zfl5eXKz8/3349aNAgJSQkaPfu3Zo1a1YwU7bn09DQEPRy3WlsbAz4PyKDOvcM6twzqHPPodY9I1J1tizLfjSlO0GFGa/Xq5SUFPsvUUuSx+OR1+sN6Ofz+dTQ0GBfWWnrJ0knT5602yorK1VWVqaamhrV19crMzPTXs/Fy7Ytf/GywWhpadHhw4dDWtaJqqqqiI2N/0OdewZ17hnUuedQ654RiTo7fVY2qDDjcrnkdrsD2vx+v+Li4jr0ky7chrq4n6SAvldffbVmzpypn//85/rud7+rbdu22eNfvGxX63EqPj5eI0aMCGnZS2lsbFRVVZWGDRumpKSksI+PC6hzz6DOPYM69xxq3TMiVefjx4877htUOsjIyFB9fX1AW319vTIyMgLa0tLS5Ha7A/rW1dXZY7RJT0/XDTfcoGuuuUY33XSTjh49qnHjxikzMzNgWcuyOl2PUy6XS8nJySEt60RSUlJEx8cF1Lln9IU6n3i6OuD1VT8YGqWZdK0v1NkU1LpnhLvOTm8xSUE+AJybmyuXyxVwKam6ulpTp07t0DcvLy8gVdXU1Cg1NVVjx47t0Hfo0KFKS0vToEGD7GWPHTtmv+/1emVZlnJzc4OZLgAAuAIEFWbS09M1Z84cbdu2TdKFS0t79+7V/PnzVVtbq9WrV6upqUmSVFBQoB07dsiyLEnSli1btGjRIsXGxuqvf/1rwDMse/bs0Q033KD+/ftLku666y7t27fPHuvNN9/UnXfeqfT09MvfYgAA0KcE/RBKUVGRiouLtXbtWjvAZGZmav/+/Xr99dc1d+5cZWZmavz48VqwYIFWrFghj8ej
AQMGaN68eZIuPPj70EMP6eqrr9b48ePVv39/LVu2zF5HVlaWVqxYoZUrV2rgwIFqampSUVFR2DYaAAD0HUGHGbfbHRA82mRnZ6u8vDygbdasWZ1+lfqGG27QDTfccMn1TJkyRVOmTAl2egAA4ArD32YCAABGI8wAAACjEWYAAIDRCDMAAMBohBkAAGA0wgwAADAaYQYAABiNMAMAAIxGmAEAAEYjzAAAAKMRZgAAgNEIMwAAwGiEGQAAYDTCDAAAMBphBgAAGI0wAwAAjEaYAQAARiPMAAAAoxFmAACA0QgzAADAaIQZAABgNMIMAAAwGmEGAAAYjTADAACMRpgBAABGI8wAAACjEWYAAIDRCDMAAMBohBkAAGA0wgwAADAaYQYAABiNMAMAAIxGmAEAAEYjzAAAAKMRZgAAgNEIMwAAwGiEGQAAYDTCDAAAMBphBgAAGI0wAwAAjEaYAQAARiPMAAAAoxFmAACA0QgzAADAaIQZAABgNMIMAAAwWly0JwAAAMxx4unqgNetra3SlChN5v/jygwAADBaUFdmWltbVVJSIr/fL5/Pp9mzZysnJ6fTvjt37tTWrVvl8XiUnJysxYsXy+VySZL27NmjlStXqqamRmPGjNHDDz+sUaNGBSw/b948VVRUSJIGDBig7du3KyEhIZRtBAAAfVhQYWbNmjVKSkrS3XffrebmZs2cOVObNm1SVlZWQL9Dhw6puLhYr776qhISErRq1SqVlpZq4cKFOnPmjH7961/r4Ycflt/v16pVq1RYWKg333xTbrdbklRZWakJEyaosLBQkvT5z3+eIIOIaX/J9KofDI3STAAAoXB8m8nn86msrEwzZsyQJCUmJmrixIkqLS3t0HfdunWaPn26HUDy8/O1ceNGNTc364MPPlBxcbEmTZqknJwcFRcX69SpUzp27Ji9/Pr165WTk6OcnBxdd911GjFixOVuJwAA6KMcX5mpqKhQS0uLhgwZYrcNHz5cmzdvDuhnWZZ27dqladOmBfSrq6vTgQMHNGnSpID+Q4YMUUxMjDIzMyVJR44c0YEDBzRv3jylp6fr4Ycf1i233BLSxl08p4aGhssaozONjY0B/0dkRLrOra2tAa8jcayYoC8dz715n/alOvd21Doy2n++/H6/pPDX2bIs+/GU7jgOM16vVykpKYqPj7fbPB6PvF5vQD+fz6eGhgalpqYG9JOkkydPdhi3srJSt912m/r37y9JGjVqlHbv3q2PPvpIzzzzjO699175/X5985vfdDrVDlpaWnT48OGQl+9OVVVVxMbG/4lUnZMbEgNeR/JYMUFfOJ5N2Kd9oc6moNbh1f7z1SYSdXb6iInjMONyuexnWtr4/X7FxcV16CdduA11cT9JHfr6/X698sorWrp0aYf1DRo0SI888ogSEhK0fv36ywoz8fHxEblV1djYqKqqKg0bNkxJSUlhHx8XRLrOtbsCQ/bQ0VeHfR0m6EvHc2/ep32pzr0dtY6M9p8vv9+vBjWHvc7Hjx933NdxmMnIyFB9fX1AW319vTIyMgLa0tLS5Ha7A/rW1dXZY1xs06ZNKiwsVHp6epfrXbhwoV566SWn0+yUy+VScnLyZY1xKUlJSREdHxdEqs6xsbEBr6/0fdkXjmcT9mlfqLMpqHV4tf98tQl3nZ3eYpKCeAA4NzdXLpcr4DJSdXW1pk6d2qFvXl5eQKKqqalRamqqxo4da7e9/PLLmjBhQoevZLcXGxurcePGOZ0mAAC4wjgOM+np6ZozZ462bdsm6cLlu71792r+/Pmqra3V6tWr1dTUJEkqKCjQjh07ZFmWJGnLli1atGiRneZeeOEFSReu1Jw4cUIHDx7Us88+K0l6++23tWfPHkkXnnV56qmn9NOf/jRMmwsAAPqaoH7PTFFRkYqLi7V27Vo7wGRmZmr//v16/fXXNXfuXGVmZmr8+PFasGCBVqxYIY/HowEDBmjevHmSpA0bNujxxx/vMPYTTzwhSTp6
9KjWrl2rkSNHauTIkZo/f76+8IUvXP6WAgCAPimoMON2u7Vs2bIO7dnZ2SovLw9omzVrlmbNmtWhb2Fhof3L8DpTUFCggoKCYKYFAACuYPxtJgAAYDTCDAAAMBphBgAAGI0wAwAAjEaYAQAARiPMAAAAoxFmAACA0QgzAADAaIQZAABgNMIMAAAwGmEGAAAYjTADAACMRpgBAABGI8wAAACjEWYAAIDRCDMAAMBohBkAAGA0wgwAADAaYQYAABgtLtoTAAA4c+Lp6oDXV/1gaJRmAvQuXJkBAABGI8wAAACjEWYAAIDReGYGUdf+OQCJZwEA9E089xQZhBkYi5MCAEDiNhMAADAcYQYAABiNMAMAAIxGmAEAAEYjzAAAAKMRZgAAgNEIMwAAwGiEGQAAYDTCDAAAMBphBgAAGI0wAwAAjEaYAQAARiPMAAAAoxFmAACA0QgzAADAaIQZAABgNMIMAAAwGmEGAAAYjTADAACMRpgBAABGI8wAAACjEWYAAIDRCDMAAMBoccEu0NraqpKSEvn9fvl8Ps2ePVs5OTmd9t25c6e2bt0qj8ej5ORkLV68WC6XS5K0Z88erVy5UjU1NRozZowefvhhjRo1yl72nXfe0Ysvvqi0tDSdP39eDzzwgBISEkLcTATrxNPVAa+v+sHQKM0EAIBLC/rKzJo1a5SQkKAlS5Zo+fLlWrp0qT788MMO/Q4dOqTi4mItW7ZMRUVFOnv2rEpLSyVJZ86c0a9//Ws9/PDD+uUvf6lz586psLBQTU1NkqRTp07pnnvu0YMPPqglS5Zo8ODBWrVq1WVuKgAA6IuCCjM+n09lZWWaMWOGJCkxMVETJ060Q8rF1q1bp+nTp9tXU/Lz87Vx40Y1Nzfrgw8+UHFxsSZNmqScnBwVFxfr1KlTOnbsmCTp2WefVXZ2tlJTU+1lN2/erFOnTl3OtgIAgD4oqNtMFRUVamlp0ZAhQ+y24cOHa/PmzQH9LMvSrl27NG3atIB+dXV1OnDggCZNmhTQf8iQIYqJiVFmZqYkqby8XPn5+fb7gwYNUkJCgnbv3q1Zs2YFM2V7Pg0NDUEv153GxsaA//clra2tAa8jUb+u1tV+fV3VOVxz7Mlt7Um1pScDXg9YkHHJ/n3peO7N+/Ry6tybt6s36o3HdF/Yh+23we/3Swp/nS3Lsh9N6U5QYcbr9SolJUXx8fF2m8fjkdfrDejn8/nU0NBgX1lp6ydJJ08GnmAlqbKyUrfddpv69+9vr+fiZduW72xZJ1paWnT48OGQlnWiqqoqYmNHS3JDYsDrSNav/bq6Wl/7Oodrjj25rT0p1O3qC8ezk21P3hrYpyG/OaJzai+UOvfVYzXSetMx3Rf2YWfnbCkydXb6rGxQYcblcsntdge0+f1+xcXFdegnXbgNdXE/SR36+v1+vfLKK1q6dGlA+8XLdrUep+Lj4zVixIiQlr2UxsZGVVVVadiwYUpKSgr7+NFUuyswOA4dfXXHPqUdw2V3//p3sq726+uqzk7mGMr6Qx2nJzm56hLsdvWl49nR8Rul/X45dY72sRrs1b5o643HdLT3YTi03wa/368GNYe9zsePH3fcN6h0kJGRofr6+oC2+vp6ZWQEHtBpaWlyu90Bfevq6uwxLrZp0yYVFhYqPT3dbsvMzAxY1rKsTtfjlMvlUnJyckjLOpGUlBTR8aMhNjY24HVn29e+T1f9gl1XV+O0r3P75T55rrbDMk6+heVkW3ubUPaP0+3qC8dzJOsTLqHUOdpzjvb6QxWJY7r9Nz6lvnu+aa+zc7YU/jo7vcUkBRlmcnNz5XK57KQrSdXV1Zo6dWqHvnl5eQGpqqamRqmpqRo7dqzd9vLLL2vChAkBX8luW7btYWDpwm0ny7KUm5sbzHQB4IrT236tQqg/9HvbdqB3C+rbTOnp6ZozZ462bdsm6cIlvL1792r+/Pmqra3V6tWr7a9XFxQUaMeOHbIsS5K0ZcsWLVq0yE50L7zwgqQL
V2pOnDihgwcP6tlnn5Uk3XXXXdq3b5891ptvvqk777wz4OoNAACAFMIvzSsqKlJxcbHWrl1rB5jMzEzt379fr7/+uubOnavMzEyNHz9eCxYs0IoVK+TxeDRgwADNmzdPkrRhwwY9/vjjHcZ+4oknJElZWVlasWKFVq5cqYEDB6qpqUlFRUWXt6UAAKBPCjrMuN1uLVu2rEN7dna2ysvLA9pmzZrV6VepCwsLVVhYeMn1TJkyRVOmTAl2egAA4ArD32YCAABGI8wAAACjEWYAAIDRCDMAAMBohBkAAGA0wgwAADAaYQYAABiNMAMAAIxGmAEAAEYjzAAAAKMRZgAAgNEIMwAAwGiEGQAAYDTCDAAAMBphBgAAGI0wAwAAjEaYAQAARiPMAAAAoxFmAACA0eKiPQEAAHB5Tjxd3aHtqh8MjcJMooMwA/Qy7U9K0T4h9bb5hKKzE/2VrC/sU+BihBlE1JX+r4UrBT8cAUQTYQaO8C/bQIQ0AOg9CDMG4wcq2hA2gd6Hz2XPIcygxzn5gF/cp7W1VZoSyRkBAExGmLlMyVsTVbvrpGJjYyVxZQQAgJ5GmEFYcVkVANDTCDMggAAAjMZvAAYAAEYjzAAAAKMRZgAAgNEIMwAAwGiEGQAAYDTCDAAAMBphBgAAGI0wAwAAjEaYAQAARiPMAAAAoxFmAACA0QgzAADAaIQZAABgNMIMAAAwWly0JwAAl+PE09XRngIgiWMxmggzAKKisxP/VT8YGoWZADAdt5kAAIDRuDLTA9r/CzSS//rsyXUhELUH+oa+ctXwSjonEWYAGIXnEgC0R5iJAicn476coAEAXbuSrqiES1BhprW1VSUlJfL7/fL5fJo9e7ZycnI67btz505t3bpVHo9HycnJWrx4sVwul/1+ZWWlNmzYoJtvvlmzZ8/usPy8efNUUVEhSRowYIC2b9+uhISEYKZrtL5ymRMA+oIr+YqgCdseVJhZs2aNkpKSdPfdd6u5uVkzZ87Upk2blJWVFdDv0KFDKi4u1quvvqqEhAStWrVKpaWlWrhwoSTpzJkzamlpUUVFhW666aYO66msrNSECRNUWFgoSfr85z9/RQUZRBdBElcajvnoMSEomMDxt5l8Pp/Kyso0Y8YMSVJiYqImTpyo0tLSDn3XrVun6dOn2wEkPz9fGzduVHNzsyTpc5/7nCZPnqz09PRO17V+/Xrl5OQoJydH1113nUaMGBH0hgEAgCuD4yszFRUVamlp0ZAhQ+y24cOHa/PmzQH9LMvSrl27NG3atIB+dXV1OnDggCZOnGi3x8bGdljPkSNHdODAAc2bN0/p6el6+OGHdcsttwS1Ue1ZlqWGhobLGqMzjY2NkiS/32+3dbae1tbWsKyv/dhOxnWy3eGaX6S01bet3m16cvsjVXsn8wnXMdXdfNrq212dncwn1LpH83iNxDmiM13V2Ylw7YtQj99o7udQ9k+ote5t58RIncec6m7srs7Rl8uyrIDHUy7FcZjxer1KSUlRfHy83ebxeOT1egP6+Xw+NTQ0KDU1NaCfJJ08ebLb9YwaNUq7d+/WRx99pGeeeUb33nuv/H6/vvnNbzqdagctLS06fPhwyMtfSrIS1dTUZL/ubD3JDYlhWVf7sZ2M62S7wzW/SKuqqgp43ZPbH0rtq3/xfoe2hvzmoOcTrmPK6Weguzo7mU+odY/m8Rqpc0RX2tfZiXDti1A/O9Hcz5ezf4KtdW87J0bqPOaU03qEckx3x+kjJo7DjMvlktvtDmjz+/2Ki4vr0E+6cBvq4n6SOvS9lEGDBumRRx5RQkKC1q9ff1lhJj4+PiK3qhobG3VaXrndbsXEXLhjN3T01R361e7qPsQ50X5sJ+N2Np/2wjW/SPH7/WpQs4YNG6akpCS7vSe3P5Tah2s+4Tqmult3Y2Ojqqqquq2zk/mEWvdoHq9O1h0OXdXZiXDti1A/O9Hcz6Hsn1Br3dvOiZE6jznV3dhdnaMv1/Hjxx33dZwu
MjIyVF9fH9BWX1+vjIyMgLa0tDS53e6AvnV1dfYYwVq4cKFeeumloJe7mMvlUnJy8mWNcSkxMTH2LbPO1tPZ7bRQtB/bybhOtjtc84u0pKSkgO3pye0Ppfbhmk+4jimnn4Hu6uxkPqHWPZrHayTPEZ1pX2cnwrUvQv3sRHM/X87+CbbWve2cGKnzmFNO6xHKMX0pTm8xSUE8AJybmyuXyxVwGam6ulpTp07t0DcvLy8gUdXU1Cg1NVVjx451PLE2sbGxGjduXNDLAQCAK4PjMJOenq45c+Zo27Ztki5cvtu7d6/mz5+v2tparV692n52pKCgQDt27JBlWZKkLVu2aNGiRR3SnWVZdp82b7/9tvbs2SPpwrMuTz31lH7605+GvoUAAKBPC+r3zBQVFam4uFhr1661A0xmZqb279+v119/XXPnzlVmZqbGjx+vBQsWaMWKFfJ4PBowYIDmzZtnj3Pu3Dm99dZbqq2t1fbt2zV69GiNGTNGknT06FGtXbtWI0eO1MiRIzV//nx94QtfCOtGAwCAviOoMON2u7Vs2bIO7dnZ2SovLw9omzVrlmbNmtXpOCkpKV2+X1BQoIKCgmCmBQAArmCObzMBAAD0RoQZAABgNMIMAAAwGmEGAAAYLagHgGEe/houAKCv48oMAAAwGmEGAAAYjTADAACMRpgBAABGI8wAAACjEWYAAIDRCDMAAMBohBkAAGA0wgwAADAaYQYAABiNMAMAAIxGmAEAAEYjzAAAAKMRZgAAgNEIMwAAwGiEGQAAYDTCDAAAMBphBgAAGI0wAwAAjEaYAQAARiPMAAAAoxFmAACA0QgzAADAaIQZAABgNMIMAAAwGmEGAAAYjTADAACMRpgBAABGI8wAAACjEWYAAIDRCDMAAMBohBkAAGA0wgwAADAaYQYAABiNMAMAAIxGmAEAAEYjzAAAAKMRZgAAgNEIMwAAwGiEGQAAYDTCDAAAMBphBgAAGI0wAwAAjBYX7AKtra0qKSmR3++Xz+fT7NmzlZOT02nfnTt3auvWrfJ4PEpOTtbixYvlcrns9ysrK7VhwwbdfPPNmj17dsCy77zzjl588UWlpaXp/PnzeuCBB5SQkBDsdAEAQB8X9JWZNWvWKCEhQUuWLNHy5cu1dOlSffjhhx36HTp0SMXFxVq2bJmKiop09uxZlZaW2u+fOXNGLS0tqqiokGVZAcueOnVK99xzjx588EEtWbJEgwcP1qpVq0LYPAAA0NcFFWZ8Pp/Kyso0Y8YMSVJiYqImTpwYEFLarFu3TtOnT7evpuTn52vjxo1qbm6WJH3uc5/T5MmTlZ6e3mHZZ599VtnZ2UpNTbWX3bx5s06dOhXUxgEAgL4vqNtMFRUVamlp0ZAhQ+y24cOHa/PmzQH9LMvSrl27NG3atIB+dXV1OnDggCZOnGi3x8bGdlhPeXm58vPz7deDBg1SQkKCdu/erVmzZgUzZXs+DQ0NQS/XncbGRkmS3++32zpbT2tra1jW137sUMcN1zg9pa2+bfVu42TeTvZ7KOOEq/ZO5hOuY6q7dbfVt7s6O5lPqHUP1/4KRSTOEZ3pqs5OhGtfhPrZieZ+DmX/hFrr3nZOjNR5zKnuxu7qHH25LMsKeDTlUoIKM16vVykpKYqPj7fbPB6PvF5vQD+fz6eGhgb7ykpbP0k6efKko/VcvGzb8k6W7UxLS4sOHz4c0rLdSVaimpqa7NedrSe5ITEs62o/dqjjhmucnlZVVRXw2sm8nez3UMYJV+2dzCdcx5TTz0B3dXYyn1DrHq79FYpInSO60r7OToRrX4T62Ynmfr6c/RNsrXvbOTFS5zGnnNYjlGO6O06flQ0qzLhcLrnd7oA2v9+vuLi4Dv2kC7ehLu4nqUPfrly8bFfrcSo+Pl4jRowIadlLaWxs1Gl55Xa7FRNz4Y7d0NFXd+hXuyu0ENZe+7FDHTdc4/QUv9+vBjVr2LBhSkpKstudzLuz/dFeKOOEq/ZO5hOuY6q7
dTc2NqqqqqrbOjuZT6h1D9f+CoWTdYdDV3V2Ilz7ItTPTjT3cyj7J9Ra97ZzYqTOY051N3ZX5+jLdfz4ccd9g0oHGRkZqq+vD2irr69XRkZGQFtaWprcbndA37q6OnuM7mRmZgYsa1lWp+txyuVyKTk5OaRlnYiJibFvl3W2ns5upYWi/dihjhuucXpaUlJSwNydzNvJfg9lnHDV3sl8wnVMOf0MdFdnJ/MJte7h2l+hiOQ5ojPt6+xEuPZFqJ+daO7ny9k/wda6t50TI3Uec8ppPUI5pi/F6S0mKcgHgHNzc+VyuQIuJVVXV2vq1Kkd+ubl5QWkqpqaGqWmpmrs2LHdricvL0/Hjh2zX3u9XlmWpdzc3GCmCwAArgBBhZn09HTNmTNH27Ztk3ThEt7evXs1f/581dbWavXq1fbzIwUFBdqxY4f9testW7Zo0aJFHRKeZVkdvpp91113ad++ffZYb775pu68885Ov/kEAACubEE/hFJUVKTi4mKtXbvWDjCZmZnav3+/Xn/9dc2dO1eZmZkaP368FixYoBUrVsjj8WjAgAGaN2+ePc65c+f01ltvqba2Vtu3b9fo0aM1ZswYSVJWVpZWrFihlStXauDAgWpqalJRUVHYNhoAAPQdQYcZt9utZcuWdWjPzs5WeXl5QNusWbO6/Cp1SkrKJd+fMmWKpkyZEuz0AADAFYa/zQQAAIxGmAEAAEYjzAAAAKMRZgAAgNEIMwAAwGiEGQAAYDTCDAAAMBphBgAAGI0wAwAAjEaYAQAARiPMAAAAoxFmAACA0QgzAADAaIQZAABgNMIMAAAwGmEGAAAYjTADAACMRpgBAABGI8wAAACjEWYAAIDRCDMAAMBohBkAAGA0wgwAADAaYQYAABiNMAMAAIxGmAEAAEYjzAAAAKMRZgAAgNEIMwAAwGiEGQAAYDTCDAAAMBphBgAAGI0wAwAAjEaYAQAARiPMAAAAoxFmAACA0QgzAADAaIQZAABgNMIMAAAwGmEGAAAYjTADAACMRpgBAABGI8wAAACjEWYAAIDRCDMAAMBohBkAAGA0wgwAADAaYQYAABgtLtgFWltbVVJSIr/fL5/Pp9mzZysnJ6fTvjt37tTWrVvl8XiUnJysxYsXy+VySZIaGxv1s5/9TP369dPp06dVUFCgkSNHBqznlltuUVVVlSTpS1/6kl555ZUQNhEAAPRlQYeZNWvWKCkpSXfffbeam5s1c+ZMbdq0SVlZWQH9Dh06pOLiYr366qtKSEjQqlWrVFpaqoULF0qSHnzwQeXl5WnOnDn6+OOPdccdd+iVV15Rv379JEmvvfaa/u7v/k5f+MIXJKnD+AAAAFKQt5l8Pp/Kyso0Y8YMSVJiYqImTpyo0tLSDn3XrVun6dOnKyEhQZKUn5+vjRs3qrm5WUeOHNFbb72lm266SZLUv39/DRo0SL/5zW8kXbgq85vf/Ebjxo1Tbm6urrvuOsIMAADoVFBXZioqKtTS0qIhQ4bYbcOHD9fmzZsD+lmWpV27dmnatGkB/erq6nTgwAHt27dP6enpSklJCXi/oqJChYWFKi8v13vvvafbb79dgwcP1sqVKzV58uRQt1GWZamhoSHk5bvS2NgoSfL7/XZbZ+tpbW0Ny/rajx3quOEap6e01bet3m2czNvJfg9lnHDV3sl8wnVMdbfutvp2V2cn8wm17uHaX6GIxDmiM13V2Ylw7YtQPzvR3M+h7J9Qa93bzomROo851d3YXZ2jL5dlWfajKd0JKsx4vV6lpKQoPj7ebvN4PPJ6vQH9fD6fGhoalJqaGtBPkk6ePCmv1xvwXvtxrr/+ev3P//yP3nvvPZWUlKigoEBlZWWaNGlSMNO1tbS06PDhwyEt251kJaqpqcl+3dl6khsSw7Ku9mOHOm64xulpbc9PtXEybyf7PZRxwlV7J/MJ1zHl9DPQXZ2dzCfUuodrf4UiUueIrrSvsxPh2hehfnaiuZ8vZ/8EW+vedk6M1HnM
Kaf1COWY7k7b3Z3uBBVmXC6X3G53QJvf71dcXFyHftKF21AX95OkuLi4Lse5OCRJF67WPPXUU/rhD3+oDRs2hBxm4uPjNWLEiJCWvZTGxkadlldut1sxMRfu2A0dfXWHfrW7ToZlfe3HDnXccI3TU/x+vxrUrGHDhikpKcludzLvzvZHe6GME67aO5lPuI6p7tbd2NioqqqqbuvsZD6h1j1c+ysUTtYdDl3V2Ylw7YtQPzvR3M+h7J9Qa93bzomROo851d3YXZ2jL9fx48cd9w0qzGRkZKi+vj6grb6+XhkZGQFtaWlpcrvdAX3r6ursMTIyMvT22293GGfgwIEd1ulyufTd735XS5cuDWaqHcZITk4OefnuxMTEKDY2VpI6XU/be5er/dihjhuucXpaUlJSwNydzNvJfg9lnHDV3sl8wnVMOf0MdFdnJ/MJte7h2l+hiOQ5ojPt6+xEuPZFqJ+daO7ny9k/wda6t50TI3Uec8ppPUI5pi/F6S0mKcgHgHNzc+VyuQIuJVVXV2vq1Kkd+ubl5QWkqpqaGqWmpmrs2LHKy8vTX/7yl4D7d12NI10oZHZ2djBTBQAAV4igwkx6errmzJmjbdu2SbpwCW/v3r2aP3++amtrtXr1avv5kYKCAu3YsUOWZUmStmzZokWLFik2NlYjR47U5MmTtXPnTknS6dOn9dFHH+n222+XJL3xxhs6dOiQJOncuXN6/vnndf/994dniwEAQJ8S9O+ZKSoqUnFxsdauXWsHmMzMTO3fv1+vv/665s6dq8zMTI0fP14LFizQihUr5PF4NGDAAM2bN88ep7i4WI899piqqqrk9Xr1y1/+0v520x//+Ec99NBDGjdunK655hr96Ec/Uv/+/cO20QAAoO8IOsy43W4tW7asQ3t2drbKy8sD2mbNmqVZs2Z1Ok5aWppWrVrV6XsPPfSQHnrooWCnBgAArkD8bSYAAGA0wgwAADAaYQYAABiNMAMAAIxGmAEAAEYjzAAAAKMRZgAAgNEIMwAAwGiEGQAAYDTCDAAAMBphBgAAGI0wAwAAjEaYAQAARiPMAAAAoxFmAACA0QgzAADAaIQZAABgNMIMAAAwGmEGAAAYjTADAACMRpgBAABGI8wAAACjEWYAAIDRCDMAAMBohBkAAGA0wgwAADAaYQYAABiNMAMAAIxGmAEAAEYjzAAAAKMRZgAAgNEIMwAAwGiEGQAAYDTCDAAAMBphBgAAGI0wAwAAjEaYAQAARiPMAAAAoxFmAACA0QgzAADAaIQZAABgNMIMAAAwGmEGAAAYjTADAACMRpgBAABGI8wAAACjEWYAAIDRCDMAAMBoccEu0NraqpKSEvn9fvl8Ps2ePVs5OTmd9t25c6e2bt0qj8ej5ORkLV68WC6XS5LU2Nion/3sZ+rXr59Onz6tgoICjRw50tGyAAAAbYIOM2vWrFFSUpLuvvtuNTc3a+bMmdq0aZOysrIC+h06dEjFxcV69dVXlZCQoFWrVqm0tFQLFy6UJD344IPKy8vTnDlz9PHHH+uOO+7QK6+8on79+nW7LAAAQJugbjP5fD6VlZVpxowZkqTExERNnDhRpaWlHfquW7dO06dPV0JCgiQpPz9fGzduVHNzs44cOaK33npLN910kySpf//+GjRokH7zm990uywAAMDFgroyU1FRoZaWFg0ZMsRuGz58uDZv3hzQz7Is7dq1S9OmTQvoV1dXpwMHDmjfvn1KT09XSkpKwPsVFRVauHDhJZedOHFiUBvY0tIiy7L0zjvvBLWcE5ZlqXXKeTW5PpX+/y2wT96p69CvdXxrWNbXfuxQxw3XOD3GsiRJx44dC7jV6GTene2P9kIZJ1y1dzKfcB1T3a3bclhnJ/MJte7h2l+hcLLucOiqzk6Ea1+E+tmJ5n4OZf+EWuvedk6M1HnMqW7Hvoxj+lJaWlocjxdUmPF6vUpJSVF8fLzd5vF45PV6A/r5fD41NDQoNTU1oJ8knTx5Ul6v
N+C9i8fpbtlgtRUiEs/buFwuxaQmdNsvrl/Qd/McCde4kZpfpEVz+yNZMydjR2L9LpfLvhraE/MJdRtMPV7bdFVnJ8K1L3qy9tHcz6HW2sRjLNrnpEhwuVyRCTMul0tutzugze/3Ky4urkM/6cJtqIv7SVJcXFyX48THx3e7bLAmTJgQ9DIAAMAcQT0zk5GRofr6+oC2+vp6ZWRkBLSlpaXJ7XYH9K2rq7PH6GqcgQMHdrssAADAxYIKM7m5uXK5XKqqqrLbqqurNXXq1A598/LydPz4cft1TU2NUlNTNXbsWOXl5ekvf/mLGhoaOh3nUssCAABcLKgwk56erjlz5mjbtm2SLvyumL1792r+/Pmqra3V6tWr1dTUJEkqKCjQjh077AewtmzZokWLFik2NlYjR47U5MmTtXPnTknS6dOn9dFHH+n222/vdlkAAICLuay2xOBQU1OTiouLlZ6ertraWt1666269tprtX//fv3whz/U5s2blZmZKUn693//d+3fv9/+xXff+9737HF8Pp8ee+wxZWVlyev1au7cuRo+fLj9/qWWBQAAaBN0mAEAAOhN+NtMAADAaIQZAABgNMIMAAAwGmEGAAAYjTADAACMRpgBAABGI8wAAACjEWYAAIDRzPs7571Ea2urSkpK5Pf75fP5NHv2bOXk5ER7WsZ79913tXz5ch05ckTDhg1TUVGRvvKVr0i68Pe71q9fr/79+6u+vl4PPPCA+vXrF+UZm6+0tFQ7d+5UWVmZJOocCSdOnNBrr72mq666SkOHDtW4ceOoc5idOHFC69ev1/Dhw9XU1KS4uDgVFhZKkt555x29+OKLSktL0/nz5/XAAw8oISEhyjM2x5EjR7Rx40YNHz5cf//3f2+3d3cMb9y4UR9//LEaGhp03XXX6eabb47cJC2EZPXq1daTTz5pWZZlNTU1Wfn5+VZNTU2UZ2W25uZma9GiRdbbb79t/elPf7Lmzp1rjR8/3jp58qR17tw56+tf/7r1wQcfWJZlWVu3brUWLFgQ3Qn3AZWVldbXvvY162/+5m8sy7KocwSUl5db3//+963//d//tduoc/h9+9vftnbv3m2//tGPfmS98cYb1smTJ63p06dbPp/PsizL+qd/+ifrpz/9aXQmaaCzZ89af/jDH6wpU6ZYv/jFL+z27o7hsrIya8mSJZZlWZbf77fmzJlj7du3L2Lz5DZTCHw+n8rKyjRjxgxJUmJioiZOnKjS0tIoz8xs1dXVWrZsmaZMmaLs7Gw9+eST+vTTT/XHP/5R//qv/6r09HQNGzZMkjRt2jTt3r1b+/fvj+6kDfbJJ5/otdde06233mq3UefwOnLkiB599FEVFxfL4/HY7dQ5/N59913V19fbr1NTU1VfX69nn31W2dnZSk1NlSTl5+dr8+bNOnXqVJRmapaUlBR95Stf0ZAhQwLaL3UMt7a2au3atfbPSJfLpeuvv15PP/10xOZJmAlBRUWFWlpaAnbu8OHDVVFREcVZme+aa65RRkaG/fqzn/2sPvvZz2rw4MEqLy9XVlaW/V5CQoKysrL03//939GYqvEsy9KTTz6p++67Ty6Xy26nzuH16KOP6ktf+pKeeeYZ3XXXXXr66afV2tpKnSNg5syZevTRR/X+++/r1KlTOnPmjG699VaVl5cHnKsHDRqkhIQE7d69O4qzNU9sbGzA60sdwwcPHpTP5+vwM3LPnj06f/58RObHMzMh8Hq9SklJUXx8vN3m8Xjk9XqjOKu+54MPPtCIESP05S9/WV6vV0OHDg14n5qHbtOmTfrWt77V4RkN6hw+NTU1qqys1C9+8QvdeOONOnLkiO644w6dP3+eOkfA0qVLde7cOX3729/WtGnTVFxcrNjYWHm9XvuqTBuPx6OTJ09GZ6J9xKWO4Y8++kiSAuru8XjU3Nwsn8+nAQMGhH0+XJkJgcvlktvtDmjz+/2KiyMbhtNzzz2nRx55RNKFmicmJga87/f7AwIl
nNm9e7c+85nPaOzYsR3eo87h8+6770qSvvrVr0qSRo0apRtvvFEvv/wydY6ATz/9VF/84he1cuVKbd++3T53SOq01pyvL8+ljuG2q70X/5z0+/2SFLG6szdDkJGREXBvVpLq6+sDbpHg8rz66qu68cYb7fuxXdV84MCBUZid2datW6eDBw/qsccekyQ1NzertbVVkyZN0pgxY6hzmLRdTo+J+b9/M44ePVpvvPGGvvjFL1LnMLv33nt13333afTo0crIyNDcuXP1la98RZmZmQG1tiyL83UYXOqcnJmZKUmqq6uzA019fb2SkpI6XCULF67MhCA3N1cul0tVVVV2W3V1taZOnRq9SfUhO3bskMfj0ZQpU+y2vLw8HTt2zH7d0tKijz76SNOmTYvGFI22Zs0a/fa3v7X/u+OOOzR27Fj99re/1Te+8Q3qHCZjxoyRJL3//vt2W1xcnK655hqO5zDz+Xz6/e9/r6uvvlqSNG7cOBUUFGjv3r0dau31emVZlnJzc6M13T7hUsfw6NGjNWDAAB0/ftx+v7q6WlOmTAl4Ri+cCDMhSE9P15w5c7Rt2zZJUmNjo/bu3av58+dHeWbm27Jli9577z2NHj1aJ06c0NGjR/X000/rtttuk9fr1enTpyVJv/vd7/TVr35Vo0aNivKMzTNgwABdddVV9n/9+vVTYmKirrrqKuocRkOHDtXNN9+sV1991W7bs2ePCgsLqXOYpaamKisrS3/+85/tNpfLpWuvvVZ33XWX9u3bp6amJknSm2++qTvvvFPp6enRmq6RLMuSZVn260sdw/Hx8Zo7d679M9Lv9+t3v/udvve970Vsfi7r4tnBsaamJhUXFys9PV21tbW69dZbde2110Z7WkZ77bXX9OCDD6q1tTWg/f7771dhYaEOHjyo559/XllZWTpz5ozuvffegK+7IjRPPfWU9uzZY//SPOocPg0NDXr00Uc1ePBgSVJaWpq+853vSKLO4fb+++9r/fr1Gj9+vGJiYuT3++1a79q1S//1X/+lgQMHqqmpSffccw/PzDjU2tqqbdu2afny5Ro2bJjuv/9+TZo0SdKlj2G/36+SkhLFxsbq7Nmzmjx5sr7+9a9HbJ6EGQAAYDRuMwEAAKMRZgAAgNEIMwAAwGiEGQAAYDTCDAAAMBphBgAAGI0wAwAAjEaYARB2Z8+e1T//8z/r+uuv15gxY1RZWRnwvt/v12uvvaYbbrhB3/3ud7Vr166wz+HYsWNavny5Zs6cGfaxAfQuhBkAYefxeHTXXXeprKxMra2tuv/+++Xz+ez3Y2JiNHPmTN14442aO3duwN/hCpfPfOYz+uSTT3Tu3Lmwjw2gdyHMAIiYrKwspaWl6fTp0/rxj3/c4X23263ExMSIrDsjI8P+w4MA+jbCDICIuuaaa/T9739fO3bs0K9+9aseXXdsbGyPrg9AdBBmAETc4sWLlZubq8cee0wHDx4MeM+yLL300kuaMGGCioqKJElHjhzR/Pnz9cUvflGSdPLkSZWUlOi6665TXV2d7r//fl177bVauHChmpqa9NJLL2nq1Kn66le/2unzN3/60580Y8YMTZ48WSUlJfL7/fZ7hw8f1j/+4z/qnnvu0YwZM/Tyyy9Lkg4cOKAf//jHWrBggf7zP/9TOTk52rBhQ6RKBOAyEGYARFxMTIwef/xxpaam6t5779XZs2ft91wul771rW9pzJgxdtuoUaM0Y8aMgDHq6+t15swZ/cd//Ifuu+8+PfHEE/r973+vn/zkJ0pLS9O//du/acyYMVq1alWH5bZv366f/OQn+sY3vqFnnnlGL7zwgiTpr3/9q55//nn9+Mc/1s9//nMVFhbqJz/5iSorK5WUlKSjR4+qqqpKDQ0NuvPOO5WVlRXBKgEIFX8DHUCP6N+/v5544gnNmzdPy5cv15o1awLej4mJ6fJ1RkaGHXa+853vSJIGDx6sz33ucxo2bJhuuOEGSdLXvvY1PfroowHj9OvX
T/fdd58k6brrrtP777+vf/mXf9Hf/u3f6oUXXtBf//pX+4pLY2OjcnNz9Ze//EWTJk3S8OHDdfToUd1+++1hrASAcCPMAOgxOTk5uvvuu1VSUqLJkycHtWxnz7+43e6A1wkJCTp//vwlx5k6daqefvppSdLRo0eVnZ2twsLCTvvGxMToM5/5TFDzBNDzuM0EoEd973vf07Rp0/Too4/qvffe6/H1p6SkKCkpSZL06aef6sCBAx36fPLJJz09LQCXgTADoEe5XC4VFxcrNTVVb7zxht0eHx+vpqYm+3XbQ7oXP6wbDtXV1fbvtbnmmmu0bds2HT582H7/ww8/1B/+8IewrhNAZBFmAERMU1NTQEBpk5aWppKSEsXHx9ttWVlZqqys1IEDB7R161Zt27ZNkrR79241NDSotbVVkuz/SxeCzsVhx7KsDn3OnTtnX2k5deqUfve73+kf/uEfJF14/iYxMVFz587V2rVr9dxzz2n58uX6+te/bo//6aefhqUWACKHMAMgIt59912tWbNGhw8f1gsvvKCPP/444P0JEybogQcesF8vWLBAn/3sZzV//nx9+OGH+sY3vqEvf/nLOnXqlGpqavTaa69JkjZs2KBPPvlEzz//vB1OKisrdfDgQf32t7+VJK1fv16NjY36zne+o5kzZ2rhwoVasmSJnnzyST3zzDMaMmSIJGngwIHauHGjBg8erI0bN2r79u1avny5EhMTtXPnTv3hD3/QwYMH9dxzzwUEJAC9i8tq+6cMAACAgbgyAwAAjEaYAQAARiPMAAAAoxFmAACA0QgzAADAaIQZAABgNMIMAAAwGmEGAAAYjTADAACMRpgBAABGI8wAAACjEWYAAIDRCDMAAMBo/w/v40FWuO1xCwAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot_distribution(\n",
    "    dist_inference,\n",
    "    n_max=n_max,\n",
    "    model_name=\"SFT-finetuned Model\",\n",
    "    color=\"orchid\",\n",
    "    number_of_NaNs=number_of_NaNs_inference,\n",
    "    xlims=(-5, n_max + 5),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ensure the output directory exists before writing the CSV.\n",
    "os.makedirs(\"plots\", exist_ok=True)\n",
    "\n",
    "df_inference = pd.DataFrame(dist_inference, columns=[\"Generated Numbers\"])\n",
    "df_inference.to_csv(f\"plots/SFT-dist_inference_{ckpt_name}.csv\", index=False)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python (myenv)",
   "language": "python",
   "name": "myenv"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
