{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c3d4e482",
   "metadata": {},
   "source": [
    "# 1. Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff898d49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports\n",
    "import transformers\n",
    "import torch\n",
    "import logging\n",
    "import random\n",
    "import pandas as pd\n",
    "from textstat import automated_readability_index\n",
    "import json\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "631e2b31",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Basic vars\n",
    "prompt = \"Once upon a time, in a big forest, there lived a rhinoc\"\n",
    "model = transformers.AutoModelForCausalLM.from_pretrained('roneneldan/TinyStories-8M')\n",
    "model.eval()\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(\"EleutherAI/gpt-neo-125M\")\n",
    "input_ = tokenizer(prompt, return_tensors=\"pt\")\n",
    "prompt_token_length = input_['input_ids'].shape[1] #16\n",
    "\n",
    "TORCH_SEED = 42\n",
    "PY_SEED = 1234\n",
    "NUM_COMPLETIONS = 20\n",
    "MAX_GEN_TOKENS = 100\n",
    "DATA_SAVE_DIR = os.path.join(\"data\", \"example\")\n",
    "os.makedirs(DATA_SAVE_DIR, exist_ok=True)\n",
    "max_length = prompt_token_length + MAX_GEN_TOKENS # for batch size 1 this is a fixed quantity for all text = 116\n",
    "transformers.set_seed(TORCH_SEED)\n",
    "random.seed(PY_SEED)\n",
    "logging.getLogger(\"transformers\").setLevel(logging.ERROR)\n",
    "\n",
    "choice_idxs = list(range(prompt_token_length, max_length)) # these are the positions that can be edited from in TPS. CHECK."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b66bbde",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_ari(prompt, completions, model):\n",
    "    # uncapped version for illustration purposes\n",
    "    # paper uses custom py-readability implementation for edge case handling and cap=15\n",
    "    aris = [automated_readability_index(prompt + c) for c in completions]\n",
    "    return aris\n",
    "\n",
    "def compute_log_probs(prompt, completions, model):\n",
    "    # unbatched version for illustration purposes\n",
    "    res = []\n",
    "\n",
    "    for c in completions:\n",
    "\n",
    "        full_text = prompt + c\n",
    "        tmp = tokenizer(full_text, return_tensors=\"pt\")\n",
    "        next_token_logits = model(**tmp).logits\n",
    "        next_token_log_probs = next_token_logits.log_softmax(dim=-1)\n",
    "        current_token_log_probs = next_token_log_probs.roll(shifts=1, dims=-2) #align current token log probs with current tokens\n",
    "        current_token_log_probs[:, 0, 0:prompt_token_length] = 0 #set all prompt token log probs to zero since these are given, P(prompt)=1\n",
    "        chosen_token_ids = tmp['input_ids'] \n",
    "        chosen_current_token_log_probs = current_token_log_probs.gather(dim = -1, index=chosen_token_ids.unsqueeze(-1)).squeeze(-1) \n",
    "        joint_chosen_log_probs = chosen_current_token_log_probs.sum(dim=1)  # shape (1,)\n",
    "        #to python float\n",
    "        joint_chosen_log_probs = joint_chosen_log_probs.item()\n",
    "        res.append(joint_chosen_log_probs)\n",
    "\n",
    "    return res\n",
    "\n",
    "def make_obs_fn(obs):\n",
    "    if obs == 'ari':\n",
    "        return compute_ari\n",
    "    elif obs == 'log_probs':\n",
    "        return compute_log_probs\n",
    "    else:\n",
    "        raise ValueError(f\"Unsupported observable: {obs}. Supported: 'ari', 'log_probs'\")\n",
    "\n",
    "def generate_tokens(input_, model, tokenizer, prompt, max_length):\n",
    "    output_tokens = model.generate(**input_, max_length=max_length, top_k=None, do_sample=True, temperature=1.0, \n",
    "                                    eos_token_id=None, pad_token_id=None)\n",
    "    output_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)\n",
    "    completion_text = output_text[len(prompt):]\n",
    "    return output_tokens, completion_text"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "64b83668",
   "metadata": {},
   "source": [
    "# 1. Direct Sampling Example\n",
    "\n",
    "Setting `num_completions = 10^6` and adding ari-cap would give paper results (approx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c69aa71",
   "metadata": {},
   "outputs": [],
   "source": [
    "def direct_sampling(input_, model, tokenizer, prompt, max_length, num_completions=10):\n",
    "\n",
    "    completions = []\n",
    "\n",
    "    with torch.no_grad():\n",
    "\n",
    "        for _ in range(num_completions):\n",
    "            output_tokens, completion_text = generate_tokens(input_, model, tokenizer, prompt, max_length)\n",
    "            assert output_tokens.shape[0] == 1, f\"Batch size greater than 1 not supported in this code snippet\"\n",
    "            assert output_tokens.shape[1] == max_length, f\"Generated sequence length {output_tokens.shape[1]} does not match expected max_length {max_length}\" #This might happen if stop-token is used\n",
    "            completions.append(completion_text)\n",
    "\n",
    "    aris = compute_ari(prompt, completions, None)\n",
    "    log_probs = compute_log_probs(prompt, completions, model)\n",
    "\n",
    "    return completions, aris, log_probs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "429f81ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "completions, aris, log_probs = direct_sampling(input_, model, tokenizer, prompt, max_length, num_completions=NUM_COMPLETIONS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c54aaaf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "completions[0:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4d9c07a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# quick histograms of both \n",
    "import matplotlib.pyplot as plt\n",
    "fig, axs = plt.subplots(1, 2, figsize=(12, 5))\n",
    "axs[0].hist(aris, bins=20, color='blue', alpha=0.7)\n",
    "axs[0].set_title(\"Automated Readability Indices\")\n",
    "axs[1].hist(log_probs, bins=20, color='green', alpha=0.7)\n",
    "axs[1].set_title(\"Log Probabilities of Completions\")\n",
    "axs[0].set_xlabel(\"ARI\")\n",
    "axs[0].set_ylabel(\"Frequency\")\n",
    "axs[1].set_xlabel(\"Log Probability\")\n",
    "axs[1].set_ylabel(\"Frequency\")\n",
    "axs[0].set_yscale('log')\n",
    "axs[1].set_yscale('log')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff62b433",
   "metadata": {},
   "source": [
    "# 2. Biased Sampling\n",
    "\n",
    "Running $10$ independent chains would reproduce paper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46b8f23f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_proposal(current_tokens, model, tokenizer, prompt, max_length, choice_idxs):\n",
    "    \"\"\"Generates a proposal by randomly choosing an index from choice_idxs and generating tokens from the truncated input.\"\"\"\n",
    "    idx = random.choice(choice_idxs) #random edit point\n",
    "    input_tokens_proposal = current_tokens.clone()[:, 0:idx]\n",
    "    #make all all ones attention mask for proposal\n",
    "    attention_mask_proposal = torch.ones_like(input_tokens_proposal) #Only works because no padding tokens in\n",
    "    input_proposal = {'input_ids': input_tokens_proposal, 'attention_mask': attention_mask_proposal} #generate tokens expects dict-like input\n",
    "    output_tokens_proposal, completion_text_proposal = generate_tokens(input_proposal, model, tokenizer, prompt, max_length)\n",
    "    assert output_tokens_proposal.shape[0] == 1, f\"Batch size greater than 1 not supported in this code snippet\"\n",
    "    assert output_tokens_proposal.shape[1] == max_length, f\"Generated sequence length {output_tokens_proposal.shape[1]} does not match expected max_length {max_length}\"\n",
    "    return output_tokens_proposal, completion_text_proposal, idx\n",
    "\n",
    "def compute_tps_acceptance_factor(bias, current_tokens, proposal_tokens, current_obs, proposal_obs, current_length, proposal_length):\n",
    "    \"\"\"Compute the acceptance factor according to Metropolis-Hastings criterion for TPS\"\"\"\n",
    "    length_ratio = current_length / proposal_length\n",
    "    obs_diff = proposal_obs - current_obs\n",
    "    acceptance_factor = length_ratio * torch.exp(-bias * torch.tensor(obs_diff)).item() #check standard terminology\n",
    "    acceptance_factor_min_1 = min(acceptance_factor, 1.0)\n",
    "    return acceptance_factor_min_1, acceptance_factor, obs_diff, length_ratio\n",
    "\n",
    "def tps_step(current_tokens, current_completion_text, model, tokenizer, prompt, max_length, choice_idxs, bias, compute_obs_fn, current_obs=None, current_length=None):\n",
    "    \"\"\"Performs a single TPS step: generates a proposal, computes observables, and acceptance factor.\n",
    "    \n",
    "    Expected signature for compute_obs_fn: compute_obs_fn(prompt, completion_text, model) -> observable_value\n",
    "    \n",
    "    \"\"\"\n",
    "    # Generate proposal\n",
    "    proposal_tokens, proposal_completion_text, edit_idx = generate_proposal(current_tokens, model, tokenizer, prompt, max_length, choice_idxs)\n",
    "    \n",
    "    # Compute observables. Batch size 1 assumed throughout\n",
    "    current_obs = compute_obs_fn(prompt, [current_completion_text], model)[0] if current_obs is None else current_obs\n",
    "    proposal_obs = compute_obs_fn(prompt, [proposal_completion_text], model)[0]\n",
    "\n",
    "    is_same_tokens = torch.equal(proposal_tokens, current_tokens)\n",
    "    is_same_text = (proposal_completion_text == current_completion_text)\n",
    "    is_same_obs = (proposal_obs == current_obs)\n",
    "\n",
    "    # Compute acceptance factor\n",
    "    current_length = current_tokens.shape[1]\n",
    "    proposal_length = proposal_tokens.shape[1]\n",
    "    alpha, acceptance_factor, obs_diff, length_ratio = compute_tps_acceptance_factor(bias, \n",
    "        current_tokens, proposal_tokens, current_obs, proposal_obs, current_length, proposal_length)\n",
    "\n",
    "    step_info = {\n",
    "        \"is_same_tokens\": is_same_tokens,\n",
    "        \"is_same_text\": is_same_text,\n",
    "        \"is_same_obs\": is_same_obs,\n",
    "        \"current_obs\": current_obs,\n",
    "        \"proposal_obs\": proposal_obs,\n",
    "        \"alpha\": alpha,\n",
    "        \"acceptance_factor\": acceptance_factor,\n",
    "        \"obs_diff\": obs_diff,\n",
    "        \"length_ratio\": length_ratio,\n",
    "        \"edit_idx\": edit_idx\n",
    "    }\n",
    "    \n",
    "\n",
    "    # Accept/reject step\n",
    "    if random.random() < alpha:\n",
    "        # Accept proposal\n",
    "        accepted = 1\n",
    "        step_info[\"accepted\"] = accepted\n",
    "        return proposal_tokens, proposal_completion_text,  proposal_obs, step_info\n",
    "    else:\n",
    "        # Reject proposal, keep current\n",
    "        accepted = 0\n",
    "        step_info[\"accepted\"] = accepted\n",
    "        return current_tokens, current_completion_text, current_obs, step_info\n",
    "\n",
    "def biased_sampling(input_, model, tokenizer, prompt, max_length, choice_idxs, num_tps_steps_per_bias=5, biases=[0.1, 0.2], obs='ari'):\n",
    "    \"\"\"Performs biased sampling using annealed TPS to generate completions. Performs a single chain of num_tps_steps TPS steps.\"\"\"\n",
    "\n",
    "    def _get_bias_for_step(step):\n",
    "        bias_idx = step // num_tps_steps_per_bias\n",
    "        return biases[bias_idx]\n",
    "\n",
    "    num_tps_steps = num_tps_steps_per_bias * len(biases)\n",
    "    sampling_info = {\"steps\": []}\n",
    "    observable_fn = make_obs_fn(obs)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # Initial generation and observable computation\n",
    "        current_tokens, completion_text = generate_tokens(input_, model, tokenizer, prompt, max_length)\n",
    "        current_length = current_tokens.shape[1]\n",
    "        current_obs = None #will be computed in first step\n",
    "\n",
    "        for step in range(num_tps_steps):\n",
    "            bias = _get_bias_for_step(step)\n",
    "            current_tokens, completion_text, current_obs, step_info = tps_step(current_tokens, \n",
    "                completion_text, model, tokenizer, prompt, max_length, choice_idxs, bias, \n",
    "                compute_obs_fn=observable_fn, current_obs=current_obs, current_length=current_length)\n",
    "            # concat step info into sampling_info\n",
    "            sampling_info[\"steps\"].append(step)\n",
    "            if step == 0: #init the keys of sampling_info\n",
    "                for key in step_info.keys():\n",
    "                    sampling_info[key] = []\n",
    "            for key, value in step_info.items():\n",
    "                sampling_info[key].append(value)\n",
    "\n",
    "    # add fixed info\n",
    "    sampling_info[\"bias\"] = [_get_bias_for_step(step) for step in range(num_tps_steps)]\n",
    "    sampling_info[\"observable\"] = obs\n",
    "    sampling_info[\"num_tps_steps_per_bias\"] = num_tps_steps_per_bias\n",
    "    sampling_info[\"num_tps_steps\"] = num_tps_steps\n",
    "    sampling_info[\"min_choice_idx\"] = min(choice_idxs)\n",
    "    sampling_info[\"max_choice_idx\"] = max(choice_idxs)\n",
    "\n",
    "\n",
    "    return current_tokens, completion_text, current_obs, sampling_info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4d426c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "current_tokens, completion_text, current_obs, sampling_info = biased_sampling(input_, model, tokenizer, prompt, max_length, choice_idxs, num_tps_steps_per_bias=100, biases=[0.1, 0.2, 0.3], obs='ari')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "673f2c35",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame(sampling_info)\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a81ffb35",
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute key statistics\n",
    "acceptance_rates = df.groupby('bias')['accepted'].mean()\n",
    "mean_obs = df.groupby('bias')['proposal_obs'].mean() #no burn-in applied\n",
    "print(\"Acceptance Rate:\", acceptance_rates.to_dict())\n",
    "print(\"Mean Observable (no burn-in):\", mean_obs.to_dict())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fff79e71",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot current_obs vs step highlighting bias changes\n",
    "biases = df['bias'].unique()\n",
    "import matplotlib.pyplot as plt\n",
    "plt.figure(figsize=(10, 6))\n",
    "plt.plot(df['steps']+1, df['proposal_obs'], marker='o', linestyle='--', label='Proposal Observable', alpha=0.6)\n",
    "plt.plot(df['steps'], df['current_obs'], marker='s', linestyle='-', label='Current Observable')\n",
    "# cumulative mean current_obs for each bias region separately\n",
    "cumulative_means = []\n",
    "for bias in biases: #not sure this always works ascending with step\n",
    "    bias_mask = df['bias'] == bias\n",
    "    bias_obs = df.loc[bias_mask, 'current_obs']\n",
    "    cummean = bias_obs.expanding().mean()\n",
    "    cumulative_means.extend(cummean)\n",
    "plt.plot(df['steps'], cumulative_means, color='red', marker='x', linestyle='--', label='Cumulative Mean Current Observable', linewidth=2)\n",
    "# highlight bias regions\n",
    "for bias in biases:\n",
    "    bias_steps = df[df['bias'] == bias]['steps']\n",
    "    plt.axvspan(bias_steps.min(), bias_steps.max(), alpha=0.2)\n",
    "plt.xlabel('TPS Step')\n",
    "plt.ylabel('Proposal Observable (ARI)')\n",
    "plt.title('Proposal Observable vs TPS Step with Bias Highlighting')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e1df7c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract data to save for reweighting and analysis scripts. \n",
    "\n",
    "\"\"\"Parameters to recreate Figure 2a in the paper are commented out below.\"\"\"\n",
    "# steps_per_bias = 40000\n",
    "# biases = [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],\n",
    "#           [-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7, -0.8, -0.9, -1.0]]\n",
    "# num_trajectories = 10\n",
    "\n",
    "\"\"\"Small example.\"\"\"\n",
    "steps_per_bias = 10\n",
    "all_biases = [[0.1, 0.2], [-0.1, -0.2]]\n",
    "num_trajectories = 2\n",
    "\n",
    "all_trajectories = []\n",
    "for biases in all_biases:\n",
    "    trajs_per_biases = []\n",
    "    for num_traj in range(num_trajectories):\n",
    "        _, _, _, sampling_info = biased_sampling(input_, model, tokenizer, prompt, max_length, choice_idxs, num_tps_steps_per_bias=steps_per_bias, biases=biases, obs='ari')\n",
    "        trajectory = sampling_info[\"current_obs\"]\n",
    "        trajs_per_biases.append(trajectory)\n",
    "    all_trajectories.append(trajs_per_biases)\n",
    "\n",
    "with open(os.path.join(DATA_SAVE_DIR, \"ari_trajectories.json\"), \"w\") as f:\n",
    "    json.dump([all_trajectories, all_biases, steps_per_bias], f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c5664a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Parameters to recreate the orange curve in Figure 3b are commented out below.\"\"\"\n",
    "# num_unbiased_samples = 4200000\n",
    "\n",
    "num_unbiased_samples = 50\n",
    "\n",
    "_, aris_save, _ = direct_sampling(input_, model, tokenizer, prompt, max_length, num_completions=num_unbiased_samples)\n",
    "\n",
    "with open(os.path.join(DATA_SAVE_DIR, \"ari_unbiased_samples.json\"), \"w\") as f:\n",
    "    json.dump(aris, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "completions_database",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
