{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook prepares the data for the subsequent notebook `10-Step-Analyze.ipynb`, which generates figures illustrating the multi-step prediction accuracy of the heat equation with Neumann boundary conditions (with a=-0.5, b=0.5), as described in Appendix A.3 of the paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "\n",
    "import torch\n",
    "num_devices = torch.cuda.device_count()\n",
    "print(\"Number of visible GPUs:\", num_devices)\n",
    "\n",
    "for i in range(num_devices):\n",
    "    print(f\"GPU {i}: {torch.cuda.get_device_name(i)}\")\n",
    "\n",
    "current_device = torch.cuda.current_device()\n",
    "print(\"Current device index:\", current_device)\n",
    "print(\"Current device name:\", torch.cuda.get_device_name(current_device))\n",
    "\n",
    "import random\n",
    "import numpy as np\n",
    "import scipy.stats as stats\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from tqdm import tqdm\n",
    "from data_processing import (\n",
    "    SimpleSerializerSettings, scale_2d_array, unscale_2d_array,\n",
    "    serialize_2d_integers, deserialize_2d_integers, extract_training_and_test\n",
    ")\n",
    "from heat_equation import (\n",
    "    compute_exact_solution_random_ic_vary_Nx,\n",
    "    visualize_spline_ic,\n",
    "    solve_heat_ftcs, solve_heat_btcs,\n",
    ")\n",
    "\n",
    "from llama_utils import load_model_and_tokenizer, generate_text_multiple\n",
    "\n",
    "MODEL_NAME = \"meta-llama/Llama-3.1-8B\"\n",
    "# MODEL_NAME = \"meta-llama/Llama-3.2-3B\"\n",
    "# MODEL_NAME = \"meta-llama/Llama-3.2-1B\"\n",
    "\n",
    "# Set random seeds for reproducibility\n",
    "seed = 42\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed_all(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model, tokenizer = load_model_and_tokenizer(MODEL_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example: Demonstrating the process of generating and visualizing a random initial condition\n",
    "L = 2\n",
    "Nx = 14\n",
    "init_cond_random = np.random.uniform(-0.5, 0.5, size=Nx)\n",
    "fig, cs = visualize_spline_ic(L, Nx, init_cond_random)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define parameters for the Heat equation solver\n",
    "L = 2       # Length of the spatial domain\n",
    "k = 0.01    # Thermal diffusivity\n",
    "T = 0.5     # Total simulation time\n",
    "Nt = 25     # Number of time steps\n",
    "dt = T/Nt\n",
    "Nx = 14     # Number of spatial steps (excluding boundary points)\n",
    "dx = L/(Nx+1)\n",
    "settings = SimpleSerializerSettings(space_sep=\",\", time_sep=\";\")\n",
    "input_time_steps = 16\n",
    "number_of_future_predictions = 10\n",
    "n_ics = 20\n",
    "n_runs_per_ic = 20\n",
    "\n",
    "def compute_integral_with_boundaries(u_interior, Nx, L):\n",
    "    \"\"\"\n",
    "    Compute integral of solution including estimated boundary values.\n",
    "    \"\"\"\n",
    "    x_full = np.linspace(-L/2, L/2, Nx + 2)\n",
    "    u_full = np.zeros(Nx + 2)\n",
    "    u_full[1:-1] = u_interior\n",
    "    # Estimate boundary values using Neumann BC (consistent with FD schemes)\n",
    "    if Nx >= 2:\n",
    "        u_full[0] = (4*u_interior[0] - u_interior[1]) / 3.0\n",
    "        u_full[-1] = (4*u_interior[-1] - u_interior[-2]) / 3.0\n",
    "    else:\n",
    "        u_full[0] = u_full[-1] = u_interior[0]\n",
    "    \n",
    "    return np.trapezoid(u_full, x_full)\n",
    "\n",
    "\n",
    "def compute_exact_solution_conservation(\n",
    "    u_exact_all, Nx, L, initial_integral_fine\n",
    "):\n",
    "    \"\"\"\n",
    "    Compute conservation error for the high-accuracy exact solution\n",
    "    evaluated on the coarse grid. This represents the best achievable\n",
    "    conservation on the coarse grid.\n",
    "    \"\"\"\n",
    "    integral_changes = []\n",
    "    for t in range(u_exact_all.shape[0]):\n",
    "        # Compute integral on coarse grid with boundary reconstruction\n",
    "        coarse_integral = compute_integral_with_boundaries(u_exact_all[t], Nx, L)\n",
    "        # Compare against fine-grid initial integral\n",
    "        relative_change = np.abs((coarse_integral - initial_integral_fine) / \n",
    "                                np.abs(initial_integral_fine) * 100)\n",
    "        integral_changes.append(relative_change)\n",
    "    \n",
    "    return integral_changes\n",
    "\n",
    "\n",
    "def log_ci(mean, std, n, tcrit):\n",
    "    \"\"\"\n",
    "    95% CI for log10 axis\n",
    "    mean : arithmetic mean of the n samples\n",
    "    std : sample std of the n samples\n",
    "    n : number of samples\n",
    "    tcrit: two-sided t critical value\n",
    "    \"\"\"\n",
    "    se = std / np.sqrt(n)  # SE in linear space\n",
    "    se_log = se / (mean * np.log(10))  # delta-method SE in log space\n",
    "    mean_log = np.log10(mean)\n",
    "    delta_log = tcrit * se_log\n",
    "    return 10**(mean_log - delta_log), 10**(mean_log + delta_log)\n",
    "\n",
    "\n",
    "def linear_ci(mean, std, n, tcrit):\n",
    "    \"\"\"95% CI for linear scale\"\"\"\n",
    "    se = std / np.sqrt(n)\n",
    "    delta = tcrit * se\n",
    "    return mean - delta, mean + delta\n",
    "\n",
    "\n",
    "def llm_multi_predictions_with_conservation(\n",
    "    full_serialized_data, input_time_steps, number_of_future_predictions,\n",
    "    model, tokenizer, Nx, settings, vmin, vmax, n_seeds, initial_integral\n",
    "):\n",
    "    \"\"\"LLM prediction function that also tracks energy conservation\"\"\"\n",
    "    all_rows_scaled = deserialize_2d_integers(full_serialized_data, settings)\n",
    "    all_seeds_integral_changes = []\n",
    "    all_seeds_max_diffs = []\n",
    "    all_seeds_rmses = []\n",
    "    for seed in range(n_seeds):\n",
    "        if n_seeds > 1:\n",
    "            random.seed(seed)\n",
    "            np.random.seed(seed)\n",
    "            torch.manual_seed(seed)\n",
    "            if torch.cuda.is_available():\n",
    "                torch.cuda.manual_seed_all(seed)\n",
    "        # run a single LLM prediction sequence\n",
    "        train_serial, _ = extract_training_and_test(full_serialized_data, input_time_steps, settings)\n",
    "        if not train_serial.endswith(settings.time_sep):\n",
    "            train_serial += settings.time_sep\n",
    "        current_prompt = train_serial\n",
    "        integral_changes = []\n",
    "        max_diffs = []\n",
    "        rmses = []\n",
    "        for step_idx in range(number_of_future_predictions):\n",
    "            gt_idx = input_time_steps + step_idx\n",
    "            if gt_idx >= all_rows_scaled.shape[0]:\n",
    "                # Stop if we exceed the available ground truth\n",
    "                break\n",
    "            next_line, _ = generate_text_multiple(current_prompt, model, tokenizer, Nx)\n",
    "            next_line = next_line.strip()\n",
    "            predicted_scaled_2d = deserialize_2d_integers(next_line, settings)\n",
    "            predicted_unscaled_2d = unscale_2d_array(predicted_scaled_2d, vmin, vmax)\n",
    "            if predicted_unscaled_2d.ndim == 2 and predicted_unscaled_2d.shape[0] == 1:\n",
    "                pred_unscaled = predicted_unscaled_2d[0]\n",
    "            else:\n",
    "                pred_unscaled = predicted_unscaled_2d\n",
    "            # Compute integral and relative change\n",
    "            pred_integral = compute_integral_with_boundaries(pred_unscaled, Nx, L)\n",
    "            relative_change = np.abs((pred_integral - initial_integral) / np.abs(initial_integral) * 100)\n",
    "            integral_changes.append(relative_change)\n",
    "            # Compute regular error metrics\n",
    "            gt_scaled = all_rows_scaled[gt_idx]\n",
    "            gt_unscaled = unscale_2d_array(gt_scaled[np.newaxis, :], vmin, vmax)[0]\n",
    "            max_diff = np.max(np.abs(pred_unscaled - gt_unscaled))\n",
    "            rmse = np.sqrt(np.mean((pred_unscaled - gt_unscaled)**2))\n",
    "            max_diffs.append(max_diff)\n",
    "            rmses.append(rmse)\n",
    "            current_prompt += next_line + settings.time_sep\n",
    "        all_seeds_integral_changes.append(integral_changes)\n",
    "        all_seeds_max_diffs.append(max_diffs)\n",
    "        all_seeds_rmses.append(rmses)\n",
    "    max_steps = min(len(changes) for changes in all_seeds_integral_changes)\n",
    "    avg_integral_changes = []\n",
    "    std_integral_changes = []\n",
    "    avg_max_diffs = []\n",
    "    avg_rmses = []\n",
    "    for step in range(max_steps):\n",
    "        step_changes = [seed_changes[step] for seed_changes in all_seeds_integral_changes]\n",
    "        avg_integral_changes.append(np.mean(step_changes))\n",
    "        std_integral_changes.append(np.std(step_changes, ddof=1))\n",
    "        step_max_diffs = [seed_diffs[step] for seed_diffs in all_seeds_max_diffs]\n",
    "        step_rmses = [seed_rmses[step] for seed_rmses in all_seeds_rmses]\n",
    "        avg_max_diffs.append(np.mean(step_max_diffs))\n",
    "        avg_rmses.append(np.mean(step_rmses))\n",
    "    \n",
    "    return avg_integral_changes, std_integral_changes, avg_max_diffs, avg_rmses\n",
    "\n",
    "\n",
    "def llm_multi_predictions_with_conservation_smaller_model(\n",
    "    full_serialized_data, input_time_steps, number_of_future_predictions,\n",
    "    model, tokenizer, Nx, settings, vmin, vmax, n_seeds, initial_integral,\n",
    "    max_retries=10\n",
    "):\n",
    "    \"\"\"Modified LLM prediction function with in-loop retry for smaller models that also tracks energy conservation\"\"\"\n",
    "    all_rows_scaled = deserialize_2d_integers(full_serialized_data, settings)\n",
    "    all_seeds_integral_changes = []\n",
    "    all_seeds_max_diffs = []\n",
    "    all_seeds_rmses = []\n",
    "    for seed in range(n_seeds):\n",
    "        if n_seeds > 1:\n",
    "            random.seed(seed)\n",
    "            np.random.seed(seed)\n",
    "            torch.manual_seed(seed)\n",
    "            if torch.cuda.is_available():\n",
    "                torch.cuda.manual_seed_all(seed)\n",
    "        # Initialize for this seed\n",
    "        train_serial, _ = extract_training_and_test(full_serialized_data, input_time_steps, settings)\n",
    "        if not train_serial.endswith(settings.time_sep):\n",
    "            train_serial += settings.time_sep\n",
    "        current_prompt = train_serial\n",
    "        integral_changes = []\n",
    "        max_diffs = []\n",
    "        rmses = []\n",
    "        for step_idx in range(number_of_future_predictions):\n",
    "            gt_idx = input_time_steps + step_idx\n",
    "            if gt_idx >= all_rows_scaled.shape[0]:\n",
    "                break\n",
    "            valid_step = False\n",
    "            for attempt in range(max_retries):\n",
    "                next_line, _ = generate_text_multiple(current_prompt, model, tokenizer, Nx)\n",
    "                next_line = next_line.strip()\n",
    "                pred_scaled_2d = deserialize_2d_integers(next_line, settings)\n",
    "                if pred_scaled_2d.shape == (1, Nx):\n",
    "                    valid_step = True\n",
    "                    break\n",
    "                else:\n",
    "                    print(f\"  step {step_idx} attempt {attempt} wrong shape {pred_scaled_2d.shape}, retrying...\")\n",
    "            if not valid_step:\n",
    "                print(f\"  Warning: step {step_idx} failed all {max_retries} retries; using last output\")\n",
    "            predicted_unscaled_2d = unscale_2d_array(pred_scaled_2d, vmin, vmax)\n",
    "            if predicted_unscaled_2d.ndim == 2 and predicted_unscaled_2d.shape[0] == 1:\n",
    "                pred_unscaled = predicted_unscaled_2d[0]\n",
    "            else:\n",
    "                pred_unscaled = predicted_unscaled_2d\n",
    "            # Compute integral and relative change\n",
    "            pred_integral = compute_integral_with_boundaries(pred_unscaled, Nx, L)\n",
    "            relative_change = np.abs((pred_integral - initial_integral) / np.abs(initial_integral) * 100)\n",
    "            integral_changes.append(relative_change)\n",
    "            # Compute error metrics\n",
    "            gt_scaled = all_rows_scaled[gt_idx]\n",
    "            gt_unscaled = unscale_2d_array(gt_scaled[np.newaxis, :], vmin, vmax)[0]\n",
    "            max_diff = np.max(np.abs(pred_unscaled - gt_unscaled))\n",
    "            rmse = np.sqrt(np.mean((pred_unscaled - gt_unscaled)**2))\n",
    "            max_diffs.append(max_diff)\n",
    "            rmses.append(rmse)\n",
    "            current_prompt += next_line + settings.time_sep\n",
    "        all_seeds_integral_changes.append(integral_changes)\n",
    "        all_seeds_max_diffs.append(max_diffs)\n",
    "        all_seeds_rmses.append(rmses)\n",
    "    max_steps = min(len(changes) for changes in all_seeds_integral_changes)\n",
    "    avg_integral_changes = []\n",
    "    std_integral_changes = []\n",
    "    avg_max_diffs = []\n",
    "    avg_rmses = []\n",
    "    for step in range(max_steps):\n",
    "        step_changes = [seed_changes[step] for seed_changes in all_seeds_integral_changes]\n",
    "        avg_integral_changes.append(np.mean(step_changes))\n",
    "        std_integral_changes.append(np.std(step_changes, ddof=1))\n",
    "        step_max_diffs = [seed_diffs[step] for seed_diffs in all_seeds_max_diffs]\n",
    "        step_rmses = [seed_rmses[step] for seed_rmses in all_seeds_rmses]\n",
    "        avg_max_diffs.append(np.mean(step_max_diffs))\n",
    "        avg_rmses.append(np.mean(step_rmses))\n",
    "    \n",
    "    return avg_integral_changes, std_integral_changes, avg_max_diffs, avg_rmses\n",
    "\n",
    "\n",
    "def finite_difference_predictions_with_conservation(\n",
    "    full_serialized_data, input_time_steps, number_of_future_predictions,\n",
    "    settings, vmin, vmax, L, k, Nt, Nx, T, initial_integral\n",
    "):\n",
    "    \"\"\"Modified FD prediction function that also tracks energy conservation\"\"\"\n",
    "    # Extract full solution from serialized data\n",
    "    all_rows_scaled = deserialize_2d_integers(full_serialized_data, settings)\n",
    "    dt = T / Nt\n",
    "    ftcs_integral_changes = []\n",
    "    btcs_integral_changes = []\n",
    "    ftcs_max_diffs = []\n",
    "    btcs_max_diffs = []\n",
    "    ftcs_rmses = []\n",
    "    btcs_rmses = []\n",
    "    # Get initial condition from last training step\n",
    "    initial_step = input_time_steps - 1\n",
    "    initial_scaled = all_rows_scaled[initial_step]\n",
    "    initial_unscaled = unscale_2d_array(initial_scaled[np.newaxis, :], vmin, vmax)[0]\n",
    "    current_ftcs = initial_unscaled.copy()\n",
    "    current_btcs = initial_unscaled.copy()\n",
    "    for step_idx in range(number_of_future_predictions):\n",
    "        gt_idx = input_time_steps + step_idx\n",
    "        if gt_idx >= all_rows_scaled.shape[0]:\n",
    "            # Stop if we exceed the available ground truth\n",
    "            break\n",
    "        # Get ground truth for this step\n",
    "        gt_scaled = all_rows_scaled[gt_idx]\n",
    "        gt_unscaled = unscale_2d_array(gt_scaled[np.newaxis, :], vmin, vmax)[0]\n",
    "        # We set T=dt and Nt=1 to evolve exactly one time step\n",
    "        _, ftcs_step, _ = solve_heat_ftcs(L, k, dt, Nx, 1, init_cond=current_ftcs)\n",
    "        _, btcs_step, _ = solve_heat_btcs(L, k, dt, Nx, 1, init_cond=current_btcs)\n",
    "        # Extract predictions (using last time step)\n",
    "        pred_ftcs = ftcs_step[-1]\n",
    "        pred_btcs = btcs_step[-1]\n",
    "        # Compute integrals and relative changes\n",
    "        ftcs_integral = compute_integral_with_boundaries(pred_ftcs, Nx, L)\n",
    "        btcs_integral = compute_integral_with_boundaries(pred_btcs, Nx, L)\n",
    "        ftcs_relative_change = np.abs((ftcs_integral - initial_integral) / np.abs(initial_integral) * 100)\n",
    "        btcs_relative_change = np.abs((btcs_integral - initial_integral) / np.abs(initial_integral) * 100)\n",
    "        ftcs_integral_changes.append(ftcs_relative_change)\n",
    "        btcs_integral_changes.append(btcs_relative_change)\n",
    "        # Compute error metrics\n",
    "        ftcs_max_diffs.append(np.max(np.abs(pred_ftcs - gt_unscaled)))\n",
    "        ftcs_rmses.append(np.sqrt(np.mean((pred_ftcs - gt_unscaled)**2)))\n",
    "        btcs_max_diffs.append(np.max(np.abs(pred_btcs - gt_unscaled)))\n",
    "        btcs_rmses.append(np.sqrt(np.mean((pred_btcs - gt_unscaled)**2)))\n",
    "        # Update current state for next step\n",
    "        current_ftcs = pred_ftcs.copy()\n",
    "        current_btcs = pred_btcs.copy()\n",
    "    \n",
    "    return {\n",
    "        'ftcs': {\n",
    "            'integral_changes': ftcs_integral_changes,\n",
    "            'max_diff': ftcs_max_diffs,\n",
    "            'rmse': ftcs_rmses\n",
    "        },\n",
    "        'btcs': {\n",
    "            'integral_changes': btcs_integral_changes,\n",
    "            'max_diff': btcs_max_diffs,\n",
    "            'rmse': btcs_rmses\n",
    "        }\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate all random initial conditions and spline objects\n",
    "stored_initial_conditions = []\n",
    "stored_spline_objects = []\n",
    "stored_initial_integrals_fine = []\n",
    "stored_initial_integrals_coarse = []\n",
    "for ic_seed in range(n_ics):\n",
    "    # Set seed for this initial condition\n",
    "    random.seed(ic_seed)\n",
    "    np.random.seed(ic_seed)\n",
    "    torch.manual_seed(ic_seed)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed_all(ic_seed)\n",
    "    # Generate and store the random initial condition\n",
    "    init_cond_random = np.random.uniform(0.0, 1.0, size=Nx)\n",
    "    stored_initial_conditions.append(init_cond_random.copy())\n",
    "    # Create and store the spline object\n",
    "    fig, cs = visualize_spline_ic(L, Nx, init_cond_random)\n",
    "    plt.close(fig)\n",
    "    stored_spline_objects.append(cs)\n",
    "    # Compute initial integral from fine spline grid\n",
    "    x_fine_init = np.linspace(-L/2, L/2, 1000)\n",
    "    u_fine_init = cs(x_fine_init)\n",
    "    initial_integral_fine = np.trapezoid(u_fine_init, x_fine_init)\n",
    "    stored_initial_integrals_fine.append(initial_integral_fine)\n",
    "    # Compute initial integral on coarse grid for reference\n",
    "    initial_integral_coarse = compute_integral_with_boundaries(init_cond_random, Nx, L)\n",
    "    stored_initial_integrals_coarse.append(initial_integral_coarse)\n",
    "stored_initial_conditions_array = np.array(stored_initial_conditions)\n",
    "\n",
    "\n",
    "all_llm_integral_changes = []\n",
    "all_llm_max_diffs = []\n",
    "all_llm_rmses = []\n",
    "all_fd_results = []\n",
    "all_exact_integral_changes = []\n",
    "for ic_seed in tqdm(range(n_ics)):\n",
    "    # Use the stored initial condition, spline, and initial integral\n",
    "    init_cond_random = stored_initial_conditions[ic_seed]\n",
    "    cs = stored_spline_objects[ic_seed]\n",
    "    initial_integral_fine = stored_initial_integrals_fine[ic_seed]\n",
    "    # Compute exact solution for this initial condition\n",
    "    u_exact = compute_exact_solution_random_ic_vary_Nx(L, k, T, Nx, Nt, spline_obj=cs)\n",
    "    u_exact_scaled, vmin_exact, vmax_exact = scale_2d_array(u_exact)\n",
    "    u_exact_serialized = serialize_2d_integers(u_exact_scaled, settings)\n",
    "    exact_integral_changes = compute_exact_solution_conservation(\n",
    "        u_exact[input_time_steps:input_time_steps+number_of_future_predictions],\n",
    "        Nx, L, initial_integral_fine\n",
    "    )\n",
    "    all_exact_integral_changes.append(exact_integral_changes)\n",
    "    # Run LLM predictions with conservation tracking\n",
    "    llm_integral_changes, llm_integral_changes_std, llm_max_diffs, llm_rmses = llm_multi_predictions_with_conservation(\n",
    "        full_serialized_data=u_exact_serialized,\n",
    "        input_time_steps=input_time_steps,\n",
    "        number_of_future_predictions=number_of_future_predictions,\n",
    "        model=model,\n",
    "        tokenizer=tokenizer,\n",
    "        Nx=Nx,\n",
    "        settings=settings,\n",
    "        vmin=vmin_exact,\n",
    "        vmax=vmax_exact,\n",
    "        n_seeds=n_runs_per_ic,\n",
    "        initial_integral=initial_integral_fine\n",
    "    )\n",
    "    # Run finite difference predictions with conservation tracking\n",
    "    fd_results = finite_difference_predictions_with_conservation(\n",
    "        full_serialized_data=u_exact_serialized,\n",
    "        input_time_steps=input_time_steps,\n",
    "        number_of_future_predictions=number_of_future_predictions,\n",
    "        settings=settings,\n",
    "        vmin=vmin_exact,\n",
    "        vmax=vmax_exact,\n",
    "        L=L,\n",
    "        k=k,\n",
    "        Nt=Nt,\n",
    "        Nx=Nx,\n",
    "        T=T,\n",
    "        initial_integral=initial_integral_fine\n",
    "    )\n",
    "    all_llm_integral_changes.append(llm_integral_changes)\n",
    "    all_llm_max_diffs.append(llm_max_diffs)\n",
    "    all_llm_rmses.append(llm_rmses)\n",
    "    all_fd_results.append(fd_results)\n",
    "\n",
    "# Compute averages for conservation errors\n",
    "avg_llm_integral_changes = np.mean(all_llm_integral_changes, axis=0)\n",
    "std_llm_integral_changes = np.std(all_llm_integral_changes, axis=0, ddof=1)\n",
    "ftcs_integral_changes = [res['ftcs']['integral_changes'] for res in all_fd_results]\n",
    "btcs_integral_changes = [res['btcs']['integral_changes'] for res in all_fd_results]\n",
    "avg_ftcs_integral_changes = np.mean(ftcs_integral_changes, axis=0)\n",
    "avg_btcs_integral_changes = np.mean(btcs_integral_changes, axis=0)\n",
    "std_ftcs_integral_changes = np.std(ftcs_integral_changes, axis=0, ddof=1)\n",
    "std_btcs_integral_changes = np.std(btcs_integral_changes, axis=0, ddof=1)\n",
    "avg_exact_integral_changes = np.mean(all_exact_integral_changes, axis=0)\n",
    "std_exact_integral_changes = np.std(all_exact_integral_changes, axis=0, ddof=1)\n",
    "# Compute averages for error metrics\n",
    "avg_llm_max_diffs = np.mean(all_llm_max_diffs, axis=0)\n",
    "avg_llm_rmses = np.mean(all_llm_rmses, axis=0)\n",
    "std_llm_max_diffs = np.std(all_llm_max_diffs, axis=0, ddof=1)\n",
    "std_llm_rmses = np.std(all_llm_rmses, axis=0, ddof=1)\n",
    "ftcs_max_diffs = [res['ftcs']['max_diff'] for res in all_fd_results]\n",
    "ftcs_rmses = [res['ftcs']['rmse'] for res in all_fd_results]\n",
    "btcs_max_diffs = [res['btcs']['max_diff'] for res in all_fd_results]\n",
    "btcs_rmses = [res['btcs']['rmse'] for res in all_fd_results]\n",
    "avg_ftcs_max_diff = np.mean(ftcs_max_diffs, axis=0)\n",
    "avg_ftcs_rmse = np.mean(ftcs_rmses, axis=0)\n",
    "avg_btcs_max_diff = np.mean(btcs_max_diffs, axis=0)\n",
    "avg_btcs_rmse = np.mean(btcs_rmses, axis=0)\n",
    "# Calculate confidence intervals\n",
    "t_critical = stats.t.ppf(0.975, n_ics - 1)  # 95% CI\n",
    "# CI for LLM error metrics\n",
    "ci_lower_max_diffs_8B = []\n",
    "ci_upper_max_diffs_8B = []\n",
    "ci_lower_rmses_8B = []\n",
    "ci_upper_rmses_8B = []\n",
    "for mean, std in zip(avg_llm_max_diffs, std_llm_max_diffs):\n",
    "    lower, upper = log_ci(mean, std, n_ics, t_critical)\n",
    "    ci_lower_max_diffs_8B.append(lower)\n",
    "    ci_upper_max_diffs_8B.append(upper)\n",
    "for mean, std in zip(avg_llm_rmses, std_llm_rmses):\n",
    "    lower, upper = log_ci(mean, std, n_ics, t_critical)\n",
    "    ci_lower_rmses_8B.append(lower)\n",
    "    ci_upper_rmses_8B.append(upper)\n",
    "# CI for conservation errors\n",
    "ci_lower_integral_8B = []\n",
    "ci_upper_integral_8B = []\n",
    "for mean, std in zip(avg_llm_integral_changes, std_llm_integral_changes):\n",
    "    lower, upper = linear_ci(mean, std, n_ics, t_critical)\n",
    "    ci_lower_integral_8B.append(max(lower, 0))\n",
    "    ci_upper_integral_8B.append(upper)\n",
    "\n",
    "# Compute quant floor \n",
    "all_baseline_max_errors_per_step = []\n",
    "all_baseline_rmse_errors_per_step = []\n",
    "for ic_seed in range(n_ics):\n",
    "    # Use the stored initial condition and spline\n",
    "    init_cond_random = stored_initial_conditions[ic_seed]\n",
    "    cs = stored_spline_objects[ic_seed]\n",
    "    # Compute exact solution for this initial condition\n",
    "    u_exact = compute_exact_solution_random_ic_vary_Nx(L, k, T, Nx, Nt, spline_obj=cs)\n",
    "    # Quantization pipeline\n",
    "    u_exact_scaled, vmin_exact, vmax_exact = scale_2d_array(u_exact)\n",
    "    u_exact_serialized = serialize_2d_integers(u_exact_scaled, settings)\n",
    "    u_exact_parsed = deserialize_2d_integers(u_exact_serialized, settings)\n",
    "    u_exact_unscaled = unscale_2d_array(u_exact_parsed, vmin_exact, vmax_exact)\n",
    "    # Compute baseline errors at each time step for this seed\n",
    "    seed_max_errors_per_step = []\n",
    "    seed_rmse_errors_per_step = []\n",
    "    for t in range(u_exact.shape[0]):\n",
    "        max_err_t = np.max(np.abs(u_exact[t] - u_exact_unscaled[t]))\n",
    "        rmse_err_t = np.sqrt(np.mean((u_exact[t] - u_exact_unscaled[t])**2))\n",
    "        seed_max_errors_per_step.append(max_err_t)\n",
    "        seed_rmse_errors_per_step.append(rmse_err_t)\n",
    "    all_baseline_max_errors_per_step.append(seed_max_errors_per_step)\n",
    "    all_baseline_rmse_errors_per_step.append(seed_rmse_errors_per_step)\n",
    "all_baseline_max_errors_per_step = np.array(all_baseline_max_errors_per_step)\n",
    "all_baseline_rmse_errors_per_step = np.array(all_baseline_rmse_errors_per_step)\n",
    "avg_baseline_max_errors_per_step = np.mean(all_baseline_max_errors_per_step, axis=0)\n",
    "avg_baseline_rmse_errors_per_step = np.mean(all_baseline_rmse_errors_per_step, axis=0)\n",
    "avg_baseline_max_errors_prediction = avg_baseline_max_errors_per_step[input_time_steps:]\n",
    "avg_baseline_rmse_errors_prediction = avg_baseline_rmse_errors_per_step[input_time_steps:]\n",
    "\n",
    "averaged_fd_results = {\n",
    "    'ftcs': {\n",
    "        'max_diff': avg_ftcs_max_diff.tolist(),\n",
    "        'rmse': avg_ftcs_rmse.tolist()\n",
    "    },\n",
    "    'btcs': {\n",
    "        'max_diff': avg_btcs_max_diff.tolist(),\n",
    "        'rmse': avg_btcs_rmse.tolist()\n",
    "    }\n",
    "}\n",
    "\n",
    "np.savez_compressed(\n",
    "    \"8B_10_step.npz\",\n",
    "    # LLM error metrics\n",
    "    llm_max_diffs_8B=avg_llm_max_diffs,\n",
    "    llm_rmses_8B=avg_llm_rmses,\n",
    "    std_max_diffs_8B=std_llm_max_diffs,\n",
    "    std_rmses_8B=std_llm_rmses,\n",
    "    # LLM confidence intervals for error metrics\n",
    "    ci_lower_max_diffs_8B=ci_lower_max_diffs_8B,\n",
    "    ci_upper_max_diffs_8B=ci_upper_max_diffs_8B,\n",
    "    ci_lower_rmses_8B=ci_lower_rmses_8B,\n",
    "    ci_upper_rmses_8B=ci_upper_rmses_8B,\n",
    "    # LLM conservation metrics\n",
    "    llm_integral_changes_8B=avg_llm_integral_changes,\n",
    "    std_llm_integral_changes_8B=std_llm_integral_changes,\n",
    "    ci_lower_integral_8B=ci_lower_integral_8B,\n",
    "    ci_upper_integral_8B=ci_upper_integral_8B,\n",
    "    # FD conservation metrics\n",
    "    ftcs_integral_changes=avg_ftcs_integral_changes,\n",
    "    std_ftcs_integral_changes=std_ftcs_integral_changes,\n",
    "    btcs_integral_changes=avg_btcs_integral_changes,\n",
    "    std_btcs_integral_changes=std_btcs_integral_changes,\n",
    "    # Exact solution conservation (baseline)\n",
    "    exact_integral_changes=avg_exact_integral_changes,\n",
    "    std_exact_integral_changes=std_exact_integral_changes,\n",
    "    # FD error metrics\n",
    "    ftcs_max_diffs=avg_ftcs_max_diff,\n",
    "    ftcs_rmses=avg_ftcs_rmse,\n",
    "    btcs_max_diffs=avg_btcs_max_diff,\n",
    "    btcs_rmses=avg_btcs_rmse,\n",
    "    fd_results=averaged_fd_results,\n",
    "    # Quantization floor for prediction steps only\n",
    "    avg_baseline_max_errors_prediction=avg_baseline_max_errors_prediction,\n",
    "    avg_baseline_rmse_errors_prediction=avg_baseline_rmse_errors_prediction,\n",
    "    # Raw results for all initial conditions\n",
    "    all_llm_max_diffs=all_llm_max_diffs,\n",
    "    all_llm_rmses=all_llm_rmses,\n",
    "    all_llm_integral_changes_8B=all_llm_integral_changes,\n",
    "    all_fd_results=all_fd_results,\n",
    "    all_exact_integral_changes=all_exact_integral_changes,\n",
    "    stored_initial_conditions=stored_initial_conditions_array,\n",
    "    stored_initial_integrals_fine=stored_initial_integrals_fine,\n",
    "    stored_initial_integrals_coarse=stored_initial_integrals_coarse\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "A100",
   "machine_shape": "hm",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "smollm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "4bca38f991eb477fb6f6448ed40b7953": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_f282a01a1fa94fd3841fa84b0bf85801",
      "placeholder": "​",
      "style": "IPY_MODEL_bfdb859e858e42869e6da9b1482a5702",
      "value": "Loading checkpoint shards: 100%"
     }
    },
    "79d7edd2ec684e25b3674d375812e5fc": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "82fd26e315b6460ab439920956ecfc4b": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "8d598e552e3e4f3f9ffd47c953554ad0": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_79d7edd2ec684e25b3674d375812e5fc",
      "placeholder": "​",
      "style": "IPY_MODEL_b9ca4f266f0247a3aca54430f78c7bf4",
      "value": " 2/2 [00:04&lt;00:00,  2.25s/it]"
     }
    },
    "b9ca4f266f0247a3aca54430f78c7bf4": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "be3db56ffe3047a6ab8493d65d18f5c6": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "bfdb859e858e42869e6da9b1482a5702": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "e0dd1da9791a4911932193befbfd4dd0": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "e8bbace417ee4d74ae8e9fdcaf023b44": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_be3db56ffe3047a6ab8493d65d18f5c6",
      "max": 2,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_e0dd1da9791a4911932193befbfd4dd0",
      "value": 2
     }
    },
    "f282a01a1fa94fd3841fa84b0bf85801": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "fd21f3afeb514a51a73822346535fdec": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_4bca38f991eb477fb6f6448ed40b7953",
       "IPY_MODEL_e8bbace417ee4d74ae8e9fdcaf023b44",
       "IPY_MODEL_8d598e552e3e4f3f9ffd47c953554ad0"
      ],
      "layout": "IPY_MODEL_82fd26e315b6460ab439920956ecfc4b"
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
