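# NOTE(review): this entire module is commented out, so importing it defines
# nothing. If PINNEnsemble below is re-enabled, address these latent issues:
#   - log_likelihood() reads self.noise_kernel_fn, but its assignment in
#     __init__ is itself commented out (as is the gpjax import), so the call
#     would raise AttributeError.
#   - sample() never adds noise (the prior sampling is commented out), so it
#     returns the same values as sample_noiseless() and ignores `rng`.
#   - _do_training() runs `steps // factor` chunks of `factor` (<= 1000) steps,
#     so when optim_steps > 1000 is not a multiple of 1000 the remainder steps
#     are silently skipped; optim_steps == 0 raises ZeroDivisionError; and the
#     tqdm progress bar is never closed.
#   - __init__ never sets inv_params / exp_params, yet prep_simulator() asserts
#     on them; they must be set beforehand (possibly by PDESampler -- confirm).
# from functools import partial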
# import os
# import pickle as pkl
# from collections.abc import MutableMapping
# from typing import Dict, Any, Callable, List
# import logging

# import tqdm

# import numpy as np
# import jax
# import jax.numpy as jnp
# from jax.scipy.special import logsumexp
# from jax.experimental.jax2tf import call_tf
# import flax
# from flax import linen as nn
# import optax
# import jaxopt

# # import gpjax as gpx

# from .. import deepxde as dde

# from ..model_loader import construct_net
# from ..utils import to_cpu, tree_stack, tree_unstack
# from ..icbc_patch import generate_residue
# from .pde_sampler import PDESampler

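# # NOTE: this sets the level of the *root* logger (affecting every module),
# # not a logger scoped to this file; prefer logging.getLogger(__name__).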
# logging.getLogger().setLevel(logging.INFO)


# class PINNEnsemble(PDESampler):
    
#     def __init__(self, pde: Callable, fixed_bcs: List, pde_domain, exp_design_fn: Callable, 
#                  noise_kernel: str = 'rbf', noise_kernel_args: Dict = None,
#                  nn_params: Dict = None, collocation_pts: int = 10000, 
#                  optim_method: str = 'adam', optim_args: Dict = None, optim_steps: int = 10000, loss_thr: float = 0.):
        
#         super().__init__()
#         # self.pde_class = pde_class  # takes in beta, returns (pde, [list of bcs])
#         self.pde = pde
#         self.fixed_bcs = fixed_bcs
#         self.exp_design_fn = exp_design_fn  # takes in experimental design params Dict, returns [list of bcs]
        
#         self.pde_domain = pde_domain
#         self.pde_collocation_pts = jnp.array(self.pde_domain.random_points(int(0.8 * collocation_pts), random='Hammersley'))
#         self.ic_pts = jnp.array(self.pde_domain.random_initial_points(collocation_pts // 10))
#         self.bc_pts = jnp.array(self.pde_domain.random_boundary_points(collocation_pts // 10))
                
#         # if noise_kernel == 'rbf':
#         #     self.noise_kernel_fn = gpx.kernels.RBF(**noise_kernel_args)
#         # elif noise_kernel == 'matern52':
#         #     self.noise_kernel_fn = gpx.kernels.matern52(**noise_kernel_args)
#         # else:
#         #     raise ValueError(f'Invalid {noise_kernel}.')
#         # self.noise_prior = gpx.Prior(mean_function=gpx.mean_functions.Zero(), kernel=self.noise_kernel_fn)
        
#         self.optim_steps = optim_steps
#         if optim_args is None:
#             self.optim_method = 'adam'
#             self.optim_args = dict(learning_rate=0.001)
#         else:
#             self.optim_method = optim_method
#             self.optim_args = optim_args
#         self.loss_thr = loss_thr
        
#         self.nn_construct_params = nn_params
#         self.net = construct_net(**self.nn_construct_params)[0]
#         self.rng = jax.random.PRNGKey(np.random.randint(1000000))
#         self.net_params = None
#         self.net_count = None
        
#         self._icbcs = None
#         self._icbcs_data = None
#         self._losses_steps = []
#         self._losses = []
        
#     def set_inv_params(self, inv_params):
#         self.inv_params = inv_params
#         self.net_count = self.inv_params.shape[0]
        
#     def _generate_solver(self, value_and_grad):
#         if self.optim_method == 'adam':
#             opt = optax.adam(**self.optim_args)
#             solver = jaxopt.OptaxSolver(opt=opt, fun=value_and_grad, value_and_grad=True)
#         elif self.optim_method == 'lbfgs':
#             solver = jaxopt.LBFGS(fun=value_and_grad, value_and_grad=True, jit=True, **self.optim_args)
#         else:
#             raise ValueError(f'Invalid optim_method: {self.optim_method}')
#         return solver
    
#     def _do_training(self, exp_params):
        
#         def _pde_residue_fn(params, xs, inv_param):
#             f_ = lambda xs: self.net.apply(params, xs, training=True)
#             return self.pde(xs, (f_(xs), f_), inv_param)[0]
        
#         design_bcs = self.exp_design_fn(exp_params)
#         bcs = self.fixed_bcs + design_bcs
#         bc_fns = [generate_residue(bc, net_apply=self.net.apply) for bc in bcs]
#         bc_data = []
#         for bc in bcs:
#             if isinstance(bc, dde.icbc.initial_conditions.IC):
#                 xs = self.ic_pts
#             else:
#                 xs = self.bc_pts
#             bc_data.append(xs)
#         pde_colloc_pts = self.pde_collocation_pts
            
#         def new_loss(params, const):
#             loss = jnp.mean(_pde_residue_fn(params, pde_colloc_pts, const) ** 2)
#             for bc_fn, bc_xs in zip(bc_fns, bc_data):
#                 loss += jnp.mean(bc_fn(params, bc_xs) ** 2)
#             return loss
        
#         solver = self._generate_solver(value_and_grad=jax.jit(jax.value_and_grad(new_loss)))
#         opt_state_list = jax.vmap(solver.init_state)(self.net_params)
            
#         const_list = jnp.array(self.inv_params)
#         params_list = self.net_params
        
#         step_batch = jax.jit(jax.vmap(solver.update))
#         loss_batch = jax.jit(jax.vmap(new_loss))
        
#         logging.info(f'Doing parallel training.')
#         tqdm.tqdm._instances.clear()
#         steps = self.optim_steps
#         factor = min(steps, 1000)
#         pbar = tqdm.trange(steps)
#         for s in range(steps // factor):
#             for _ in range(factor):
#                 params_list, opt_state_list = step_batch(params_list, opt_state_list, const_list)
#                 pbar.update()
#             l = loss_batch(params_list, const_list)
#             self.net_params = params_list
#             if (l < self.loss_thr).all():
#                 logging.info(f'Training terminate early due to achieving required accuracy.')
#                 break
#         logging.info(f'Training completed.')

#     def prep_simulator(self):
        
#         assert self.inv_params is not None
#         assert self.exp_params is not None
                
#         if self.net_params is None:
#             net_params = []
#             for _ in range(len(self.inv_params)):
#                 self.rng, key_ = jax.random.split(self.rng)
#                 params = self.net.init(key_, self.pde_collocation_pts[:1])
#                 net_params.append(params)
#             self.net_params = tree_stack(net_params)
                
#         self._do_training(exp_params=self.exp_params)
                
#     def log_likelihood(self, xs, ys):
        
#         @jax.jit
#         def indiv_llh(xs, ys, params):
#             mu = self.net.apply(params, xs, training=False)
#             d = ys - mu
#             sigma = self.noise_kernel_fn.gram(xs).matrix
#             sigma_inv = jnp.linalg.inv(sigma)
#             return - 0.5 * jnp.linalg.slogdet(sigma)[1] - (0.5 * (d.T @ sigma_inv @ d))[0,0]
        
#         batched_llh = jax.jit(jax.vmap(indiv_llh, in_axes=(None, None, 0)))
#         return batched_llh(xs, ys, self.net_params)
    
#     def sample_noiseless(self, xs):
#         apply_fn = jax.jit(jax.vmap(lambda p_, x_: self.net.apply(p_, x_, training=False), in_axes=(0, None)))
#         return apply_fn(self.net_params, xs) 
    
#     def sample(self, xs, rng):
#         apply_fn = jax.jit(jax.vmap(lambda p_, x_: self.net.apply(p_, x_, training=False), in_axes=(0, None)))
#         # prior_dist = self.noise_prior.predict(xs)
#         noiseless_y = apply_fn(self.net_params, xs) 
#         return noiseless_y # + prior_dist.sample(seed=rng, sample_shape=(self.net_count,)).reshape(noiseless_y.shape)
    
#     def reset(self):
#         self.inv_params = None
#         self.exp_params = None
#         self.net_params = None
    
#     def generate_intermediate_info(self):
#         return {
#             'inv_params': self.inv_params,
#             'exp_params': self.exp_params,
#             'net_params': self.net_params,
#         }
