from functools import partial
import os
import pickle as pkl
from collections.abc import MutableMapping
from datetime import datetime
from itertools import product
from functools import partial
from typing import Dict, Callable
import time
import logging

import matplotlib.pyplot as plt
import matplotlib.tri as tri
import numpy as np
import tqdm

import jax
import jax.numpy as jnp
from jax import random, grad, vmap, jit, hessian, lax
from jax.scipy.special import logsumexp
from jax.example_libraries import optimizers
from jax.nn import relu
from jax.config import config
from jax.flatten_util import ravel_pytree
import flax

from ..icbc_patch import generate_residue
from ..models.pinn_ensemble import PINNEnsemble
from ..utils import sample_from_uniform
from ..utils.jax_utils import flatten, vmap_mjp, vmap_jmp, jacobian_outer_product
from ..utils.vmap_chunked import vmap_chunked
# from .ed_loop import ExperimentalDesign
from .criterion_based import CriterionBasedAbstractMethod
from .eig_estimators.losses import generate_loss

# NOTE: ``logging.getLogger()`` called without a name returns the ROOT logger,
# so this sets the INFO level process-wide as an import-time side effect; it
# does not create or configure a logger that is specific to this file.
logging.getLogger().setLevel(logging.INFO)


class PINNACLEMethod(CriterionBasedAbstractMethod):
    """Criterion-based design method using a kernel-weighted residue norm.

    For a given experiment design the class builds a function of the
    observation design.  For every ensemble member, that function stacks the
    PDE / boundary-condition residues at collocation points with the
    "anchor" residue at the candidate observations, forms the Gram matrix
    ``J J^T`` of the Jacobians of those residues w.r.t. the inverse-PINN
    parameters (named ``ntk_*`` below), and scores the design with
    ``sqrt(r^T A r) / len(r)`` where ``A`` is either the regularised Gram
    matrix or its inverse.
    """

    def __init__(self, simulator_xs, pde, pde_domain, exp_design_fn, obs_design_fn,
                 inv_embedding, inv_param_in_domain, exp_in_domain, obs_in_domain,
                 inv_input_dim, exp_input_dim, obs_input_dim, obs_reading_count,
                 x_input_dim, y_output_dim, 
                 ensemble_size: int = 100, pinn_ensemble_args: Dict = dict(), 
                 pinn_share_init: bool = False, pinn_init_meta_rounds: int = 0, pinn_init_meta_steps: int = 1000, pinn_meta_eps: float = 0.1,
                 ensemble_steps: int = 100000, acq_fn: str = 'ucb', reg: float = 1e-6,
                 pde_colloc_sample_num: int = 1000, icbc_colloc_sample_num: int = 100, use_inverse_crit: bool = True,
                 exp_setup_rounds: int = 10, obs_setup_rounds: int = 10, obs_search_time_limit: float = 3600., noise_std: float = 1e-3, 
                 obs_optim_gd_params: Dict = dict(stepsize=1e-2, maxiter=100, acceleration=True), obs_optim_grad_clip: float = None, obs_optim_grad_jitter: float = None, 
                 min_obs_round: int = 3, seed: int = 0):
        """Configure the base class and store the criterion-specific settings.

        Every argument not listed below is forwarded unchanged to
        ``CriterionBasedAbstractMethod.__init__`` (``min_obs_round`` is
        forwarded under the keyword ``min_obs_rounds``).  The base class is
        always called with ``use_pinns=True``, ``obs_optim_with_gd=True`` and
        ``do_jit=True``.

        Args:
            reg: multiple of the identity added to the Gram matrix before it
                is used (and, when ``use_inverse_crit`` is set, inverted).
            pde_colloc_sample_num: upper bound on the number of PDE
                collocation points used in the criterion; also passed to the
                inverse ensemble as ``n_pde_collocation_pts``.
            icbc_colloc_sample_num: upper bound on the number of points taken
                per boundary/initial condition; also passed to the inverse
                ensemble as ``n_icbc_collocation_pts``.
            use_inverse_crit: if true, weight the residue with the inverse of
                the regularised Gram matrix and negate the resulting scores;
                otherwise weight with the matrix itself and keep the sign.

        NOTE(review): ``pinn_ensemble_args`` and ``obs_optim_gd_params`` have
        dict defaults that are created once and shared between calls.  This
        class only reads them (``_generate_criterion_inner`` copies before
        adding keys) - confirm the base class does not mutate them either.
        """
        super().__init__(
            simulator_xs=simulator_xs,
            pde=pde,
            pde_domain=pde_domain,
            exp_design_fn=exp_design_fn,
            obs_design_fn=obs_design_fn,
            inv_embedding=inv_embedding,
            inv_param_in_domain=inv_param_in_domain,
            exp_in_domain=exp_in_domain,
            obs_in_domain=obs_in_domain,
            inv_input_dim=inv_input_dim,
            exp_input_dim=exp_input_dim,
            obs_input_dim=obs_input_dim,
            obs_reading_count=obs_reading_count,
            x_input_dim=x_input_dim,
            y_output_dim=y_output_dim,
            use_pinns=True,
            ensemble_size=ensemble_size,
            ensemble_steps=ensemble_steps,
            pinn_ensemble_args=pinn_ensemble_args,
            pinn_share_init=pinn_share_init,
            pinn_init_meta_rounds=pinn_init_meta_rounds,
            pinn_init_meta_steps=pinn_init_meta_steps,
            pinn_meta_eps=pinn_meta_eps,
            acq_fn=acq_fn,
            exp_setup_rounds=exp_setup_rounds,
            obs_setup_rounds=obs_setup_rounds,
            obs_search_time_limit=obs_search_time_limit,
            noise_std=noise_std,
            obs_optim_with_gd=True,
            do_jit=True,
            obs_optim_gd_params=obs_optim_gd_params,
            obs_optim_grad_clip=obs_optim_grad_clip,
            obs_optim_grad_jitter=obs_optim_grad_jitter,
            min_obs_rounds=min_obs_round,
            seed=seed,
        )

        self.pde_colloc_sample_num = pde_colloc_sample_num
        self.icbc_colloc_sample_num = icbc_colloc_sample_num
        self.use_inverse_crit = use_inverse_crit
        self.reg = reg

    def _generate_criterion_inner(self, exp_design, true_inv_prior_samples):
        """Build the observation-design criterion for one experiment design.

        A fresh inverse ``PINNEnsemble`` is created, reset and prepared for
        ``exp_design``; no optimisation step is run on it here.  The
        collocation residues, their Jacobians and Gram matrices are computed
        once per ensemble member; only the anchor (observation) part is
        recomputed on every criterion call.

        Args:
            exp_design: experiment parameters, passed to the inverse
                ensemble's ``prep_simulator`` and to ``self.pde``.
            true_inv_prior_samples: not used by this implementation; kept so
                the signature stays compatible with callers.

        Returns:
            A pair ``(criterion, helper_fns)``.  ``criterion(obs_design, rng)``
            returns ``(mean_score, {'indiv_scores': per_member_scores})``;
            ``rng`` is accepted but not used.  ``helper_fns`` exposes the
            intermediate closures and precomputed arrays under string keys.
        """
        # Copy first so the dict held by the instance (possibly the shared
        # default from ``__init__``) is never modified.
        ensemble_args = dict(self.pinn_ensemble_args)
        ensemble_args['n_pde_collocation_pts'] = self.pde_colloc_sample_num
        ensemble_args['n_icbc_collocation_pts'] = self.icbc_colloc_sample_num
        inverse_ens = PINNEnsemble(
            pde=self.pde,
            pde_domain=self.pde_domain,
            exp_design_fn=self.exp_design_fn,
            obs_design_fn=self.obs_design_fn,
            inv_embedding=self.inv_embedding,
            inv_problem=True,
            rng=self.get_rng(),
            **ensemble_args
        )

        inv_params_guesses = self.sample_inv_param(self.ensemble_size)
        if self.pinn_share_init:
            new_nn_params = self._generate_shared_params(n=self.ensemble_size, inv=inv_params_guesses)
        else:
            new_nn_params = None

        inverse_ens.reset()
        inverse_ens.prep_simulator(
            exp_params=exp_design,
            inv_params_guesses=inv_params_guesses,
            new_nn_params=new_nn_params,
        )

        modified_apply = inverse_ens.net.apply

        # Random subset (without replacement) of the forward ensemble's PDE
        # collocation points, capped at ``pde_colloc_sample_num``.
        # NOTE(review): this draws from NumPy's global RNG rather than from
        # ``self.get_rng()``, so it is only reproducible if the global NumPy
        # seed is fixed elsewhere - confirm this is intended.
        n_available = self.forward_ens.pde_collocation_pts.shape[0]
        idxs_sample = jnp.array(np.random.choice(
            n_available,
            size=min(n_available, self.pde_colloc_sample_num),
            replace=False
        ))
        xs_colloc = self.forward_ens.pde_collocation_pts[idxs_sample]

        # ``exp_design_fn`` of the forward ensemble is iterated as
        # (residue_fn, points) pairs; keep at most ``icbc_colloc_sample_num``
        # leading points of each condition.
        bc_fns = [fn for (fn, _) in self.forward_ens.exp_design_fn]
        bc_data = [xs[:min(xs.shape[0], self.icbc_colloc_sample_num)] for (_, xs) in self.forward_ens.exp_design_fn]

        @jax.jit
        def colloc_residues(params):
            """Return PDE and boundary-condition residues for one parameter set.

            Each residue vector is divided by ``sqrt(number of points)`` so
            that its squared norm is the mean of the squared residues.
            """
            nn_params = params['net']
            beta = params['inv']
            f_ = lambda xs: modified_apply(nn_params, xs)
            return {
                # First element of the PDE output, first column only.
                'pde': self.pde(xs_colloc, (f_(xs_colloc), f_), beta, exp_design)[0][:,0].reshape(-1) / (xs_colloc.shape[0] ** 0.5),
                # NOTE(review): the conditions read ``forward_ens.exp_params``
                # while the PDE term reads ``exp_design`` - presumably the
                # same value; verify against the caller.
                'bcs': [
                    bc_fn(nn_params, modified_apply, self.forward_ens.exp_params, beta, bc_xs).reshape(-1) / (bc_xs.shape[0] ** 0.5)
                    for bc_fn, bc_xs in zip(bc_fns, bc_data)
                ]
            }

        @jax.jit
        def colloc_residue_stacked(params):
            """Concatenate the PDE residue and all condition residues into one vector."""
            res = colloc_residues(params)
            return jnp.concatenate([res['pde']] + res['bcs'])

        @jax.jit
        def anc_residue(params, obs, forward_params):
            """Return the flattened difference between inverse- and forward-net readings at ``obs``.

            The forward network parameters are wrapped in ``stop_gradient`` so
            they act as a constant target.
            """
            reading_fn = lambda p: self.obs_design_fn((lambda xs: modified_apply(p, xs)), obs)
            return (reading_fn(params['net']) - reading_fn(jax.lax.stop_gradient(forward_params['net']))).reshape(-1)

        def _flatten_dict(d, parent_key='', sep='_'):
            """Flatten a nested mapping into a single-level dict with joined keys.

            Needed because network parameters are stored in nested
            dictionaries, e.g. ``{'params': {'dense': {'kernel': ...}}}``.
            See https://stackoverflow.com/questions/6027558/flatten-nested-python-dictionaries-compressing-keys
            """
            items = []
            for k, v in d.items():
                new_key = parent_key + sep + k if parent_key else k
                if isinstance(v, (MutableMapping, flax.core.frozen_dict.FrozenDict)):
                    items.extend(_flatten_dict(v, new_key, sep=sep).items())
                else:
                    items.append((new_key, v))
            return dict(items)

        def get_jac(fn, p):
            """Return ``d fn / d p`` as a flat dict of 2-D ``(n_outputs, n_params_in_leaf)`` arrays.

            Currently only works when ``fn`` returns a one-dimensional array.
            """
            flat = _flatten_dict(jax.jacrev(fn)(p))
            return {k: v.reshape(v.shape[0], -1) for k, v in flat.items()}

        def jac_mult_fn(jac1, jac2):
            """Return ``sum_k jac1[k] @ jac2[k].T`` over the parameter leaves of ``jac1``."""
            prods = None
            for k in jac1.keys():
                m = jac1[k] @ jac2[k].T
                prods = m if (prods is None) else (prods + m)
            return prods

        @jax.jit
        def single_network_criterion(obs_param, p, colloc_ntk, colloc_jac, colloc_res, forward_p):
            """Score ``obs_param`` for a single ensemble member.

            Returns an array of shape ``(1, 1)`` holding
            ``sqrt(r^T A r) / len(r)`` where ``r`` stacks the collocation and
            anchor residues and ``A`` is the regularised Gram matrix or its
            inverse depending on ``self.use_inverse_crit``.
            """
            # The residue must be a function of the lambda's own argument so
            # that ``jacrev`` differentiates through it; closing over ``p``
            # instead would make the anchor Jacobian identically zero.
            anc_fn = lambda p_: anc_residue(params=p_, obs=obs_param, forward_params=forward_p)
            anc_jac = get_jac(anc_fn, p)
            anc_res = anc_fn(p)

            # Assemble the symmetric block matrix [[CC, CA], [CA^T, AA]]; the
            # collocation-collocation block is precomputed by the caller.
            ntk_CA = jac_mult_fn(colloc_jac, anc_jac)
            ntk_AA = jac_mult_fn(anc_jac, anc_jac)
            ntk_matrix = jnp.block([[colloc_ntk, ntk_CA], [ntk_CA.T, ntk_AA]])
            residue = jnp.block([[colloc_res.reshape(-1, 1)], [anc_res.reshape(-1, 1)]])

            ntk_matrix = ntk_matrix + self.reg * jnp.eye(ntk_matrix.shape[0])
            if self.use_inverse_crit:
                # Solve K x = r rather than forming K^{-1} explicitly.
                weighted = jnp.linalg.solve(ntk_matrix, residue)
            else:
                weighted = ntk_matrix @ residue

            return jnp.sqrt(residue.T @ weighted) / residue.shape[0]

        # Quantities that do not depend on the observation design are computed
        # once per ensemble member (leading axis = ensemble member).
        FOR_PARAMS_ALL = self.forward_ens.params
        INV_PARAMS_ALL = inverse_ens.params
        COLLOC_JAC_ALL = jax.vmap(lambda p_: get_jac(colloc_residue_stacked, p_))(INV_PARAMS_ALL)
        COLLOC_NTK_ALL = jax.vmap(lambda j_: jac_mult_fn(j_, j_))(COLLOC_JAC_ALL)
        COLLOC_RES_ALL = jax.vmap(colloc_residue_stacked)(INV_PARAMS_ALL)

        @jax.jit
        def criterion(obs_design, rng=jax.random.PRNGKey(0)):
            """Return the ensemble-mean score of ``obs_design`` and the per-member scores.

            ``rng`` is accepted for interface compatibility and is not used.
            """
            scores_indiv = jax.vmap(single_network_criterion, in_axes=(None, 0, 0, 0, 0, 0))(
                obs_design, INV_PARAMS_ALL, COLLOC_NTK_ALL, COLLOC_JAC_ALL, COLLOC_RES_ALL, FOR_PARAMS_ALL,
            )
            if self.use_inverse_crit:
                # The inverse-weighted scores are negated before averaging.
                scores_indiv = -1. * scores_indiv
            return jnp.mean(scores_indiv), {'indiv_scores': scores_indiv}

        helper_fns = {
            'colloc_residues': colloc_residues,
            'colloc_residue_stacked': colloc_residue_stacked,
            'anc_residue': anc_residue,
            'single_network_criterion': single_network_criterion,
            'for_params': FOR_PARAMS_ALL,
            'inv_params': INV_PARAMS_ALL,
            'colloc_jacs': COLLOC_JAC_ALL,
            'colloc_ntks': COLLOC_NTK_ALL,
            'colloc_res': COLLOC_RES_ALL,
        }

        return criterion, helper_fns








