{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aeac1d80-0593-4d86-a09d-16f77b74cea5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-05-14 19:12:39.553651: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
      "E0000 00:00:1747249959.569577   41283 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "E0000 00:00:1747249959.574841   41283 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "W0000 00:00:1747249959.587412   41283 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
      "W0000 00:00:1747249959.587428   41283 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
      "W0000 00:00:1747249959.587429   41283 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
      "W0000 00:00:1747249959.587430   41283 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
      "/tmp/ipykernel_41283/3645238986.py:61: DeprecationWarning: Pickled array contains an aval with a named_shape attribute. This is deprecated and the code path supporting such avals will be removed. Please re-pickle the array.\n",
      "  opt_state_resnet= pickle.load(a_file)\n",
      "/tmp/ipykernel_41283/3645238986.py:64: DeprecationWarning: Pickled array contains an aval with a named_shape attribute. This is deprecated and the code path supporting such avals will be removed. Please re-pickle the array.\n",
      "  parameters_compressor= pickle.load(a_file)\n"
     ]
    }
   ],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "from functools import partial\n",
    "from approxml.cosmo import lensing_simulator, sample_lensing_prior\n",
    "import pickle\n",
    "import torch\n",
    "import importlib\n",
    "import types\n",
    "from pathlib import Path\n",
    "import tensorflow_probability as tfp\n",
    "\n",
    "# Try both module paths under which TFP may expose its JAX substrate.\n",
    "for _candidate in (\n",
    "        'tensorflow_probability.substrates.jax',\n",
    "        'tensorflow_probability.experimental.substrates.jax'):\n",
    "    try:\n",
    "        _jax_backend = importlib.import_module(_candidate)\n",
    "        break\n",
    "    except ModuleNotFoundError:\n",
    "        continue\n",
    "else:\n",
    "    # No candidate module could be imported.\n",
    "    raise ImportError(\"Couldn’t locate the JAX substrate inside tensorflow_probability.\")\n",
    "\n",
    "# Alias the located substrate under both attribute paths on `tfp`, creating\n",
    "# placeholder namespaces where they are missing.\n",
    "if not hasattr(tfp, 'substrates'):\n",
    "    tfp.substrates = types.SimpleNamespace()\n",
    "tfp.substrates.jax = _jax_backend\n",
    "\n",
    "if not hasattr(tfp, 'experimental'):\n",
    "    tfp.experimental = types.SimpleNamespace()\n",
    "if not hasattr(tfp.experimental, 'substrates'):\n",
    "    tfp.experimental.substrates = types.SimpleNamespace()\n",
    "tfp.experimental.substrates.jax = _jax_backend\n",
    "\n",
    "DATA_DIR = Path(\"./data\")\n",
    "\n",
    "lognormal_shifts_params = np.load(DATA_DIR / \"lognormal_shifts_LSSTY10_om_s8_w_bin.npy\")\n",
    "\n",
    "# Pre-trained compressor artifacts. `with` closes each handle after loading;\n",
    "# pickle.load can execute arbitrary code, so only load trusted files here.\n",
    "with open(DATA_DIR / \"params_compressor\" / \"opt_state_resnet_vmim.pkl\", \"rb\") as a_file:\n",
    "    opt_state_resnet = pickle.load(a_file)\n",
    "\n",
    "with open(DATA_DIR / \"params_compressor\" / \"params_nd_compressor_vmim.pkl\", \"rb\") as a_file:\n",
    "    parameters_compressor = pickle.load(a_file)\n",
    "\n",
    "\n",
    "@partial(jax.jit, static_argnames=(\"n_obs\", \"compress\"))\n",
    "def simulator_fn_jit(\n",
    "        key,\n",
    "        unbound_params,\n",
    "        n_obs: int,\n",
    "        *,\n",
    "        compress: bool = True):\n",
    "    \"\"\"Simulate lensing data from unconstrained parameters.\n",
    "\n",
    "    Args:\n",
    "        key: JAX PRNG key forwarded to lensing_simulator.\n",
    "        unbound_params: length-6 vector; entries 0-2 are logs and are mapped\n",
    "            back with exp, entries 3-5 are passed through unchanged.\n",
    "        n_obs: forwarded to lensing_simulator (static under jit, so each\n",
    "            distinct value triggers its own compilation).\n",
    "        compress: forwarded to lensing_simulator (static under jit).\n",
    "\n",
    "    Returns:\n",
    "        The output of lensing_simulator for the bounded parameters.\n",
    "    \"\"\"\n",
    "    # Unpacking also checks that exactly six parameters were given.\n",
    "    p0, p1, p2, p3, p4, p5 = unbound_params\n",
    "    params_bound = jnp.array([jnp.exp(p0), jnp.exp(p1), jnp.exp(p2), p3, p4, p5])\n",
    "\n",
    "    # opt_state_resnet, parameters_compressor and lognormal_shifts_params are\n",
    "    # module-level globals read at trace time; jit bakes them in as constants.\n",
    "    return lensing_simulator(\n",
    "        key,\n",
    "        params_bound,\n",
    "        n_obs,\n",
    "        compress=compress,\n",
    "        opt_state_resnet=opt_state_resnet,\n",
    "        parameters_compressor=parameters_compressor,\n",
    "        lognormal_shifts_params=lognormal_shifts_params,\n",
    "    )\n",
    "\n",
    "\n",
    "sim_fn = simulator_fn_jit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "36e6abc3-7b55-46d9-889b-e33f662ef0a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pickled neural likelihood estimator. pickle.load can execute\n",
    "# arbitrary code, so this file must come from a trusted source.\n",
    "with open(\"nle_10K.pkl\", mode=\"rb\") as nle_file:\n",
    "    density_estimator = pickle.load(nle_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "80b34f2c-1e2c-4cd9-9756-01cbfecc0f77",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/workspace/SM-ApproxML/approxml/cosmo.py:449: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.\n",
      "  a = jnp.asarray(a, dtype=jnp.promote_types(float, getattr(a, 'dtype', float)))\n",
      "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/array_methods.py:122: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.\n",
      "  return lax_numpy.astype(self, dtype, copy=copy, device=device)\n",
      "2025-05-14 19:12:59.591233: W external/xla/xla/stream_executor/cuda/subprocess_compilation.cc:237] Falling back to the CUDA driver for PTX compilation; ptxas does not support CC 9.0\n",
      "2025-05-14 19:12:59.591245: W external/xla/xla/stream_executor/cuda/subprocess_compilation.cc:240] Used ptxas at /usr/local/cuda/bin/ptxas\n",
      "2025-05-14 19:12:59.591275: W external/xla/xla/stream_executor/gpu/redzone_allocator_kernel_cuda.cc:135] UNIMPLEMENTED: /usr/local/cuda/bin/ptxas ptxas too old. Falling back to the driver to compile.\n",
      "Relying on driver to perform ptx compilation. \n",
      "Modify $PATH to customize ptxas location.\n",
      "This message will be only logged once.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iter 010  NLL = 277.4568\n",
      "iter 020  NLL = -166.5627\n",
      "iter 030  NLL = -277.8683\n",
      "iter 040  NLL = -358.6677\n",
      "iter 050  NLL = -385.8805\n",
      "iter 060  NLL = -393.0500\n",
      "iter 070  NLL = -394.8662\n",
      "iter 080  NLL = -395.6410\n",
      "iter 090  NLL = -396.3203\n",
      "iter 100  NLL = -397.0311\n",
      "iter 110  NLL = -397.8787\n",
      "iter 120  NLL = -399.0250\n",
      "iter 130  NLL = -400.5421\n",
      "iter 140  NLL = -402.5488\n",
      "iter 150  NLL = -405.1787\n",
      "iter 160  NLL = -408.5904\n",
      "iter 170  NLL = -412.9607\n",
      "iter 180  NLL = -418.4491\n",
      "iter 190  NLL = -425.0680\n",
      "iter 200  NLL = -432.3144\n",
      "iter 210  NLL = -438.7263\n",
      "iter 220  NLL = -442.5251\n",
      "iter 230  NLL = -443.6243\n",
      "iter 240  NLL = -443.6209\n",
      "iter 250  NLL = -443.6010\n",
      "iter 260  NLL = -443.6469\n",
      "iter 270  NLL = -443.6668\n",
      "iter 280  NLL = -443.6675\n",
      "iter 290  NLL = -443.6671\n",
      "iter 300  NLL = -443.6677\n",
      "iter 310  NLL = -443.6680\n",
      "iter 320  NLL = -443.6682\n",
      "iter 330  NLL = -443.6682\n",
      "iter 340  NLL = -443.6683\n",
      "iter 350  NLL = -443.6682\n",
      "iter 360  NLL = -443.6682\n",
      "iter 370  NLL = -443.6683\n",
      "iter 380  NLL = -443.6682\n",
      "iter 390  NLL = -443.6680\n",
      "iter 400  NLL = -443.6682\n",
      "iter 410  NLL = -443.6680\n",
      "iter 420  NLL = -443.6682\n",
      "iter 430  NLL = -443.6681\n",
      "iter 440  NLL = -443.6682\n",
      "iter 450  NLL = -443.6682\n",
      "iter 460  NLL = -443.6682\n",
      "iter 470  NLL = -443.6681\n",
      "iter 480  NLL = -443.6682\n",
      "iter 490  NLL = -443.6682\n",
      "iter 500  NLL = -443.6682\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_41283/1607512988.py:35: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  sims = sim_fn(subkey, jnp.array(torch.tensor(theta)), n_obs)\n",
      "/tmp/ipykernel_41283/1607512988.py:37: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  mse = jnp.linalg.norm(jnp.array(torch.tensor(theta)) - params)\n",
      "/tmp/ipykernel_41283/1607512988.py:41: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  'theta': jnp.array(torch.tensor(theta)),\n",
      "/workspace/SM-ApproxML/approxml/cosmo.py:449: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.\n",
      "  a = jnp.asarray(a, dtype=jnp.promote_types(float, getattr(a, 'dtype', float)))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n_obs: 100, err: 0.5244500041007996, mse: 0.9621155858039856\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/array_methods.py:122: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.\n",
      "  return lax_numpy.astype(self, dtype, copy=copy, device=device)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iter 010  NLL = 708.7800\n",
      "iter 020  NLL = -402.8170\n",
      "iter 030  NLL = -693.3678\n",
      "iter 040  NLL = -893.3175\n",
      "iter 050  NLL = -966.6942\n",
      "iter 060  NLL = -987.8702\n",
      "iter 070  NLL = -993.4218\n",
      "iter 080  NLL = -996.0474\n",
      "iter 090  NLL = -998.6967\n",
      "iter 100  NLL = -1001.4089\n",
      "iter 110  NLL = -1004.6997\n",
      "iter 120  NLL = -1009.0854\n",
      "iter 130  NLL = -1014.7377\n",
      "iter 140  NLL = -1022.0165\n",
      "iter 150  NLL = -1031.3204\n",
      "iter 160  NLL = -1043.0994\n",
      "iter 170  NLL = -1057.8176\n",
      "iter 180  NLL = -1075.6062\n",
      "iter 190  NLL = -1095.2122\n",
      "iter 200  NLL = -1112.5732\n",
      "iter 210  NLL = -1122.6079\n",
      "iter 220  NLL = -1125.2731\n",
      "iter 230  NLL = -1125.1833\n",
      "iter 240  NLL = -1125.1677\n",
      "iter 250  NLL = -1125.3020\n",
      "iter 260  NLL = -1125.3479\n",
      "iter 270  NLL = -1125.3470\n",
      "iter 280  NLL = -1125.3472\n",
      "iter 290  NLL = -1125.3496\n",
      "iter 300  NLL = -1125.3501\n",
      "iter 310  NLL = -1125.3501\n",
      "iter 320  NLL = -1125.3501\n",
      "iter 330  NLL = -1125.3501\n",
      "iter 340  NLL = -1125.3501\n",
      "iter 350  NLL = -1125.3501\n",
      "iter 360  NLL = -1125.3502\n",
      "iter 370  NLL = -1125.3499\n",
      "iter 380  NLL = -1125.3503\n",
      "iter 390  NLL = -1125.3502\n",
      "iter 400  NLL = -1125.3502\n",
      "iter 410  NLL = -1125.3501\n",
      "iter 420  NLL = -1125.3501\n",
      "iter 430  NLL = -1125.3501\n",
      "iter 440  NLL = -1125.3501\n",
      "iter 450  NLL = -1125.3503\n",
      "iter 460  NLL = -1125.3501\n",
      "iter 470  NLL = -1125.3501\n",
      "iter 480  NLL = -1125.3503\n",
      "iter 490  NLL = -1125.3501\n",
      "iter 500  NLL = -1125.3500\n",
      "n_obs: 250, err: 0.5303230285644531, mse: 0.9729365706443787\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-05-14 19:14:35.583221: E external/xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.\n",
      "2025-05-14 19:14:35.706607: E external/xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.\n",
      "2025-05-14 19:14:37.141582: E external/xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.\n",
      "2025-05-14 19:14:37.265193: E external/xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iter 010  NLL = 1463.5188\n",
      "iter 020  NLL = -870.3006\n",
      "iter 030  NLL = -1438.6108\n",
      "iter 040  NLL = -1829.0684\n",
      "iter 050  NLL = -1965.1284\n",
      "iter 060  NLL = -1999.5947\n",
      "iter 070  NLL = -2009.0579\n",
      "iter 080  NLL = -2014.2932\n",
      "iter 090  NLL = -2019.1261\n",
      "iter 100  NLL = -2024.1515\n",
      "iter 110  NLL = -2030.5339\n",
      "iter 120  NLL = -2039.1506\n",
      "iter 130  NLL = -2050.5022\n",
      "iter 140  NLL = -2065.3789\n",
      "iter 150  NLL = -2084.6182\n",
      "iter 160  NLL = -2109.0898\n",
      "iter 170  NLL = -2139.4844\n",
      "iter 180  NLL = -2175.4897\n",
      "iter 190  NLL = -2213.7102\n",
      "iter 200  NLL = -2245.7109\n",
      "iter 210  NLL = -2262.9570\n",
      "iter 220  NLL = -2267.1421\n",
      "iter 230  NLL = -2266.9297\n",
      "iter 240  NLL = -2266.9287\n",
      "iter 250  NLL = -2267.1582\n",
      "iter 260  NLL = -2267.2354\n",
      "iter 270  NLL = -2267.2354\n",
      "iter 280  NLL = -2267.2351\n",
      "iter 290  NLL = -2267.2390\n",
      "iter 300  NLL = -2267.2397\n",
      "iter 310  NLL = -2267.2402\n",
      "iter 320  NLL = -2267.2400\n",
      "iter 330  NLL = -2267.2400\n",
      "iter 340  NLL = -2267.2402\n",
      "iter 350  NLL = -2267.2400\n",
      "iter 360  NLL = -2267.2400\n",
      "iter 370  NLL = -2267.2402\n",
      "iter 380  NLL = -2267.2400\n",
      "iter 390  NLL = -2267.2402\n",
      "iter 400  NLL = -2267.2400\n",
      "iter 410  NLL = -2267.2400\n",
      "iter 420  NLL = -2267.2397\n",
      "iter 430  NLL = -2267.2400\n",
      "iter 440  NLL = -2267.2400\n",
      "iter 450  NLL = -2267.2400\n",
      "iter 460  NLL = -2267.2397\n",
      "iter 470  NLL = -2267.2397\n",
      "iter 480  NLL = -2267.2400\n",
      "iter 490  NLL = -2267.2397\n",
      "iter 500  NLL = -2267.2397\n",
      "n_obs: 500, err: 0.5279261469841003, mse: 0.959726870059967\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-05-14 19:15:29.267399: E external/xla/xla/service/slow_operation_alarm.cc:73] Trying algorithm eng0{} for conv %cudnn-conv-bias-activation.81 = (f32[1000,64,64,64]{3,2,1,0}, u8[0]{0}) custom-call(%bitcast.56294, %bitcast.56295, %bitcast.56296), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target=\"__cudnn$convBiasActivationForward\", metadata={op_name=\"jit(simulator_fn_jit)/jit(main)/res_net18/block_group_0/block_0/batchnorm_0/mul\" source_file=\"/usr/local/lib/python3.10/dist-packages/haiku/_src/batch_norm.py\" source_line=205}, backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"cudnn_conv_backend_config\":{\"conv_result_scale\":1,\"activation_mode\":\"kRelu\",\"side_input_scale\":0,\"leakyrelu_alpha\":0},\"force_earliest_schedule\":false} is taking a while...\n",
      "2025-05-14 19:15:30.123106: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 1.855815786s\n",
      "Trying algorithm eng0{} for conv %cudnn-conv-bias-activation.81 = (f32[1000,64,64,64]{3,2,1,0}, u8[0]{0}) custom-call(%bitcast.56294, %bitcast.56295, %bitcast.56296), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target=\"__cudnn$convBiasActivationForward\", metadata={op_name=\"jit(simulator_fn_jit)/jit(main)/res_net18/block_group_0/block_0/batchnorm_0/mul\" source_file=\"/usr/local/lib/python3.10/dist-packages/haiku/_src/batch_norm.py\" source_line=205}, backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"cudnn_conv_backend_config\":{\"conv_result_scale\":1,\"activation_mode\":\"kRelu\",\"side_input_scale\":0,\"leakyrelu_alpha\":0},\"force_earliest_schedule\":false} is taking a while...\n",
      "2025-05-14 19:15:31.807323: E external/xla/xla/service/slow_operation_alarm.cc:73] Trying algorithm eng0{} for conv %cudnn-conv-bias-activation.82 = (f32[1000,64,64,64]{3,2,1,0}, u8[0]{0}) custom-call(%bitcast.56301, %bitcast.56302, %bitcast.56303, %bitcast.56294), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target=\"__cudnn$convBiasActivationForward\", metadata={op_name=\"jit(simulator_fn_jit)/jit(main)/res_net18/block_group_0/block_0/batchnorm_1/mul\" source_file=\"/usr/local/lib/python3.10/dist-packages/haiku/_src/batch_norm.py\" source_line=205}, backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"cudnn_conv_backend_config\":{\"conv_result_scale\":1,\"activation_mode\":\"kRelu\",\"side_input_scale\":1,\"leakyrelu_alpha\":0},\"force_earliest_schedule\":false} is taking a while...\n",
      "2025-05-14 19:15:32.668252: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 1.860999782s\n",
      "Trying algorithm eng0{} for conv %cudnn-conv-bias-activation.82 = (f32[1000,64,64,64]{3,2,1,0}, u8[0]{0}) custom-call(%bitcast.56301, %bitcast.56302, %bitcast.56303, %bitcast.56294), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target=\"__cudnn$convBiasActivationForward\", metadata={op_name=\"jit(simulator_fn_jit)/jit(main)/res_net18/block_group_0/block_0/batchnorm_1/mul\" source_file=\"/usr/local/lib/python3.10/dist-packages/haiku/_src/batch_norm.py\" source_line=205}, backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"cudnn_conv_backend_config\":{\"conv_result_scale\":1,\"activation_mode\":\"kRelu\",\"side_input_scale\":1,\"leakyrelu_alpha\":0},\"force_earliest_schedule\":false} is taking a while...\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iter 010  NLL = 2782.9531\n",
      "iter 020  NLL = -1727.6005\n",
      "iter 030  NLL = -2884.7927\n",
      "iter 040  NLL = -3712.0967\n",
      "iter 050  NLL = -3953.1406\n",
      "iter 060  NLL = -4005.5190\n",
      "iter 070  NLL = -4017.5876\n",
      "iter 080  NLL = -4027.6550\n",
      "iter 090  NLL = -4035.8206\n",
      "iter 100  NLL = -4043.7773\n",
      "iter 110  NLL = -4053.8899\n",
      "iter 120  NLL = -4067.9365\n",
      "iter 130  NLL = -4086.9587\n",
      "iter 140  NLL = -4112.6523\n",
      "iter 150  NLL = -4146.8940\n",
      "iter 160  NLL = -4191.7471\n",
      "iter 170  NLL = -4249.0264\n",
      "iter 180  NLL = -4318.9604\n",
      "iter 190  NLL = -4396.7539\n",
      "iter 200  NLL = -4467.8672\n",
      "iter 210  NLL = -4512.7324\n",
      "iter 220  NLL = -4527.4829\n",
      "iter 230  NLL = -4528.0879\n",
      "iter 240  NLL = -4527.7085\n",
      "iter 250  NLL = -4528.1758\n",
      "iter 260  NLL = -4528.4355\n",
      "iter 270  NLL = -4528.4561\n",
      "iter 280  NLL = -4528.4507\n",
      "iter 290  NLL = -4528.4570\n",
      "iter 300  NLL = -4528.4609\n",
      "iter 310  NLL = -4528.4609\n",
      "iter 320  NLL = -4528.4614\n",
      "iter 330  NLL = -4528.4614\n",
      "iter 340  NLL = -4528.4609\n",
      "iter 350  NLL = -4528.4609\n",
      "iter 360  NLL = -4528.4614\n",
      "iter 370  NLL = -4528.4619\n",
      "iter 380  NLL = -4528.4614\n",
      "iter 390  NLL = -4528.4609\n",
      "iter 400  NLL = -4528.4609\n",
      "iter 410  NLL = -4528.4609\n",
      "iter 420  NLL = -4528.4614\n",
      "iter 430  NLL = -4528.4609\n",
      "iter 440  NLL = -4528.4614\n",
      "iter 450  NLL = -4528.4609\n",
      "iter 460  NLL = -4528.4609\n",
      "iter 470  NLL = -4528.4609\n",
      "iter 480  NLL = -4528.4609\n",
      "iter 490  NLL = -4528.4609\n",
      "iter 500  NLL = -4528.4609\n",
      "n_obs: 1000, err: 0.520664393901825, mse: 0.9596536159515381\n"
     ]
    }
   ],
   "source": [
    "# Fit the pre-trained density estimator at a fixed fiducial point for several\n",
    "# dataset sizes: theta is optimised with Adam to minimise the summed\n",
    "# density-estimator loss over all observations.\n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "N_OBS_GRID = [100, 250, 500, 1000]\n",
    "N_ITERS = 500\n",
    "LEARNING_RATE = 5e-2\n",
    "LOG_EVERY = 10\n",
    "\n",
    "# NOTE(review): PRNGKey(0) seeds both this prior draw and every observed\n",
    "# dataset below; JAX advises against reusing a key for independent draws.\n",
    "# Left unchanged so results stay comparable with the saved outputs.\n",
    "theta_init = sample_lensing_prior(jax.random.PRNGKey(0))\n",
    "theta_init_t = torch.as_tensor(np.asarray(theta_init).copy()).float()\n",
    "key = jax.random.PRNGKey(0)\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Fiducial parameters; entries 0-2 are stored in log space because sim_fn\n",
    "# exponentiates them before simulating.\n",
    "params = jnp.array([0.2664, 0.0492, 0.831, 0.6727, 0.9645, -1.0])\n",
    "params = params.at[:3].set(jnp.log(params[:3]))\n",
    "\n",
    "\n",
    "def neg_loglik(theta_vec, obs_batch):\n",
    "    \"\"\"Sum of the density-estimator loss over obs_batch, conditioned on theta_vec.\n",
    "\n",
    "    theta_vec is repeated once per row of obs_batch. Assumes\n",
    "    density_estimator.loss returns a per-row negative log-likelihood (as sbi\n",
    "    estimators do) -- TODO confirm, since the logged \"NLL\" relies on it.\n",
    "    \"\"\"\n",
    "    theta_tiled = torch.tile(theta_vec, (obs_batch.shape[0], 1))\n",
    "    per_obs_loss = density_estimator.loss(obs_batch, condition=theta_tiled)\n",
    "    return per_obs_loss.sum()\n",
    "\n",
    "\n",
    "def _to_jax(tensor):\n",
    "    \"\"\"Detach a (possibly CUDA, grad-tracking) torch tensor and return it as a JAX array.\"\"\"\n",
    "    return jnp.asarray(tensor.detach().cpu().numpy())\n",
    "\n",
    "\n",
    "res_list = []\n",
    "for n_obs in N_OBS_GRID:\n",
    "    obs = sim_fn(jax.random.PRNGKey(0), params, n_obs)\n",
    "    obs_t = torch.as_tensor(np.asarray(obs).copy()).float().to(DEVICE)\n",
    "\n",
    "    # Every fit starts from the same prior draw.\n",
    "    theta = torch.nn.Parameter(theta_init_t.clone().to(DEVICE))\n",
    "    opt = torch.optim.Adam([theta], lr=LEARNING_RATE)\n",
    "\n",
    "    for i in range(N_ITERS):\n",
    "        opt.zero_grad()\n",
    "        loss = neg_loglik(theta, obs_t)\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "\n",
    "        if (i+1) % LOG_EVERY == 0:\n",
    "            # Logged loss is the value before this iteration's update.\n",
    "            print(f\"iter {i+1:03d}  NLL = {loss.item():.4f}\")\n",
    "\n",
    "    # Convert once via detach/cpu/numpy instead of torch.tensor(theta), which\n",
    "    # copy-constructs a Parameter (UserWarning) and was repeated three times.\n",
    "    theta_hat = _to_jax(theta)\n",
    "\n",
    "    key, subkey = jax.random.split(key)\n",
    "    sims = sim_fn(subkey, theta_hat, n_obs)\n",
    "    # Mean of row-wise Euclidean distances between simulations at the fit and the observations.\n",
    "    err = jnp.linalg.norm(sims - obs, axis=1).mean()\n",
    "    # Despite the name (kept for the saved 'mse' key), this is the L2 distance\n",
    "    # between fitted and fiducial parameters, not a mean squared error.\n",
    "    mse = jnp.linalg.norm(theta_hat - params)\n",
    "    print(f\"n_obs: {n_obs}, err: {err}, mse: {mse}\")\n",
    "\n",
    "    res_list.append({\n",
    "        'theta': theta_hat,\n",
    "        'err': err,\n",
    "        'mse': mse,\n",
    "        'n_obs': n_obs,\n",
    "    })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3c3674c1-ae13-4d3c-8254-49fc2a703000",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the per-n_obs fit results (dicts with keys 'theta', 'err', 'mse',\n",
    "# 'n_obs') so they can be analysed without re-running the fits.\n",
    "with open(\"cosmo_nle_res.pkl\", mode=\"wb\") as results_file:\n",
    "    pickle.dump(res_list, results_file)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
