{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-05-14 18:51:12.510450: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
      "E0000 00:00:1747248672.526274   40373 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "E0000 00:00:1747248672.531348   40373 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "W0000 00:00:1747248672.543916   40373 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
      "W0000 00:00:1747248672.543935   40373 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
      "W0000 00:00:1747248672.543936   40373 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
      "W0000 00:00:1747248672.543937   40373 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n"
     ]
    }
   ],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np \n",
    "from functools import partial\n",
    "from approxml.cosmo import lensing_simulator\n",
    "from tqdm import tqdm\n",
    "import pickle\n",
    "import torch\n",
    "from torch.optim import Adam\n",
    "from sbi.neural_nets.net_builders import build_maf\n",
    "import math\n",
    "import importlib\n",
    "import types\n",
    "import tensorflow_probability as tfp\n",
    "for _candidate in (\n",
    "        'tensorflow_probability.substrates.jax', \n",
    "        'tensorflow_probability.experimental.substrates.jax'):    \n",
    "    try:\n",
    "        _jax_backend = importlib.import_module(_candidate)\n",
    "        break\n",
    "    except ModuleNotFoundError:\n",
    "        _jax_backend = None\n",
    "if _jax_backend is None:\n",
    "    raise ImportError(\"Couldn’t locate the JAX substrate inside tensorflow_probability.\")\n",
    "\n",
    "if not hasattr(tfp, 'substrates'):\n",
    "    tfp.substrates = types.SimpleNamespace()\n",
    "tfp.substrates.jax = _jax_backend     \n",
    "\n",
    "if not hasattr(tfp, 'experimental'):\n",
    "    tfp.experimental = types.SimpleNamespace()\n",
    "if not hasattr(tfp.experimental, 'substrates'):\n",
    "    tfp.experimental.substrates = types.SimpleNamespace()\n",
    "tfp.experimental.substrates.jax = _jax_backend \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_OBS = 100\n",
    "N_ITER = 100\n",
    "N_RUNS = 100\n",
    "N_PROP = 10    \n",
    "N_SIM_DST = 100\n",
    "N_PARAM_DIM = 6\n",
    "N_DATA_DIM = 6\n",
    "MOD_SIGMA = jnp.eye(N_PARAM_DIM)\n",
    "LR = 1e-3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_40373/1863810017.py:4: DeprecationWarning: Pickled array contains an aval with a named_shape attribute. This is deprecated and the code path supporting such avals will be removed. Please re-pickle the array.\n",
      "  opt_state_resnet= pickle.load(a_file)\n",
      "/tmp/ipykernel_40373/1863810017.py:7: DeprecationWarning: Pickled array contains an aval with a named_shape attribute. This is deprecated and the code path supporting such avals will be removed. Please re-pickle the array.\n",
      "  parameters_compressor= pickle.load(a_file)\n"
     ]
    }
   ],
   "source": [
    "lognormal_shifts_params = np.load(\"./data/lognormal_shifts_LSSTY10_om_s8_w_bin.npy\")\n",
    "\n",
    "a_file = open('./data/params_compressor/opt_state_resnet_vmim.pkl', \"rb\")\n",
    "opt_state_resnet= pickle.load(a_file)\n",
    "\n",
    "a_file = open('./data/params_compressor/params_nd_compressor_vmim.pkl', \"rb\")\n",
    "parameters_compressor= pickle.load(a_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@partial(jax.jit, static_argnames=(\"n_obs\", \"compress\"))\n",
    "def simulator_fn_jit(\n",
    "        key,\n",
    "        unbound_params,\n",
    "        n_obs: int,\n",
    "        *,\n",
    "        compress: bool = True):\n",
    "    p0, p1, p2, p3, p4, p5 = unbound_params  \n",
    "    params_bound = jnp.array([jnp.exp(p0), jnp.exp(p1), jnp.exp(p2), p3, p4 ,p5])\n",
    "\n",
    "    return lensing_simulator(\n",
    "        key,\n",
    "        params_bound,\n",
    "        n_obs,\n",
    "        compress=compress,\n",
    "        opt_state_resnet=opt_state_resnet,          \n",
    "        parameters_compressor=parameters_compressor,\n",
    "        lognormal_shifts_params=lognormal_shifts_params,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "sim_fn = simulator_fn_jit  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "key = jax.random.PRNGKey(0)\n",
    "params = jnp.array([0.2664, 0.0492, 0.831, 0.6727, 0.9645, -1.0])\n",
    "params = params.at[2].set(jnp.log(params[2]))\n",
    "params = params.at[1].set(jnp.log(params[1]))\n",
    "params = params.at[0].set(jnp.log(params[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "names = [\"omega_c\", \"omega_b\", \"sigma_8\", \"h_0\", \"n_s\", \"w_0\"]\n",
    "lo   = jnp.array([0.0, \n",
    "                  0.0492 - 3 * math.sqrt(0.006), \n",
    "                  0.831  - 3 * math.sqrt(0.14), \n",
    "                  0.6727 - 3 * math.sqrt(0.063), \n",
    "                  0.9645 - 3 * math.sqrt(0.08), \n",
    "                  -2.0])\n",
    "hi   = jnp.array([0.2664 + 3 * math.sqrt(0.2), \n",
    "                  0.0492 + 3 * math.sqrt(0.006), \n",
    "                  0.831  + 3 * math.sqrt(0.14), \n",
    "                  0.6727 + 3 * math.sqrt(0.063), \n",
    "                  0.9645 + 3 * math.sqrt(0.08), \n",
    "                  -0.3])\n",
    "\n",
    "def sample_prior(key, *shape):\n",
    "    u = jax.random.uniform(key, shape + lo.shape)    \n",
    "    s = lo + (hi - lo) * u\n",
    "    batch = {n: s[..., i] for i, n in enumerate(names)}\n",
    "\n",
    "    for k in (\"omega_c\", \"omega_b\", \"sigma_8\"):\n",
    "        batch[k] = jnp.log(batch[k])\n",
    "    return batch\n",
    "\n",
    "key = jax.random.PRNGKey(0)\n",
    "batch = sample_prior(key, 10_000)     \n",
    "p_batch = jnp.stack([batch[n] for n in names], axis=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/10000 [00:00<?, ?it/s]/workspace/SM-ApproxML/approxml/cosmo.py:449: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.\n",
      "  a = jnp.asarray(a, dtype=jnp.promote_types(float, getattr(a, 'dtype', float)))\n",
      "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/array_methods.py:122: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.\n",
      "  return lax_numpy.astype(self, dtype, copy=copy, device=device)\n",
      "2025-05-14 18:51:31.702275: W external/xla/xla/stream_executor/cuda/subprocess_compilation.cc:237] Falling back to the CUDA driver for PTX compilation; ptxas does not support CC 9.0\n",
      "2025-05-14 18:51:31.702287: W external/xla/xla/stream_executor/cuda/subprocess_compilation.cc:240] Used ptxas at /usr/local/cuda/bin/ptxas\n",
      "2025-05-14 18:51:31.702316: W external/xla/xla/stream_executor/gpu/redzone_allocator_kernel_cuda.cc:135] UNIMPLEMENTED: /usr/local/cuda/bin/ptxas ptxas too old. Falling back to the driver to compile.\n",
      "Relying on driver to perform ptx compilation. \n",
      "Modify $PATH to customize ptxas location.\n",
      "This message will be only logged once.\n",
      "100%|██████████| 10000/10000 [19:33<00:00,  8.52it/s]\n"
     ]
    }
   ],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "n = 10_000\n",
    "key = jax.random.PRNGKey(0)\n",
    "key, subkey = jax.random.split(key)\n",
    "\n",
    "keys      = jax.random.split(subkey, n)          \n",
    "\n",
    "sims_batch = []\n",
    "max_attempts = 25          \n",
    "\n",
    "for i in tqdm(range(n)):\n",
    "    attempts = 0\n",
    "    sim_ok   = False\n",
    "\n",
    "    while not sim_ok:\n",
    "        attempts += 1\n",
    "        if attempts > max_attempts:\n",
    "            raise RuntimeError(\n",
    "                f\"Simulator returned NaNs {max_attempts}× in a row \"\n",
    "                f\"for sample {i}\"\n",
    "            )\n",
    "\n",
    "        key, subkey = jax.random.split(key)\n",
    "\n",
    "        p_batch = p_batch.at[i, :].set(jnp.stack([sample_prior(subkey)[n] for n in names], axis=-1))\n",
    "        sim   = simulator_fn_jit(subkey, p_batch[i, :], 1)\n",
    "        \n",
    "        sim_ok = not jnp.isnan(sim).any()     \n",
    "\n",
    "    sims_batch.append(sim)\n",
    "\n",
    "sims_batch = jax.numpy.stack(sims_batch)  \n",
    "\n",
    "sims_batch_t  = torch.as_tensor(np.asarray(sims_batch).copy()).float()\n",
    "p_batch_t = torch.as_tensor(np.asarray(p_batch).copy()).float()\n",
    "\n",
    "sims_batch_t = sims_batch_t.reshape(10_000, 6)\n",
    "\n",
    "torch.save(sims_batch_t, \"sims_batch_t_u_10K.pt\")\n",
    "torch.save(p_batch_t, \"p_batch_u_t_10K.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "sims_batch_t = torch.load(\"sims_batch_t_u_10K.pt\").to(\"cuda\")\n",
    "p_batch_t = torch.load(\"p_batch_u_t_10K.pt\").to(\"cuda\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-1.3886, -1.3647, -0.2016,  0.2275,  1.3012, -1.9330],\n",
       "        [-0.9998, -1.5849,  0.1599,  1.0994,  0.5246, -1.1902],\n",
       "        [-0.5049, -2.7669, -0.8873,  0.5463,  1.2078, -0.4930],\n",
       "        ...,\n",
       "        [ 0.1663, -2.1619, -0.0454,  0.1906,  0.7719, -0.5513],\n",
       "        [ 0.1467, -2.4886, -0.4669,  1.2307,  1.3107, -1.2518],\n",
       "        [-1.0890, -1.3206,  0.1029,  0.8272,  0.6525, -1.8215]],\n",
       "       device='cuda:0')"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "p_batch_t"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iteration 0, loss: 17.933374404907227\n",
      "iteration 100, loss: 4.787280559539795\n",
      "iteration 200, loss: 2.6578938961029053\n",
      "iteration 300, loss: 1.0838611125946045\n",
      "iteration 400, loss: 0.17521323263645172\n",
      "iteration 500, loss: -0.3896007835865021\n",
      "iteration 600, loss: -0.7402608394622803\n",
      "iteration 700, loss: -1.0814344882965088\n",
      "iteration 800, loss: -1.3012248277664185\n",
      "iteration 900, loss: -1.4849404096603394\n",
      "iteration 1000, loss: -1.654344916343689\n",
      "iteration 1100, loss: -1.7934495210647583\n",
      "iteration 1200, loss: -1.7426137924194336\n",
      "iteration 1300, loss: -1.9865784645080566\n",
      "iteration 1400, loss: -2.1151812076568604\n",
      "iteration 1500, loss: -2.1937637329101562\n",
      "iteration 1600, loss: -2.2741711139678955\n",
      "iteration 1700, loss: -2.3426496982574463\n",
      "iteration 1800, loss: -2.393488883972168\n",
      "iteration 1900, loss: -2.3811988830566406\n",
      "iteration 2000, loss: -2.5075013637542725\n",
      "iteration 2100, loss: -2.5163652896881104\n",
      "iteration 2200, loss: -2.4071619510650635\n",
      "iteration 2300, loss: -2.63608980178833\n",
      "iteration 2400, loss: -2.6776297092437744\n",
      "iteration 2500, loss: -2.703366994857788\n",
      "iteration 2600, loss: -2.749027729034424\n",
      "iteration 2700, loss: -2.7725393772125244\n",
      "iteration 2800, loss: -2.8171072006225586\n",
      "iteration 2900, loss: -2.8375959396362305\n",
      "iteration 3000, loss: -2.8591763973236084\n",
      "iteration 3100, loss: -2.9010000228881836\n",
      "iteration 3200, loss: -2.9073150157928467\n",
      "iteration 3300, loss: -2.9504644870758057\n",
      "iteration 3400, loss: -2.9514875411987305\n",
      "iteration 3500, loss: -2.8950631618499756\n",
      "iteration 3600, loss: -3.021071195602417\n",
      "iteration 3700, loss: -3.0116565227508545\n",
      "iteration 3800, loss: -3.0476696491241455\n",
      "iteration 3900, loss: -3.0839900970458984\n",
      "iteration 4000, loss: -3.0849995613098145\n",
      "iteration 4100, loss: -3.1028268337249756\n",
      "iteration 4200, loss: -3.121610403060913\n",
      "iteration 4300, loss: -3.150514602661133\n",
      "iteration 4400, loss: -3.122514486312866\n",
      "iteration 4500, loss: -3.1735193729400635\n",
      "iteration 4600, loss: -3.205477237701416\n",
      "iteration 4700, loss: -3.1753077507019043\n",
      "iteration 4800, loss: -3.2390730381011963\n",
      "iteration 4900, loss: -3.2486510276794434\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(0)\n",
    "\n",
    "density_estimator = build_maf(sims_batch_t, p_batch_t).to(\"cuda\")\n",
    "\n",
    "opt = Adam(list(density_estimator.parameters()), lr=1e-3)\n",
    "for i in range(5000):\n",
    "    opt.zero_grad()\n",
    "    losses = density_estimator.loss(sims_batch_t, \n",
    "                                    condition=p_batch_t)\n",
    "    loss = torch.mean(losses)\n",
    "    if i % 100 == 0:\n",
    "        print(f\"iteration {i}, loss: {loss}\")\n",
    "    loss.backward()\n",
    "    opt.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"nle_10K.pkl\", \"wb\") as f:\n",
    "    pickle.dump(density_estimator, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