# from functools import partial
# import os
# import pickle as pkl
# from collections.abc import MutableMapping
# from datetime import datetime
# from itertools import product
# from functools import partial
# from typing import Dict
# import time
# import logging

# import matplotlib.pyplot as plt
# import matplotlib.tri as tri
# import numpy as np
# import tqdm

# import jax
# import jax.numpy as jnp
# from jax import random, grad, vmap, jit, hessian, lax
# from jax.scipy.special import logsumexp
# from jax.example_libraries import optimizers
# from jax.nn import relu
# from jax.config import config
# from jax.flatten_util import ravel_pytree
# import optax
# import jaxopt
# from scipy.stats import spearmanr, pearsonr
# from scipy.interpolate import griddata
# import torch

# from gpytorch.mlls import ExactMarginalLogLikelihood
# from botorch.fit import fit_gpytorch_mll
# from botorch.models import SingleTaskGP
# from botorch.acquisition import qExpectedImprovement, qUpperConfidenceBound, qSimpleRegret
# from botorch.optim.initializers import gen_batch_initial_conditions, initialize_q_batch_nonneg
# from botorch.generation import gen_candidates_torch, get_best_candidates
# from botorch.sampling.stochastic_samplers import StochasticSampler

# from ..icbc_patch import generate_residue
# from ..models.pinn_ensemble import PINNEnsemble
# from ..utils import sample_from_uniform
# from .ed_loop import ExperimentalDesign
# from .utils.ntk import NTKHelper

