{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook prepares the data for the subsequent notebook `10-Step-Entropy-Analyze.ipynb`, which generates figures illustrating the accumulation of predictive uncertainty in multi-step rollouts, as described in Appendix A.10 of the paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Disable the HuggingFace tokenizers' internal parallelism (avoids its fork-after-use warning).\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "\n",
    "import torch\n",
    "\n",
    "# Report the GPUs visible to this kernel.\n",
    "num_devices = torch.cuda.device_count()\n",
    "print(\"Number of visible GPUs:\", num_devices)\n",
    "\n",
    "for i in range(num_devices):\n",
    "    print(f\"GPU {i}: {torch.cuda.get_device_name(i)}\")\n",
    "\n",
    "# torch.cuda.current_device() raises when CUDA is unavailable, so guard it\n",
    "# with the same torch.cuda.is_available() check used for seeding.\n",
    "if torch.cuda.is_available():\n",
    "    current_device = torch.cuda.current_device()\n",
    "    print(\"Current device index:\", current_device)\n",
    "    print(\"Current device name:\", torch.cuda.get_device_name(current_device))\n",
    "else:\n",
    "    print(\"CUDA is not available.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "from allen_cahn_equation import (\n",
    "    compute_exact_solution_random_ic_vary_Nx,\n",
    "    visualize_spline_ic,\n",
    "    plot_both_grids\n",
    ")\n",
    "\n",
    "from data_processing import (\n",
    "    SimpleSerializerSettings,\n",
    "    scale_2d_array,\n",
    "    serialize_2d_integers,\n",
    "    extract_training_and_test\n",
    ")\n",
    "\n",
    "from llama_utils import load_model_and_tokenizer, generate_text_multiple\n",
    "\n",
    "# Hugging Face model ID used for generation; the commented IDs are the smaller\n",
    "# models the same experiment can be repeated with.\n",
    "MODEL_NAME = \"meta-llama/Llama-3.1-8B\"\n",
    "# MODEL_NAME = \"meta-llama/Llama-3.2-3B\"\n",
    "# MODEL_NAME = \"meta-llama/Llama-3.2-1B\"\n",
    "\n",
    "# Seed Python, NumPy and PyTorch (CPU, then all CUDA devices if present)\n",
    "# so every random draw in the notebook is reproducible.\n",
    "seed = 42\n",
    "for seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
    "    seed_fn(seed)\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed_all(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Allen-Cahn problem setup\n",
    "L = 2        # spatial domain length\n",
    "k = 0.001    # diffusion coefficient (passed to the solver as k)\n",
    "T = 0.5      # final simulation time\n",
    "Nx = 14      # interior spatial points (boundary points excluded)\n",
    "Nt = 25      # number of time steps\n",
    "dx = L / (Nx + 1)  # grid spacing: Nx interior points give Nx + 1 intervals\n",
    "dt = T / Nt        # time-step size\n",
    "\n",
    "# Values within a time step are separated by ',' and time steps by ';'\n",
    "settings = SimpleSerializerSettings(space_sep=\",\", time_sep=\";\")\n",
    "\n",
    "# Demo 1: draw one random initial condition on the Nx-point grid and visualize it\n",
    "init_cond_random = np.random.uniform(-0.5, 0.5, size=Nx)\n",
    "fig = visualize_spline_ic(L, Nx, init_cond_random)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Demo 2: resample the same underlying initial condition onto an Nx_new-point grid.\n",
    "# Nx_original and Nx_new are reused by the main experiment loop.\n",
    "Nx_original, Nx_new = Nx, 14\n",
    "fig, cs, init_cond_random_new = plot_both_grids(L, Nx_original, Nx_new, init_cond_random)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the model and tokenizer named by MODEL_NAME once; all generation below reuses them.\n",
    "model, tokenizer = load_model_and_tokenizer(MODEL_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_multiple_timesteps_with_scores(prompt, model, tokenizer, Nx, number_of_future_predictions):\n",
    "    \"\"\"\n",
    "    Autoregressively generate `number_of_future_predictions` future time steps.\n",
    "\n",
    "    After each step the stripped prediction plus a ';' time separator is appended to\n",
    "    the prompt, so every step is conditioned on all previous predictions.\n",
    "\n",
    "    Args:\n",
    "        prompt: Serialized history; a trailing ';' is appended if missing.\n",
    "        model, tokenizer: Forwarded to `generate_text_multiple`.\n",
    "        Nx: Number of spatial values per time step, forwarded to `generate_text_multiple`.\n",
    "        number_of_future_predictions: Number of time steps to generate.\n",
    "\n",
    "    Returns:\n",
    "        predictions: Stripped generated text, one entry per time step.\n",
    "        generation_outputs: Raw generation output of each step (its `.scores` are\n",
    "            read by `calculate_entropies`).\n",
    "    \"\"\"\n",
    "    context = prompt if prompt.endswith(\";\") else prompt + \";\"\n",
    "    predictions, generation_outputs = [], []\n",
    "    for _ in range(number_of_future_predictions):\n",
    "        raw_text, step_output = generate_text_multiple(\n",
    "            prompt=context,\n",
    "            model=model,\n",
    "            tokenizer=tokenizer,\n",
    "            Nx=Nx,\n",
    "        )\n",
    "        step_text = raw_text.strip()\n",
    "        predictions.append(step_text)\n",
    "        generation_outputs.append(step_output)\n",
    "        # Feed the prediction back so the next step conditions on it.\n",
    "        context += step_text + \";\"\n",
    "    return predictions, generation_outputs\n",
    "\n",
    "\n",
    "def calculate_entropies(generation_outputs, Nx):\n",
    "    \"\"\"\n",
    "    Shannon entropy (nats) of the next-token distribution at every spatial point of\n",
    "    every generated time step.\n",
    "\n",
    "    The distribution for spatial point `j` is read from generation step `2 * j`, i.e.\n",
    "    each value and each ',' separator are assumed to be exactly one token.\n",
    "    NOTE(review): this holds only if the tokenizer encodes every serialized integer\n",
    "    as a single token - confirm for the chosen model and scaling.\n",
    "\n",
    "    Args:\n",
    "        generation_outputs: One generation output per time step; each exposes\n",
    "            `.scores`, a sequence of per-token logit tensors (batch row 0 is used).\n",
    "        Nx: Number of spatial points per time step.\n",
    "\n",
    "    Returns:\n",
    "        entropies: Array of shape (Nx, n_future_steps). Points beyond the number of\n",
    "            generated tokens are left at 0.\n",
    "        avg_entropy: Mean over spatial points, shape (n_future_steps,).\n",
    "    \"\"\"\n",
    "    n_future_steps = len(generation_outputs)\n",
    "    entropies = np.zeros((Nx, n_future_steps))\n",
    "    for time_idx, gen_output in enumerate(generation_outputs):\n",
    "        scores = gen_output.scores\n",
    "        for grid_idx in range(Nx):\n",
    "            token_position = grid_idx * 2  # skip the separator token between values\n",
    "            if token_position >= len(scores):\n",
    "                break  # generation stopped early; remaining points stay 0\n",
    "            # Upcast before softmax: in float16 tiny probabilities underflow to 0 (and a\n",
    "            # 1e-30 clamp also rounds to 0), so p * log(p) becomes 0 * -inf = NaN.\n",
    "            # torch.special.entr defines entr(0) = 0, which also covers logits masked\n",
    "            # to -inf by generation-time logits processors.\n",
    "            logits = scores[token_position][0].float()\n",
    "            p = torch.softmax(logits, dim=-1)\n",
    "            entropies[grid_idx, time_idx] = torch.special.entr(p).sum().item()\n",
    "    # Average entropy across spatial points, one value per time step\n",
    "    avg_entropy = entropies.mean(axis=0)\n",
    "    return entropies, avg_entropy\n",
    "\n",
    "\n",
    "def analyze_multistep_distributions(model, tokenizer, u_exact_serialized,\n",
    "                                   input_time_steps, number_of_future_predictions,\n",
    "                                   Nx, settings):\n",
    "    \"\"\"\n",
    "    Roll the LLM forward from a serialized trajectory and measure per-value token entropies.\n",
    "\n",
    "    Args:\n",
    "        model, tokenizer: LLM and tokenizer used for generation.\n",
    "        u_exact_serialized: Serialized solution of all time steps.\n",
    "        input_time_steps: Number of time steps placed in the prompt (the split is\n",
    "            done by `extract_training_and_test`).\n",
    "        number_of_future_predictions: Number of time steps to generate.\n",
    "        Nx: Number of spatial values per time step.\n",
    "        settings: Serializer settings used for the split.\n",
    "\n",
    "    Returns:\n",
    "        (predictions, generation_outputs, entropies, avg_entropy) from\n",
    "        `generate_multiple_timesteps_with_scores` and `calculate_entropies`.\n",
    "    \"\"\"\n",
    "    # Only the prompt part is needed; the held-out part is not used here.\n",
    "    train_serial, _ = extract_training_and_test(u_exact_serialized, input_time_steps, settings)\n",
    "    predictions, generation_outputs = generate_multiple_timesteps_with_scores(\n",
    "        prompt=train_serial,\n",
    "        model=model,\n",
    "        tokenizer=tokenizer,\n",
    "        Nx=Nx,\n",
    "        number_of_future_predictions=number_of_future_predictions,\n",
    "    )\n",
    "    entropies, avg_entropy = calculate_entropies(generation_outputs, Nx)\n",
    "    return predictions, generation_outputs, entropies, avg_entropy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment parameters\n",
    "input_time_steps = 16              # time steps given to the LLM as context\n",
    "number_of_future_predictions = 10  # time steps generated autoregressively\n",
    "n_initial_conditions = 20          # random initial conditions (outer loop)\n",
    "n_llm_seeds = 20                   # repeated LLM runs per initial condition\n",
    "\n",
    "# Label the output with the model size parsed from MODEL_NAME\n",
    "# (\"meta-llama/Llama-3.1-8B\" -> \"8B\") and the rollout length, so switching\n",
    "# MODEL_NAME does not overwrite or mislabel another model's results.\n",
    "model_tag = MODEL_NAME.rsplit(\"/\", 1)[-1].rsplit(\"-\", 1)[-1]\n",
    "output_path = f\"{model_tag}_{number_of_future_predictions}_step_token_dist.npz\"\n",
    "\n",
    "all_ic_mean_entropies = []      # per IC: seed-averaged entropies, shape (Nx, n_steps)\n",
    "all_ic_mean_avg_entropies = []  # per IC: seed-averaged spatial-mean entropy, shape (n_steps,)\n",
    "all_initial_conditions = []\n",
    "\n",
    "\n",
    "def seed_everything(seed_value):\n",
    "    \"\"\"Seed Python, NumPy and PyTorch (plus all CUDA devices when available).\"\"\"\n",
    "    random.seed(seed_value)\n",
    "    np.random.seed(seed_value)\n",
    "    torch.manual_seed(seed_value)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed_all(seed_value)\n",
    "\n",
    "\n",
    "# Run analysis across multiple initial conditions\n",
    "for ic_idx in tqdm(range(n_initial_conditions)):\n",
    "    # The IC index seeds the draw of this initial condition\n",
    "    seed_everything(ic_idx)\n",
    "    init_cond_random = np.random.uniform(-0.5, 0.5, size=Nx)\n",
    "    all_initial_conditions.append(init_cond_random)\n",
    "    # Build the spline for this IC (Nx_original / Nx_new are set in the demo cell);\n",
    "    # only `cs` is needed, so the figure is closed right away.\n",
    "    fig, cs, init_cond_random_new = plot_both_grids(L, Nx_original, Nx_new, init_cond_random)\n",
    "    plt.close(fig)\n",
    "    # Exact solution -> scaled integers -> serialized text\n",
    "    u_exact = compute_exact_solution_random_ic_vary_Nx(L, k, T, Nx, Nt, spline_obj=cs)\n",
    "    u_exact_scaled, vmin_exact, vmax_exact = scale_2d_array(u_exact)\n",
    "    u_exact_serialized = serialize_2d_integers(u_exact_scaled, settings)\n",
    "\n",
    "    ic_entropies = []\n",
    "    ic_avg_entropies = []\n",
    "\n",
    "    # Run LLM generation multiple times for this initial condition.\n",
    "    # NOTE: the same seeds 0..n_llm_seeds-1 are reused for every initial condition.\n",
    "    for llm_seed_idx in range(n_llm_seeds):\n",
    "        seed_everything(llm_seed_idx)\n",
    "        predictions, generation_outputs, entropies, avg_entropy = analyze_multistep_distributions(\n",
    "            model=model,\n",
    "            tokenizer=tokenizer,\n",
    "            u_exact_serialized=u_exact_serialized,\n",
    "            input_time_steps=input_time_steps,\n",
    "            number_of_future_predictions=number_of_future_predictions,\n",
    "            Nx=Nx,\n",
    "            settings=settings,\n",
    "        )\n",
    "        ic_entropies.append(entropies)\n",
    "        ic_avg_entropies.append(avg_entropy)\n",
    "\n",
    "    # Average over LLM runs to get one point estimate per initial condition\n",
    "    ic_entropies = np.array(ic_entropies)\n",
    "    ic_avg_entropies = np.array(ic_avg_entropies)\n",
    "    mean_entropy_for_ic = ic_entropies.mean(axis=0)\n",
    "    mean_avg_entropy_for_ic = ic_avg_entropies.mean(axis=0)\n",
    "    all_ic_mean_entropies.append(mean_entropy_for_ic)\n",
    "    all_ic_mean_avg_entropies.append(mean_avg_entropy_for_ic)\n",
    "\n",
    "# Statistics across initial conditions: sample std (ddof=1) and standard error\n",
    "all_ic_mean_entropies = np.array(all_ic_mean_entropies)\n",
    "all_ic_mean_avg_entropies = np.array(all_ic_mean_avg_entropies)\n",
    "all_initial_conditions = np.array(all_initial_conditions)\n",
    "mean_entropies = all_ic_mean_entropies.mean(axis=0)\n",
    "std_entropies = all_ic_mean_entropies.std(axis=0, ddof=1)\n",
    "mean_avg_entropy = all_ic_mean_avg_entropies.mean(axis=0)\n",
    "std_avg_entropy = all_ic_mean_avg_entropies.std(axis=0, ddof=1)\n",
    "se_entropies = std_entropies / np.sqrt(n_initial_conditions)\n",
    "se_avg_entropy = std_avg_entropy / np.sqrt(n_initial_conditions)\n",
    "\n",
    "# Summary arrays carry the model tag in their key (e.g. \"mean_entropies_8B\")\n",
    "np.savez_compressed(\n",
    "    output_path,\n",
    "    **{\n",
    "        f\"mean_entropies_{model_tag}\": mean_entropies,\n",
    "        f\"std_entropies_{model_tag}\": std_entropies,\n",
    "        f\"se_entropies_{model_tag}\": se_entropies,\n",
    "        f\"mean_avg_entropy_{model_tag}\": mean_avg_entropy,\n",
    "        f\"std_avg_entropy_{model_tag}\": std_avg_entropy,\n",
    "        f\"se_avg_entropy_{model_tag}\": se_avg_entropy,\n",
    "    },\n",
    "    all_initial_conditions=all_initial_conditions,\n",
    "    all_ic_mean_entropies=all_ic_mean_entropies,\n",
    "    all_ic_mean_avg_entropies=all_ic_mean_avg_entropies,\n",
    "    n_initial_conditions=n_initial_conditions,\n",
    ")\n",
    "print(f\"Saved results to {output_path}\")"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "A100",
   "machine_shape": "hm",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "llama3-jiajun",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "4bca38f991eb477fb6f6448ed40b7953": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_f282a01a1fa94fd3841fa84b0bf85801",
      "placeholder": "​",
      "style": "IPY_MODEL_bfdb859e858e42869e6da9b1482a5702",
      "value": "Loading checkpoint shards: 100%"
     }
    },
    "79d7edd2ec684e25b3674d375812e5fc": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "82fd26e315b6460ab439920956ecfc4b": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "8d598e552e3e4f3f9ffd47c953554ad0": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_79d7edd2ec684e25b3674d375812e5fc",
      "placeholder": "​",
      "style": "IPY_MODEL_b9ca4f266f0247a3aca54430f78c7bf4",
      "value": " 2/2 [00:04&lt;00:00,  2.25s/it]"
     }
    },
    "b9ca4f266f0247a3aca54430f78c7bf4": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "be3db56ffe3047a6ab8493d65d18f5c6": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "bfdb859e858e42869e6da9b1482a5702": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "e0dd1da9791a4911932193befbfd4dd0": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "e8bbace417ee4d74ae8e9fdcaf023b44": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_be3db56ffe3047a6ab8493d65d18f5c6",
      "max": 2,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_e0dd1da9791a4911932193befbfd4dd0",
      "value": 2
     }
    },
    "f282a01a1fa94fd3841fa84b0bf85801": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "fd21f3afeb514a51a73822346535fdec": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_4bca38f991eb477fb6f6448ed40b7953",
       "IPY_MODEL_e8bbace417ee4d74ae8e9fdcaf023b44",
       "IPY_MODEL_8d598e552e3e4f3f9ffd47c953554ad0"
      ],
      "layout": "IPY_MODEL_82fd26e315b6460ab439920956ecfc4b"
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
