# from functools import partial
# import os
# import pickle as pkl
# from collections.abc import MutableMapping
# from datetime import datetime
# from itertools import product
# from functools import partial
# from typing import Dict
# import time
# import logging

# import jax
# import jax.numpy as jnp

# from .ed_loop import ExperimentalDesign
# from ..utils import sample_from_uniform
# from ..models.deeponet_modified import DeepONetModified, train_deeponet, train_deeponet_with_pde, \
#     generate_train_set_for_deeponet, generate_collocation_train_set_for_deeponet, generate_train_set_for_deeponet_from_known_xs
# from .utils.losses import generate_loss


# class DeepONetBased(ExperimentalDesign):
    
#     def __init__(self, simulator_xs, pde, pde_domain, exp_design_fn, obs_design_fn,
#                  inv_param_in_domain, exp_in_domain, obs_in_domain,
#                  inv_input_dim, exp_input_dim, obs_input_dim, x_input_dim,
#                  deeponet_hidden_layers: int = 4, deeponet_hidden_dim: int = 64, deeponet_arch: str= None,
#                  n_exp_sim: int = 10, n_inv_sim: int = 10, n_xs_sim: int = 10000,
#                  n_exp_colloc: int = 100, n_inv_colloc: int = 1000, n_xs_colloc: int = 10000,
#                  train_steps: int = 100000, batch_size: int = 32, optim_type: str = 'adam', optim_args: Dict = None,
#                  mcmc_estimate_samples: int = 100, posterior_loss: str = 'mse', posterior_loss_args: Dict = dict(),
#                  seed: int = 0):
#         super().__init__(
#             simulator_xs=simulator_xs,
#             pde=pde, 
#             pde_domain=pde_domain, 
#             exp_design_fn=exp_design_fn, 
#             obs_design_fn=obs_design_fn,
#             inv_param_in_domain=inv_param_in_domain, 
#             exp_in_domain=exp_in_domain, 
#             obs_in_domain=obs_in_domain,
#             inv_input_dim=inv_input_dim, 
#             exp_input_dim=exp_input_dim, 
#             obs_input_dim=obs_input_dim, 
#             x_input_dim=x_input_dim,
#             seed=seed,
#         )
        
#         self.n_exp_sim = n_exp_sim
#         self.n_inv_sim = n_inv_sim
#         self.n_xs_sim = n_xs_sim
#         self.n_exp_colloc = n_exp_colloc
#         self.n_inv_colloc = n_inv_colloc
#         self.n_xs_colloc = n_xs_colloc
#         self.train_steps = train_steps
#         self.batch_size = batch_size
#         self.optim_type = optim_type
#         self.optim_args = optim_args
#         self.posterior_loss = posterior_loss
#         self.posterior_loss_args = posterior_loss_args
#         self.mcmc_estimate_samples = mcmc_estimate_samples
        
#         self.deeponet = DeepONetModified(
#             inv_input_dim=inv_input_dim, 
#             exp_input_dim=exp_input_dim, 
#             x_input_dim=x_input_dim, 
#             hidden_layers=deeponet_hidden_layers, 
#             hidden_dim=deeponet_hidden_dim, 
#             arch=deeponet_arch,
#         )
#         self.deeponet_params = None
#         self.log_inv_prior_pdf = lambda inv: 1.
        
#     def _inner_sample_inv_param(self, n, rng):
#         raise ValueError
        
#     def _inner_experiment_round(self, given_exp_design=None, given_obs_design=None):
#         raise NotImplementedError
    
#     def _train_deeponet(self):
        
#         exp_setup_list = sample_from_uniform(
#             n=self.n_exp_sim, 
#             bounds=self.exp_in_domain, 
#             sample_dim=self.exp_input_dim, 
#             rng=self.get_rng()
#         )
#         dset = generate_train_set_for_deeponet(
#             noisy_simulator_xs=self.simulator_xs, 
#             exp_setup_list=exp_setup_list, 
#             inv_param_sampler=self.sample_inv_param, 
#             pde_domain=self.pde_domain,
#             n_inv=self.n_inv_sim, 
#             n_xs=self.n_xs_sim, 
#             batch_size=self.batch_size, 
#             rng=self.get_rng(),
#         )

#         dset_colloc = generate_collocation_train_set_for_deeponet(
#             exp_setup_list=exp_setup_list, 
#             inv_param_sampler=self.sample_inv_param, 
#             pde_domain=self.pde_domain,
#             n_inv=self.n_inv_colloc, 
#             n_xs=self.n_xs_colloc, 
#             batch_size=self.batch_size, 
#             rng=self.get_rng(),
#         )