# # logger for this file
# logging.getLogger().setLevel(logging.INFO)


# class PINNACLEEnsembleMethod(ExperimentalDesign):
    
#     def __init__(self, simulator_xs, pde, pde_domain, exp_design_fn, obs_design_fn,
#                  inv_embedding, inv_param_in_domain, exp_in_domain, obs_in_domain,
#                  inv_input_dim, exp_input_dim, obs_input_dim, obs_reading_count,
#                  x_input_dim, y_output_dim, 
#                  ensemble_size: int = 100, pinn_ensemble_args: Dict = dict(), 
#                  forward_ensemble_steps: int = 100000, inverse_ensemble_steps: int = 100000,
#                  acq_fn: str = 'ucb', exp_setup_rounds: int = 10, 
#                  obs_design_opt_rounds: int = 10, colloc_point_sample_count: int = 100, 
#                  eigval_reg: float = 0., inverse_reg: float = 1e-2, noise_std: float = 1e-3, 
#                  optim_args: Dict = dict(stepsize=1e-2, maxiter=2000, acceleration=False),
#                  seed: int = 0):
#         super().__init__(
#             simulator_xs=simulator_xs,
#             pde=pde, 
#             pde_domain=pde_domain, 
#             exp_design_fn=exp_design_fn, 
#             obs_design_fn=obs_design_fn,
#             inv_embedding=inv_embedding,
#             inv_param_in_domain=inv_param_in_domain, 
#             exp_in_domain=exp_in_domain, 
#             obs_in_domain=obs_in_domain,
#             inv_input_dim=inv_input_dim, 
#             exp_input_dim=exp_input_dim, 
#             obs_input_dim=obs_input_dim, 
#             obs_reading_count=obs_reading_count,
#             x_input_dim=x_input_dim,
#             y_output_dim=y_output_dim,
#         )
        
