{
 "cells": [
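  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook generates the data for the companion notebook `1-Step-Analyze.ipynb`, which produces the figures on one-step prediction accuracy for the Fisher-KPP equation described in Appendix A.3 of the paper. It runs two sweeps: one over the number of time steps `Nt` (fixed `Nx`) and one over the number of spatial points `Nx` (fixed `Nt`). Each sweep compares LLM one-step predictions with FTCS and IMEX finite-difference steps, and the resulting error statistics are saved to `.npz` files."
   ]
  },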
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Disable the Hugging Face tokenizers library's internal parallelism for this process\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "\n",
    "import torch\n",
    "\n",
    "# Report the visible CUDA devices. torch.cuda.current_device() raises when CUDA is\n",
    "# unavailable, so it is only queried when a GPU is actually present.\n",
    "num_devices = torch.cuda.device_count()\n",
    "print(\"Number of visible GPUs:\", num_devices)\n",
    "\n",
    "for i in range(num_devices):\n",
    "    print(f\"GPU {i}: {torch.cuda.get_device_name(i)}\")\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    current_device = torch.cuda.current_device()\n",
    "    print(\"Current device index:\", current_device)\n",
    "    print(\"Current device name:\", torch.cuda.get_device_name(current_device))\n",
    "else:\n",
    "    print(\"CUDA is not available; no current GPU device to report.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import scipy.stats as stats\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Project-local helpers: LLM loading/generation, numeric <-> text serialization, PDE solvers\n",
    "from llama_utils import load_model_and_tokenizer, generate_text_multiple\n",
    "from data_processing import (\n",
    "    SimpleSerializerSettings,\n",
    "    deserialize_2d_integers,\n",
    "    scale_2d_array,\n",
    "    serialize_2d_integers,\n",
    "    unscale_2d_array,\n",
    ")\n",
    "from fisher_kpp_equation import (\n",
    "    compute_exact_solution_random_ic_vary_Nx,\n",
    "    plot_both_grids,\n",
    "    solve_fisher_kpp_ftcs,\n",
    "    solve_fisher_kpp_imex,\n",
    "    visualize_spline_ic,\n",
    ")\n",
    "\n",
    "# Checkpoint under test; swap in one of the commented alternatives to run a smaller Llama model\n",
    "MODEL_NAME = \"meta-llama/Llama-3.1-8B\"\n",
    "# MODEL_NAME = \"meta-llama/Llama-3.2-3B\"\n",
    "# MODEL_NAME = \"meta-llama/Llama-3.2-1B\"\n",
    "\n",
    "# Seed every RNG in use: Python, NumPy, PyTorch (CPU) and, if present, all CUDA devices\n",
    "seed = 42\n",
    "for seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
    "    seed_fn(seed)\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed_all(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the language model and tokenizer selected by MODEL_NAME in the configuration cell\n",
    "model, tokenizer = load_model_and_tokenizer(MODEL_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Demo 1: draw a random initial condition on Nx interior points and visualize its spline\n",
    "L = 2\n",
    "Nx = 14\n",
    "init_cond_random = np.random.uniform(0.3, 0.7, size=Nx)\n",
    "fig, cs = visualize_spline_ic(L, Nx, init_cond_random)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Demo 2: resample the same underlying initial condition on a grid of Nx_new points.\n",
    "# Note: the spline `cs` returned here replaces the one from Demo 1 and is used by the next cells.\n",
    "Nx_original, Nx_new = Nx, 14\n",
    "fig, cs, init_cond_random_new = plot_both_grids(L, Nx_original, Nx_new, init_cond_random)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameters of the Fisher-KPP equation\n",
    "L = 2       # Length of the spatial domain [-L/2, L/2]\n",
    "k = 0.002   # Diffusion coefficient\n",
    "T = 0.5     # Total simulation time\n",
    "Nx = 14     # Number of interior spatial points (boundary points excluded)\n",
    "Nt = 25     # Number of time steps\n",
    "dx = L/(Nx+1)  # Spacing of the Nx+2 grid points including both boundaries\n",
    "dt = T/Nt\n",
    "\n",
    "# Reference solution for all time steps, built from the spline `cs` of the demo cell above\n",
    "u_exact = compute_exact_solution_random_ic_vary_Nx(L, k, T, Nx, Nt, spline_obj=cs)\n",
    "\n",
    "# Quantize and serialize the full space-time array as text (the format fed to the LLM)\n",
    "settings = SimpleSerializerSettings(space_sep=\",\", time_sep=\";\")\n",
    "u_exact_scaled, vmin_exact, vmax_exact = scale_2d_array(u_exact)\n",
    "u_exact_serialized = serialize_2d_integers(u_exact_scaled, settings)\n",
    "\n",
    "# Round trip: parse the text back and undo the scaling; the residual is the quantization error\n",
    "u_exact_parsed = deserialize_2d_integers(u_exact_serialized, settings)\n",
    "u_exact_unscaled = unscale_2d_array(u_exact_parsed, vmin_exact, vmax_exact)\n",
    "assert np.shape(u_exact_unscaled) == np.shape(u_exact), \"serialization round trip changed the array shape\"\n",
    "print(\"Max round-trip (quantization) error:\", np.max(np.abs(u_exact - u_exact_unscaled)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference solution on the coarse (Nt+1) x Nx grid, alongside the fine-grid output\n",
    "# returned by the same solver call with return_fine=True\n",
    "u_coarse, u_fine, x_fine, t_fine = compute_exact_solution_random_ic_vary_Nx(\n",
    "    L, k, T, Nx, Nt, spline_obj=cs, return_fine=True)\n",
    "Nt_plus1, Nx = u_coarse.shape\n",
    "\n",
    "# Add the two boundary nodes at x = -L/2 and x = L/2, where the solution is fixed to zero\n",
    "u_coarse_full = np.zeros((Nt_plus1, Nx + 2))\n",
    "u_coarse_full[:, 1:-1] = u_coarse\n",
    "x_coarse_full = np.linspace(-L / 2, L / 2, Nx + 2)\n",
    "t_coarse = np.linspace(0, t_fine[-1], Nt_plus1)\n",
    "\n",
    "# (time axis, space axis, field, title) for each panel\n",
    "panels = [\n",
    "    (t_coarse, x_coarse_full, u_coarse_full, 'Reference Solution (Coarse Grid)'),\n",
    "    (t_fine, x_fine, u_fine, 'Reference Solution (Fine Grid)'),\n",
    "]\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6), tight_layout=True)\n",
    "for ax, (t_axis, x_axis, field, title) in zip((ax1, ax2), panels):\n",
    "    filled = ax.contourf(t_axis, x_axis, field.T, levels=50, cmap='jet')\n",
    "    ax.set_title(title, fontsize=18)\n",
    "    ax.set_xlabel('Time $(t)$', fontsize=18)\n",
    "    ax.set_ylabel('Space $(x)$', fontsize=18)\n",
    "    plt.colorbar(filled, ax=ax)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Temporal sweep: vary the number of time steps Nt at fixed Nx. For every Nt and random initial\n",
    "# condition, predict the final time step one step ahead with the LLM, FTCS and IMEX, compare\n",
    "# each against the quantized reference, and record the quantization floor of the text format.\n",
    "all_Nt_values = range(2, 41, 2)\n",
    "Nx = 14           # Fixed number of spatial points (excluding boundary points)\n",
    "n_seeds = 50      # Number of random initial conditions\n",
    "max_retries = 10  # LLM re-generations allowed when a prediction has the wrong length\n",
    "settings = SimpleSerializerSettings(space_sep=\",\", time_sep=\";\")\n",
    "\n",
    "llm_final_max_diff = []\n",
    "llm_final_rmse = []\n",
    "llm_final_max_diff_std = []\n",
    "llm_final_rmse_std = []\n",
    "ftcs_final_max_diff = []\n",
    "ftcs_final_rmse = []\n",
    "imex_final_max_diff = []\n",
    "imex_final_rmse = []\n",
    "temporal_baseline_max_errors = []\n",
    "temporal_baseline_rmse_errors = []\n",
    "\n",
    "# Generate one random initial condition (and its spline) per seed; the spatial sweep reuses them\n",
    "stored_initial_conditions = []\n",
    "stored_spline_objects = []\n",
    "for ic_seed in range(n_seeds):\n",
    "    random.seed(ic_seed)\n",
    "    np.random.seed(ic_seed)\n",
    "    torch.manual_seed(ic_seed)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed_all(ic_seed)\n",
    "    init_cond_random = np.random.uniform(0.3, 0.7, size=Nx)\n",
    "    stored_initial_conditions.append(init_cond_random.copy())\n",
    "    fig, cs = visualize_spline_ic(L, Nx, init_cond_random)\n",
    "    plt.close(fig)  # only the spline object is needed here\n",
    "    stored_spline_objects.append(cs)\n",
    "stored_initial_conditions_array = np.array(stored_initial_conditions)\n",
    "\n",
    "for Nt in tqdm(all_Nt_values):\n",
    "    dt = T / Nt\n",
    "    seed_max_diffs_llm, seed_rmses_llm = [], []\n",
    "    seed_max_diffs_ftcs, seed_rmses_ftcs = [], []\n",
    "    seed_max_diffs_imex, seed_rmses_imex = [], []\n",
    "    seed_baseline_max_errors, seed_baseline_rmse_errors = [], []\n",
    "    for ic_seed in range(n_seeds):\n",
    "        cs = stored_spline_objects[ic_seed]\n",
    "        # Reference solution (Nt+1 time rows) and its quantized text serialization\n",
    "        u_exact = compute_exact_solution_random_ic_vary_Nx(L, k, T, Nx, Nt, spline_obj=cs)\n",
    "        u_exact_scaled, vmin_exact, vmax_exact = scale_2d_array(u_exact)\n",
    "        u_exact_serialized = serialize_2d_integers(u_exact_scaled, settings)\n",
    "\n",
    "        # Quantization floor: error of the serialize -> parse -> unscale round trip alone\n",
    "        # (computed here to reuse u_exact instead of re-solving every case in a second sweep)\n",
    "        u_exact_unscaled = unscale_2d_array(\n",
    "            deserialize_2d_integers(u_exact_serialized, settings), vmin_exact, vmax_exact)\n",
    "        seed_baseline_max_errors.append(np.max(np.abs(u_exact - u_exact_unscaled)))\n",
    "        seed_baseline_rmse_errors.append(np.sqrt(np.mean((u_exact - u_exact_unscaled)**2)))\n",
    "\n",
    "        # Ground truth: final time step, quantized with the same vmin/vmax as the prompt\n",
    "        final_exact = u_exact[Nt]\n",
    "        quantized_gt_2d, _, _ = scale_2d_array(final_exact[np.newaxis, :], vmin_exact, vmax_exact)\n",
    "        quantized_gt_2d = unscale_2d_array(quantized_gt_2d, vmin_exact, vmax_exact)\n",
    "        quantized_ground_truth = quantized_gt_2d[0, :]\n",
    "\n",
    "        # Prompt: all serialized time steps except the last, terminated by the time separator\n",
    "        rows = [row.strip() for row in u_exact_serialized.split(settings.time_sep) if row.strip()]\n",
    "        train_serial = settings.time_sep.join(rows[:-1]) + settings.time_sep\n",
    "\n",
    "        # LLM prediction. A row without exactly Nx values would crash or silently broadcast in the\n",
    "        # error formulas below, so regenerate it (as the spatial sweep does) up to max_retries times.\n",
    "        parsed_data = None\n",
    "        for attempt in range(max_retries):\n",
    "            next_token, _ = generate_text_multiple(train_serial, model, tokenizer, Nx)\n",
    "            candidate = deserialize_2d_integers(next_token, settings)\n",
    "            if np.ndim(candidate) == 2 and len(candidate) > 0 and np.shape(candidate)[1] == Nx:\n",
    "                parsed_data = candidate[:1]\n",
    "                break\n",
    "            print(f\"Attempt {attempt+1} failed for Nt={Nt}, seed={ic_seed}: Got shape {np.shape(candidate)}, expected second dim to be {Nx}\")\n",
    "        if parsed_data is None:\n",
    "            # Fail loudly rather than bias the statistics: the CIs below assume n_seeds samples per Nt\n",
    "            raise RuntimeError(f\"No valid LLM prediction for Nt={Nt}, seed={ic_seed} after {max_retries} attempts\")\n",
    "        unscaled_data = unscale_2d_array(parsed_data, vmin_exact, vmax_exact)[0, :]\n",
    "        llm_error = unscaled_data - quantized_ground_truth\n",
    "        seed_max_diffs_llm.append(np.max(np.abs(llm_error)))\n",
    "        seed_rmses_llm.append(np.sqrt(np.mean(llm_error**2)))\n",
    "\n",
    "        # Finite-difference baselines: one step from the quantized penultimate time step\n",
    "        scaled_init_2d, _, _ = scale_2d_array(u_exact[Nt-1][np.newaxis, :], vmin_exact, vmax_exact)\n",
    "        initial_degraded = unscale_2d_array(scaled_init_2d, vmin_exact, vmax_exact)[0, :]\n",
    "        _, ftcs_result, _ = solve_fisher_kpp_ftcs(L, k, dt, Nx, 1, init_cond=initial_degraded)\n",
    "        _, imex_result, _ = solve_fisher_kpp_imex(L, k, dt, Nx, 1, init_cond=initial_degraded)\n",
    "        ftcs_error = ftcs_result[1] - quantized_ground_truth\n",
    "        imex_error = imex_result[1] - quantized_ground_truth\n",
    "        seed_max_diffs_ftcs.append(np.max(np.abs(ftcs_error)))\n",
    "        seed_rmses_ftcs.append(np.sqrt(np.mean(ftcs_error**2)))\n",
    "        seed_max_diffs_imex.append(np.max(np.abs(imex_error)))\n",
    "        seed_rmses_imex.append(np.sqrt(np.mean(imex_error**2)))\n",
    "\n",
    "    # Aggregate over seeds (sample std with ddof=1, as used by the t-based intervals)\n",
    "    llm_final_max_diff.append(np.mean(seed_max_diffs_llm))\n",
    "    llm_final_rmse.append(np.mean(seed_rmses_llm))\n",
    "    llm_final_max_diff_std.append(np.std(seed_max_diffs_llm, ddof=1))\n",
    "    llm_final_rmse_std.append(np.std(seed_rmses_llm, ddof=1))\n",
    "    ftcs_final_max_diff.append(np.mean(seed_max_diffs_ftcs))\n",
    "    ftcs_final_rmse.append(np.mean(seed_rmses_ftcs))\n",
    "    imex_final_max_diff.append(np.mean(seed_max_diffs_imex))\n",
    "    imex_final_rmse.append(np.mean(seed_rmses_imex))\n",
    "    temporal_baseline_max_errors.append(np.mean(seed_baseline_max_errors))\n",
    "    temporal_baseline_rmse_errors.append(np.mean(seed_baseline_rmse_errors))\n",
    "\n",
    "temporal_baseline_max_errors = np.array(temporal_baseline_max_errors)\n",
    "temporal_baseline_rmse_errors = np.array(temporal_baseline_rmse_errors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def log_ci(mean, std, n, tcrit):\n",
    "    \"\"\"\n",
    "    Confidence interval for a mean shown on a log10 axis, via the delta method.\n",
    "\n",
    "    The standard error of the mean, se = std / sqrt(n), is propagated to log10 space as\n",
    "    se_log = se / (mean * ln 10). The interval log10(mean) +/- tcrit * se_log is then\n",
    "    mapped back, so the bounds are symmetric on a log axis (asymmetric in linear space).\n",
    "\n",
    "    mean  : arithmetic mean(s) of the n samples; must be > 0 (scalar or array)\n",
    "    std   : sample standard deviation(s) of the same samples\n",
    "    n     : number of samples behind each mean (scalar or array broadcastable with mean)\n",
    "    tcrit : two-sided Student-t critical value, e.g. t_{0.975, n-1} for a 95% interval\n",
    "\n",
    "    Returns (lower, upper) arrays with the broadcast shape of the inputs.\n",
    "    \"\"\"\n",
    "    mean = np.asarray(mean, dtype=float)\n",
    "    se = np.asarray(std, dtype=float) / np.sqrt(n)  # SE in linear space\n",
    "    se_log = se / (mean * np.log(10))  # delta-method SE in log10 space\n",
    "    half_width = tcrit * se_log\n",
    "    mean_log = np.log10(mean)\n",
    "    return 10**(mean_log - half_width), 10**(mean_log + half_width)\n",
    "\n",
    "# Floor the temporal-sweep means before taking log10, matching the epsilon floor of the spatial\n",
    "# sweep: a mean error of exactly 0 would otherwise yield -inf / NaN interval bounds.\n",
    "log_floor = 1e-4\n",
    "llm_final_max_diff = np.maximum(np.array(llm_final_max_diff), log_floor)\n",
    "llm_final_rmse = np.maximum(np.array(llm_final_rmse), log_floor)\n",
    "llm_final_max_diff_std = np.array(llm_final_max_diff_std)\n",
    "llm_final_rmse_std = np.array(llm_final_rmse_std)\n",
    "\n",
    "# Log-scale 95% confidence intervals (Student-t with n_seeds - 1 degrees of freedom)\n",
    "t_critical = stats.t.ppf(0.975, df=n_seeds-1)\n",
    "lower_max_diff_log, upper_max_diff_log = log_ci(llm_final_max_diff, llm_final_max_diff_std, n_seeds, t_critical)\n",
    "lower_rmse_log, upper_rmse_log = log_ci(llm_final_rmse, llm_final_rmse_std, n_seeds, t_critical)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tag the file name and LLM keys with the model size parsed from MODEL_NAME\n",
    "# (e.g. \"meta-llama/Llama-3.1-8B\" -> \"8B\"). Switching MODEL_NAME in the configuration cell\n",
    "# then cannot overwrite the 8B results with another model's numbers under the \"8B\" label.\n",
    "model_tag = MODEL_NAME.rsplit(\"-\", 1)[-1]\n",
    "np.savez_compressed(\n",
    "    f\"{model_tag}_1_step_time_discretization.npz\",\n",
    "    **{\n",
    "        # LLM metrics\n",
    "        f\"llm_final_max_diff_{model_tag}\": llm_final_max_diff,\n",
    "        f\"llm_final_rmse_{model_tag}\": llm_final_rmse,\n",
    "        f\"llm_final_max_diff_std_{model_tag}\": llm_final_max_diff_std,\n",
    "        f\"llm_final_rmse_std_{model_tag}\": llm_final_rmse_std,\n",
    "        # Log-scale confidence intervals\n",
    "        f\"llm_lower_max_diff_log_{model_tag}\": lower_max_diff_log,\n",
    "        f\"llm_upper_max_diff_log_{model_tag}\": upper_max_diff_log,\n",
    "        f\"llm_lower_rmse_log_{model_tag}\": lower_rmse_log,\n",
    "        f\"llm_upper_rmse_log_{model_tag}\": upper_rmse_log,\n",
    "    },\n",
    "    # Finite difference metrics (independent of the LLM, stored without a suffix)\n",
    "    ftcs_final_max_diff=ftcs_final_max_diff,\n",
    "    ftcs_final_rmse=ftcs_final_rmse,\n",
    "    imex_final_max_diff=imex_final_max_diff,\n",
    "    imex_final_rmse=imex_final_rmse,\n",
    "    # Baseline metrics\n",
    "    temporal_baseline_max_errors=temporal_baseline_max_errors,\n",
    "    temporal_baseline_rmse_errors=temporal_baseline_rmse_errors,\n",
    "    initial_conditions=stored_initial_conditions_array,\n",
    "    all_Nt_values=list(all_Nt_values),\n",
    "    n_seeds=n_seeds,\n",
    "    t_critical=t_critical,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spatial sweep: vary the number of interior points Nx at fixed Nt, reusing the random initial\n",
    "# conditions (splines) stored by the temporal sweep, and compare LLM / FTCS / IMEX one-step errors.\n",
    "all_Nx_values = range(2, 41, 2)\n",
    "Nt = 50           # Fixed number of time steps\n",
    "dt = T / Nt\n",
    "max_retries = 10  # LLM re-generations allowed when a prediction has the wrong length\n",
    "settings = SimpleSerializerSettings(space_sep=\",\", time_sep=\";\")\n",
    "\n",
    "llm_final_max_diff = []\n",
    "llm_final_rmse = []\n",
    "llm_final_max_diff_std = []\n",
    "llm_final_rmse_std = []\n",
    "llm_valid_counts = []  # seeds with a usable LLM prediction, per Nx\n",
    "ftcs_final_max_diff = []\n",
    "ftcs_final_rmse = []\n",
    "imex_final_max_diff = []\n",
    "imex_final_rmse = []\n",
    "spatial_baseline_max_errors = []\n",
    "spatial_baseline_rmse_errors = []\n",
    "\n",
    "for Nx in tqdm(all_Nx_values):\n",
    "    seed_max_diffs_llm, seed_rmses_llm = [], []\n",
    "    seed_max_diffs_ftcs, seed_rmses_ftcs = [], []\n",
    "    seed_max_diffs_imex, seed_rmses_imex = [], []\n",
    "    seed_baseline_max_errors, seed_baseline_rmse_errors = [], []\n",
    "    for ic_seed in range(n_seeds):\n",
    "        # The stored spline defines the initial condition; it is passed to the reference solver\n",
    "        # together with the current Nx\n",
    "        cs = stored_spline_objects[ic_seed]\n",
    "        u_exact = compute_exact_solution_random_ic_vary_Nx(L, k, T, Nx, Nt, spline_obj=cs)\n",
    "        u_exact_scaled, vmin_exact, vmax_exact = scale_2d_array(u_exact)\n",
    "        u_exact_serialized = serialize_2d_integers(u_exact_scaled, settings)\n",
    "\n",
    "        # Quantization floor: error of the serialize -> parse -> unscale round trip alone\n",
    "        # (computed here to reuse u_exact instead of re-solving every case in a second sweep)\n",
    "        u_exact_unscaled = unscale_2d_array(\n",
    "            deserialize_2d_integers(u_exact_serialized, settings), vmin_exact, vmax_exact)\n",
    "        seed_baseline_max_errors.append(np.max(np.abs(u_exact - u_exact_unscaled)))\n",
    "        seed_baseline_rmse_errors.append(np.sqrt(np.mean((u_exact - u_exact_unscaled)**2)))\n",
    "\n",
    "        # Ground truth: final time step, quantized with the same vmin/vmax as the prompt\n",
    "        final_exact = u_exact[Nt]\n",
    "        quantized_gt_2d, _, _ = scale_2d_array(final_exact[np.newaxis, :], vmin_exact, vmax_exact)\n",
    "        quantized_gt_2d = unscale_2d_array(quantized_gt_2d, vmin_exact, vmax_exact)\n",
    "        quantized_ground_truth = quantized_gt_2d[0, :]\n",
    "\n",
    "        # Prompt: all serialized time steps except the last, terminated by the time separator\n",
    "        rows = [row.strip() for row in u_exact_serialized.split(settings.time_sep) if row.strip()]\n",
    "        train_serial = settings.time_sep.join(rows[:-1]) + settings.time_sep\n",
    "\n",
    "        # LLM prediction with retries. If every attempt fails, the seed is left out of the LLM\n",
    "        # statistics only. Substituting another seed's result would duplicate a sample and\n",
    "        # understate the spread, and the FD baselines below do not depend on the LLM at all.\n",
    "        parsed_data = None\n",
    "        for attempt in range(max_retries):\n",
    "            next_token, _ = generate_text_multiple(train_serial, model, tokenizer, Nx)\n",
    "            candidate = deserialize_2d_integers(next_token, settings)\n",
    "            if np.ndim(candidate) == 2 and len(candidate) > 0 and np.shape(candidate)[1] == Nx:\n",
    "                parsed_data = candidate[:1]\n",
    "                break\n",
    "            print(f\"Attempt {attempt+1} failed for Nx={Nx}, seed={ic_seed}: Got shape {np.shape(candidate)}, expected second dim to be {Nx}\")\n",
    "        if parsed_data is None:\n",
    "            print(f\"Failed to get valid prediction for Nx={Nx}, seed={ic_seed} after {max_retries} attempts; excluding it from the LLM statistics\")\n",
    "        else:\n",
    "            unscaled_data = unscale_2d_array(parsed_data, vmin_exact, vmax_exact)[0, :]\n",
    "            llm_error = unscaled_data - quantized_ground_truth\n",
    "            seed_max_diffs_llm.append(np.max(np.abs(llm_error)))\n",
    "            seed_rmses_llm.append(np.sqrt(np.mean(llm_error**2)))\n",
    "\n",
    "        # Finite-difference baselines: one step from the quantized penultimate time step\n",
    "        scaled_init_2d, _, _ = scale_2d_array(u_exact[Nt-1][np.newaxis, :], vmin_exact, vmax_exact)\n",
    "        initial_degraded = unscale_2d_array(scaled_init_2d, vmin_exact, vmax_exact)[0, :]\n",
    "        _, ftcs_result, _ = solve_fisher_kpp_ftcs(L, k, dt, Nx, 1, init_cond=initial_degraded)\n",
    "        _, imex_result, _ = solve_fisher_kpp_imex(L, k, dt, Nx, 1, init_cond=initial_degraded)\n",
    "        ftcs_error = ftcs_result[1] - quantized_ground_truth\n",
    "        imex_error = imex_result[1] - quantized_ground_truth\n",
    "        seed_max_diffs_ftcs.append(np.max(np.abs(ftcs_error)))\n",
    "        seed_rmses_ftcs.append(np.sqrt(np.mean(ftcs_error**2)))\n",
    "        seed_max_diffs_imex.append(np.max(np.abs(imex_error)))\n",
    "        seed_rmses_imex.append(np.sqrt(np.mean(imex_error**2)))\n",
    "\n",
    "    # Aggregate over seeds; LLM statistics use only the seeds with a valid prediction\n",
    "    n_valid = len(seed_max_diffs_llm)\n",
    "    llm_valid_counts.append(n_valid)\n",
    "    llm_final_max_diff.append(np.mean(seed_max_diffs_llm) if n_valid > 0 else np.nan)\n",
    "    llm_final_rmse.append(np.mean(seed_rmses_llm) if n_valid > 0 else np.nan)\n",
    "    llm_final_max_diff_std.append(np.std(seed_max_diffs_llm, ddof=1) if n_valid > 1 else 0)\n",
    "    llm_final_rmse_std.append(np.std(seed_rmses_llm, ddof=1) if n_valid > 1 else 0)\n",
    "    ftcs_final_max_diff.append(np.mean(seed_max_diffs_ftcs))\n",
    "    ftcs_final_rmse.append(np.mean(seed_rmses_ftcs))\n",
    "    imex_final_max_diff.append(np.mean(seed_max_diffs_imex))\n",
    "    imex_final_rmse.append(np.mean(seed_rmses_imex))\n",
    "    spatial_baseline_max_errors.append(np.mean(seed_baseline_max_errors))\n",
    "    spatial_baseline_rmse_errors.append(np.mean(seed_baseline_rmse_errors))\n",
    "\n",
    "epsilon = 1e-4  # Avoid zero errors in log-log plots by enforcing a small minimum value\n",
    "spatial_baseline_max_errors = np.array(spatial_baseline_max_errors)\n",
    "spatial_baseline_rmse_errors = np.array(spatial_baseline_rmse_errors)\n",
    "llm_final_max_diff = np.maximum(np.array(llm_final_max_diff), epsilon)\n",
    "llm_final_rmse = np.maximum(np.array(llm_final_rmse), epsilon)\n",
    "llm_final_max_diff_std = np.array(llm_final_max_diff_std)\n",
    "llm_final_rmse_std = np.array(llm_final_rmse_std)\n",
    "llm_valid_counts = np.array(llm_valid_counts)\n",
    "\n",
    "# Log-scale 95% confidence intervals using the actual number of LLM samples per Nx. When no\n",
    "# seed had to be excluded, this count equals n_seeds and the t value equals t_critical.\n",
    "# Counts are clamped to 1 so that an Nx where every seed failed does not divide by zero.\n",
    "n_for_ci = np.maximum(llm_valid_counts, 1)\n",
    "t_crit_per_Nx = stats.t.ppf(0.975, df=np.maximum(n_for_ci - 1, 1))\n",
    "lower_max_diff_log, upper_max_diff_log = log_ci(llm_final_max_diff, llm_final_max_diff_std, n_for_ci, t_crit_per_Nx)\n",
    "lower_rmse_log, upper_rmse_log = log_ci(llm_final_rmse, llm_final_rmse_std, n_for_ci, t_crit_per_Nx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tag the file name and keys with the model size parsed from MODEL_NAME\n",
    "# (e.g. \"meta-llama/Llama-3.1-8B\" -> \"8B\"). Switching MODEL_NAME in the configuration cell\n",
    "# then cannot overwrite the 8B results with another model's numbers under the \"8B\" label.\n",
    "model_tag = MODEL_NAME.rsplit(\"-\", 1)[-1]\n",
    "np.savez_compressed(\n",
    "    f\"{model_tag}_1_step_space_discretization.npz\",\n",
    "    **{\n",
    "        # LLM metrics\n",
    "        f\"llm_final_max_diff_{model_tag}\": llm_final_max_diff,\n",
    "        f\"llm_final_rmse_{model_tag}\": llm_final_rmse,\n",
    "        f\"llm_final_max_diff_std_{model_tag}\": llm_final_max_diff_std,\n",
    "        f\"llm_final_rmse_std_{model_tag}\": llm_final_rmse_std,\n",
    "        # Log-scale confidence intervals\n",
    "        f\"llm_lower_max_diff_log_{model_tag}\": lower_max_diff_log,\n",
    "        f\"llm_upper_max_diff_log_{model_tag}\": upper_max_diff_log,\n",
    "        f\"llm_lower_rmse_log_{model_tag}\": lower_rmse_log,\n",
    "        f\"llm_upper_rmse_log_{model_tag}\": upper_rmse_log,\n",
    "        # Finite difference metrics\n",
    "        f\"ftcs_final_max_diff_{model_tag}\": ftcs_final_max_diff,\n",
    "        f\"ftcs_final_rmse_{model_tag}\": ftcs_final_rmse,\n",
    "        f\"imex_final_max_diff_{model_tag}\": imex_final_max_diff,\n",
    "        f\"imex_final_rmse_{model_tag}\": imex_final_rmse,\n",
    "    },\n",
    "    # Baseline metrics\n",
    "    spatial_baseline_max_errors=spatial_baseline_max_errors,\n",
    "    spatial_baseline_rmse_errors=spatial_baseline_rmse_errors,\n",
    "    n_seeds=n_seeds,\n",
    "    t_critical=t_critical,\n",
    "    all_Nx_values=list(all_Nx_values),\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "A100",
   "machine_shape": "hm",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "smollm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "4bca38f991eb477fb6f6448ed40b7953": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_f282a01a1fa94fd3841fa84b0bf85801",
      "placeholder": "​",
      "style": "IPY_MODEL_bfdb859e858e42869e6da9b1482a5702",
      "value": "Loading checkpoint shards: 100%"
     }
    },
    "79d7edd2ec684e25b3674d375812e5fc": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "82fd26e315b6460ab439920956ecfc4b": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "8d598e552e3e4f3f9ffd47c953554ad0": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_79d7edd2ec684e25b3674d375812e5fc",
      "placeholder": "​",
      "style": "IPY_MODEL_b9ca4f266f0247a3aca54430f78c7bf4",
      "value": " 2/2 [00:04&lt;00:00,  2.25s/it]"
     }
    },
    "b9ca4f266f0247a3aca54430f78c7bf4": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "be3db56ffe3047a6ab8493d65d18f5c6": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "bfdb859e858e42869e6da9b1482a5702": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "e0dd1da9791a4911932193befbfd4dd0": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "e8bbace417ee4d74ae8e9fdcaf023b44": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_be3db56ffe3047a6ab8493d65d18f5c6",
      "max": 2,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_e0dd1da9791a4911932193befbfd4dd0",
      "value": 2
     }
    },
    "f282a01a1fa94fd3841fa84b0bf85801": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "fd21f3afeb514a51a73822346535fdec": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_4bca38f991eb477fb6f6448ed40b7953",
       "IPY_MODEL_e8bbace417ee4d74ae8e9fdcaf023b44",
       "IPY_MODEL_8d598e552e3e4f3f9ffd47c953554ad0"
      ],
      "layout": "IPY_MODEL_82fd26e315b6460ab439920956ecfc4b"
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