#         dset_bcs = []
#         for (_, bc_xs) in self.exp_design_fn:
#             exp_setup_list = sample_from_uniform(
#                 n=self.n_exp_colloc, 
#                 bounds=self.exp_in_domain, 
#                 sample_dim=self.exp_input_dim, 
#                 rng=self.get_rng()
#             )
#             dset_bcs.append(generate_train_set_for_deeponet_from_known_xs(
#                 exp_setup_list=exp_setup_list, 
#                 inv_param_sampler=self.sample_inv_param, 
#                 xs=bc_xs,
#                 n_inv=self.n_inv_colloc, 
#                 n_xs=self.n_xs_colloc, 
#                 batch_size=self.batch_size, 
#                 rng=self.get_rng(),
#             ))
            
#         self.deeponet_params, aux_training = train_deeponet_with_pde(
#             deeponet=self.deeponet,
#             dset=dset,
#             colloc_dset=dset_colloc,
#             pde=self.pde,
#             bcs_dset=dset_bcs,
#             bc_exp_fns=self.exp_design_fn,
#             steps=self.train_steps,
#             optim_type=self.optim_type, 
#             optim_args=self.optim_args, 
#             rng=self.get_rng(),
#         )
            
#         aux = dict(
#             dset=dset,
#             dset_colloc=dset_colloc,
#             dset_bcs=dset_bcs,
#             deeponet_params=self.deeponet_params,
#             aux_training=aux_training,
#         )
#         return aux
    
#     def _laplace_approx_vals(self, exp, obs, inv, ys):
        
#         def mock_simulator(exp, inv, rng):
#             return lambda xs: self.deeponet.apply(self.deeponet_params, inv, exp, xs)
        
#         log_loss = generate_loss(
#             noisy_simulator=mock_simulator, obs_design_fn=self.obs_design_fn,
#             loss=self.posterior_loss, **self.posterior_loss_args)
        
#         # compute unnormalised log posterior
#         # rng is passed as None because mock_simulator ignores it
#         log_posterior = lambda inv_: log_loss(y_obs=ys, exp_params=exp, obs_params=obs, inv=inv_, rng=None) + self.log_inv_prior_pdf(inv=inv_)
        
#         # treat the given inv as the MAP point (no optimisation is performed)
#         inv_MAP = inv
        
#         # compute the negative Hessian of the log posterior w.r.t. inv, evaluated at inv_MAP
#         H_inv = - jax.hessian(log_posterior)(inv_MAP)
        
#         return inv_MAP, log_posterior(inv_MAP), H_inv
    
#     def _get_eig_estimator(self):
        
#         inv_param_samples = self.sample_inv_param(n=self.mcmc_estimate_samples)
#         d = self.obs_design_fn(self.obs_in_domain[:,0]).shape[0]
        
#         def one_mcmc_round(exp, obs, inv):
#             xs = self.obs_design_fn(obs)
#             ys = self.deeponet.apply_single_branch(self.deeponet_params, inv, exp, xs)
#             _, _, H_inv = self._laplace_approx_vals(exp=exp, obs=obs, inv=inv, ys=ys)
#             posterior_entropy = 0.5 * d * (1. + jnp.log(2 * jnp.pi)) - 0.5 * jnp.linalg.slogdet(H_inv)[1]
#             prior_cross_entropy = self.log_inv_prior_pdf(inv)
#             return posterior_entropy - prior_cross_entropy
        
#         def estimator(exp, obs):
#             return jnp.mean(jax.vmap(one_mcmc_round, in_axes=(None, None, 0))(exp, obs, inv_param_samples))
        
#         return estimator, dict(one_mcmc_round=one_mcmc_round, inv_param_samples=inv_param_samples)
    
#     def _inner_experiment_round(self, given_exp_design=None, given_obs_design=None):
        
#         t = time.time()
#         training_aux = self._train_deeponet()
#         t = time.time() - t
#         logging.info(f'[TIMING] Training DeepONet (s) : {t:.6f}')
        
#         t = time.time()
#         eig_fn, eig_fn_aux = self._get_eig_estimator()
#         t = time.time() - t
#         logging.info(f'[TIMING] Preparing EIG estimator (s) : {t:.6f}')
        
        
    
#     def _inner_process_obs(self, best_exp, best_obs, observation):
#         raise NotImplementedError