#         self.ensemble_size = ensemble_size
#         self.forward_ensemble_steps = forward_ensemble_steps
#         self.inverse_ensemble_steps = inverse_ensemble_steps
#         self.pinn_ensemble_args = pinn_ensemble_args
#         self.acq_fn = acq_fn
#         self.exp_setup_rounds = exp_setup_rounds
#         self.inverse_reg = inverse_reg
#         self.colloc_point_sample_count = colloc_point_sample_count
#         self.obs_design_opt_rounds = obs_design_opt_rounds
#         self.eigval_regulariser = eigval_reg
#         self.noise_std = noise_std
#         self.optim_args = optim_args
        
#         self.forward_ens = PINNEnsemble(
#             pde=self.pde, 
#             pde_domain=self.pde_domain, 
#             exp_design_fn=self.exp_design_fn, 
#             inv_problem=False,
#             rng=self.get_rng(),
#             **self.pinn_ensemble_args
#         )
        
#         self.inverse_ens = PINNEnsemble(
#             pde=self.pde, 
#             pde_domain=self.pde_domain, 
#             exp_design_fn=self.exp_design_fn, 
#             inv_problem=True,
#             rng=self.get_rng(),
#             **self.pinn_ensemble_args
#         )
        
#     def _inner_sample_inv_param(self, n, rng):
#         raise ValueError
        
