{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd335648",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "# chip package imports\n",
    "from chip.utils.utils import get_uniform_angles, create_circle_filter, create_gaussian_filter\n",
    "from chip.datasets.tomogram_dataset import TomogramDataset\n",
    "from chip.models.forward_models import fourier_filtering\n",
    "from chip.models.iterative_model import TomographicReconstruction\n",
    "from chip.training.iterative_reconstruction import finetune_sinogram_consistency\n",
    "from chip.utils.sinogram import Sinogram, compute_sinogram\n",
    "from chip.utils.plotting import plot_sinogram\n",
    "from chip.utils.metrics import PSNR, RMSE\n",
    "\n",
    "# Run on the GPU whenever torch can see one; swap in the commented line to force the CPU.\n",
    "device = 'cpu' if not torch.cuda.is_available() else 'cuda'\n",
    "# device = 'cpu'\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dddc65c-7488-4cd5-a30c-5df1a12413a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from bayes_dip.utils.experiment_utils import get_standard_ray_trafo, get_standard_dataset\n",
    "import torch\n",
    "\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from hydra import compose, initialize\n",
    "from hydra.core.global_hydra import GlobalHydra\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "# context initialization; hydra raises if `initialize` is called while its global\n",
    "# state is already set, so clear it first to keep this cell re-runnable\n",
    "if GlobalHydra.instance().is_initialized():\n",
    "    GlobalHydra.instance().clear()\n",
    "initialize(version_base=None, config_path=\"../experiments/hydra_cfg/\", job_name=\"test_app\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13a9f626-23db-4219-b046-8aa4f032a839",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "im_size = 512\n",
    "# Tomogram HDF5 file; set CHIP_TOMOGRAM_PATH to point elsewhere instead of editing this line.\n",
    "data_path = os.environ.get('CHIP_TOMOGRAM_PATH', '/mydata/chip/shared/data/tomogram_synthetic.h5')\n",
    "\n",
    "# Commented-out alternatives keep a trailing comma: without it, uncommenting one would\n",
    "# silently concatenate it with the next string literal into a single malformed override.\n",
    "cfg = compose(config_name=\"config\", overrides=[\n",
    "    \"experiment=chip\",\n",
    "    \"dip.optim.iterations=2000\",\n",
    "    \"dip.optim.gamma=10e-3\",  # note: 10e-3 == 1e-2\n",
    "    \"dip.net.use_sigmoid=True\",\n",
    "    \"dataset.rotation_angle=30\",\n",
    "    \"dataset.noise_stddev=0.0\",\n",
    "    f\"dataset.im_size={im_size}\",\n",
    "    f\"dataset.path={data_path}\",\n",
    "    \"mll_optim=walnut_sample_based_mll_optim\",\n",
    "    # \"mll_optim=walnut_sample_based_mll_optim_sto\",\n",
    "    # \"mll_optim=base_sam\",\n",
    "    \"mll_optim.activate_debugging_mode=False\",\n",
    "    \"mll_optim.num_samples=20\",\n",
    "    \"mll_optim.sampling.batch_size=20\",\n",
    "    \"mll_optim.sampling.use_conj_grad_inv=true\",\n",
    "    # \"mll_optim.use_sample_then_optimise=True\",\n",
    "    \"priors.use_gprior=True\",\n",
    "    \"priors.gprior.scale.obs_subsample_fct=100\",\n",
    "    # \"mll_optim.num_samples=8\",\n",
    "    \"trafo.num_angles=30\",\n",
    "    # \"trafo.geometry_specs.num_det_pixels=128\",\n",
    "    f\"trafo.geometry_specs.num_det_pixels={im_size}\",\n",
    "])\n",
    "print(OmegaConf.to_yaml(cfg))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22b4ca0d-ae4b-4b3b-8a61-0f1fa93e7649",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Working dtype and device for the reconstruction.\n",
    "dtype = torch.get_default_dtype()\n",
    "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "print(device)\n",
    "\n",
    "ray_trafo = get_standard_ray_trafo(cfg)\n",
    "# NOTE(review): the trafo is left where get_standard_ray_trafo puts it; confirm whether\n",
    "# the move below should be re-enabled so it matches `dtype`/`device`.\n",
    "# ray_trafo.to(dtype=dtype, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e5cb0ce-f221-4394-a6be-3351eb124355",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the dataset; each item unpacks to (observation, ground_truth, filtbackproj).\n",
    "dataset = get_standard_dataset(\n",
    "    cfg,\n",
    "    ray_trafo,\n",
    "    fold=cfg.dataset.fold,\n",
    "    use_fixed_seeds_starting_from=cfg.seed,\n",
    "    device=device,\n",
    ")\n",
    "obs, ground_truth, fbp = dataset[0]\n",
    "# ground_truth_orig = dataset.image_dataset.dataset_orig[0][1]\n",
    "ground_truth.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73c755d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from glob import glob\n",
    "from bayes_dip.dip.deep_image_prior import DeepImagePriorReconstructor\n",
    "from bayes_dip.inference.sample_based_predictive_posterior import SampleBasedPredictivePosterior\n",
    "from bayes_dip.marginal_likelihood_optim.preconditioner import get_preconditioner\n",
    "from bayes_dip.marginal_likelihood_optim.sample_based_mll_optim import sample_based_marginal_likelihood_optim\n",
    "from bayes_dip.marginal_likelihood_optim.utils import get_ordered_nn_params_vec\n",
    "from bayes_dip.probabilistic_models.linearized_dip.default_unet_priors import get_default_unet_gprior_dicts\n",
    "from bayes_dip.probabilistic_models.linearized_dip.image_cov import ImageCov\n",
    "from bayes_dip.probabilistic_models.linearized_dip.neural_basis_expansion.getter import get_neural_basis_expansion\n",
    "from bayes_dip.probabilistic_models.linearized_dip.parameter_cov import ParameterCov\n",
    "from bayes_dip.probabilistic_models.observation_cov import ObservationCov\n",
    "from bayes_dip.utils.experiment_utils import assert_sample_matches\n",
    "from bayes_dip.utils.utils import SSIM, PSNR\n",
    "import os\n",
    "from itertools import islice\n",
    "\n",
    "# Only the first sample is processed; drawing it through a DataLoader adds the\n",
    "# leading batch dimension that the `[0, 0]` indexing further down relies on.\n",
    "i = 0  # sample index\n",
    "data_sample = next(iter(DataLoader(dataset)))\n",
    "\n",
    "if cfg.seed is not None:\n",
    "    torch.manual_seed(cfg.seed + i)\n",
    "\n",
    "# configs needed for 3D UQ estimation\n",
    "em_step = cfg.get('em_step', 0)\n",
    "load_previous_em_step_from_path = cfg.get('load_previous_em_step_from_path', None)\n",
    "load_previous_observation_cov_from_path = cfg.get('load_previous_observation_cov_from_path', None)\n",
    "\n",
    "observation, ground_truth, filtbackproj = data_sample\n",
    "\n",
    "# DIP parameters come from `load_dip_params_from_path` if set, falling back to\n",
    "# `mll_optim.init_load_path`; if neither is set the DIP is trained below.\n",
    "load_dip_params_from_path = cfg.load_dip_params_from_path\n",
    "if cfg.mll_optim.init_load_path is not None and load_dip_params_from_path is None:\n",
    "    load_dip_params_from_path = cfg.mll_optim.init_load_path\n",
    "\n",
    "if load_dip_params_from_path is not None:\n",
    "    # assert that sample data matches with that from the dip to be loaded\n",
    "    assert_sample_matches(data_sample, load_dip_params_from_path, i, raise_if_file_not_found=False)\n",
    "\n",
    "torch.save(\n",
    "    {'observation': observation, 'filtbackproj': filtbackproj, 'ground_truth': ground_truth}, f'sample_{i}.pt')\n",
    "\n",
    "observation = observation.to(dtype=dtype, device=device)\n",
    "filtbackproj = filtbackproj.to(dtype=dtype, device=device)\n",
    "ground_truth = ground_truth.to(dtype=dtype, device=device)\n",
    "\n",
    "net_kwargs = OmegaConf.to_object(cfg.dip.net)\n",
    "reconstructor = DeepImagePriorReconstructor(\n",
    "    ray_trafo, torch_manual_seed=cfg.dip.torch_manual_seed,\n",
    "    device=device, net_kwargs=net_kwargs,\n",
    "    load_params_path=cfg.load_pretrained_dip_params)\n",
    "# Branch on the resolved path so the `mll_optim.init_load_path` fallback is honoured.\n",
    "if load_dip_params_from_path is None:\n",
    "    optim_kwargs = {\n",
    "        'lr': cfg.dip.optim.lr,\n",
    "        'iterations': cfg.dip.optim.iterations,\n",
    "        'loss_function': cfg.dip.optim.loss_function,\n",
    "        'gamma': cfg.dip.optim.gamma}\n",
    "    recon = reconstructor.reconstruct(\n",
    "        observation,\n",
    "        filtbackproj=filtbackproj,\n",
    "        ground_truth=ground_truth,\n",
    "        recon_from_randn=cfg.dip.recon_from_randn,\n",
    "        log_path=os.path.join(cfg.dip.log_path, f'dip_optim_{i}'),\n",
    "        optim_kwargs=optim_kwargs)\n",
    "else:\n",
    "    dip_params_filepath = os.path.join(load_dip_params_from_path, f'dip_model_{i}.pt')\n",
    "    print(f'loading DIP network parameters from {dip_params_filepath}')\n",
    "    reconstructor.load_params(dip_params_filepath)\n",
    "    assert not cfg.dip.recon_from_randn  # would need to re-create random input\n",
    "    recon = reconstructor.nn_model(filtbackproj).detach()  # pylint: disable=not-callable\n",
    "torch.save(reconstructor.nn_model.state_dict(), f'dip_model_{i}.pt')\n",
    "torch.save(recon.cpu(), f'recon_{i}.pt')\n",
    "\n",
    "print(f'DIP reconstruction of sample {i}')\n",
    "print('PSNR:', PSNR(recon[0, 0].cpu().numpy(), ground_truth[0, 0].cpu().numpy()))\n",
    "print('SSIM:', SSIM(recon[0, 0].cpu().numpy(), ground_truth[0, 0].cpu().numpy()))\n",
    "\n",
    "# sample_based_marginal_likelihood_optim requires the g-prior assumption\n",
    "# https://en.wikipedia.org/wiki/G-prior\n",
    "assert cfg.priors.use_gprior\n",
    "prior_assignment_dict, hyperparams_init_dict = get_default_unet_gprior_dicts(\n",
    "    nn_model=reconstructor.nn_model,\n",
    "    gprior_hyperparams_init={'variance': cfg.priors.gprior.init_prior_variance_value})\n",
    "parameter_cov = ParameterCov(\n",
    "    reconstructor.nn_model,\n",
    "    prior_assignment_dict,\n",
    "    hyperparams_init_dict,\n",
    "    device=device\n",
    ")\n",
    "if cfg.load_gprior_scale_from_path is not None:\n",
    "    # 3D requires pre-computing and loading g-prior scale vec\n",
    "    load_scale_from_path = os.path.join(\n",
    "        cfg.load_gprior_scale_from_path, f'gprior_scale_vector_{i}.pt')\n",
    "else:\n",
    "    load_scale_from_path = None\n",
    "\n",
    "neural_basis_expansion = get_neural_basis_expansion(\n",
    "    nn_model=reconstructor.nn_model,\n",
    "    nn_input=filtbackproj,\n",
    "    ordered_nn_params=parameter_cov.ordered_nn_params,\n",
    "    nn_out_shape=filtbackproj.shape,\n",
    "    use_gprior=True,  # requires the g-prior assumption\n",
    "    trafo=ray_trafo,\n",
    "    load_scale_from_path=load_scale_from_path,\n",
    "    scale_kwargs=OmegaConf.to_object(cfg.priors.gprior.scale)\n",
    ")\n",
    "image_cov = ImageCov(\n",
    "    parameter_cov=parameter_cov,\n",
    "    neural_basis_expansion=neural_basis_expansion\n",
    ")\n",
    "# sample-based MLL based methods do not optimise noise variance, i.e. fixed to 1.\n",
    "observation_cov = ObservationCov(\n",
    "    trafo=ray_trafo,\n",
    "    image_cov=image_cov,\n",
    "    device=device\n",
    ")\n",
    "\n",
    "# if `em_step == 0` the g-prior keeps its init value\n",
    "if em_step > 0:\n",
    "    assert load_previous_observation_cov_from_path is not None, (\n",
    "        'em_step > 0 requires load_previous_observation_cov_from_path')\n",
    "    # overwrite the g-prior variance with the one optimised in EM step `em_step - 1`;\n",
    "    # map_location lets a checkpoint written on another device load here\n",
    "    observation_cov.load_state_dict(torch.load(\n",
    "        os.path.join(load_previous_observation_cov_from_path, f'observation_cov_iter_{em_step - 1}.pt'),\n",
    "        map_location=device))\n",
    "\n",
    "optim_kwargs = {\n",
    "    'iterations': cfg.mll_optim.iterations,\n",
    "    'activate_debugging_mode': cfg.mll_optim.activate_debugging_mode,\n",
    "    'num_samples': cfg.mll_optim.num_samples,\n",
    "    'use_sample_then_optimise': cfg.mll_optim.use_sample_then_optimise\n",
    "}\n",
    "optim_kwargs['sample_kwargs'] = OmegaConf.to_object(cfg.mll_optim.sampling)\n",
    "precon_kwargs = OmegaConf.to_object(cfg.mll_optim.preconditioner)\n",
    "\n",
    "if cfg.load_sample_based_precon_state_from_path is not None:\n",
    "    precon_kwargs['load_approx_basis'] = os.path.join(\n",
    "        cfg.load_sample_based_precon_state_from_path, f'preconditioner_{i}.pt')\n",
    "    precon_kwargs['load_state_dict'] = os.path.join(\n",
    "        cfg.load_sample_based_precon_state_from_path, f'observation_cov_{i}.pt')\n",
    "\n",
    "cg_preconditioner = None\n",
    "if cfg.mll_optim.use_preconditioner:\n",
    "    cg_preconditioner = get_preconditioner(observation_cov=observation_cov, kwargs=precon_kwargs)\n",
    "    optim_kwargs['sample_kwargs']['cg_kwargs']['precon_closure'] = cg_preconditioner.get_closure()\n",
    "optim_kwargs['cg_preconditioner'] = cg_preconditioner\n",
    "if cfg.mll_optim.activate_debugging_mode:\n",
    "    optim_kwargs['debugging_mode_kwargs'] = OmegaConf.to_object(cfg.mll_optim.debugging_mode_kwargs)\n",
    "\n",
    "predictive_posterior = SampleBasedPredictivePosterior(observation_cov)\n",
    "posterior_obs_samples_sq_sum = {}  # to compute eff. dims in 3D\n",
    "prev_linear_weights = None\n",
    "if load_previous_em_step_from_path is not None:\n",
    "    # sorted() fixes the accumulation order, so the summed value is reproducible across runs\n",
    "    post_sample_sq_sum_paths = sorted(glob(\n",
    "        os.path.join(load_previous_em_step_from_path, f'posterior_obs_samples_sq_sum_{i}_em={em_step}_seed=*.pt')))\n",
    "    for k, path in enumerate(post_sample_sq_sum_paths):\n",
    "        print('Loading sample from : ', path)\n",
    "        posterior_obs_samples_sq_sum_i = torch.load(path)\n",
    "        if k == 0:\n",
    "            posterior_obs_samples_sq_sum['value'] = posterior_obs_samples_sq_sum_i['value']\n",
    "            posterior_obs_samples_sq_sum['num_samples'] = posterior_obs_samples_sq_sum_i['num_samples']\n",
    "        else:\n",
    "            posterior_obs_samples_sq_sum['value'] += posterior_obs_samples_sq_sum_i['value']\n",
    "            posterior_obs_samples_sq_sum['num_samples'] += posterior_obs_samples_sq_sum_i['num_samples']\n",
    "\n",
    "    # NOTE(review): read from the current working directory rather than from\n",
    "    # `load_previous_em_step_from_path` -- confirm this is intended\n",
    "    prev_linear_weights = torch.load(f'linearized_weights_em={em_step - 1}_{i}.pt')\n",
    "\n",
    "# linearized_weights, linearized_recon = sample_based_marginal_likelihood_optim(\n",
    "image_samples, obs_samples = sample_based_marginal_likelihood_optim(\n",
    "    predictive_posterior=predictive_posterior,\n",
    "    map_weights=get_ordered_nn_params_vec(parameter_cov).clone(),\n",
    "    observation=observation,\n",
    "    nn_recon=recon,\n",
    "    ground_truth=ground_truth,\n",
    "    optim_kwargs=optim_kwargs,\n",
    "    log_path=os.path.join(cfg.mll_optim.log_path, f'mrglik_optim_{i}'),\n",
    "    posterior_obs_samples_sq_sum=posterior_obs_samples_sq_sum,\n",
    "    em_start_step=em_step,\n",
    "    prev_linear_weights=prev_linear_weights,\n",
    "    return_samples=True\n",
    ")\n",
    "\n",
    "# torch.save(linearized_weights, f'linearized_weights_em={em_step}_{i}.pt')\n",
    "# torch.save(linearized_recon, f'linearized_recon_em={em_step}_{i}.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca5bec6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# shapes of the samples returned by sample_based_marginal_likelihood_optim\n",
    "(image_samples.shape, obs_samples.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23c253ed",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
