{
 "cells": [
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-23T16:56:10.450261Z",
     "start_time": "2024-09-23T16:56:09.047412Z"
    }
   },
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from dpll import dpll\n",
    "from compiler import compile_model\n",
    "from visualization import Visualization\n",
    "from sops import indices, ones, OneHotTokEmb\n",
    "from sop_utils import sop_generate\n",
    "\n",
    "# test_id is the index into sample_inputs used by the rest of the notebook.\n",
    "# Each sample is one whitespace-separated token string: \"[BOS]\", integer literals,\n",
    "# \"[SEP]\", then an optional trace (entry 0 has no trace; entry 5 stops before any\n",
    "# SAT/UNSAT token).\n",
    "# NOTE(review): presumably a 0 terminates each clause and the trace markers\n",
    "# D / [UP] / [BT] mean decision / unit propagation / backtrack -- confirm against dpll.\n",
    "test_id = 0\n",
    "sample_inputs = [\n",
    "    \"[BOS] 1 3 0 -1 2 -3 0 4 -1 0 1 -3 4 0 -2 -3 -4 0 3 -4 -1 0 1 -3 -4 0 3 -1 4 0 [SEP]\",\n",
    "    \"[BOS] 6 -7 -5 0 -4 -8 9 0 -1 -8 2 0 -3 6 1 0 5 -9 -10 0 -7 -1 9 0 -4 -6 10 0 -4 -8 7 0 -7 -2 -8 0 3 6 1 0 8 1 3 0 6 3 7 0 3 9 7 0 3 -9 -5 0 -1 3 2 0 5 3 -6 0 -10 7 4 0 8 -9 -10 0 1 -4 5 0 -2 10 7 0 -10 5 -2 0 -8 10 -7 0 -4 -5 -1 0 -10 -7 6 0 10 7 -1 0 -3 -6 7 0 4 6 -9 0 -9 -10 7 0 -7 -10 8 0 5 4 -2 0 -1 -3 4 0 -3 -8 1 0 7 9 10 0 4 3 2 0 10 2 3 0 9 -4 8 0 9 5 2 0 1 -2 -8 0 8 9 -1 0 1 -10 4 0 5 2 -1 0 [SEP] D 7 [UP] D 7 D 1 [UP] D 7 D 1 9 D 6 [UP] D 7 D 1 9 D 6 D 5 [UP] D 7 D 1 9 D 6 D 5 3 -4 [BT] D 7 D 1 9 D 6 -5 -10 -4 3 [BT] D 7 D 1 9 -6 -5 -10 -8 4 2 SAT\",\n",
    "    \"[BOS] 6 -8 10 0 2 -6 -7 0 8 -7 3 0 -3 1 -2 0 -2 -10 4 0 1 -6 -2 0 -8 4 -9 0 1 10 -5 0 8 4 5 0 -7 -6 -4 0 6 3 10 0 1 7 -10 0 1 6 -3 0 7 10 -9 0 -2 1 4 0 -7 -9 4 0 9 6 -3 0 4 -2 -9 0 3 8 -1 0 -1 5 -3 0 -2 -3 -7 0 4 -2 8 0 -6 -9 -10 0 4 5 10 0 -7 3 -2 0 -6 7 4 0 -8 -1 -10 0 -7 8 10 0 -9 -7 -1 0 -6 -2 5 0 -6 5 -3 0 -6 -4 9 0 9 -10 7 0 8 -3 -7 0 7 -10 -5 0 1 2 -4 0 4 -8 3 0 8 -3 6 0 -8 -2 1 0 6 -3 -10 0 -1 -6 -2 0 6 -4 2 0 10 3 6 0 [SEP] D 8 [UP] D 8 D -4 [UP] D 8 D -4 -9 3 6 7 2 [BT] D 8 4 D -6 [UP] D 8 4 D -6 10 -1 7 -3 -2 [BT] D 8 4 6 -7 9 10 [BT] -8 D 2 [UP] -8 D 2 4 D -7 [UP] -8 D 2 4 D -7 D -5 [UP] -8 D 2 4 D -7 D -5 -6 -3 10 1 [BT] -8 D 2 4 D -7 5 -10 1 -9 3 6 [BT] -8 D 2 4 7 3 [BT] -8 -2 D 3 [UP] -8 -2 D 3 -7 6 4 5 9 10 [BT] -8 -2 -3 -7 -1 -10 -5 4 UNSAT\",\n",
    "    \"[BOS] 1 9 -5 0 3 6 -8 0 1 -7 4 0 5 -9 -1 0 -3 -7 -1 0 -6 -2 -3 0 6 2 3 0 -7 -6 5 0 -6 -9 7 0 3 5 -8 0 9 4 -6 0 -1 -4 -8 0 2 1 7 0 9 4 7 0 5 7 8 0 5 -1 -7 0 7 2 3 0 5 2 -7 0 3 -5 4 0 8 -7 -4 0 9 3 1 0 8 -5 7 0 3 9 -6 0 4 9 2 0 4 1 -2 0 -3 9 -4 0 -9 -6 -8 0 3 8 2 0 5 -7 4 0 -5 3 -8 0 -8 1 -4 0 -7 2 4 0 9 -6 1 0 9 -1 5 0 8 -1 -9 0 9 4 5 0 7 1 -5 0 -6 -1 -2 0 -1 -9 3 0 [SEP] D 1 [UP] D 1 D 2 [UP] D 1 D 2 -6 D 3 [UP] D 1 D 2 -6 D 3 -7 D 4 [UP] D 1 D 2 -6 D 3 -7 D 4 9 5 8 [BT] D 1 D 2 -6 D 3 -7 -4 9 5 8 SAT\",\n",
    "    \"[BOS] 4 -10 -9 0 -11 -2 5 0 13 5 -1 0 -11 -14 2 0 15 -13 -11 0 -5 -11 -1 0 -3 -1 12 0 13 8 5 0 11 -12 -2 0 5 10 -1 0 5 -1 -15 0 11 6 -1 0 6 5 -4 0 -11 5 -7 0 -8 -3 -7 0 9 -15 -11 0 -11 -13 15 0 -7 4 6 0 2 -10 -14 0 8 12 -6 0 -7 -15 -10 0 -13 7 10 0 7 15 -4 0 7 -4 9 0 -4 6 2 0 15 4 1 0 15 -6 -7 0 8 5 -6 0 -3 -6 -2 0 -4 11 -2 0 2 9 -8 0 -12 -11 -5 0 -11 -6 14 0 -2 -6 -10 0 -3 -4 -11 0 6 -9 15 0 -12 -5 15 0 -4 9 -3 0 -4 2 13 0 -2 -12 8 0 -5 -12 15 0 15 7 6 0 -3 5 -2 0 2 -4 -13 0 -3 -15 -11 0 15 -4 -9 0 -3 8 -11 0 13 -1 14 0 -10 3 7 0 8 -14 2 0 5 13 7 0 8 4 6 0 13 -4 6 0 -2 7 10 0 -12 -9 7 0 8 4 -10 0 4 -14 -12 0 -10 -13 5 0 -7 -5 3 0 5 -10 -14 0 10 -13 -11 0 -13 -6 10 0 10 11 4 0 [SEP] D 14 [UP] D 14 D 13 [UP] D 14 D 13 D 1 [UP] D 14 D 13 D 1 D 10 [UP] D 14 D 13 D 1 D 10 2 -6 11 5 [BT] D 14 D 13 D 1 -10 5 -11 6 [BT] D 14 D 13 -1 D 8 [UP] D 14 D 13 -1 D 8 D 10 [UP] D 14 D 13 -1 D 8 D 10 2 -6 5 D -11 [UP] D 14 D 13 -1 D 8 D 10 2 -6 5 D -11 -12 -4 -9 -7 15 3 SAT\",\n",
    "    \"[BOS] -12 -9 6 0 11 14 -15 0 -3 2 -13 0 -11 6 13 0 15 -3 -14 0 6 -4 -1 0 -15 -13 11 0 2 -14 -5 0 14 15 -1 0 15 11 -8 0 12 9 6 0 -3 -2 1 0 -11 -3 8 0 -7 5 -14 0 -8 4 -15 0 1 -10 15 0 13 3 -1 0 12 -9 1 0 -2 7 14 0 3 7 -1 0 -7 2 13 0 -1 -8 13 0 7 4 13 0 6 -9 -7 0 -14 -9 -8 0 14 -11 13 0 -3 -10 1 0 6 15 -11 0 -3 -1 -10 0 -10 -9 11 0 2 13 5 0 8 -4 -12 0 -10 4 11 0 2 -13 -6 0 9 5 -15 0 8 4 -15 0 6 -15 -12 0 15 -11 10 0 5 -12 1 0 -4 -1 12 0 -7 5 11 0 9 -8 -1 0 -8 -9 4 0 -3 8 10 0 2 -5 -7 0 -13 -10 12 0 -8 5 1 0 -12 -3 -6 0 -11 -14 10 0 12 14 3 0 3 9 15 0 6 12 -15 0 9 3 -11 0 -3 13 6 0 7 1 -15 0 3 2 1 0 -12 3 11 0 -7 -1 -13 0 -10 -5 -14 0 -9 -14 1 0 5 -14 -7 0 -2 13 -5 0 14 3 -6 0 [SEP] D 1 [UP] D 1 D 13 [UP] D 1 D 13 -7 3 -10 8 2 14 15 9 [BT] D 1 -13\"\n",
    "]\n",
    "# Alternative hand-written formula kept for manual experiments:\n",
    "# sample_input_unsat = \"[BOS] 1 -2 3 0 -1 2 -3 0 2 4 -1 0 1 -3 4 0 -2 -3 -4 0 3 -4 -1 0 -2 4 -1 0 1 2 -3 0 1 2 3 0 [SEP]\"\n",
    "sample_input = sample_inputs[test_id]\n",
    "\n",
    "# Split once at the first [SEP]: everything up to and including it is the prompt,\n",
    "# whatever follows (possibly nothing) is the reference trace.\n",
    "formula_part, sep_token, trace_part = sample_input.partition(\"[SEP]\")\n",
    "if not sep_token:\n",
    "    raise ValueError(f\"sample_inputs[{test_id}] has no [SEP] delimiter\")\n",
    "prompt_str = formula_part + sep_token\n",
    "trace_str = trace_part.strip()\n",
    "prompt_tokens = prompt_str.split()\n",
    "trace_tokens = trace_str.split()\n",
    "full_trace_tokens = prompt_tokens + trace_tokens\n",
    "\n",
    "# Fixed size bounds handed to dpll; the samples above use variables up to 15.\n",
    "num_vars = 15\n",
    "num_clauses = 71\n",
    "# Keep the 1200-token context, but grow it (trace length plus 500 tokens of headroom)\n",
    "# if a longer sample is ever selected. Previously the length-based value was computed\n",
    "# and then unconditionally overwritten by 1200.\n",
    "context_len = max(1200, len(full_trace_tokens) + 500)\n",
    "sop, vocab, sop_logs = dpll(num_vars=num_vars, num_clauses=num_clauses, context_len=context_len, return_logs=True, mean_exactness=100, nonsep_penalty=100)\n"
   ],
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'up_k' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mNameError\u001B[0m                                 Traceback (most recent call last)",
      "Cell \u001B[0;32mIn[1], line 31\u001B[0m\n\u001B[1;32m     29\u001B[0m num_clauses \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m71\u001B[39m\n\u001B[1;32m     30\u001B[0m context_len \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m1200\u001B[39m\n\u001B[0;32m---> 31\u001B[0m sop, vocab, sop_logs \u001B[38;5;241m=\u001B[39m \u001B[43mdpll\u001B[49m\u001B[43m(\u001B[49m\u001B[43mnum_vars\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mnum_vars\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mnum_clauses\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mnum_clauses\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcontext_len\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mcontext_len\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mreturn_logs\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mmean_exactness\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m100\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mnonsep_penalty\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m100\u001B[39;49m\u001B[43m)\u001B[49m\n",
      "File \u001B[0;32m~/Desktop/Research/LLM Reasoning/GPT_SAT_Solve/theory/dpll.py:193\u001B[0m, in \u001B[0;36mdpll\u001B[0;34m(num_vars, num_clauses, context_len, mean_exactness, nonsep_penalty, return_logs)\u001B[0m\n\u001B[1;32m    191\u001B[0m \u001B[38;5;66;03m# Heuristic for decision literal selection: Find the most common literal in remaining clauses\u001B[39;00m\n\u001B[1;32m    192\u001B[0m heuristic_q \u001B[38;5;241m=\u001B[39m [t(r_i, num_vars, true_vec\u001B[38;5;241m=\u001B[39m(\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m10\u001B[39m, \u001B[38;5;241m1\u001B[39m), false_vec\u001B[38;5;241m=\u001B[39m(\u001B[38;5;241m1\u001B[39m, \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m10\u001B[39m), none_vec\u001B[38;5;241m=\u001B[39m(\u001B[38;5;241m0\u001B[39m, \u001B[38;5;241m0\u001B[39m)), ones]\n\u001B[0;32m--> 193\u001B[0m heuristic_k \u001B[38;5;241m=\u001B[39m \u001B[43mup_k\u001B[49m\n\u001B[1;32m    194\u001B[0m heuristic_v \u001B[38;5;241m=\u001B[39m r_i\n\u001B[1;32m    195\u001B[0m heuristic_o \u001B[38;5;241m=\u001B[39m SelfAttention(heuristic_q, heuristic_k, heuristic_v)\u001B[38;5;241m.\u001B[39mnamed(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mheuristic_o\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n",
      "\u001B[0;31mNameError\u001B[0m: name 'up_k' is not defined"
     ]
    }
   ],
   "execution_count": 1
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-23T04:49:56.556379Z",
     "start_time": "2024-09-23T04:49:45.928520Z"
    }
   },
   "source": [
    "abstract_out = sop.abstract_eval(full_trace_tokens)\n",
    "\n",
    "# Title and axis labels distinguish this heatmap from the concrete and compiled-model\n",
    "# heatmaps plotted later. Axis names follow how later cells index the array\n",
    "# (argmax over axis 1 is used to index vocab).\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(abstract_out, cmap='Greys')\n",
    "ax.set(title='Abstract evaluation', xlabel='vocabulary index', ylabel='token position')\n",
    "plt.show()"
   ],
   "outputs": [
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mKeyboardInterrupt\u001B[0m                         Traceback (most recent call last)",
      "Cell \u001B[0;32mIn[2], line 2\u001B[0m\n\u001B[1;32m      1\u001B[0m abstract_out \u001B[38;5;241m=\u001B[39m sop\u001B[38;5;241m.\u001B[39mabstract_eval(full_trace_tokens)\n\u001B[0;32m----> 2\u001B[0m \u001B[43mplt\u001B[49m\u001B[38;5;241m.\u001B[39mimshow(abstract_out, cmap\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mGreys\u001B[39m\u001B[38;5;124m'\u001B[39m)\n\u001B[1;32m      3\u001B[0m plt\u001B[38;5;241m.\u001B[39mshow()\n",
      "File \u001B[0;32m_pydevd_bundle/pydevd_cython_darwin_311_64.pyx:1187\u001B[0m, in \u001B[0;36m_pydevd_bundle.pydevd_cython_darwin_311_64.SafeCallWrapper.__call__\u001B[0;34m()\u001B[0m\n",
      "File \u001B[0;32m_pydevd_bundle/pydevd_cython_darwin_311_64.pyx:627\u001B[0m, in \u001B[0;36m_pydevd_bundle.pydevd_cython_darwin_311_64.PyDBFrame.trace_dispatch\u001B[0;34m()\u001B[0m\n",
      "File \u001B[0;32m_pydevd_bundle/pydevd_cython_darwin_311_64.pyx:1103\u001B[0m, in \u001B[0;36m_pydevd_bundle.pydevd_cython_darwin_311_64.PyDBFrame.trace_dispatch\u001B[0;34m()\u001B[0m\n",
      "File \u001B[0;32m_pydevd_bundle/pydevd_cython_darwin_311_64.pyx:1061\u001B[0m, in \u001B[0;36m_pydevd_bundle.pydevd_cython_darwin_311_64.PyDBFrame.trace_dispatch\u001B[0;34m()\u001B[0m\n",
      "File \u001B[0;32m/Applications/PyCharm.app/Contents/plugins/python/helpers-pro/jupyter_debug/pydev_jupyter_plugin.py:169\u001B[0m, in \u001B[0;36mstop\u001B[0;34m(plugin, pydb, frame, event, args, stop_info, arg, step_cmd)\u001B[0m\n\u001B[1;32m    167\u001B[0m     frame \u001B[38;5;241m=\u001B[39m suspend_jupyter(main_debugger, thread, frame, step_cmd)\n\u001B[1;32m    168\u001B[0m     \u001B[38;5;28;01mif\u001B[39;00m frame:\n\u001B[0;32m--> 169\u001B[0m         \u001B[43mmain_debugger\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mdo_wait_suspend\u001B[49m\u001B[43m(\u001B[49m\u001B[43mthread\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mframe\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mevent\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43marg\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m    170\u001B[0m         \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;01mTrue\u001B[39;00m\n\u001B[1;32m    171\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;01mFalse\u001B[39;00m\n",
      "File \u001B[0;32m/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/pydevd.py:1201\u001B[0m, in \u001B[0;36mPyDB.do_wait_suspend\u001B[0;34m(self, thread, frame, event, arg, send_suspend_message, is_unhandled_exception)\u001B[0m\n\u001B[1;32m   1198\u001B[0m         from_this_thread\u001B[38;5;241m.\u001B[39mappend(frame_id)\n\u001B[1;32m   1200\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_threads_suspended_single_notification\u001B[38;5;241m.\u001B[39mnotify_thread_suspended(thread_id, stop_reason):\n\u001B[0;32m-> 1201\u001B[0m     \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_do_wait_suspend\u001B[49m\u001B[43m(\u001B[49m\u001B[43mthread\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mframe\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mevent\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43marg\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43msuspend_type\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mfrom_this_thread\u001B[49m\u001B[43m)\u001B[49m\n",
      "File \u001B[0;32m/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/pydevd.py:1216\u001B[0m, in \u001B[0;36mPyDB._do_wait_suspend\u001B[0;34m(self, thread, frame, event, arg, suspend_type, from_this_thread)\u001B[0m\n\u001B[1;32m   1213\u001B[0m             \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_call_mpl_hook()\n\u001B[1;32m   1215\u001B[0m         \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mprocess_internal_commands()\n\u001B[0;32m-> 1216\u001B[0m         \u001B[43mtime\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43msleep\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m0.01\u001B[39;49m\u001B[43m)\u001B[49m\n\u001B[1;32m   1218\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mcancel_async_evaluation(get_current_thread_id(thread), \u001B[38;5;28mstr\u001B[39m(\u001B[38;5;28mid\u001B[39m(frame)))\n\u001B[1;32m   1220\u001B[0m \u001B[38;5;66;03m# process any stepping instructions\u001B[39;00m\n",
      "\u001B[0;31mKeyboardInterrupt\u001B[0m: "
     ]
    }
   ],
   "execution_count": 2
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-14T04:45:56.102463Z",
     "start_time": "2024-09-14T04:45:50.351968Z"
    }
   },
   "outputs": [],
   "source": [
    "concrete_out = sop.concrete_eval(full_trace_tokens)\n",
    "\n",
    "# Title and axis labels distinguish this heatmap from the abstract and compiled-model\n",
    "# heatmaps. Axis names follow how later cells index the array\n",
    "# (argmax over axis 1 is used to index vocab).\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(concrete_out, cmap='Greys')\n",
    "ax.set(title='Concrete evaluation', xlabel='vocabulary index', ylabel='token position')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-14T04:46:08.468854Z",
     "start_time": "2024-09-14T04:45:56.103934Z"
    }
   },
   "outputs": [],
   "source": [
    "# Debugging aid: report every logged SOp whose abstract and concrete evaluations\n",
    "# differ by more than the tolerance on this trace. Nothing is printed for SOps\n",
    "# that stay within the tolerance.\n",
    "threshold = 0.01\n",
    "for name, log_sop in sop_logs.items():\n",
    "    abstract_val = log_sop.abstract_eval(full_trace_tokens).squeeze()\n",
    "    concrete_val = log_sop.concrete_eval(full_trace_tokens).squeeze()\n",
    "    max_error = np.abs(abstract_val - concrete_val).max()\n",
    "    if max_error > threshold:\n",
    "        print(f\"Large Error SOp {name}: {max_error}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-14T04:46:09.175Z",
     "start_time": "2024-09-14T04:46:08.473950Z"
    }
   },
   "outputs": [],
   "source": [
    "# Build the model from the SOp program; with return_alloc=True the call also hands\n",
    "# back the residual allocation and the flattened SOp.\n",
    "dpll_model, residual_alloc, flat_sop = compile_model(sop, vocab, context_len, return_alloc=True)\n",
    "flattened_sops = flat_sop.all_named_deps()\n",
    "\n",
    "# Re-key the allocation by SOp name, leaving out SOps whose name is None.\n",
    "name_alloc = {\n",
    "    alloc_sop.name: residual_alloc[alloc_sop]\n",
    "    for alloc_sop in residual_alloc.keys()\n",
    "    if alloc_sop.name is not None\n",
    "}\n",
    "\n",
    "print(\"Model Hidden Size\", dpll_model.hidden_size)\n",
    "dpll_model.summary()\n",
    "\n",
    "# Apply the model to the full token sequence and keep element 0 of its output as a\n",
    "# NumPy array for plotting and later comparisons.\n",
    "out, residual_dict = dpll_model.apply_tokens(full_trace_tokens, residual_alloc=name_alloc)\n",
    "model_out = out.detach().numpy()[0]\n",
    "plt.imshow(model_out, cmap='Greys')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-14T04:46:09.432955Z",
     "start_time": "2024-09-14T04:46:09.175672Z"
    }
   },
   "outputs": [],
   "source": [
    "# Dump the allocation and the sorted key listings, then print the last entry of\n",
    "# 'r_i_pre' from the model residuals and of 'r_i' from the flattened and logged SOps.\n",
    "print(residual_alloc)\n",
    "for keyed in (residual_dict, flattened_sops):\n",
    "    print(sorted(keyed.keys()))\n",
    "\n",
    "model_r_i_pre_last = residual_dict['r_i_pre'][0, -1]\n",
    "print(model_r_i_pre_last)\n",
    "flat_r_i_last = flattened_sops['r_i'].concrete_eval(full_trace_tokens)[-1]\n",
    "print(flat_r_i_last)\n",
    "logged_r_i_last = sop_logs['r_i'].concrete_eval(full_trace_tokens).squeeze()[-1]\n",
    "print(logged_r_i_last)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-14T04:46:13.582268Z",
     "start_time": "2024-09-14T04:46:09.444988Z"
    }
   },
   "outputs": [],
   "source": [
    "# For every logged SOp that has a name present in residual_dict, compare its concrete\n",
    "# evaluation with the value recorded from the compiled model.\n",
    "for name, log_sop in sop_logs.items():\n",
    "    if log_sop.name and log_sop.name in residual_dict:\n",
    "        concrete_val = log_sop.concrete_eval(full_trace_tokens).squeeze()\n",
    "        model_val = np.squeeze(residual_dict[log_sop.name])\n",
    "        errors = np.abs(model_val - concrete_val)\n",
    "        threshold = 0.01\n",
    "        max_error = errors.max()\n",
    "        report = f\"Large Error SOp {name}: {max_error}\" if max_error > threshold else f\"{name} within threshold\"\n",
    "        print(report)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-14T04:46:18.323106Z",
     "start_time": "2024-09-14T04:46:13.585332Z"
    }
   },
   "outputs": [],
   "source": [
    "# concrete_out is computed here as well, so this cell works even if the concrete\n",
    "# heatmap cell was skipped.\n",
    "concrete_out = sop.concrete_eval(full_trace_tokens)\n",
    "\n",
    "# Highest-scoring vocabulary index per row for each evaluation path.\n",
    "abs_pred_ids = abstract_out.argmax(axis=1).squeeze()\n",
    "model_pred_ids = model_out.argmax(axis=1).squeeze()\n",
    "concrete_pred_ids = concrete_out.argmax(axis=1).squeeze()\n",
    "\n",
    "# Convert vocab to an array once so it can be indexed by the id arrays.\n",
    "vocab_arr = np.array(vocab)\n",
    "abs_pred_tokens = vocab_arr[abs_pred_ids]\n",
    "model_pred_tokens = vocab_arr[model_pred_ids]\n",
    "concrete_pred_tokens = vocab_arr[concrete_pred_ids]\n",
    "\n",
    "# Label each line so the three token sequences can be compared at a glance.\n",
    "for label, pred_tokens in (\n",
    "    ('abstract', abs_pred_tokens),\n",
    "    ('model   ', model_pred_tokens),\n",
    "    ('concrete', concrete_pred_tokens),\n",
    "):\n",
    "    print(f'{label}: ' + ' '.join(pred_tokens))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-14T04:46:18.329443Z",
     "start_time": "2024-09-14T04:46:18.324583Z"
    }
   },
   "outputs": [],
   "source": [
    "# Last row of each output rounded to 4 decimals (the model row is additionally\n",
    "# clamped below at 0), followed by the concrete-minus-abstract difference.\n",
    "last_abstract = abstract_out[-1]\n",
    "print(last_abstract.round(4))\n",
    "last_concrete = concrete_out[-1]\n",
    "print(last_concrete.round(4))\n",
    "last_model_rounded = model_out[-1].round(4)\n",
    "print(np.maximum(last_model_rounded, 0))\n",
    "diff = last_concrete - last_abstract\n",
    "print(diff.round(4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-14T04:46:18.386886Z",
     "start_time": "2024-09-14T04:46:18.331497Z"
    }
   },
   "outputs": [],
   "source": [
    "# Heatmap of element 0 of the compiled model's output (presumably the only batch\n",
    "# entry -- confirm). Title and axis labels distinguish it from the abstract and\n",
    "# concrete heatmaps; axis names follow the argmax-over-axis-1 vocab lookup used earlier.\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(out.detach().numpy()[0], cmap='Greys')\n",
    "ax.set(title='Compiled model output', xlabel='vocabulary index', ylabel='token position')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-14T04:47:13.558198Z",
     "start_time": "2024-09-14T04:46:18.390003Z"
    }
   },
   "outputs": [],
   "source": [
    "# Generate from the prompt with the SOp program via sop_generate and print the\n",
    "# resulting tokens separated by spaces.\n",
    "sop_generated_tokens = sop_generate(prompt_tokens, vocab, sop)\n",
    "print(' '.join(sop_generated_tokens))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-14T04:47:14.722410Z",
     "start_time": "2024-09-14T04:47:13.561169Z"
    }
   },
   "outputs": [],
   "source": [
    "# Generate from the same prompt with the compiled model and print the resulting\n",
    "# tokens separated by spaces.\n",
    "model_generated_tokens = dpll_model.generate(prompt_tokens)\n",
    "print(' '.join(model_generated_tokens))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-14T04:47:14.726511Z",
     "start_time": "2024-09-14T04:47:14.723073Z"
    }
   },
   "outputs": [],
   "source": [
    "# Print the model's string representation, then the value reported by\n",
    "# find_max_parameter().\n",
    "print(dpll_model)\n",
    "max_param = dpll_model.find_max_parameter()\n",
    "print(\"Max Param:\", max_param)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-14T04:47:14.730481Z",
     "start_time": "2024-09-14T04:47:14.728752Z"
    }
   },
   "outputs": [],
   "source": [
    "# Calls dpll_model.summary() again; the cell that compiles the model calls it as well.\n",
    "dpll_model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-14T04:47:14.732095Z",
     "start_time": "2024-09-14T04:47:14.731003Z"
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