#     def _inner_experiment_round(self, given_exp_design=None, given_obs_design=None):
        
#         ran_exp_params = []
#         ran_exp_scores = []
#         ran_exp_coresponding_obs = []
#         ran_exp_auxs = []
        
#         t = time.time()
#         inv_prior_samples = self.sample_inv_param(n=self.ensemble_size, rng=self.get_rng())
#         t = time.time() - t
#         logging.info(f'[TIMING] Sampling inverse params (s) : {t:.6f}')
            
#         for i in range(self.exp_setup_rounds):
            
#             logging.info(f'[OUTER_LOOP] Running outer loop {i+1} of {self.exp_setup_rounds}.')
            
#             if i < 2:
                
#                 exp_design_candidate = sample_from_uniform(
#                     n=1, 
#                     bounds=self.exp_in_domain, 
#                     sample_dim=self.exp_input_dim, 
#                     rng=self.get_rng()
#                 )[0]
            
#             else:
                
#                 train_X = torch.tensor(np.array(ran_exp_params))
#                 train_Y = torch.tensor(np.array(ran_exp_scores).reshape(-1, 1))
#                 train_Y = (train_Y - torch.mean(train_Y)) / torch.std(train_Y)

#                 model = SingleTaskGP(train_X, train_Y)
#                 mll = ExactMarginalLogLikelihood(model.likelihood, model)
#                 fit_gpytorch_mll(mll)
                
