{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-27T01:18:38.283592Z",
     "start_time": "2024-08-27T01:18:38.216023Z"
    },
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "OtI3tcAvUGCv",
    "outputId": "c09a456b-c8e2-4a4c-b509-0e33a2bb7c09"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
     "# Index of the entry in sample_inputs that the rest of the notebook runs on.\n",
     "test_id = 5\n",
     "# Each entry is '[BOS] <clauses, each terminated by 0> [SEP] <reference solver trace>'.\n",
     "# Traces interleave literals with the marker tokens D, [UP] and [BT]; most end in\n",
     "# SAT or UNSAT (the entry at index 3 has no final SAT/UNSAT token).\n",
     "sample_inputs = [\n",
     "    \"[BOS] 1 -2 3 0 -1 2 -3 0 2 4 -1 0 1 -3 4 0 -2 -3 -4 0 3 -4 -1 0 -2 4 -1 0 [SEP] D 1 [UP] D 1 D 2 [UP] D 1 D 2 4 -3 [BT] D 1 -2 -3 4 [BT] -1 D 2 [UP] -1 D 2 3 4 [BT] -1 -2 D 3 [UP] -1 -2 D 3 4 SAT\",\n",
     "    \"[BOS] 6 -7 -5 0 -4 -8 9 0 -1 -8 2 0 -3 6 1 0 5 -9 -10 0 -7 -1 9 0 -4 -6 10 0 -4 -8 7 0 -7 -2 -8 0 3 6 1 0 8 1 3 0 6 3 7 0 3 9 7 0 3 -9 -5 0 -1 3 2 0 5 3 -6 0 -10 7 4 0 8 -9 -10 0 1 -4 5 0 -2 10 7 0 -10 5 -2 0 -8 10 -7 0 -4 -5 -1 0 -10 -7 6 0 10 7 -1 0 -3 -6 7 0 4 6 -9 0 -9 -10 7 0 -7 -10 8 0 5 4 -2 0 -1 -3 4 0 -3 -8 1 0 7 9 10 0 4 3 2 0 10 2 3 0 9 -4 8 0 9 5 2 0 1 -2 -8 0 8 9 -1 0 1 -10 4 0 5 2 -1 0 [SEP] D 7 [UP] D 7 D 1 [UP] D 7 D 1 9 D 6 [UP] D 7 D 1 9 D 6 D 5 [UP] D 7 D 1 9 D 6 D 5 3 -4 [BT] D 7 D 1 9 D 6 -5 -10 -4 3 [BT] D 7 D 1 9 -6 -5 -10 -8 4 2 SAT\",\n",
     "    \"[BOS] 6 -8 10 0 2 -6 -7 0 8 -7 3 0 -3 1 -2 0 -2 -10 4 0 1 -6 -2 0 -8 4 -9 0 1 10 -5 0 8 4 5 0 -7 -6 -4 0 6 3 10 0 1 7 -10 0 1 6 -3 0 7 10 -9 0 -2 1 4 0 -7 -9 4 0 9 6 -3 0 4 -2 -9 0 3 8 -1 0 -1 5 -3 0 -2 -3 -7 0 4 -2 8 0 -6 -9 -10 0 4 5 10 0 -7 3 -2 0 -6 7 4 0 -8 -1 -10 0 -7 8 10 0 -9 -7 -1 0 -6 -2 5 0 -6 5 -3 0 -6 -4 9 0 9 -10 7 0 8 -3 -7 0 7 -10 -5 0 1 2 -4 0 4 -8 3 0 8 -3 6 0 -8 -2 1 0 6 -3 -10 0 -1 -6 -2 0 6 -4 2 0 10 3 6 0 [SEP] D 8 [UP] D 8 D -4 [UP] D 8 D -4 -9 3 6 7 2 [BT] D 8 4 D -6 [UP] D 8 4 D -6 10 -1 7 -3 -2 [BT] D 8 4 6 -7 9 10 [BT] -8 D 2 [UP] -8 D 2 4 D -7 [UP] -8 D 2 4 D -7 D -5 [UP] -8 D 2 4 D -7 D -5 -6 -3 10 1 [BT] -8 D 2 4 D -7 5 -10 1 -9 3 6 [BT] -8 D 2 4 7 3 [BT] -8 -2 D 3 [UP] -8 -2 D 3 -7 6 4 5 9 10 [BT] -8 -2 -3 -7 -1 -10 -5 4 UNSAT\",\n",
     "    \"[BOS] 1 9 -5 0 3 6 -8 0 1 -7 4 0 5 -9 -1 0 -3 -7 -1 0 -6 -2 -3 0 6 2 3 0 -7 -6 5 0 -6 -9 7 0 3 5 -8 0 9 4 -6 0 -1 -4 -8 0 2 1 7 0 9 4 7 0 5 7 8 0 5 -1 -7 0 7 2 3 0 5 2 -7 0 3 -5 4 0 8 -7 -4 0 9 3 1 0 8 -5 7 0 3 9 -6 0 4 9 2 0 4 1 -2 0 -3 9 -4 0 -9 -6 -8 0 3 8 2 0 5 -7 4 0 -5 3 -8 0 -8 1 -4 0 -7 2 4 0 9 -6 1 0 9 -1 5 0 8 -1 -9 0 9 4 5 0 7 1 -5 0 -6 -1 -2 0 -1 -9 3 0 [SEP] D 3 [UP] D 3 D 9 [UP] D 3 D 9 -1 -6 7 4 8 [BT] D 3 -9 -4 2 1 5 7 [BT] -3 -8 2 D -5 [UP] -3 -8 2 D -5 7 -4 [BT] -3 -8 2 5 4 -7\",\n",
     "    \"[BOS] 4 -10 -9 0 -11 -2 5 0 13 5 -1 0 -11 -14 2 0 15 -13 -11 0 -5 -11 -1 0 -3 -1 12 0 13 8 5 0 11 -12 -2 0 5 10 -1 0 5 -1 -15 0 11 6 -1 0 6 5 -4 0 -11 5 -7 0 -8 -3 -7 0 9 -15 -11 0 -11 -13 15 0 -7 4 6 0 2 -10 -14 0 8 12 -6 0 -7 -15 -10 0 -13 7 10 0 7 15 -4 0 7 -4 9 0 -4 6 2 0 15 4 1 0 15 -6 -7 0 8 5 -6 0 -3 -6 -2 0 -4 11 -2 0 2 9 -8 0 -12 -11 -5 0 -11 -6 14 0 -2 -6 -10 0 -3 -4 -11 0 6 -9 15 0 -12 -5 15 0 -4 9 -3 0 -4 2 13 0 -2 -12 8 0 -5 -12 15 0 15 7 6 0 -3 5 -2 0 2 -4 -13 0 -3 -15 -11 0 15 -4 -9 0 -3 8 -11 0 13 -1 14 0 -10 3 7 0 8 -14 2 0 5 13 7 0 8 4 6 0 13 -4 6 0 -2 7 10 0 -12 -9 7 0 8 4 -10 0 4 -14 -12 0 -10 -13 5 0 -7 -5 3 0 5 -10 -14 0 10 -13 -11 0 -13 -6 10 0 10 11 4 0 [SEP] D 14 [UP] D 14 D 13 [UP] D 14 D 13 D 1 [UP] D 14 D 13 D 1 D 10 [UP] D 14 D 13 D 1 D 10 2 -6 11 5 [BT] D 14 D 13 D 1 -10 5 -11 6 [BT] D 14 D 13 -1 D 8 [UP] D 14 D 13 -1 D 8 D 10 [UP] D 14 D 13 -1 D 8 D 10 2 -6 5 D -11 [UP] D 14 D 13 -1 D 8 D 10 2 -6 5 D -11 -12 -4 -9 -7 15 3 SAT\",\n",
     "    \"[BOS] -15 -8 -2 0 9 -5 6 0 2 10 6 0 14 13 2 0 -14 -7 8 0 15 8 14 0 -15 5 -9 0 8 -13 5 0 15 -4 -6 0 11 13 2 0 5 11 -15 0 -8 2 -14 0 -4 -5 7 0 -14 8 5 0 6 2 -7 0 -5 -2 14 0 -1 7 -6 0 -10 -15 -5 0 1 13 -6 0 15 12 -9 0 -2 15 10 0 -5 -6 15 0 -12 6 13 0 7 -14 15 0 8 4 13 0 2 -10 -13 0 11 -14 7 0 15 -5 3 0 12 13 1 0 -6 9 -10 0 15 4 -11 0 -15 -12 -4 0 -14 1 15 0 -1 14 -12 0 -9 6 4 0 4 2 1 0 14 -4 3 0 14 -4 -2 0 12 -15 -3 0 -5 10 7 0 1 -8 10 0 6 -13 10 0 -8 -6 15 0 9 15 -11 0 -13 11 -8 0 -13 10 11 0 2 15 -4 0 -3 1 4 0 3 -1 -7 0 2 8 -1 0 13 15 3 0 -6 -8 -3 0 10 -3 -9 0 3 8 5 0 11 -13 -6 0 -9 3 -10 0 -10 11 15 0 -2 -11 4 0 15 11 -9 0 -2 -15 6 0 -8 14 -6 0 -12 2 4 0 10 11 -14 0 -2 -12 -15 0 [SEP] D 5 [UP] D 5 D 7 [UP] D 5 D 7 D 15 [UP] D 5 D 7 D 15 -10 D 13 [UP] D 5 D 7 D 15 -10 D 13 6 11 D -3 [UP] D 5 D 7 D 15 -10 D 13 6 11 D -3 -1 -8 -14 -2 4 [BT] D 5 D 7 D 15 -10 D 13 6 11 3 12 -4 1 14 8 [BT] D 5 D 7 D 15 -10 -13 D -3 [UP] D 5 D 7 D 15 -10 -13 D -3 -1 -6 9 2 [BT] D 5 D 7 D 15 -10 -13 3 12 6 1 -4 8 [BT] D 5 D 7 -15 -6 9 2 14 8 12 10 13 3 1 4 11 SAT\"\n",
     "]\n",
    "# sample_input_unsat = \"[BOS] 1 -2 3 0 -1 2 -3 0 2 4 -1 0 1 -3 4 0 -2 -3 -4 0 3 -4 -1 0 -2 4 -1 0 1 2 -3 0 1 2 3 0 [SEP]\"\n",
    "sample_input = sample_inputs[test_id]\n",
    "prompt_str = sample_input.split(\"[SEP]\")[0] + \"[SEP]\"\n",
    "trace_str = sample_input.split(\"[SEP]\")[1].strip()\n",
    "prompt_tokens = prompt_str.split()\n",
    "trace_tokens = trace_str.split()\n",
    "full_trace_tokens = prompt_tokens + trace_tokens\n",
    "\n",
    "M=10\n",
    "\n",
    "context_len = len(full_trace_tokens) + 10\n",
    "num_vars = max(int(tok) for tok in full_trace_tokens if tok.isdigit())\n",
    "num_clauses = prompt_str.count(\"0\") + 2\n",
    "token_set = np.array([str(i) for i in range(1, num_vars+1)] + [str(-i) for i in range(1, num_vars+1)] + ['0', '[SEP]',  '[UP]', '[BT]', '[BOS]', 'D', 'SAT', 'UNSAT'])\n",
    "print(token_set)\n",
    "token_to_idx = {token: idx for idx, token in enumerate(token_set)}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-27T01:18:39.653103Z",
     "start_time": "2024-08-27T01:18:38.297952Z"
    },
    "id": "7sk0Pkm_QPKO"
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from visualization import Visualization\n",
    "\n",
    "def encode_tokens(tokens):\n",
    "    return np.eye(len(token_set))[np.array([token_to_idx[token] for token in tokens])]\n",
    "\n",
    "def e(token):\n",
    "    return np.eye(len(token_set))[token_to_idx[token]]\n",
    "\n",
    "def nearst_token(tokens, targets):\n",
    "    ret = []\n",
    "    last_sep = 0\n",
    "    for i in range(len(tokens)):\n",
    "        if tokens[i] in targets:\n",
    "            last_sep = i\n",
    "        ret.append(last_sep)\n",
    "    return np.array(ret)\n",
    "\n",
    "def simulated_self_attention(Q, K, V):\n",
    "    attention_logits = np.dot(Q, K.T)\n",
    "    mask = np.tril(np.ones(attention_logits.shape), k=0)\n",
    "    attention_logits = np.where(mask == 0, -1e9, attention_logits)\n",
    "    max_attention_logits = np.max(attention_logits, axis=-1)\n",
    "    max_attention_positions = np.isclose(attention_logits, max_attention_logits[:, None])\n",
    "    selected_V = np.where(max_attention_positions[:, :, None], V[None, :, :], 0)\n",
    "    avg_V = np.sum(selected_V, axis=1) / np.sum(max_attention_positions, axis=1, keepdims=True)\n",
    "    return avg_V\n",
    "\n",
    "def T_transform(encodings, true_vec=(1, 0), false_vec=(0, 1), none_vec=(0, 0)):\n",
    "    mat = np.zeros((2 * num_vars, 2 * num_vars))\n",
    "    true_vec_off = (true_vec[0] - none_vec[0], true_vec[1] - none_vec[1])\n",
    "    false_vec_off = (false_vec[0] - none_vec[0], false_vec[1] - none_vec[1])\n",
    "    for i in range(num_vars):\n",
    "        true_id = i\n",
    "        false_id = num_vars + i\n",
    "        mat[true_id, true_id] = true_vec_off[0]\n",
    "        mat[true_id, false_id] = true_vec_off[1]\n",
    "        mat[false_id, true_id] = false_vec_off[0]\n",
    "        mat[false_id, false_id] = false_vec_off[1]\n",
    "\n",
    "    res = encodings @ mat\n",
    "    res[:, :num_vars] += none_vec[0]\n",
    "    res[:, num_vars:] += none_vec[1]\n",
    "    return res\n",
    "\n",
    "\n",
    "def predict_next_token_abstract(tokens, visualize=False):\n",
    "    # Layer 1\n",
    "    tokens = np.array(tokens)\n",
    "    encoding = encode_tokens(tokens)\n",
    "    viz = Visualization()\n",
    "    viz.add_array('token', tokens)\n",
    "\n",
    "    p_i_sep_p = nearst_token(tokens, ['0', '[SEP]', '[UP]', '[BT]'])\n",
    "    viz.add_array('p_i_sep_p', p_i_sep_p)\n",
    "\n",
    "    b_0 = tokens[p_i_sep_p] == '0'\n",
    "    b_SEP = tokens[p_i_sep_p] == '[SEP]'\n",
    "    b_UP = tokens[p_i_sep_p] == '[UP]'\n",
    "    b_BackTrack = tokens[p_i_sep_p] == '[BT]'\n",
    "    viz.add_array('b_0', b_0)\n",
    "    viz.add_array('b_SEP', b_SEP)\n",
    "    viz.add_array('b_UP', b_UP)\n",
    "    viz.add_array('b_BackTrack', b_BackTrack)\n",
    "    p_i_D = np.array(nearst_token(tokens, ['D']))\n",
    "    viz.add_array('p_i_D', p_i_D)\n",
    "    # Layer 2\n",
    "    p_i_sep = np.insert(p_i_sep_p[:-1], 0, 0).astype(int)\n",
    "    p_i_sep[0] = p_i_sep[0] - 1\n",
    "    b_decision = np.array([False] + [token == 'D' for token in tokens][:-1])\n",
    "    viz.add_array('b_decision', b_decision)\n",
    "    # MODIFICATION: np.arange(len(tokens)) - p_i_sep => p_i_sep_p\n",
    "    d_i_sep = np.arange(len(tokens)) - p_i_sep_p\n",
    "    viz.add_array('p_i_sep', p_i_sep)\n",
    "    viz.add_array('d_i_sep', d_i_sep)\n",
    "    viz.add_array('b_decision', b_decision)\n",
    "    # Layer 3\n",
    "    r_i = np.array([np.array(encoding[p_i_sep[i]+1:i+1,:]).sum(axis=0) for i in range(len(tokens))])[:,:2*num_vars]\n",
    "    #viz.add_array('r_i', r_i)\n",
    "    # Modification: p_i_sep[p_i_sep] => p_i_sep[p_i_sep_p]\n",
    "    p_i_sep_min = p_i_sep[p_i_sep_p]\n",
    "    viz.add_array('p_i_sep_min', p_i_sep_min)\n",
    "    # MODIFICATION: Add M instead of i\n",
    "    p_i_min = p_i_sep_min + d_i_sep + 4 * b_SEP\n",
    "    viz.add_array('p_i_min', p_i_min)\n",
    "\n",
    "    p_i_D_min = p_i_D[p_i_sep_p]\n",
    "    viz.add_array('p_i_D_min', p_i_D_min)\n",
    "    # MODIFICATION: p_i_D_min + 1\n",
    "    b_exceed = p_i_min > (p_i_D_min + 1)\n",
    "    viz.add_array('b_exceed', b_exceed)\n",
    "    b_D_min = (p_i_D_min == p_i_min + 1)\n",
    "    viz.add_array('b_D_min', b_D_min)\n",
    "    # viz.add_array('b_BT_finish', b_BT_finish)\n",
    "    # Layer 4\n",
    "    sat_Q = np.hstack((r_i, np.ones((r_i.shape[0], 1))))\n",
    "    sat_K_offset = np.zeros((r_i.shape[0], 1))\n",
    "    sat_K_offset[np.array(tokens) != '0'] = -M\n",
    "    sat_K_offset[0][0] = -0.5\n",
    "    sat_K = np.hstack((-r_i, sat_K_offset))\n",
    "    sat_V = np.zeros((r_i.shape[0], 1))\n",
    "    sat_V[0][0] = 1\n",
    "    viz.add_array('attn', (sat_Q @ sat_K.T)[:, :12])\n",
    "    b_sat = simulated_self_attention(sat_Q, sat_K, sat_V).flatten().astype(bool)\n",
    "    viz.add_array('b_sat', b_sat)\n",
    "\n",
    "    unsat_Q = np.hstack((T_transform(r_i, true_vec=(1, 0), false_vec=(0, 1), none_vec=(1, 1)), np.ones((r_i.shape[0], 1))))\n",
    "    unsat_K = sat_K\n",
    "    unsat_V = np.ones((r_i.shape[0], 1))\n",
    "    unsat_V[0][0] = 0\n",
    "    b_cont = simulated_self_attention(unsat_Q, unsat_K, unsat_V).flatten().astype(bool)\n",
    "    viz.add_array('b_cont', b_cont)\n",
    "    # MODIFICATION: p_i_sep*_p* - 1\n",
    "    b_copy_p = p_i_min < (p_i_sep_p - 1)\n",
    "    viz.add_array('b_copy_p', b_copy_p)\n",
    "\n",
    "    up_q = np.hstack((T_transform(r_i, true_vec=(0, 1), false_vec=(1, 0), none_vec=(0, 0)), np.ones((r_i.shape[0], 1))))\n",
    "    up_k_offset = np.zeros((r_i.shape[0], 1))\n",
    "    up_k_offset[np.array(tokens) != '0'] = -M\n",
    "    up_k_offset[0][0] = 1.5\n",
    "    up_k = np.hstack((r_i, up_k_offset))\n",
    "    up_v = num_clauses * r_i\n",
    "    o_up = simulated_self_attention(up_q, up_k, up_v)\n",
    "    # Goal: (o_up > 0) and not T(r_i)=ReLU(ReLU(o_up) - ReLU(o_up - 1) - T(r_i))\n",
    "    e_up = np.maximum(o_up - T_transform(r_i, true_vec=(1, 1), false_vec=(1, 1), none_vec=(0, 0)), 0) - np.maximum(o_up - 1, 0)\n",
    "\n",
    "    heuristic_q = np.hstack((T_transform(r_i, true_vec=(-10, 1), false_vec=(-10, 1), none_vec=(0, 0)), np.ones((r_i.shape[0], 1))))\n",
    "    up_k_offset = np.zeros((r_i.shape[0], 1))\n",
    "    up_k_offset[np.array(tokens) != '0'] = -M\n",
    "    up_k = np.hstack((r_i, up_k_offset))\n",
    "    up_v = r_i\n",
    "\n",
    "\n",
    "    b_final = np.logical_and(b_exceed, b_decision)\n",
    "    b_no_decision = p_i_D <= p_i_sep\n",
    "    # MODIFICATION: add b_BT_finish for exception\n",
    "    b_BT_finish = (p_i_D_min <= p_i_min) & b_BackTrack\n",
    "    # Layer 5\n",
    "    e_BT = T_transform(encoding[p_i_D_min + 1][:, :2*num_vars], true_vec=(0, 1), false_vec=(1, 0), none_vec=(0, 0))\n",
    "    # MODIFICATION: Get id 0 when p_i_min+1 does not exist\n",
    "    p_i_min_index = p_i_min + 1\n",
    "    # p_i_min_index[p_i_min_index >= len(tokens)] = 0\n",
    "    # p_i_min_index =\n",
    "    e_copy = encoding[np.clip(p_i_min_index, 0, np.arange(len(p_i_min_index)))]\n",
    "    b_unsat = np.logical_and(b_no_decision, b_cont)\n",
    "    # MODICIATION: p_i_min == p_i_D_min - 1 -> b_D_min\n",
    "    b_backtrack = b_D_min & b_BackTrack\n",
    "    # MODIFICATION: b_copy moved to Layer 4 and renamed to b_copy_p\n",
    "    # MODIFICATION: b_copy\n",
    "    b_copy = b_copy_p & np.logical_not(b_BT_finish)\n",
    "    # MODIFICATION: Add BT token condition\n",
    "    b_BT_token = b_cont & np.logical_not(encoding[:, token_to_idx['[BT]']])\n",
    "\n",
    "    viz.add_array('b_unsat', b_unsat)\n",
    "    viz.add_array('b_backtrack', b_backtrack)\n",
    "    viz.add_array('b_copy', b_copy)\n",
    "    viz.add_array('b_final', b_final)\n",
    "\n",
    "    output_logits = np.zeros_like(encoding)\n",
    "    output_logits[b_sat, :] += 2 ** 8 * e('SAT')\n",
    "    output_logits[b_unsat, :] += 2 ** 7 * e('UNSAT')\n",
    "    output_logits[b_BT_token, :] += 2 ** 6 * e('[BT]')\n",
    "    output_logits[b_final, :] += 2 ** 5 * e('[UP]')\n",
    "    output_logits[b_backtrack, :2*num_vars] += 2 ** 4 * e_BT[b_backtrack, :]\n",
    "    output_logits[b_copy, :] += 2 ** 3 * e_copy[b_copy, :]\n",
    "    output_logits[:, :2*num_vars] += 2 ** 2 * e_up\n",
    "    output_logits += 2 ** 1 * np.where((1 - encoding[:, token_to_idx['D']])[:, None], e('D'), 0)\n",
    "    output_logits[:, :2*num_vars] += 2 ** 0 * T_transform(r_i, true_vec=(0, 0), false_vec=(0, 0), none_vec=(1, 1))\n",
    "\n",
    "    plt.imshow(output_logits, cmap='Greys')\n",
    "\n",
    "    max_idx = output_logits.argmax(axis=1)\n",
    "    pred_tokens = token_set[max_idx]\n",
    "    viz.add_array('pred_tokens', pred_tokens)\n",
    "    if visualize:\n",
    "      viz.display(start_index=len(prompt_tokens), max_len=30)\n",
    "\n",
    "    return pred_tokens[-1]\n",
    "\n",
     "# Run the model over the full reference trace with visualization enabled; the cell\n",
     "# displays the token predicted to follow the last trace token.\n",
     "predict_next_token_abstract(prompt_tokens + trace_tokens, visualize=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-27T01:18:39.655525Z",
     "start_time": "2024-08-27T01:18:39.653745Z"
    },
    "id": "BUNN6c8kA_jZ"
   },
   "outputs": [],
   "source": [
    "def generation(prompt_tokens):\n",
    "  for _ in range(500):\n",
    "    next_token = predict_next_token_abstract(prompt_tokens)\n",
    "    prompt_tokens.append(next_token)\n",
    "    if next_token in ['SAT', 'UNSAT']:\n",
    "      break\n",
    "  return prompt_tokens\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-27T01:18:39.767653Z",
     "start_time": "2024-08-27T01:18:39.656040Z"
    },
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "sf7fZ-dhBWJZ",
    "outputId": "97acb255-3bbc-423d-ea30-93f81f76f45d"
   },
   "outputs": [],
   "source": [
    "all_tokens = generation(prompt_tokens.copy())\n",
    "print(' '.join(all_tokens))\n",
    "print(len(all_tokens) - len(prompt_tokens))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "_vsHLfcFVpvr",
    "outputId": "38993034-7d40-41dc-9bf4-3187b0dc7fd9"
   },
   "outputs": [],
   "source": [
    "print(' '.join(generation(unsat_tokens.copy())))"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