#                 sampler = StochasticSampler(sample_shape=torch.Size([128]))
#                 if self.acq_fn == 'ucb':
#                     q_fn = qUpperConfidenceBound(model, beta=1., sampler=sampler)
#                 elif self.acq_fn == 'ei':
#                     q_fn = qExpectedImprovement(model, best_f=train_Y.max(), sampler=sampler)
#                 else:
#                     raise ValueError(f'Invalid self.acq_fn {self.acq_fn}')
                
#                 exp_domain_np = np.array(self.exp_in_domain)
#                 Xinit = gen_batch_initial_conditions(q_fn, torch.tensor(exp_domain_np.T), q=1, num_restarts=25, raw_samples=500)
#                 batch_candidates, batch_acq_values = gen_candidates_torch(
#                     initial_conditions=Xinit,
#                     acquisition_function=q_fn,
#                     lower_bounds=torch.tensor(exp_domain_np[:,0]),
#                     upper_bounds=torch.tensor(exp_domain_np[:,1]),
#                 )
#                 exp_design_candidate = jnp.array(get_best_candidates(batch_candidates, batch_acq_values)[0].cpu().detach().numpy())
            
#             if self.exp_input_dim <= 5:
#                 logging.info(f'[OUTER_LOOP] Candidate for round {i+1} of {self.exp_setup_rounds} is {exp_design_candidate}.')
            
#             t = time.time()
#             score, best_obs_param, aux = self._process_exp_design(
#                 exp_design=exp_design_candidate,
#                 true_inv_prior_samples=inv_prior_samples,
#             )
#             t = time.time() - t
#             logging.info(f'[TIMING] [OUTER_LOOP] Computing score for outer loop {i+1} of {self.exp_setup_rounds} (s) : {t:.6f}')
            
#             logging.info(f'[OUTER_LOOP] Score for outer loop {i+1} of {self.exp_setup_rounds} is {score:.10f}.')
#             logging.info(f'[OUTER_LOOP] Finished outer loop {i+1} of {self.exp_setup_rounds}.')
#             ran_exp_params.append(exp_design_candidate)
#             ran_exp_scores.append(score)
#             ran_exp_coresponding_obs.append(best_obs_param)
#             ran_exp_auxs.append(aux)
            
#         best_i = np.argmax(ran_exp_scores)
#         best_exp = ran_exp_params[best_i]
#         best_obs = ran_exp_coresponding_obs[best_i]
#         aux = {
#             'best_exp': best_exp,
#             'best_obs': best_obs,
#             'ran_exp_params': ran_exp_params,
#             'ran_exp_scores': ran_exp_scores,
#             'ran_exp_coresponding_obs': ran_exp_coresponding_obs,
#             'ran_exp_auxs': ran_exp_auxs,
#             'inv_prior_samples': inv_prior_samples,
#         }
#         return best_exp, best_obs, aux
        
#     def _process_exp_design(self, exp_design, true_inv_prior_samples):
        
#         t = time.time()
#         self.forward_ens.reset()
#         self.forward_ens.prep_simulator(exp_params=exp_design, inv_params=true_inv_prior_samples)
#         for _ in tqdm.trange(self.forward_ensemble_steps):
#             self.forward_ens.step_opt()
#         t = time.time() - t
#         logging.info(f'[TIMING] [OUTER_LOOP] Training forward ensemble (s) : {t:.6f}')
            
#         t2 = time.time()
#         inv_prior_guess = self.sample_inv_param(n=self.ensemble_size, rng=self.get_rng())
#         self.inverse_ens.reset()
#         self.inverse_ens.prep_simulator(exp_params=exp_design, inv_params_guesses=inv_prior_guess)
#         net_params_init = self.inverse_ens.params
        
#         xs_pde = self.forward_ens.pde_collocation_pts[:self.colloc_point_sample_count]
#         xs_bcs = [xs[:self.colloc_point_sample_count] for xs in self.forward_ens.icbc_points]
        
#         ntk_helper = NTKHelper(net=self.inverse_ens.net, pde=self.inverse_ens.pde, bcs=self.inverse_ens.icbc, inverse_problem=True)
#         colloc_jacs = jax.vmap(lambda p: ntk_helper.get_combined_jac(p, xs_pde=xs_pde, xs_bcs=xs_bcs))(net_params_init)
#         colloc_ntks = jax.vmap(ntk_helper.get_ntk)(colloc_jacs, colloc_jacs)
#         colloc_residues = jax.vmap(lambda p: ntk_helper.get_combined_res(p, xs_pde=xs_pde, xs_bcs=xs_bcs))(net_params_init)
        
#         anc_jac_fn = lambda xs: jax.vmap(ntk_helper.get_jac_fn(code=-2), in_axes=(0, None))(net_params_init, xs)
#         anc_res_fn = lambda xs: jax.vmap(
#             lambda x_, fp, ip: ntk_helper.get_residue_fn(code=-2, anc_idx=0, anc_model_param=fp)(ip, x_), 
#             in_axes=(None, 0, 0)
#         )(xs, self.forward_ens.params, net_params_init)
        
#         jac_mult_fn = ntk_helper.get_ntk_fn()
        
#         @jax.jit
#         def single_network_criterion(colloc_ntk, colloc_jac, anc_jac, colloc_res, anc_res):
#             ntk_CC = colloc_ntk
#             ntk_CA = jac_mult_fn(colloc_jac, anc_jac)
#             ntk_AA = jac_mult_fn(anc_jac, anc_jac)
#             ntk_matrix = jnp.block([[ntk_CC, ntk_CA], [ntk_CA.T, ntk_AA]])
#             ntk_matrix = ntk_matrix + self.inverse_reg * jnp.eye(ntk_matrix.shape[0])
#             residue = jnp.block([[colloc_res], [anc_res]])
#             min_eig = jnp.linalg.eigvalsh(ntk_AA)[0]
#             # return - residue.T @ ntk_matrix @ residue - self.eigval_regulariser**2 * min_eig  # negative because we want to maximise criterion
#             return (residue.T @ jnp.linalg.inv(ntk_matrix) @ residue) - (self.eigval_regulariser**2 * min_eig)
        
#         @jax.jit
#         def criterion(obs_param):
#             xs_obs = self.obs_design_fn(obs_param)
#             anc_jacs = anc_jac_fn(xs_obs)
#             anc_ress = anc_res_fn(xs_obs)
#             criteria = jax.vmap(single_network_criterion)(
#                 colloc_ntk=colloc_ntks, 
#                 colloc_jac=colloc_jacs,
#                 anc_jac=anc_jacs,
#                 colloc_res=colloc_residues,
#                 anc_res=anc_ress
#             )
#             return jnp.mean(criteria)
            
#         obs_param_candidates = []
#         obs_param_scores = []
#         current_best_i = 0
            
#         for r in range(self.obs_design_opt_rounds):
            
#             t = time.time()
#             obs_candidate = sample_from_uniform(
#                 n=1, 
#                 bounds=self.obs_in_domain, 
#                 sample_dim=self.obs_input_dim, 
#                 rng=self.get_rng()
#             )[0]
#             pg = jaxopt.ProjectedGradient(
#                 fun=criterion, 
#                 projection=jaxopt.projection.projection_box,
#                 **self.optim_args
#             )
#             pg = pg.run(obs_candidate, hyperparams_proj=self.obs_in_domain.T)
#             best_obs_param = pg.params
            
#             obs_param_candidates.append(best_obs_param)
#             obs_param_scores.append(criterion(best_obs_param))
#             if obs_param_scores[current_best_i] > obs_param_scores[r]:
#                 current_best_i = r
#             t = time.time() - t
#             logging.info(f'[TIMING] [INNER_LOOP] Finding obs param candidate {r+1} (s) : {t:.6f}')
                
#         best_obs_param = obs_param_candidates[current_best_i]
#         best_obs_score = obs_param_scores[current_best_i]
#         t2 = time.time() - t2
#         logging.info(f'[TIMING] [OUTER_LOOP] Finding best obs param (s) : {t2:.6f}')
        
#         t = time.time()
#         self.inverse_ens.reset()
#         self.inverse_ens.prep_simulator(exp_params=exp_design, inv_params_guesses=inv_prior_guess)

#         xs_obs = self.obs_design_fn(obs_param=best_obs_param)
#         ys_obs = self.forward_ens.generate_pred_function()(xs_obs)
#         ys_obs += self.noise_std * jax.random.normal(key=self.get_rng(), shape=ys_obs.shape)
        
#         xs_obs_split = jnp.repeat(xs_obs[None,:], self.ensemble_size, axis=0)
#         for _ in tqdm.trange(self.inverse_ensemble_steps):
#             self.inverse_ens.step_opt(xs_obs_split, ys_obs)
            
#         t = time.time() - t
#         logging.info(f'[TIMING] [OUTER_LOOP] Training inverse ensemble on best candidate (s) : {t:.6f}')
        
#         # negative since BO does maximisation
#         score = - jnp.mean(jnp.linalg.norm(self.inverse_ens.params["inv"] - true_inv_prior_samples, axis=1))
        
#         aux = {
#             'exp_param': exp_design,
#             'best_score': score,
#             'best_obs_param': best_obs_param,
#             'criterion': criterion,
#             'criterion_components': {
#                 'colloc_jac': colloc_jacs,
#                 'colloc_res': colloc_residues,
#                 'net_params_init': net_params_init,
#             },
#             'obs_param_candidates': obs_param_candidates,
#             'obs_param_scores': obs_param_scores,
#             'inv_prior_samples': true_inv_prior_samples,
#             'forward_ensemble_nn_params': self.forward_ens.params,
#             'inverse_ensemble_nn_params': self.inverse_ens.params,
#         }
        
#         return score, best_obs_param, aux
    
#     def _inner_process_obs(self, best_exp, best_obs, observation):
#         raise NotImplementedError
